[mlir][python] 8b/16b DenseIntElements access

This extends dense attribute element access to support 8b and 16b ints.
Also extends the corresponding parts of the C api.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D117731
This commit is contained in:
Rahul Kayaith 2022-01-21 05:21:00 +00:00 committed by Mehdi Amini
parent 26167cae45
commit 308d8b8c66
5 changed files with 91 additions and 0 deletions

View file

@ -355,6 +355,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
MlirType shapedType, intptr_t numElements, const uint8_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
MlirType shapedType, intptr_t numElements, const int8_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get(
MlirType shapedType, intptr_t numElements, const uint16_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get(
MlirType shapedType, intptr_t numElements, const int16_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
MlirType shapedType, intptr_t numElements, const uint32_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@ -416,6 +420,10 @@ MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED uint8_t
mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED int16_t
mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED uint16_t
mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED int32_t
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED uint32_t

View file

@ -673,6 +673,12 @@ public:
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 8) {
return mlirDenseElementsAttrGetUInt8Value(*this, pos);
}
if (width == 16) {
return mlirDenseElementsAttrGetUInt16Value(*this, pos);
}
if (width == 32) {
return mlirDenseElementsAttrGetUInt32Value(*this, pos);
}
@ -683,6 +689,12 @@ public:
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 8) {
return mlirDenseElementsAttrGetInt8Value(*this, pos);
}
if (width == 16) {
return mlirDenseElementsAttrGetInt16Value(*this, pos);
}
if (width == 32) {
return mlirDenseElementsAttrGetInt32Value(*this, pos);
}

View file

@ -426,6 +426,16 @@ MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
const int8_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
intptr_t numElements,
const uint16_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
intptr_t numElements,
const int16_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
intptr_t numElements,
const uint32_t *elements) {
@ -530,6 +540,12 @@ int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
}
int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
}
uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
}
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
}

View file

@ -904,6 +904,8 @@ int printBuiltinAttributes(MlirContext ctx) {
int bools[] = {0, 1};
uint8_t uints8[] = {0u, 1u};
int8_t ints8[] = {0, 1};
uint16_t uints16[] = {0u, 1u};
int16_t ints16[] = {0, 1};
uint32_t uints32[] = {0u, 1u};
int32_t ints32[] = {0, 1};
uint64_t uints64[] = {0u, 1u};
@ -921,6 +923,13 @@ int printBuiltinAttributes(MlirContext ctx) {
MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
2, ints8);
MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
encoding),
2, uints16);
MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
2, ints16);
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
encoding),
@ -956,6 +965,8 @@ int printBuiltinAttributes(MlirContext ctx) {
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||

View file

@ -292,6 +292,50 @@ def testDenseIntAttr():
print(ShapedType(a.type).element_type)
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
@run
def testDenseIntAttrGetItem():
def print_item(attr_asm):
attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
dtype = ShapedType(attr.type).element_type
try:
item = attr[0]
print(f"{dtype}:", item)
except TypeError as e:
print(f"{dtype}:", e)
with Context():
# CHECK: i1: 1
print_item("dense<true> : tensor<i1>")
# CHECK: i8: 123
print_item("dense<123> : tensor<i8>")
# CHECK: i16: 123
print_item("dense<123> : tensor<i16>")
# CHECK: i32: 123
print_item("dense<123> : tensor<i32>")
# CHECK: i64: 123
print_item("dense<123> : tensor<i64>")
# CHECK: ui8: 123
print_item("dense<123> : tensor<ui8>")
# CHECK: ui16: 123
print_item("dense<123> : tensor<ui16>")
# CHECK: ui32: 123
print_item("dense<123> : tensor<ui32>")
# CHECK: ui64: 123
print_item("dense<123> : tensor<ui64>")
# CHECK: si8: -123
print_item("dense<-123> : tensor<si8>")
# CHECK: si16: -123
print_item("dense<-123> : tensor<si16>")
# CHECK: si32: -123
print_item("dense<-123> : tensor<si32>")
# CHECK: si64: -123
print_item("dense<-123> : tensor<si64>")
# CHECK: i7: Unsupported integer type
print_item("dense<123> : tensor<i7>")
# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():