[MLIR][Shape] Allow shape.mul
to operate in indices
Differential Revision: https://reviews.llvm.org/D84437
This commit is contained in:
parent
cf42877812
commit
783a351785
|
@ -307,18 +307,25 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
|
|||
let results = (outs Shape_ShapeOrSizeType:$result);
|
||||
}
|
||||
|
||||
def Shape_MulOp : Shape_Op<"mul", [Commutative, SameOperandsAndResultType]> {
|
||||
let summary = "Multiplication of sizes";
|
||||
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
||||
let summary = "Multiplication of sizes and indices";
|
||||
let description = [{
|
||||
Multiplies two valid sizes as follows:
|
||||
- lhs * rhs = unknown if either lhs or rhs unknown;
|
||||
- lhs * rhs = (int)lhs * (int)rhs if both known;
|
||||
Multiplies two sizes or indices. If either operand is an error it will be
|
||||
propagated to the result. The operands can be of type `size` or `index`. If
|
||||
at least one of the operands can hold an error, i.e. if it is of type `size`,
|
||||
then also the result must be of type `size`. If error propagation is not
|
||||
possible because both operands are of type `index` then the result must also
|
||||
be of type `index`.
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
|
||||
let results = (outs Shape_SizeType:$result);
|
||||
let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
|
||||
let results = (outs Shape_SizeOrIndexType:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict";
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
|
||||
|
|
|
@ -28,6 +28,13 @@ static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
|
|||
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
|
||||
}
|
||||
|
||||
static bool isErrorPropagationPossible(ArrayRef<Type> operandTypes) {
|
||||
for (Type ty : operandTypes)
|
||||
if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
ShapeDialect::ShapeDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
|
@ -539,9 +546,7 @@ static LogicalResult verify(GetExtentOp op) {
|
|||
Type shapeTy = op.shape().getType();
|
||||
Type dimTy = op.dim().getType();
|
||||
Type extentTy = op.extent().getType();
|
||||
bool errorPropagationPossible =
|
||||
shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
|
||||
if (errorPropagationPossible) {
|
||||
if (isErrorPropagationPossible({shapeTy, dimTy})) {
|
||||
if (!extentTy.isa<SizeType>())
|
||||
op.emitError()
|
||||
<< "if at least one of the operands can hold error values then the "
|
||||
|
@ -593,9 +598,8 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(shape::RankOp op) {
|
||||
Type argTy = op.shape().getType();
|
||||
Type resultTy = op.rank().getType();
|
||||
if (argTy.isa<ShapeType>() && !resultTy.isa<SizeType>())
|
||||
if (op.shape().getType().isa<ShapeType>() &&
|
||||
!op.rank().getType().isa<SizeType>())
|
||||
return op.emitOpError()
|
||||
<< "if operand is of type `shape` then the result must be of type "
|
||||
"`size` to propagate potential errors";
|
||||
|
@ -672,6 +676,25 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
|
|||
return builder.getIndexAttr(product.getLimitedValue());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(MulOp op) {
|
||||
Type resultTy = op.result().getType();
|
||||
if (isErrorPropagationPossible({op.lhs().getType(), op.rhs().getType()})) {
|
||||
if (!resultTy.isa<SizeType>())
|
||||
return op.emitOpError()
|
||||
<< "if at least one of the operands can hold error values then "
|
||||
"the result must be of type `size` to propagate them";
|
||||
} else {
|
||||
if (resultTy.isa<SizeType>())
|
||||
return op.emitError() << "if none of the operands can hold error values "
|
||||
"then the result must be of type `index`";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeOfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -685,15 +708,13 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
|
|||
}
|
||||
|
||||
static LogicalResult verify(ShapeOfOp op) {
|
||||
Type argTy = op.arg().getType();
|
||||
Type resultTy = op.result().getType();
|
||||
if (argTy.isa<ValueShapeType>()) {
|
||||
if (isErrorPropagationPossible(op.arg().getType())) {
|
||||
if (!resultTy.isa<ShapeType>())
|
||||
return op.emitOpError()
|
||||
<< "if operand is of type `value_shape` then the result must be "
|
||||
"of type `shape` to propagate potential error shapes";
|
||||
} else {
|
||||
assert(argTy.isa<ShapedType>());
|
||||
if (resultTy != getExtentTensorType(op.getContext()))
|
||||
return op.emitOpError() << "if operand is a shaped type then the result "
|
||||
"must be an extent tensor";
|
||||
|
|
|
@ -38,8 +38,8 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
|
|||
// Generate reduce operator.
|
||||
Block *body = reduce.getBody();
|
||||
OpBuilder b = OpBuilder::atBlockEnd(body);
|
||||
Value product =
|
||||
b.create<MulOp>(loc, body->getArgument(1), body->getArgument(2));
|
||||
Value product = b.create<MulOp>(loc, b.getType<SizeType>(),
|
||||
body->getArgument(1), body->getArgument(2));
|
||||
b.create<YieldOp>(loc, product);
|
||||
|
||||
rewriter.replaceOp(op, reduce.result());
|
||||
|
|
|
@ -24,10 +24,19 @@ func @shape_id(%shape : !shape.shape) -> !shape.shape {
|
|||
// CHECK-LABEL: @binary_ops
|
||||
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
|
||||
func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
|
||||
// CHECK: addi %[[LHS]], %[[RHS]] : index
|
||||
%sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size
|
||||
// CHECK-NEXT: addi %[[LHS]], %[[RHS]] : index
|
||||
%product = shape.mul %lhs, %rhs
|
||||
// CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Lower binary ops.
|
||||
// CHECK-LABEL: @binary_ops
|
||||
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
|
||||
func @binary_ops(%lhs : index, %rhs : index) {
|
||||
// CHECK: muli %[[LHS]], %[[RHS]] : index
|
||||
%product = shape.mul %lhs, %rhs : index, index -> index
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
|||
^bb0(%index: index, %dim: !shape.size):
|
||||
shape.yield %dim : !shape.size
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -18,6 +19,7 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
|
|||
: (!shape.size, !shape.size) -> !shape.size
|
||||
shape.yield %new_acc : !shape.size
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -28,6 +30,7 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
|
|||
^bb0(%index: index, %dim: f32, %lci: !shape.size):
|
||||
shape.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -38,6 +41,7 @@ func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
|
|||
^bb0(%index: index, %dim: f32, %lci: index):
|
||||
shape.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -48,6 +52,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
|
|||
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
|
||||
shape.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -58,6 +63,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
|||
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
|
||||
shape.yield %dim, %dim : !shape.size, !shape.size
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -69,6 +75,7 @@ func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
|||
%c0 = constant 1 : index
|
||||
shape.yield %c0 : index
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -85,6 +92,7 @@ func @shape_of(%value_arg : !shape.value_shape,
|
|||
%shaped_arg : tensor<?x3x4xf32>) {
|
||||
// expected-error@+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
|
||||
%0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -93,6 +101,7 @@ func @shape_of(%value_arg : !shape.value_shape,
|
|||
%shaped_arg : tensor<?x3x4xf32>) {
|
||||
// expected-error@+1 {{if operand is a shaped type then the result must be an extent tensor}}
|
||||
%1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -100,6 +109,7 @@ func @shape_of(%value_arg : !shape.value_shape,
|
|||
func @rank(%arg : !shape.shape) {
|
||||
// expected-error@+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
|
||||
%0 = shape.rank %arg : !shape.shape -> index
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -120,3 +130,19 @@ func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
|
|||
return %result : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @mul_error_free(%arg : index) -> !shape.size {
|
||||
// expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
|
||||
%result = shape.mul %arg, %arg : index, index -> !shape.size
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
|
||||
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
|
||||
%result = shape.mul %lhs, %rhs : !shape.size, index -> index
|
||||
return %result : index
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
|
|||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
|
||||
%acc_next = shape.mul %acc, %extent
|
||||
: !shape.size, !shape.size -> !shape.size
|
||||
shape.yield %acc_next : !shape.size
|
||||
}
|
||||
return %num_elements : !shape.size
|
||||
|
@ -19,7 +20,7 @@ func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
|
|||
%init = constant 1 : index
|
||||
%num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
|
||||
^bb0(%index : index, %extent : index, %acc : index):
|
||||
%acc_next = muli %acc, %extent : index
|
||||
%acc_next = shape.mul %acc, %extent : index, index -> index
|
||||
shape.yield %acc_next : index
|
||||
}
|
||||
return %num_elements : index
|
||||
|
@ -110,9 +111,13 @@ func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
|
|||
return
|
||||
}
|
||||
|
||||
func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size {
|
||||
%product = shape.mul %lhs, %rhs
|
||||
return %product: !shape.size
|
||||
func @mul(%size_arg : !shape.size, %index_arg : index) {
|
||||
%size_prod = shape.mul %size_arg, %size_arg
|
||||
: !shape.size, !shape.size -> !shape.size
|
||||
%index_prod = shape.mul %index_arg, %index_arg : index, index -> index
|
||||
%mixed_prod = shape.mul %size_arg, %index_arg
|
||||
: !shape.size, index -> !shape.size
|
||||
return
|
||||
}
|
||||
|
||||
func @const_size() {
|
||||
|
|
Loading…
Reference in a new issue