[mlir][shape] Further operand and result type generalization

Previous changes generalized some of the operands and results. Complete
a larger group of those to simplify progressive lowering. Also update
some of the declarative asm form due to generalization. Tried to keep it
mostly mechanical.
This commit is contained in:
Jacques Pienaar 2020-07-25 21:37:15 -07:00
parent 9162b70e51
commit 595d214f47
5 changed files with 87 additions and 39 deletions

View file

@ -86,11 +86,12 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
broadcastable output shape possible for the given inputs. broadcastable output shape possible for the given inputs.
}]; }];
let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs, let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
Shape_ShapeOrExtentTensorType:$rhs,
OptionalAttr<StrAttr>:$error); OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeType:$result); let results = (outs Shape_ShapeType:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict"; let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
let hasFolder = 1; let hasFolder = 1;
} }
@ -220,10 +221,10 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
If the shape represents an error, this op's behavior is undefined. If the shape represents an error, this op's behavior is undefined.
}]; }];
let arguments = (ins Shape_ShapeType:$input); let arguments = (ins Shape_ShapeOrExtentTensorType:$input);
let results = (outs IndexTensor:$result); let results = (outs IndexTensor:$result);
let assemblyFormat = "attr-dict $input `:` type($result)"; let assemblyFormat = "attr-dict $input `:` type($input) `->` type($result)";
let hasFolder = 1; let hasFolder = 1;
} }
@ -342,6 +343,10 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape); let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
let results = (outs Shape_SizeOrIndexType:$result); let results = (outs Shape_SizeOrIndexType:$result);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, Value shape">,
];
let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict"; let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict";
let hasFolder = 1; let hasFolder = 1;
@ -412,23 +417,28 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict"; let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, Value arg">
];
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
let hasCanonicalizer = 1;
let hasFolder = 1; let hasFolder = 1;
} }
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
let summary = "Casts between index types of the shape and standard dialect"; let summary = "Casts between index types of the shape and standard dialect";
let description = [{ let description = [{
Converts a `shape.size` to a standard index. Converts a `shape.size` to a standard index. This operation and its
This operation and its inverse, `index_to_size`, facilitate index conversion inverse, `index_to_size`, facilitate index conversion between the standard
between the standard and the shape dialect. and the shape dialect. The behavior is undefined for unknown and invalid
The behavior is undefined for unknown and invalid arguments. arguments.
}]; }];
let arguments = (ins Shape_SizeType:$arg); let arguments = (outs Shape_SizeOrIndexType:$arg);
let results = (outs Index:$result); let results = (outs Index:$result);
let assemblyFormat = "$arg attr-dict"; let assemblyFormat = "$arg attr-dict `:` type($arg)";
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
@ -490,7 +500,7 @@ def Shape_SplitAtOp : Shape_Op<"split_at", []> {
- `index` is in the range [-rank(operand),rank(operand)] - `index` is in the range [-rank(operand),rank(operand)]
}]; }];
let arguments = (ins Shape_ShapeType:$operand, I32:$index); let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index);
let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail); let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
let hasFolder = 1; let hasFolder = 1;
} }
@ -520,8 +530,7 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
// TODO: Move the code below and witnesses to a different file. // TODO: Move the code below and witnesses to a different file.
def Shape_AnyOp : Shape_Op<"any", [Commutative, def Shape_AnyOp : Shape_Op<"any", [Commutative,
NoSideEffect, NoSideEffect]> {
SameOperandsAndResultType]> {
let summary = "Return any combination of the input shapes"; let summary = "Return any combination of the input shapes";
let description = [{ let description = [{
This operation takes multiple input shapes or extent tensors and returns This operation takes multiple input shapes or extent tensors and returns
@ -541,7 +550,6 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative,
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs); let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
let results = (outs Shape_ShapeOrExtentTensorType:$result); let results = (outs Shape_ShapeOrExtentTensorType:$result);
let assemblyFormat = "$inputs `:` type($result) attr-dict";
let hasFolder = 1; let hasFolder = 1;
} }

View file

@ -674,6 +674,16 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
return builder.getIndexAttr(product.getLimitedValue()); return builder.getIndexAttr(product.getLimitedValue());
} }
void NumElementsOp::build(OpBuilder &builder, OperationState &result,
Value shape) {
if (shape.getType().isa<ShapedType>()) {
auto type = builder.getIndexType();
return build(builder, result, type, shape);
}
auto type = SizeType::get(builder.getContext());
return build(builder, result, type, shape);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MulOp // MulOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -702,6 +712,38 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
return builder.getIndexTensorAttr(type.getShape()); return builder.getIndexTensorAttr(type.getShape());
} }
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
if (arg.getType().isa<ShapedType>()) {
auto type = RankedTensorType::get({ShapedType::kDynamicSize},
builder.getIndexType());
return ShapeOfOp::build(builder, result, type, arg);
}
auto type = ShapeType::get(builder.getContext());
return ShapeOfOp::build(builder, result, type, arg);
}
namespace {
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
if (!op.arg().getType().isa<ShapedType>())
return failure();
if (op.getType().isa<ShapedType>())
return failure();
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
return success();
}
};
} // namespace
void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
MLIRContext *context) {
patterns.insert<ShapeOfWithTensor>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// SizeToIndexOp // SizeToIndexOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View file

@ -50,7 +50,6 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
// CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex> // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
// CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
%shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex> %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
return return
} }
@ -66,7 +65,6 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
// CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32> // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
// CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex> // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
// CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
%shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex> %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
return return
} }
@ -120,7 +118,7 @@ func @any_of_three(%a : tensor<?xindex>,
%b : tensor<?xindex>, %b : tensor<?xindex>,
%c : tensor<?xindex>) -> tensor<?xindex> { %c : tensor<?xindex>) -> tensor<?xindex> {
// CHECK: return %[[A]] : tensor<?xindex> // CHECK: return %[[A]] : tensor<?xindex>
%result = shape.any %a, %b, %c : tensor<?xindex> %result = "shape.any"(%a, %b, %c) : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
return %result : tensor<?xindex> return %result : tensor<?xindex>
} }
@ -131,7 +129,7 @@ func @any_of_three(%a : tensor<?xindex>,
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex> // CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> { func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
// CHECK: return %[[A]] : tensor<?xindex> // CHECK: return %[[A]] : tensor<?xindex>
%result = shape.any %a : tensor<?xindex> %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
return %result : tensor<?xindex> return %result : tensor<?xindex>
} }

