[MLIR][Shape] Generalze shape.const_shape
to extent tensors
The operation `shape.const_shape` was used for constants of type shape only. We can now also use it to create constant extent tensors. Differential Revision: https://reviews.llvm.org/D84157
This commit is contained in:
parent
f7ffb122d0
commit
14d3cef012
|
@ -96,18 +96,20 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
|
|||
}
|
||||
|
||||
def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
|
||||
let summary = "Creates a constant of !shape.shape type";
|
||||
let summary = "Creates a constant shape or extent tensor";
|
||||
let description = [{
|
||||
Creates a !shape.shape with rank given by the length of `shape` and with
|
||||
dimension sizes given by the values of `shape`.
|
||||
Creates a constant shape or extent tensor. The individual extents are given
|
||||
as the `shape` attribute. The number of these values equals the shape's
|
||||
rank.
|
||||
|
||||
```mlir
|
||||
%0 = shape.const_shape []
|
||||
%1 = shape.const_shape [1, 2, 3]
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%2 = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins IndexElementsAttr:$shape);
|
||||
let results = (outs Shape_ShapeType:$result);
|
||||
let results = (outs Shape_ShapeOrExtentTensorType:$result);
|
||||
|
||||
// TODO: Move this to main so that all shape ops implement these.
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
|
|
|
@ -23,6 +23,11 @@ namespace {
|
|||
#include "ShapeCanonicalization.inc"
|
||||
}
|
||||
|
||||
static RankedTensorType getExtentTensorType(OpBuilder &builder) {
|
||||
return RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
builder.getIndexType());
|
||||
}
|
||||
|
||||
ShapeDialect::ShapeDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
|
@ -40,12 +45,12 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
|
|||
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
if (auto shapeType = type.dyn_cast<ShapeType>())
|
||||
if (type.isa<ShapeType>() || type == getExtentTensorType(builder))
|
||||
return builder.create<ConstShapeOp>(loc, type,
|
||||
value.cast<DenseIntElementsAttr>());
|
||||
if (auto sizeType = type.dyn_cast<SizeType>())
|
||||
if (type.isa<SizeType>())
|
||||
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
|
||||
if (auto witnessType = type.dyn_cast<WitnessType>())
|
||||
if (type.isa<WitnessType>())
|
||||
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -290,7 +295,8 @@ static void print(OpAsmPrinter &p, ConstShapeOp &op) {
|
|||
p << "[";
|
||||
interleaveComma(op.shape().getValues<int64_t>(), p,
|
||||
[&](int64_t i) { p << i; });
|
||||
p << "]";
|
||||
p << "] : ";
|
||||
p.printType(op.getType());
|
||||
}
|
||||
|
||||
static ParseResult parseConstShapeOp(OpAsmParser &parser,
|
||||
|
@ -316,8 +322,10 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
|
|||
}
|
||||
Builder &builder = parser.getBuilder();
|
||||
result.addAttribute("shape", builder.getIndexTensorAttr(ints));
|
||||
|
||||
result.types.push_back(ShapeType::get(builder.getContext()));
|
||||
Type resultTy;
|
||||
if (parser.parseColonType(resultTy))
|
||||
return failure();
|
||||
result.types.push_back(resultTy);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
|
||||
// CHECK: shape.const_shape [2, 3, 4]
|
||||
// CHECK: shape.const_shape [2, 3, 4] : !shape.shape
|
||||
%0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
|
||||
return %0 : !shape.shape
|
||||
}
|
||||
|
@ -12,10 +12,10 @@ func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
|
|||
// Basic case.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() -> (!shape.shape, !shape.shape) {
|
||||
// CHECK: shape.const_shape [2, 3]
|
||||
// CHECK: shape.const_shape [4, 5]
|
||||
// CHECK: shape.const_shape [2, 3] : !shape.shape
|
||||
// CHECK: shape.const_shape [4, 5] : !shape.shape
|
||||
%c2 = constant 2 : i32
|
||||
%0 = shape.const_shape [2, 3, 4, 5]
|
||||
%0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
|
||||
%head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
return %head, %tail : !shape.shape, !shape.shape
|
||||
|
||||
|
@ -26,10 +26,10 @@ func @f() -> (!shape.shape, !shape.shape) {
|
|||
// Negative split point.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() -> (!shape.shape, !shape.shape) {
|
||||
// CHECK: shape.const_shape [2, 3, 4]
|
||||
// CHECK: shape.const_shape [5]
|
||||
// CHECK: shape.const_shape [2, 3, 4] : !shape.shape
|
||||
// CHECK: shape.const_shape [5] : !shape.shape
|
||||
%c-1 = constant -1 : i32
|
||||
%0 = shape.const_shape [2, 3, 4, 5]
|
||||
%0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
|
||||
%head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
return %head, %tail : !shape.shape, !shape.shape
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ func @f() -> (!shape.shape, !shape.shape) {
|
|||
func @f() -> (!shape.shape, !shape.shape) {
|
||||
// CHECK: shape.split_at
|
||||
%c5 = constant 5 : i32
|
||||
%0 = shape.const_shape [2, 3, 4, 5]
|
||||
%0 = shape.const_shape [2, 3, 4, 5] : !shape.shape
|
||||
%head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
return %head, %tail : !shape.shape, !shape.shape
|
||||
}
|
||||
|
@ -51,9 +51,9 @@ func @f() -> (!shape.shape, !shape.shape) {
|
|||
// Basic case.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() -> !shape.shape {
|
||||
// CHECK: shape.const_shape [7, 2]
|
||||
%0 = shape.const_shape [1, 2]
|
||||
%1 = shape.const_shape [7, 1]
|
||||
// CHECK: shape.const_shape [7, 2] : !shape.shape
|
||||
%0 = shape.const_shape [1, 2] : !shape.shape
|
||||
%1 = shape.const_shape [7, 1] : !shape.shape
|
||||
%2 = shape.broadcast %0, %1
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ func @f() -> !shape.shape {
|
|||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0 : !shape.shape) -> !shape.shape {
|
||||
// CHECK: return %arg0
|
||||
%0 = shape.const_shape []
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.broadcast %arg0, %0
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
|
|||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0 : !shape.shape) -> !shape.shape {
|
||||
// CHECK: return %arg0
|
||||
%0 = shape.const_shape []
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.broadcast %0, %arg0
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
@ -85,10 +85,10 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
|
|||
// Lhs is a scalar and rhs is constant.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() -> !shape.shape {
|
||||
// CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3]
|
||||
// CHECK: %[[CST:.*]] = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = shape.const_shape []
|
||||
%1 = shape.const_shape [1, 2, 3]
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%2 = shape.broadcast %0, %1
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
@ -99,8 +99,8 @@ func @f() -> !shape.shape {
|
|||
// CHECK-LABEL: func @f
|
||||
func @f() -> !shape.shape {
|
||||
// CHECK: shape.broadcast
|
||||
%0 = shape.const_shape [2]
|
||||
%1 = shape.const_shape [7]
|
||||
%0 = shape.const_shape [2] : !shape.shape
|
||||
%1 = shape.const_shape [7] : !shape.shape
|
||||
%2 = shape.broadcast %0, %1
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
@ -110,9 +110,9 @@ func @f() -> !shape.shape {
|
|||
// Basic case.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() -> !shape.shape {
|
||||
// CHECK: shape.const_shape [0, 1, 2, 3]
|
||||
%lhs = shape.const_shape [0, 1]
|
||||
%rhs = shape.const_shape [2, 3]
|
||||
// CHECK: shape.const_shape [0, 1, 2, 3] : !shape.shape
|
||||
%lhs = shape.const_shape [0, 1] : !shape.shape
|
||||
%rhs = shape.const_shape [2, 3] : !shape.shape
|
||||
%0 = shape.concat %lhs, %rhs
|
||||
return %0 : !shape.shape
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ func @f() -> !shape.shape {
|
|||
// CHECK-LABEL: func @f
|
||||
func @f() -> tensor<2xindex> {
|
||||
// CHECK: constant dense<[0, 1]> : tensor<2xindex>
|
||||
%cs = shape.const_shape [0, 1]
|
||||
%cs = shape.const_shape [0, 1] : !shape.shape
|
||||
%0 = shape.to_extent_tensor %cs : tensor<2xindex>
|
||||
return %0 : tensor<2xindex>
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ func @f() -> tensor<2xindex> {
|
|||
// Basic case.
|
||||
// CHECK-LABEL: func @f()
|
||||
func @f() -> !shape.shape {
|
||||
// CHECK: shape.const_shape [3, 5, 11]
|
||||
// CHECK: shape.const_shape [3, 5, 11] : !shape.shape
|
||||
%e0 = constant 3 : index
|
||||
%e1 = constant 5 : index
|
||||
%e2 = constant 11 : index
|
||||
|
@ -215,7 +215,7 @@ func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
|
|||
// CHECK-LABEL: func @num_elements
|
||||
func @num_elements() -> !shape.size {
|
||||
// CHECK-NOT: shape.const_shape
|
||||
%shape = shape.const_shape [4, 5, 6]
|
||||
%shape = shape.const_shape [4, 5, 6] : !shape.shape
|
||||
// CHECK-NOT: shape.num_elements
|
||||
%num_elements = shape.num_elements %shape
|
||||
// CHECK: %[[NUM:.*]] = shape.const_size 120
|
||||
|
@ -239,7 +239,7 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
|
|||
// CHECK-LABEL: func @basic
|
||||
func @basic() -> !shape.size {
|
||||
// CHECK: shape.const_size 2
|
||||
%0 = shape.const_shape [0, 1, 2]
|
||||
%0 = shape.const_shape [0, 1, 2] : !shape.shape
|
||||
%c2 = shape.const_size 2
|
||||
%1 = shape.get_extent %0, %c2
|
||||
return %1 : !shape.size
|
||||
|
@ -252,7 +252,7 @@ func @basic() -> !shape.size {
|
|||
func @out_of_bounds() -> !shape.size {
|
||||
// CHECK: shape.const_shape
|
||||
// CHECK: shape.get_extent
|
||||
%0 = shape.const_shape [0, 1, 2]
|
||||
%0 = shape.const_shape [0, 1, 2] : !shape.shape
|
||||
%c3 = shape.const_size 3
|
||||
%1 = shape.get_extent %0, %c3
|
||||
return %1 : !shape.size
|
||||
|
@ -289,9 +289,9 @@ func @f() {
|
|||
// CHECK-NEXT: shape.const_witness true
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [0, 1]
|
||||
%cs1 = shape.const_shape [0, 1]
|
||||
%cs2 = shape.const_shape [0, 1]
|
||||
%cs0 = shape.const_shape [0, 1] : !shape.shape
|
||||
%cs1 = shape.const_shape [0, 1] : !shape.shape
|
||||
%cs2 = shape.const_shape [0, 1] : !shape.shape
|
||||
%0 = shape.cstr_eq %cs0, %cs1, %cs2
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
|
@ -306,8 +306,8 @@ func @f() {
|
|||
// CHECK-NEXT: shape.cstr_eq
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [0, 1]
|
||||
%cs1 = shape.const_shape [3, 1]
|
||||
%cs0 = shape.const_shape [0, 1] : !shape.shape
|
||||
%cs1 = shape.const_shape [3, 1] : !shape.shape
|
||||
%0 = shape.cstr_eq %cs0, %cs1
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
|
@ -367,7 +367,7 @@ func @f() {
|
|||
func @f(%arg0 : !shape.shape) -> !shape.shape {
|
||||
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
|
||||
// CHECK-NEXT: return %[[CS]]
|
||||
%0 = shape.const_shape [2, 3, 4]
|
||||
%0 = shape.const_shape [2, 3, 4] : !shape.shape
|
||||
%1 = shape.any %0, %arg0
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
@ -429,8 +429,8 @@ func @f() {
|
|||
// CHECK-NEXT: shape.const_witness true
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [3, 1]
|
||||
%cs1 = shape.const_shape [1, 5]
|
||||
%cs0 = shape.const_shape [3, 1] : !shape.shape
|
||||
%cs1 = shape.const_shape [1, 5] : !shape.shape
|
||||
%0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
|
@ -445,8 +445,8 @@ func @static_non_broadcastable() {
|
|||
// CHECK-NEXT: shape.cstr_broadcastable
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [1, 3]
|
||||
%cs1 = shape.const_shape [1, 5]
|
||||
%cs0 = shape.const_shape [1, 3] : !shape.shape
|
||||
%cs1 = shape.const_shape [1, 5] : !shape.shape
|
||||
%0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
|
@ -460,7 +460,7 @@ func @f(%arg0 : !shape.shape) {
|
|||
// CHECK-NEXT: shape.cstr_broadcastable
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [1,3]
|
||||
%cs0 = shape.const_shape [1, 3] : !shape.shape
|
||||
%0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
|
@ -498,7 +498,7 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
|
|||
func @fold_rank() -> !shape.size {
|
||||
// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
|
||||
// CHECK-DAG: return %[[RESULT]] : !shape.size
|
||||
%shape = shape.const_shape [3, 4, 5, 6, 7]
|
||||
%shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
|
||||
%rank = shape.rank %shape : !shape.shape
|
||||
return %rank : !shape.size
|
||||
}
|
||||
|
@ -571,7 +571,7 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
|
|||
// CHECK-NEXT: shape.const_witness true
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%0 = shape.const_shape []
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.shape_of %arg0 : tensor<?xf32>
|
||||
%2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
|
||||
"consume.witness"(%2) : (!shape.witness) -> ()
|
||||
|
@ -617,9 +617,9 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
|
|||
func @shape_eq_fold_1() -> i1 {
|
||||
// CHECK: %[[RESULT:.*]] = constant true
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%a = shape.const_shape [1, 2, 3]
|
||||
%b = shape.const_shape [1, 2, 3]
|
||||
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
|
||||
%a = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%b = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
|
@ -630,9 +630,9 @@ func @shape_eq_fold_1() -> i1 {
|
|||
func @shape_eq_fold_0() -> i1 {
|
||||
// CHECK: %[[RESULT:.*]] = constant false
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%a = shape.const_shape [1, 2, 3]
|
||||
%b = shape.const_shape [4, 5, 6]
|
||||
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
|
||||
%a = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%b = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
|
@ -643,8 +643,8 @@ func @shape_eq_fold_0() -> i1 {
|
|||
func @shape_eq_fold_0() -> i1 {
|
||||
// CHECK: %[[RESULT:.*]] = constant false
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%a = shape.const_shape [1, 2, 3, 4, 5, 6]
|
||||
%b = shape.const_shape [1, 2, 3]
|
||||
%a = shape.const_shape [1, 2, 3, 4, 5, 6] : !shape.shape
|
||||
%b = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
|
||||
return %result : i1
|
||||
}
|
||||
|
@ -658,7 +658,7 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
|
|||
// CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
|
||||
// CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[B]] : !shape.shape, !shape.shape
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%b = shape.const_shape [4, 5, 6]
|
||||
%b = shape.const_shape [4, 5, 6] : !shape.shape
|
||||
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
|
||||
return %result : i1
|
||||
}
|
||||
|
|
|
@ -33,48 +33,55 @@ func @test_shape_num_elements_unknown() {
|
|||
return
|
||||
}
|
||||
|
||||
func @const_shape() {
|
||||
%0 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
func @test_shape_num_elements_fixed() {
|
||||
%0 = shape.const_shape [1, 57, 92]
|
||||
%0 = shape.const_shape [1, 57, 92] : !shape.shape
|
||||
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
|
||||
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
|
||||
return
|
||||
}
|
||||
|
||||
func @test_broadcast_fixed() {
|
||||
%0 = shape.const_shape [10, 1, 57, 92]
|
||||
%1 = shape.const_shape [4, 57, 92]
|
||||
%0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
|
||||
%1 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%2 = shape.broadcast %0, %1
|
||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
||||
func @test_shape_any_fixed() {
|
||||
%0 = shape.const_shape [4, 57, 92]
|
||||
%1 = shape.const_shape [4, 57, 92]
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
||||
func @test_shape_any_unknown() {
|
||||
%0 = shape.const_shape [4, -1, 92]
|
||||
%1 = shape.const_shape [-1, 57, 92]
|
||||
%0 = shape.const_shape [4, -1, 92] : !shape.shape
|
||||
%1 = shape.const_shape [-1, 57, 92] : !shape.shape
|
||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
||||
func @test_shape_any_fixed_mismatch() {
|
||||
%0 = shape.const_shape [4, 57, 92]
|
||||
%1 = shape.const_shape [2, 57, 92]
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.const_shape [2, 57, 92] : !shape.shape
|
||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
||||
func @test_parse_const_shape() {
|
||||
%0 = shape.const_shape []
|
||||
%1 = shape.const_shape [1, 2, 3]
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -84,8 +91,8 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
|
|||
}
|
||||
|
||||
func @test_constraints() {
|
||||
%0 = shape.const_shape []
|
||||
%1 = shape.const_shape [1, 2, 3]
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
|
||||
%w1 = shape.cstr_eq %0, %1
|
||||
%w2 = shape.const_witness true
|
||||
|
|
Loading…
Reference in a new issue