Add TensorRankOf for ranked tensor types with specific ranks

This commit adds `TensorRankOf<types, typeNames, ranks>` to specify ranked
tensor types with the specified types and ranks.  For example,
`TensorRankOf<[I32, F32], ["i32", "F32"], [0, 1]>` matches `tensor<i32>`,
`tensor<?xi32>`, `tensor<f32>`, or `tensor<?xf32>`.

PiperOrigin-RevId: 266461256
This commit is contained in:
Logan Chien 2019-08-30 14:53:28 -07:00 committed by A. Unique TensorFlower
parent 140757050b
commit 6b1d7f51ef
3 changed files with 74 additions and 18 deletions

View file

@ -385,15 +385,6 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
// Whether a type is a ranked tensor type.
def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
// Whether a type is a ranked tensor type with one of the specified ranks.
class HasAnyRankOfPred<list<int> ranks> : And<[
HasRankPred,
Or<!foreach(rank, ranks,
CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
def I1Tensor : TensorOf<[I1]>;
def I8Tensor : TensorOf<[I8]>;
def I16Tensor : TensorOf<[I16]>;
@ -405,6 +396,27 @@ def F16Tensor : TensorOf<[F16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;
// Whether a type is a ranked tensor type.
def HasRankPred : CPred<"$_self.cast<ShapedType>().hasRank()">;
// Whether a type is a ranked tensor type with one of the specified ranks.
class HasAnyRankOfPred<list<int> ranks> : And<[
HasRankPred,
Or<!foreach(rank, ranks,
CPred<"$_self.cast<ShapedType>().getRank() == " # rank>)>]>;
// Ranked tensor type with one of the specified types and ranks.
class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
StrJoin<!foreach(rank, ranks, rank # "D"), "/">.result # " " #
TensorOf<allowedTypes>.description>;
class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
// Memref type.
// Memrefs are blocks of data with fixed type and rank.

View file

@ -50,10 +50,19 @@ def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> {
let arguments = (ins AnyStaticShapeMemRef:$x);
}
def I32TensorRank0Or1Op : TEST_Op<"i32_tensor_rank_0_or_1"> {
def NDTensorOfOp : TEST_Op<"nd_tensor_of"> {
let arguments = (ins
Type<And<[I32Tensor.predicate, HasAnyRankOfPred<[0, 1]>]>,
"tensor<i32> or tensor<?xi32>">:$arg0
0DTensorOf<[F32]>:$arg0,
1DTensorOf<[F32]>:$arg1,
2DTensorOf<[I16]>:$arg2,
3DTensorOf<[I16]>:$arg3,
4DTensorOf<[I16]>:$arg4
);
}
def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> {
let arguments = (ins
TensorRankOf<[I8, I32, F32], [0, 1]>:$arg0
);
}

View file

@ -81,17 +81,52 @@ func @nested_tuple_multi_level_wrong_type() {
// -----
func @tensor_has_rank_0_or_1_success(%arg0: tensor<i32>, %arg1: tensor<5xi32>) {
"test.i32_tensor_rank_0_or_1"(%arg0) : (tensor<i32>) -> ()
"test.i32_tensor_rank_0_or_1"(%arg1) : (tensor<5xi32>) -> ()
func @nd_tensor_of_success(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi16>) {
"test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi16>) -> ()
return
}
// -----
func @tensor_has_rank_0_or_1_wrong_type(%arg0: tensor<2x2xi32>) {
// expected-error @+1 {{test.i32_tensor_rank_0_or_1' op operand #0 must be tensor<i32> or tensor<?xi32>}}
"test.i32_tensor_rank_0_or_1"(%arg0) : (tensor<2x2xi32>) -> ()
func @nd_tensor_of_success_wrong_type_0d(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi32>) {
// expected-error @+1 {{'test.nd_tensor_of' op operand #0 must be 0D tensor of 32-bit float values}}
"test.nd_tensor_of"(%arg1, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<70x80x90x100xi32>) -> ()
return
}
// -----
func @nd_tensor_of_success_wrong_type_4d(%arg0: tensor<f32>, %arg1: tensor<10xf32>, %arg2: tensor<20x30xi16>, %arg3: tensor<40x50x60xi16>, %arg4: tensor<70x80x90x100xi32>) {
// expected-error @+1 {{'test.nd_tensor_of' op operand #4 must be 4D tensor of 16-bit integer values}}
"test.nd_tensor_of"(%arg0, %arg1, %arg2, %arg3, %arg3) : (tensor<f32>, tensor<10xf32>, tensor<20x30xi16>, tensor<40x50x60xi16>, tensor<40x50x60xi16>) -> ()
return
}
// -----
func @multi_tensor_rank_of_success(%arg0: tensor<i8>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi32>, %arg5: tensor<1xf32>) {
"test.multi_tensor_rank_of"(%arg0) : (tensor<i8>) -> ()
"test.multi_tensor_rank_of"(%arg1) : (tensor<i32>) -> ()
"test.multi_tensor_rank_of"(%arg2) : (tensor<f32>) -> ()
"test.multi_tensor_rank_of"(%arg3) : (tensor<1xi8>) -> ()
"test.multi_tensor_rank_of"(%arg4) : (tensor<1xi32>) -> ()
"test.multi_tensor_rank_of"(%arg5) : (tensor<1xf32>) -> ()
return
}
// -----
func @multi_tensor_rank_of_wrong_unranked_type(%arg0: tensor<2x2xi8>) {
// expected-error @+1 {{'test.multi_tensor_rank_of' op operand #0 must be 0D/1D tensor of 8-bit integer or 32-bit integer or 32-bit float values}}
"test.multi_tensor_rank_of"(%arg0) : (tensor<2x2xi8>) -> ()
return
}
// -----
func @multi_tensor_rank_of_wrong_element_type(%arg0: tensor<2xi16>) {
// expected-error @+1 {{'test.multi_tensor_rank_of' op operand #0 must be 0D/1D tensor of 8-bit integer or 32-bit integer or 32-bit float values}}
"test.multi_tensor_rank_of"(%arg0) : (tensor<2xi16>) -> ()
return
}