[MLIR][Shape] Generalze shape.const_shape to extent tensors

The operation `shape.const_shape` was used for constants of type shape only.
We can now also use it to create constant extent tensors.

Differential Revision: https://reviews.llvm.org/D84157
This commit is contained in:
Frederik Gossen 2020-07-24 08:05:26 +00:00
parent f7ffb122d0
commit 14d3cef012
4 changed files with 90 additions and 73 deletions

View file

@ -96,18 +96,20 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
}
def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
let summary = "Creates a constant of !shape.shape type";
let summary = "Creates a constant shape or extent tensor";
let description = [{
Creates a !shape.shape with rank given by the length of `shape` and with
dimension sizes given by the values of `shape`.
Creates a constant shape or extent tensor. The individual extents are given
as the `shape` attribute. The number of these values equals the shape's
rank.
```mlir
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
```
}];
let arguments = (ins IndexElementsAttr:$shape);
let results = (outs Shape_ShapeType:$result);
let results = (outs Shape_ShapeOrExtentTensorType:$result);
// TODO: Move this to main so that all shape ops implement these.
let printer = [{ return ::print(p, *this); }];

View file

@ -23,6 +23,11 @@ namespace {
#include "ShapeCanonicalization.inc"
}
static RankedTensorType getExtentTensorType(OpBuilder &builder) {
return RankedTensorType::get({ShapedType::kDynamicSize},
builder.getIndexType());
}
ShapeDialect::ShapeDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
@ -40,12 +45,12 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto shapeType = type.dyn_cast<ShapeType>())
if (type.isa<ShapeType>() || type == getExtentTensorType(builder))
return builder.create<ConstShapeOp>(loc, type,
value.cast<DenseIntElementsAttr>());
if (auto sizeType = type.dyn_cast<SizeType>())
if (type.isa<SizeType>())
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
if (auto witnessType = type.dyn_cast<WitnessType>())
if (type.isa<WitnessType>())
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
return nullptr;
}
@ -290,7 +295,8 @@ static void print(OpAsmPrinter &p, ConstShapeOp &op) {
p << "[";
interleaveComma(op.shape().getValues<int64_t>(), p,
[&](int64_t i) { p << i; });
p << "]";
p << "] : ";
p.printType(op.getType());
}
static ParseResult parseConstShapeOp(OpAsmParser &parser,
@ -316,8 +322,10 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
}
Builder &builder = parser.getBuilder();
result.addAttribute("shape", builder.getIndexTensorAttr(ints));
result.types.push_back(ShapeType::get(builder.getContext()));
Type resultTy;
if (parser.parseColonType(resultTy))
return failure();
result.types.push_back(resultTy);
return success();
}

View file