View file

@ -54,7 +54,7 @@ func @f() -> !shape.shape {
// CHECK: shape.const_shape [7, 2] : !shape.shape // CHECK: shape.const_shape [7, 2] : !shape.shape
%0 = shape.const_shape [1, 2] : !shape.shape %0 = shape.const_shape [1, 2] : !shape.shape
%1 = shape.const_shape [7, 1] : !shape.shape %1 = shape.const_shape [7, 1] : !shape.shape
%2 = shape.broadcast %0, %1 %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
return %2 : !shape.shape return %2 : !shape.shape
} }
@ -65,7 +65,7 @@ func @f() -> !shape.shape {
func @f(%arg0 : !shape.shape) -> !shape.shape { func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0 // CHECK: return %arg0
%0 = shape.const_shape [] : !shape.shape %0 = shape.const_shape [] : !shape.shape
%1 = shape.broadcast %arg0, %0 %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape
return %1 : !shape.shape return %1 : !shape.shape
} }
@ -76,7 +76,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
func @f(%arg0 : !shape.shape) -> !shape.shape { func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0 // CHECK: return %arg0
%0 = shape.const_shape [] : !shape.shape %0 = shape.const_shape [] : !shape.shape
%1 = shape.broadcast %0, %arg0 %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape
return %1 : !shape.shape return %1 : !shape.shape
} }
@ -89,7 +89,7 @@ func @f() -> !shape.shape {
// CHECK: return %[[CST]] // CHECK: return %[[CST]]
%0 = shape.const_shape [] : !shape.shape %0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape %1 = shape.const_shape [1, 2, 3] : !shape.shape
%2 = shape.broadcast %0, %1 %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
return %2 : !shape.shape return %2 : !shape.shape
} }
@ -101,7 +101,7 @@ func @f() -> !shape.shape {
// CHECK: shape.broadcast // CHECK: shape.broadcast
%0 = shape.const_shape [2] : !shape.shape %0 = shape.const_shape [2] : !shape.shape
%1 = shape.const_shape [7] : !shape.shape %1 = shape.const_shape [7] : !shape.shape
%2 = shape.broadcast %0, %1 %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
return %2 : !shape.shape return %2 : !shape.shape
} }
@ -124,7 +124,7 @@ func @f() -> !shape.shape {
func @f() -> tensor<2xindex> { func @f() -> tensor<2xindex> {
// CHECK: constant dense<[0, 1]> : tensor<2xindex> // CHECK: constant dense<[0, 1]> : tensor<2xindex>
%cs = shape.const_shape [0, 1] : !shape.shape %cs = shape.const_shape [0, 1] : !shape.shape
%0 = shape.to_extent_tensor %cs : tensor<2xindex> %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
return %0 : tensor<2xindex> return %0 : tensor<2xindex>
} }
@ -159,7 +159,7 @@ func @const_size_to_index() -> index {
// CHECK-NOT: shape.index_cast // CHECK-NOT: shape.index_cast
%cs = shape.const_size 123 %cs = shape.const_size 123
// CHECK: constant 123 : index // CHECK: constant 123 : index
%ci = shape.size_to_index %cs %ci = shape.size_to_index %cs : !shape.size
return %ci : index return %ci : index
} }
@ -185,7 +185,7 @@ func @const_index_to_size_to_index() -> index {
%cs0 = shape.index_to_size %ci0 %cs0 = shape.index_to_size %ci0
// CHECK: %[[CI:.*]] = constant 123 : index // CHECK: %[[CI:.*]] = constant 123 : index
// CHECK-NEXT: return %[[CI]] : index // CHECK-NEXT: return %[[CI]] : index
%ci1 = shape.size_to_index %cs0 %ci1 = shape.size_to_index %cs0 : !shape.size
return %ci1 : index return %ci1 : index
} }
@ -195,7 +195,7 @@ func @const_index_to_size_to_index() -> index {
// CHECK-LABEL: func @nonfoldable_size_to_index // CHECK-LABEL: func @nonfoldable_size_to_index
func @nonfoldable_size_to_index(%cs : !shape.size) -> index { func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
// CHECK: shape.size_to_index // CHECK: shape.size_to_index
%ci = shape.size_to_index %cs %ci = shape.size_to_index %cs : !shape.size
return %ci : index return %ci : index
} }
@ -403,7 +403,7 @@ func @f(%arg : !shape.shape) -> !shape.shape {
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
// CHECK-NEXT: return %[[CS]] // CHECK-NEXT: return %[[CS]]
%0 = shape.const_shape [2, 3, 4] : !shape.shape %0 = shape.const_shape [2, 3, 4] : !shape.shape
%1 = shape.any %0, %arg : !shape.shape %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape
return %1 : !shape.shape return %1 : !shape.shape
} }
@ -415,7 +415,7 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex> // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
// CHECK-NEXT: return %[[CS]] : tensor<?xindex> // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
%0 = shape.const_shape [2, 3, 4] : tensor<?xindex> %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
%1 = shape.any %0, %arg : tensor<?xindex> %1 = "shape.any"(%0, %arg) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
return %1 : tensor<?xindex> return %1 : tensor<?xindex>
} }
@ -424,9 +424,9 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
// Folding of any with partially constant operands is not yet implemented. // Folding of any with partially constant operands is not yet implemented.
// CHECK-LABEL: func @f // CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape { func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
// CHECK-NEXT: %[[CS:.*]] = shape.any // CHECK-NEXT: %[[CS:.*]] = "shape.any"
// CHECK-NEXT: return %[[CS]] // CHECK-NEXT: return %[[CS]]
%1 = shape.any %arg0, %arg1 : !shape.shape %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape
return %1 : !shape.shape return %1 : !shape.shape
} }
@ -619,7 +619,7 @@ func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index {
func @index_to_size_to_index(%index : index) -> index { func @index_to_size_to_index(%index : index) -> index {
// CHECK: return %[[IDX]] : index // CHECK: return %[[IDX]] : index
%size = shape.index_to_size %index %size = shape.index_to_size %index
%result = shape.size_to_index %size %result = shape.size_to_index %size : !shape.size
return %result : index return %result : index
} }
@ -630,7 +630,7 @@ func @index_to_size_to_index(%index : index) -> index {
// CHECK-SAME: (%[[SIZE:.*]]: !shape.size) -> !shape.size // CHECK-SAME: (%[[SIZE:.*]]: !shape.size) -> !shape.size
func @size_to_index_to_size(%size : !shape.size) -> !shape.size { func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
// CHECK: return %[[SIZE]] : !shape.size // CHECK: return %[[SIZE]] : !shape.size
%idx = shape.size_to_index %size %idx = shape.size_to_index %size : !shape.size
%result = shape.index_to_size %idx %result = shape.index_to_size %idx
return %result : !shape.size return %result : !shape.size
} }

View file

@ -49,7 +49,7 @@ func @test_shape_num_elements_fixed() {
func @test_broadcast_fixed() { func @test_broadcast_fixed() {
%0 = shape.const_shape [10, 1, 57, 92] : !shape.shape %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape %1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = shape.broadcast %0, %1 %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return return
} }
@ -99,7 +99,7 @@ func @test_constraints() {
%w3 = shape.const_witness false %w3 = shape.const_witness false
%w4 = shape.assuming_all %w0, %w1, %w2, %w3 %w4 = shape.assuming_all %w0, %w1, %w2, %w3
shape.assuming %w4 -> !shape.shape { shape.assuming %w4 -> !shape.shape {
%2 = shape.any %0, %1 : !shape.shape %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
shape.assuming_yield %2 : !shape.shape shape.assuming_yield %2 : !shape.shape
} }
return return
@ -131,7 +131,7 @@ func @const_size() {
} }
func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> { func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
%0 = shape.to_extent_tensor %arg : tensor<3xindex> %0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
return %0 : tensor<3xindex> return %0 : tensor<3xindex>
} }
@ -188,10 +188,10 @@ func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
func @any() { func @any() {
%0 = shape.const_shape [1, 2, 3] : !shape.shape %0 = shape.const_shape [1, 2, 3] : !shape.shape
%1 = shape.const_shape [4, 5, 6] : !shape.shape %1 = shape.const_shape [4, 5, 6] : !shape.shape
%2 = shape.any %0, %1 : !shape.shape %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = shape.const_shape [1, 2, 3] : tensor<?xindex> %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
%4 = shape.const_shape [4, 5, 6] : tensor<?xindex> %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
%5 = shape.any %3, %4 : tensor<?xindex> %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
return return
} }