[mlir][shape] Further operand and result type generalization

Previous changes generalized some of the operands and results. This
completes a larger group of those generalizations to simplify
progressive lowering, and updates the declarative assembly forms where
the generalization requires it. The changes are kept mostly mechanical.
Jacques Pienaar 2020-07-25 21:37:15 -07:00
parent 9162b70e51
commit 595d214f47
5 changed files with 87 additions and 39 deletions
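
For orientation, the effect on the textual IR: during progressive lowering the same op can now appear with either the opaque shape type or an extent tensor (a sketch consistent with the generalized ops below; SSA names illustrative):

    %s = shape.shape_of %t : tensor<?x3xf32> -> !shape.shape      // early, fully symbolic
    %e = shape.shape_of %t : tensor<?x3xf32> -> tensor<?xindex>   // later, closer to std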


@@ -86,11 +86,12 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
     broadcastable output shape possible for the given inputs.
   }];

-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs,
-                   OptionalAttr<StrAttr>:$error);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs,
+                       OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeType:$result);

-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";

   let hasFolder = 1;
 }
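
The updated format spells out both operand types, so homogeneous and (with the generalized operand type) mixed forms can be parsed. A minimal sketch, the mixed case being an assumption not exercised by the tests here:

    %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
    %3 = shape.broadcast %s, %e : !shape.shape, tensor<?xindex>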
@@ -220,10 +221,10 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
     If the shape represents an error, this op's behavior is undefined.
   }];

-  let arguments = (ins Shape_ShapeType:$input);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$input);
   let results = (outs IndexTensor:$result);

-  let assemblyFormat = "attr-dict $input `:` type($result)";
+  let assemblyFormat = "attr-dict $input `:` type($input) `->` type($result)";

   let hasFolder = 1;
 }
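
Printing the input type as well keeps the form unambiguous now that both shapes and extent tensors are accepted. From the test updates below (the tensor-input variant is an assumption permitted by the new operand type):

    %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
    %1 = shape.to_extent_tensor %et : tensor<?xindex> -> tensor<3xindex>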
@@ -342,6 +343,10 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeOrIndexType:$result);

+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value shape">,
+  ];
+
   let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict";
   let hasFolder = 1;
 }
@@ -412,23 +417,28 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {

   let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";

+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value arg">
+  ];
+
   let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }

 def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
   let summary = "Casts between index types of the shape and standard dialect";
   let description = [{
-    Converts a `shape.size` to a standard index.
-    This operation and its inverse, `index_to_size`, facilitate index conversion
-    between the standard and the shape dialect.
-    The behavior is undefined for unknown and invalid arguments.
+    Converts a `shape.size` to a standard index. This operation and its
+    inverse, `index_to_size`, facilitate index conversion between the standard
+    and the shape dialect. The behavior is undefined for unknown and invalid
+    arguments.
   }];

-  let arguments = (ins Shape_SizeType:$arg);
+  let arguments = (outs Shape_SizeOrIndexType:$arg);
   let results = (outs Index:$result);

-  let assemblyFormat = "$arg attr-dict";
+  let assemblyFormat = "$arg attr-dict `:` type($arg)";

   let hasFolder = 1;
   let hasCanonicalizer = 1;
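
`size_to_index` now prints its operand type, and `shape_of` gains a result-type-inferring builder. The resulting syntax, as exercised by the test updates below:

    %i = shape.size_to_index %size : !shape.size
    %s = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>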
@@ -490,7 +500,7 @@ def Shape_SplitAtOp : Shape_Op<"split_at", []> {
     - `index` is in the range [-rank(operand),rank(operand)]
   }];

-  let arguments = (ins Shape_ShapeType:$operand, I32:$index);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index);
   let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
   let hasFolder = 1;
 }
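
`split_at` appears to have no declarative assembly format at this point, so the generalized operand would round-trip in the generic form. A hypothetical example (types per the ODS above; names illustrative):

    %head, %tail = "shape.split_at"(%extents, %idx)
        : (tensor<?xindex>, i32) -> (!shape.shape, !shape.shape)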
@@ -520,8 +530,7 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {

 // TODO: Move the code below and witnesses to a different file.
 def Shape_AnyOp : Shape_Op<"any", [Commutative,
-                                   NoSideEffect,
-                                   SameOperandsAndResultType]> {
+                                   NoSideEffect]> {
   let summary = "Return any combination of the input shapes";
   let description = [{
     This operation takes multiple input shapes or extent tensors and returns
@@ -541,7 +550,6 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative,
   let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);

-  let assemblyFormat = "$inputs `:` type($result) attr-dict";
   let hasFolder = 1;
 }

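
With `SameOperandsAndResultType` dropped, a format that prints only `type($result)` can no longer infer the operand types, which is presumably why the declarative form is removed here; the tests below switch to the generic syntax accordingly:

    %r = "shape.any"(%a, %b) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>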


@@ -674,6 +674,16 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
   return builder.getIndexAttr(product.getLimitedValue());
 }

+void NumElementsOp::build(OpBuilder &builder, OperationState &result,
+                          Value shape) {
+  if (shape.getType().isa<ShapedType>()) {
+    auto type = builder.getIndexType();
+    return build(builder, result, type, shape);
+  }
+  auto type = SizeType::get(builder.getContext());
+  return build(builder, result, type, shape);
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
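
The overload lets callers omit the result type; it is inferred from the operand, mirroring the `assemblyFormat` above: `index` for a `ShapedType` (extent tensor) operand, `!shape.size` otherwise. A call like `builder.create<NumElementsOp>(loc, v)` (illustrative usage) thus yields one of:

    %n0 = shape.num_elements %extents : tensor<?xindex> -> index
    %n1 = shape.num_elements %shape : !shape.shape -> !shape.size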
@@ -702,6 +712,38 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   return builder.getIndexTensorAttr(type.getShape());
 }

+void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
+  if (arg.getType().isa<ShapedType>()) {
+    auto type = RankedTensorType::get({ShapedType::kDynamicSize},
+                                      builder.getIndexType());
+    return ShapeOfOp::build(builder, result, type, arg);
+  }
+  auto type = ShapeType::get(builder.getContext());
+  return ShapeOfOp::build(builder, result, type, arg);
+}
+
+namespace {
+struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.arg().getType().isa<ShapedType>())
+      return failure();
+    if (op.getType().isa<ShapedType>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
+    return success();
+  }
+};
+} // namespace
+
+void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context) {
+  patterns.insert<ShapeOfWithTensor>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
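
`ShapeOfWithTensor` matches a `shape_of` whose argument is a shaped value but whose result still uses the opaque shape type, and rebuilds it through the new builder so the result becomes an extent tensor. A sketch of the intended rewrite (SSA names illustrative):

    // Before canonicalization:
    %s = shape.shape_of %arg : tensor<?x3xf32> -> !shape.shape
    // After:
    %s = shape.shape_of %arg : tensor<?x3xf32> -> tensor<?xindex>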


@@ -50,7 +50,6 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[C3:.*]] = constant 3 : index
   // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
   return
 }
@@ -66,7 +65,6 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
   // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
   return
 }
@@ -120,7 +118,7 @@ func @any_of_three(%a : tensor<?xindex>,
                    %b : tensor<?xindex>,
                    %c : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a, %b, %c : tensor<?xindex>
+  %result = "shape.any"(%a, %b, %c) : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
   return %result : tensor<?xindex>
 }
@@ -131,7 +129,7 @@ func @any_of_three(%a : tensor<?xindex>,
 // CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
 func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a : tensor<?xindex>
+  %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
   return %result : tensor<?xindex>
 }


@@ -54,7 +54,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.const_shape [7, 2] : !shape.shape
   %0 = shape.const_shape [1, 2] : !shape.shape
   %1 = shape.const_shape [7, 1] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   return %2 : !shape.shape
 }
@@ -65,7 +65,7 @@ func @f() -> !shape.shape {
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %arg0, %0
+  %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape
   return %1 : !shape.shape
 }
@@ -76,7 +76,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %0, %arg0
+  %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape
   return %1 : !shape.shape
 }
@@ -89,7 +89,7 @@ func @f() -> !shape.shape {
   // CHECK: return %[[CST]]
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   return %2 : !shape.shape
 }
@@ -101,7 +101,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.broadcast
   %0 = shape.const_shape [2] : !shape.shape
   %1 = shape.const_shape [7] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   return %2 : !shape.shape
 }
@@ -124,7 +124,7 @@ func @f() -> !shape.shape {
 func @f() -> tensor<2xindex> {
   // CHECK: constant dense<[0, 1]> : tensor<2xindex>
   %cs = shape.const_shape [0, 1] : !shape.shape
-  %0 = shape.to_extent_tensor %cs : tensor<2xindex>
+  %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
   return %0 : tensor<2xindex>
 }
@@ -159,7 +159,7 @@ func @const_size_to_index() -> index {
   // CHECK-NOT: shape.index_cast
   %cs = shape.const_size 123
   // CHECK: constant 123 : index
-  %ci = shape.size_to_index %cs
+  %ci = shape.size_to_index %cs : !shape.size
   return %ci : index
 }
@@ -185,7 +185,7 @@ func @const_index_to_size_to_index() -> index {
   %cs0 = shape.index_to_size %ci0
   // CHECK: %[[CI:.*]] = constant 123 : index
   // CHECK-NEXT: return %[[CI]] : index
-  %ci1 = shape.size_to_index %cs0
+  %ci1 = shape.size_to_index %cs0 : !shape.size
   return %ci1 : index
 }
@@ -195,7 +195,7 @@ func @const_index_to_size_to_index() -> index {
 // CHECK-LABEL: func @nonfoldable_size_to_index
 func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
   // CHECK: shape.size_to_index
-  %ci = shape.size_to_index %cs
+  %ci = shape.size_to_index %cs : !shape.size
   return %ci : index
 }
@@ -403,7 +403,7 @@ func @f(%arg : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
   %0 = shape.const_shape [2, 3, 4] : !shape.shape
-  %1 = shape.any %0, %arg : !shape.shape
+  %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape
   return %1 : !shape.shape
 }
@@ -415,7 +415,7 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
   // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
   %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = shape.any %0, %arg : tensor<?xindex>
+  %1 = "shape.any"(%0, %arg) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
   return %1 : tensor<?xindex>
 }
@@ -424,9 +424,9 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
 // Folding of any with partially constant operands is not yet implemented.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
-  // CHECK-NEXT: %[[CS:.*]] = shape.any
+  // CHECK-NEXT: %[[CS:.*]] = "shape.any"
   // CHECK-NEXT: return %[[CS]]
-  %1 = shape.any %arg0, %arg1 : !shape.shape
+  %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape
   return %1 : !shape.shape
 }
@@ -619,7 +619,7 @@ func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index {
 func @index_to_size_to_index(%index : index) -> index {
   // CHECK: return %[[IDX]] : index
   %size = shape.index_to_size %index
-  %result = shape.size_to_index %size
+  %result = shape.size_to_index %size : !shape.size
   return %result : index
 }
@@ -630,7 +630,7 @@ func @index_to_size_to_index(%index : index) -> index {
 // CHECK-SAME: (%[[SIZE:.*]]: !shape.size) -> !shape.size
 func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
   // CHECK: return %[[SIZE]] : !shape.size
-  %idx = shape.size_to_index %size
+  %idx = shape.size_to_index %size : !shape.size
   %result = shape.index_to_size %idx
   return %result : !shape.size
 }


@@ -49,7 +49,7 @@ func @test_shape_num_elements_fixed() {
 func @test_broadcast_fixed() {
   %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
   %1 = shape.const_shape [4, 57, 92] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
@@ -99,7 +99,7 @@ func @test_constraints() {
   %w3 = shape.const_witness false
   %w4 = shape.assuming_all %w0, %w1, %w2, %w3
   shape.assuming %w4 -> !shape.shape {
-    %2 = shape.any %0, %1 : !shape.shape
+    %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
     shape.assuming_yield %2 : !shape.shape
   }
   return
@@ -131,7 +131,7 @@ func @const_size() {
 }

 func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
-  %0 = shape.to_extent_tensor %arg : tensor<3xindex>
+  %0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
   return %0 : tensor<3xindex>
 }
@@ -188,10 +188,10 @@ func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
 func @any() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
   %1 = shape.const_shape [4, 5, 6] : !shape.shape
-  %2 = shape.any %0, %1 : !shape.shape
+  %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
   %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %5 = shape.any %3, %4 : tensor<?xindex>
+  %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
   return
 }