diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 3c50a4f8b39f..7b676a2b0598 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -86,11 +86,12 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> { broadcastable output shape possible for the given inputs. }]; - let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs, - OptionalAttr:$error); + let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs, + Shape_ShapeOrExtentTensorType:$rhs, + OptionalAttr:$error); let results = (outs Shape_ShapeType:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict"; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)"; let hasFolder = 1; } @@ -220,10 +221,10 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { If the shape represents an error, this op's behavior is undefined. }]; - let arguments = (ins Shape_ShapeType:$input); + let arguments = (ins Shape_ShapeOrExtentTensorType:$input); let results = (outs IndexTensor:$result); - let assemblyFormat = "attr-dict $input `:` type($result)"; + let assemblyFormat = "attr-dict $input `:` type($input) `->` type($result)"; let hasFolder = 1; } @@ -342,6 +343,10 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> { let arguments = (ins Shape_ShapeOrExtentTensorType:$shape); let results = (outs Shape_SizeOrIndexType:$result); + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &result, Value shape">, + ]; + let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict"; let hasFolder = 1; @@ -412,23 +417,28 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> { let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict"; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &result, Value arg"> + ]; + let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; + let hasCanonicalizer = 1; let hasFolder = 1; } def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { let summary = "Casts between index types of the shape and standard dialect"; let description = [{ - Converts a `shape.size` to a standard index. - This operation and its inverse, `index_to_size`, facilitate index conversion - between the standard and the shape dialect. - The behavior is undefined for unknown and invalid arguments. + Converts a `shape.size` to a standard index. This operation and its + inverse, `index_to_size`, facilitate index conversion between the standard + and the shape dialect. The behavior is undefined for unknown and invalid + arguments. }]; - let arguments = (ins Shape_SizeType:$arg); + let arguments = (outs Shape_SizeOrIndexType:$arg); let results = (outs Index:$result); - let assemblyFormat = "$arg attr-dict"; + let assemblyFormat = "$arg attr-dict `:` type($arg)"; let hasFolder = 1; let hasCanonicalizer = 1; @@ -490,7 +500,7 @@ def Shape_SplitAtOp : Shape_Op<"split_at", []> { - `index` is in the range [-rank(operand),rank(operand)] }]; - let arguments = (ins Shape_ShapeType:$operand, I32:$index); + let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index); let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail); let hasFolder = 1; } @@ -520,8 +530,7 @@ def Shape_ConcatOp : Shape_Op<"concat", []> { // TODO: Move the code below and witnesses to a different file. def Shape_AnyOp : Shape_Op<"any", [Commutative, - NoSideEffect, - SameOperandsAndResultType]> { + NoSideEffect]> { let summary = "Return any combination of the input shapes"; let description = [{ This operation takes multiple input shapes or extent tensors and returns @@ -541,7 +550,6 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative, let arguments = (ins Variadic:$inputs); let results = (outs Shape_ShapeOrExtentTensorType:$result); - let assemblyFormat = "$inputs `:` type($result) attr-dict"; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 104ab46c5581..4887c87c1e5f 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -674,6 +674,16 @@ OpFoldResult NumElementsOp::fold(ArrayRef operands) { return builder.getIndexAttr(product.getLimitedValue()); } +void NumElementsOp::build(OpBuilder &builder, OperationState &result, + Value shape) { + if (shape.getType().isa()) { + auto type = builder.getIndexType(); + return build(builder, result, type, shape); + } + auto type = SizeType::get(builder.getContext()); + return build(builder, result, type, shape); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -702,6 +712,38 @@ OpFoldResult ShapeOfOp::fold(ArrayRef) { return builder.getIndexTensorAttr(type.getShape()); } +void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { + if (arg.getType().isa()) { + auto type = RankedTensorType::get({ShapedType::kDynamicSize}, + builder.getIndexType()); + return ShapeOfOp::build(builder, result, type, arg); + } + auto type = ShapeType::get(builder.getContext()); + return ShapeOfOp::build(builder, result, type, arg); +} + +namespace { +struct ShapeOfWithTensor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::ShapeOfOp op, + PatternRewriter &rewriter) const override { + if (!op.arg().getType().isa()) + return failure(); + if (op.getType().isa()) + return failure(); + + rewriter.replaceOpWithNewOp(op.getOperation(), op.arg()); + return success(); + } +}; +} // namespace + +void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, + MLIRContext *context) { + patterns.insert(context); +} + //===----------------------------------------------------------------------===// // SizeToIndexOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 8236c6f27975..9336402d86da 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -50,7 +50,6 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) { // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex> - // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor return } @@ -66,7 +65,6 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) { // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32> // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex> - // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor return } @@ -120,7 +118,7 @@ func @any_of_three(%a : tensor, %b : tensor, %c : tensor) -> tensor { // CHECK: return %[[A]] : tensor - %result = shape.any %a, %b, %c : tensor + %result = "shape.any"(%a, %b, %c) : (tensor, tensor, tensor) -> tensor return %result : tensor } @@ -131,7 +129,7 @@ func @any_of_three(%a : tensor, // CHECK-SAME: (%[[A:.*]]: tensor) -> tensor func @any_of_one(%a : tensor) -> tensor { // CHECK: return %[[A]] : tensor - %result = shape.any %a : tensor + %result = "shape.any"(%a) : (tensor) -> tensor return %result : tensor } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index e147fbeb81ac..5fe2ac108a69 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -54,7 +54,7 @@ func @f() -> !shape.shape { // CHECK: shape.const_shape [7, 2] : !shape.shape %0 = shape.const_shape [1, 2] : !shape.shape %1 = shape.const_shape [7, 1] : !shape.shape - %2 = shape.broadcast %0, %1 + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape return %2 : !shape.shape } @@ -65,7 +65,7 @@ func @f() -> !shape.shape { func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 %0 = shape.const_shape [] : !shape.shape - %1 = shape.broadcast %arg0, %0 + %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape return %1 : !shape.shape } @@ -76,7 +76,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape { func @f(%arg0 : !shape.shape) -> !shape.shape { // CHECK: return %arg0 %0 = shape.const_shape [] : !shape.shape - %1 = shape.broadcast %0, %arg0 + %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape return %1 : !shape.shape } @@ -89,7 +89,7 @@ func @f() -> !shape.shape { // CHECK: return %[[CST]] %0 = shape.const_shape [] : !shape.shape %1 = shape.const_shape [1, 2, 3] : !shape.shape - %2 = shape.broadcast %0, %1 + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape return %2 : !shape.shape } @@ -101,7 +101,7 @@ func @f() -> !shape.shape { // CHECK: shape.broadcast %0 = shape.const_shape [2] : !shape.shape %1 = shape.const_shape [7] : !shape.shape - %2 = shape.broadcast %0, %1 + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape return %2 : !shape.shape } @@ -124,7 +124,7 @@ func @f() -> !shape.shape { func @f() -> tensor<2xindex> { // CHECK: constant dense<[0, 1]> : tensor<2xindex> %cs = shape.const_shape [0, 1] : !shape.shape - %0 = shape.to_extent_tensor %cs : tensor<2xindex> + %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex> return %0 : tensor<2xindex> } @@ -159,7 +159,7 @@ func @const_size_to_index() -> index { // CHECK-NOT: shape.index_cast %cs = shape.const_size 123 // CHECK: constant 123 : index - %ci = shape.size_to_index %cs + %ci = shape.size_to_index %cs : !shape.size return %ci : index } @@ -185,7 +185,7 @@ func @const_index_to_size_to_index() -> index { %cs0 = shape.index_to_size %ci0 // CHECK: %[[CI:.*]] = constant 123 : index // CHECK-NEXT: return %[[CI]] : index - %ci1 = shape.size_to_index %cs0 + %ci1 = shape.size_to_index %cs0 : !shape.size return %ci1 : index } @@ -195,7 +195,7 @@ func @const_index_to_size_to_index() -> index { // CHECK-LABEL: func @nonfoldable_size_to_index func @nonfoldable_size_to_index(%cs : !shape.size) -> index { // CHECK: shape.size_to_index - %ci = shape.size_to_index %cs + %ci = shape.size_to_index %cs : !shape.size return %ci : index } @@ -403,7 +403,7 @@ func @f(%arg : !shape.shape) -> !shape.shape { // CHECK-NEXT: %[[CS:.*]] = shape.const_shape // CHECK-NEXT: return %[[CS]] %0 = shape.const_shape [2, 3, 4] : !shape.shape - %1 = shape.any %0, %arg : !shape.shape + %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape return %1 : !shape.shape } @@ -415,7 +415,7 @@ func @f(%arg : tensor) -> tensor { // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor // CHECK-NEXT: return %[[CS]] : tensor %0 = shape.const_shape [2, 3, 4] : tensor - %1 = shape.any %0, %arg : tensor + %1 = "shape.any"(%0, %arg) : (tensor, tensor) -> tensor return %1 : tensor } @@ -424,9 +424,9 @@ func @f(%arg : tensor) -> tensor { // Folding of any with partially constant operands is not yet implemented. // CHECK-LABEL: func @f func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape { - // CHECK-NEXT: %[[CS:.*]] = shape.any + // CHECK-NEXT: %[[CS:.*]] = "shape.any" // CHECK-NEXT: return %[[CS]] - %1 = shape.any %arg0, %arg1 : !shape.shape + %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape return %1 : !shape.shape } @@ -619,7 +619,7 @@ func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index { func @index_to_size_to_index(%index : index) -> index { // CHECK: return %[[IDX]] : index %size = shape.index_to_size %index - %result = shape.size_to_index %size + %result = shape.size_to_index %size : !shape.size return %result : index } @@ -630,7 +630,7 @@ func @index_to_size_to_index(%index : index) -> index { // CHECK-SAME: (%[[SIZE:.*]]: !shape.size) -> !shape.size func @size_to_index_to_size(%size : !shape.size) -> !shape.size { // CHECK: return %[[SIZE]] : !shape.size - %idx = shape.size_to_index %size + %idx = shape.size_to_index %size : !shape.size %result = shape.index_to_size %idx return %result : !shape.size } diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir index f57826097d34..87af623fe0f7 100644 --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -49,7 +49,7 @@ func @test_shape_num_elements_fixed() { func @test_broadcast_fixed() { %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape %1 = shape.const_shape [4, 57, 92] : !shape.shape - %2 = shape.broadcast %0, %1 + %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } @@ -99,7 +99,7 @@ func @test_constraints() { %w3 = shape.const_witness false %w4 = shape.assuming_all %w0, %w1, %w2, %w3 shape.assuming %w4 -> !shape.shape { - %2 = shape.any %0, %1 : !shape.shape + %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape shape.assuming_yield %2 : !shape.shape } return @@ -131,7 +131,7 @@ func @const_size() { } func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> { - %0 = shape.to_extent_tensor %arg : tensor<3xindex> + %0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex> return %0 : tensor<3xindex> } @@ -188,10 +188,10 @@ func @get_extent_on_mixed_operands(%arg : tensor) -> !shape.size { func @any() { %0 = shape.const_shape [1, 2, 3] : !shape.shape %1 = shape.const_shape [4, 5, 6] : !shape.shape - %2 = shape.any %0, %1 : !shape.shape + %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %3 = shape.const_shape [1, 2, 3] : tensor %4 = shape.const_shape [4, 5, 6] : tensor - %5 = shape.any %3, %4 : tensor + %5 = "shape.any"(%3, %4) : (tensor, tensor) -> tensor return }