@ -2,7 +2,7 @@
// CHECK-LABEL: func @f
func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
// CHECK: shape.const_shape [2, 3, 4]
// CHECK: shape.const_shape [2, 3, 4] : !shape.shape
%0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
return %0 : !shape.shape
}
@ -12,10 +12,10 @@ func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
// Basic case.
// CHECK-LABEL: func @f
func @f() -> (!shape.shape, !shape.shape) {
// CHECK: shape.const_shape [2, 3]
// CHECK: shape.const_shape [4, 5]
// CHECK: shape.const_shape [2, 3] : !shape.shape
// CHECK: shape.const_shape [4, 5] : !shape.shape
%c2 = constant 2 : i32
%0 = shape.const_shape [2, 3, 4, 5]
%0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
%head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
return %head, %tail : !shape.shape, !shape.shape
@ -26,10 +26,10 @@ func @f() -> (!shape.shape, !shape.shape) {
// Negative split point.
// CHECK-LABEL: func @f
func @f() -> (!shape.shape, !shape.shape) {
// CHECK: shape.const_shape [2, 3, 4]
// CHECK: shape.const_shape [5]
// CHECK: shape.const_shape [2, 3, 4] : !shape.shape
// CHECK: shape.const_shape [5] : !shape.shape
%c-1 = constant -1 : i32
%0 = shape.const_shape [2, 3, 4, 5]
%0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
%head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
return %head, %tail : !shape.shape, !shape.shape
}
@ -41,7 +41,7 @@ func @f() -> (!shape.shape, !shape.shape) {
func @f() -> (!shape.shape, !shape.shape) {
// CHECK: shape.split_at
%c5 = constant 5 : i32
%0 = shape.const_shape [2, 3, 4, 5]
%0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
%head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
return %head, %tail : !shape.shape, !shape.shape
}
@ -51,9 +51,9 @@ func @f() -> (!shape.shape, !shape.shape) {
// Basic case.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
// CHECK: shape.const_shape [7, 2]
%0 = shape.const_shape [1, 2]
%1 = shape.const_shape [7, 1]
// CHECK: shape.const_shape [7, 2] : !shape.shape
%0 = shape.const_shape [1, 2] : !shape.shape
%1 = shape.const_shape [7, 1] : !shape.shape
%2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@ -64,7 +64,7 @@ func @f() -> !shape.shape {
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0
%0 = shape.const_shape []
%0 = shape.const_shape [] : !shape.shape
%1 = shape.broadcast %arg0, %0
return %1 : !shape.shape
}
@ -75,7 +75,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK: return %arg0
%0 = shape.const_shape []
%0 = shape.const_shape [] : !shape.shape
%1 = shape.broadcast %0, %arg0
return %1 : !shape.shape
}
@ -85,10 +85,10 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
// Lhs is a scalar and rhs is constant.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
// CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3]
// CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3] : !shape.shape
// CHECK: return %[[CST]]
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@ -99,8 +99,8 @@ func @f() -> !shape.shape {
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
// CHECK: shape.broadcast
%0 = shape.const_shape [2]
%1 = shape.const_shape [7]
%0 = shape.const_shape [2] : !shape.shape
%1 = shape.const_shape [7] : !shape.shape
%2 = shape.broadcast %0, %1
return %2 : !shape.shape
}
@ -110,9 +110,9 @@ func @f() -> !shape.shape {
// Basic case.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
// CHECK: shape.const_shape [0, 1, 2, 3]
%lhs = shape.const_shape [0, 1]
%rhs = shape.const_shape [2, 3]
// CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape
%lhs = shape.const_shape [0, 1] : !shape.shape
%rhs = shape.const_shape [2, 3] : !shape.shape
%0 = shape.concat %lhs, %rhs
return %0 : !shape.shape
}
@ -123,7 +123,7 @@ func @f() -> !shape.shape {
// CHECK-LABEL: func @f
func @f() -> tensor<2xindex> {
// CHECK: constant dense<[0, 1]> : tensor<2xindex>
%cs = shape.const_shape [0, 1]
%cs = shape.const_shape [0, 1] : !shape.shape
%0 = shape.to_extent_tensor %cs : tensor<2xindex>
return %0 : tensor<2xindex>
}
@ -133,7 +133,7 @@ func @f() -> tensor<2xindex> {
// Basic case.
// CHECK-LABEL: func @f()
func @f() -> !shape.shape {
// CHECK: shape.const_shape [3, 5, 11]
// CHECK: shape.const_shape [3, 5, 11] : !shape.shape
%e0 = constant 3 : index
%e1 = constant 5 : index
%e2 = constant 11 : index
@ -215,7 +215,7 @@ func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
// CHECK-LABEL: func @num_elements
func @num_elements() -> !shape.size {
// CHECK-NOT: shape.const_shape
%shape = shape.const_shape [4, 5, 6]
%shape = shape.const_shape [4, 5, 6] : !shape.shape
// CHECK-NOT: shape.num_elements
%num_elements = shape.num_elements %shape
// CHECK: %[[NUM:.*]] = shape.const_size 120
@ -239,7 +239,7 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// CHECK-LABEL: func @basic
func @basic() -> !shape.size {
// CHECK: shape.const_size 2
%0 = shape.const_shape [0, 1, 2]
%0 = shape.const_shape [0, 1, 2] : !shape.shape
%c2 = shape.const_size 2
%1 = shape.get_extent %0, %c2
return %1 : !shape.size
@ -252,7 +252,7 @@ func @basic() -> !shape.size {
func @out_of_bounds() -> !shape.size {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
%0 = shape.const_shape [0, 1, 2]
%0 = shape.const_shape [0, 1, 2] : !shape.shape
%c3 = shape.const_size 3
%1 = shape.get_extent %0, %c3
return %1 : !shape.size
@ -289,9 +289,9 @@ func @f() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [0, 1]
%cs1 = shape.const_shape [0, 1]
%cs2 = shape.const_shape [0, 1]
%cs0 = shape.const_shape [0, 1] : !shape.shape
%cs1 = shape.const_shape [0, 1] : !shape.shape
%cs2 = shape.const_shape [0, 1] : !shape.shape
%0 = shape.cstr_eq %cs0, %cs1, %cs2
"consume.witness"(%0) : (!shape.witness) -> ()
return
@ -306,8 +306,8 @@ func @f() {
// CHECK-NEXT: shape.cstr_eq
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [0, 1]
%cs1 = shape.const_shape [3, 1]
%cs0 = shape.const_shape [0, 1] : !shape.shape
%cs1 = shape.const_shape [3, 1] : !shape.shape
%0 = shape.cstr_eq %cs0, %cs1
"consume.witness"(%0) : (!shape.witness) -> ()
return
@ -367,7 +367,7 @@ func @f() {
func @f(%arg0 : !shape.shape) -> !shape.shape {
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
// CHECK-NEXT: return %[[CS]]
%0 = shape.const_shape [2, 3, 4]
%0 = shape.const_shape [2, 3, 4] : !shape.shape
%1 = shape.any %0, %arg0
return %1 : !shape.shape
}
@ -429,8 +429,8 @@ func @f() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [3, 1]
%cs1 = shape.const_shape [1, 5]
%cs0 = shape.const_shape [3, 1] : !shape.shape
%cs1 = shape.const_shape [1, 5] : !shape.shape
%0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
@ -445,8 +445,8 @@ func @static_non_broadcastable() {
// CHECK-NEXT: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [1, 3]
%cs1 = shape.const_shape [1, 5]
%cs0 = shape.const_shape [1, 3] : !shape.shape
%cs1 = shape.const_shape [1, 5] : !shape.shape
%0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
@ -460,7 +460,7 @@ func @f(%arg0 : !shape.shape) {
// CHECK-NEXT: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%cs0 = shape.const_shape [1,3]
%cs0 = shape.const_shape [1, 3] : !shape.shape
%0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape
"consume.witness"(%0) : (!shape.witness) -> ()
return
@ -498,7 +498,7 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
func @fold_rank() -> !shape.size {
// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
// CHECK-DAG: return %[[RESULT]] : !shape.size
%shape = shape.const_shape [3, 4, 5, 6, 7]
%shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
%rank = shape.rank %shape : !shape.shape
return %rank : !shape.size
}
@ -571,7 +571,7 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.const_shape []
%0 = shape.const_shape [] : !shape.shape
%1 = shape.shape_of %arg0 : tensor<?xf32>
%2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
"consume.witness"(%2) : (!shape.witness) -> ()
@ -617,9 +617,9 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
func @shape_eq_fold_1() -> i1 {
// CHECK: %[[RESULT:.*]] = constant true
// CHECK: return %[[RESULT]] : i1
%a = shape.const_shape [1, 2, 3]
%b = shape.const_shape [1, 2, 3]
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
%a = shape.const_shape [1, 2, 3] : !shape.shape
%b = shape.const_shape [1, 2, 3] : tensor<?xindex>
%result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
return %result : i1
}
@ -630,9 +630,9 @@ func @shape_eq_fold_1() -> i1 {
func @shape_eq_fold_0() -> i1 {
// CHECK: %[[RESULT:.*]] = constant false
// CHECK: return %[[RESULT]] : i1
%a = shape.const_shape [1, 2, 3]
%b = shape.const_shape [4, 5, 6]
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
%a = shape.const_shape [1, 2, 3] : tensor<?xindex>
%b = shape.const_shape [4, 5, 6] : tensor<?xindex>
%result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
return %result : i1
}
@ -643,8 +643,8 @@ func @shape_eq_fold_0() -> i1 {
func @shape_eq_fold_0() -> i1 {
// CHECK: %[[RESULT:.*]] = constant false
// CHECK: return %[[RESULT]] : i1
%a = shape.const_shape [1, 2, 3, 4, 5, 6]
%b = shape.const_shape [1, 2, 3]
%a = shape.const_shape [1, 2, 3, 4, 5, 6] : !shape.shape
%b = shape.const_shape [1, 2, 3] : !shape.shape
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
}
@ -658,7 +658,7 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
// CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
// CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[B]] : !shape.shape, !shape.shape
// CHECK: return %[[RESULT]] : i1
%b = shape.const_shape [4, 5, 6]
%b = shape.const_shape [4, 5, 6] : !shape.shape
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
}

View file

@ -33,48 +33,55 @@ func @test_shape_num_elements_unknown() {
return
}
func @const_shape() {
%0 = shape.const_shape [1, 2, 3] : !shape.shape
%1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
return
}
func @test_shape_num_elements_fixed() {
%0 = shape.const_shape [1, 57, 92]
%0 = shape.const_shape [1, 57, 92] : !shape.shape
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
return
}
func @test_broadcast_fixed() {
%0 = shape.const_shape [10, 1, 57, 92]
%1 = shape.const_shape [4, 57, 92]
%0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = shape.broadcast %0, %1
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_fixed() {
%0 = shape.const_shape [4, 57, 92]
%1 = shape.const_shape [4, 57, 92]
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_unknown() {
%0 = shape.const_shape [4, -1, 92]
%1 = shape.const_shape [-1, 57, 92]
%0 = shape.const_shape [4, -1, 92] : !shape.shape
%1 = shape.const_shape [-1, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_fixed_mismatch() {
%0 = shape.const_shape [4, 57, 92]
%1 = shape.const_shape [2, 57, 92]
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [2, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_parse_const_shape() {
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
return
}
@ -84,8 +91,8 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
}
func @test_constraints() {
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
%w1 = shape.cstr_eq %0, %1
%w2 = shape.const_witness true