[mlir] Centralize handling of memref element types.

This also beefs up the test coverage:
- Make unranked memref testing consistent with ranked memrefs.
- Add testing for the invalid element type cases.

This is not quite NFC: index types are now allowed in unranked memrefs.

Differential Revision: https://reviews.llvm.org/D85541
This commit is contained in:
Sean Silva 2020-08-07 11:40:58 -07:00
parent a97dfdc30b
commit b0d76f454d
4 changed files with 21 additions and 6 deletions

View file

@ -426,6 +426,11 @@ class BaseMemRefType : public ShapedType {
public:
using ShapedType::ShapedType;
/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
};

View file

@ -408,9 +408,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
Optional<Location> location) {
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
if (!elementType.isIntOrIndexOrFloat() &&
!elementType.isa<VectorType, ComplexType>())
if (!BaseMemRefType::isValidElementType(elementType))
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
@ -486,9 +484,7 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
if (!elementType.isIntOrFloat() &&
!elementType.isa<VectorType, ComplexType>())
if (!BaseMemRefType::isValidElementType(elementType))
return emitError(loc, "invalid memref element type");
return success();
}

View file

@ -17,6 +17,14 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
// -----
func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
// -----
func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
// -----
func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
// -----

View file

@ -152,6 +152,12 @@ func @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
// CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
// CHECK: func @unranked_memref_with_index_elems(memref<*xindex>)
func @unranked_memref_with_index_elems(memref<*xindex>)
// CHECK: func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())