[mlir][sparse] add int64 storage type to sparse tensor runtime support library

This format was missing from the support library. Although there are some
subtleties reading in an external format for int64 as double, there is no
good reason to omit support for this data type form the support library.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D106016
This commit is contained in:
Aart Bik 2021-07-15 11:06:40 -07:00
parent d774b4aa5e
commit afc760ef35
2 changed files with 32 additions and 10 deletions

View file

@ -27,7 +27,19 @@ using namespace mlir::sparse_tensor;
namespace {
/// Returns internal type encoding for overhead storage.
/// Internal encoding of primary storage. Keep this enum consistent
/// with the equivalent enum in the sparse runtime support library.
enum PrimaryTypeEnum : uint64_t {
kF64 = 1,
kF32 = 2,
kI64 = 3,
kI32 = 4,
kI16 = 5,
kI8 = 6
};
/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getOverheadTypeEncoding(unsigned width) {
switch (width) {
default:
@ -41,7 +53,8 @@ static unsigned getOverheadTypeEncoding(unsigned width) {
}
}
/// Returns internal dimension level type encoding.
/// Returns internal dimension level type encoding. Keep these
/// values consistent with the sparse runtime support library.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
switch (dlt) {
@ -159,15 +172,17 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
unsigned primary;
if (eltType.isF64())
primary = 1;
primary = kF64;
else if (eltType.isF32())
primary = 2;
primary = kF32;
else if (eltType.isInteger(64))
primary = kI64;
else if (eltType.isInteger(32))
primary = 3;
primary = kI32;
else if (eltType.isInteger(16))
primary = 4;
primary = kI16;
else if (eltType.isInteger(8))
primary = 5;
primary = kI8;
else
return failure();
params.push_back(
@ -256,6 +271,8 @@ public:
name = "sparseValuesF64";
else if (eltType.isF32())
name = "sparseValuesF32";
else if (eltType.isInteger(64))
name = "sparseValuesI64";
else if (eltType.isInteger(32))
name = "sparseValuesI32";
else if (eltType.isInteger(16))

View file

@ -129,6 +129,7 @@ public:
// Primary storage.
virtual void getValues(std::vector<double> **) { fatal("valf64"); }
virtual void getValues(std::vector<float> **) { fatal("valf32"); }
virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
@ -437,6 +438,7 @@ TEMPLATE(MemRef1DU64, uint64_t);
TEMPLATE(MemRef1DU32, uint32_t);
TEMPLATE(MemRef1DU16, uint16_t);
TEMPLATE(MemRef1DU8, uint8_t);
TEMPLATE(MemRef1DI64, int64_t);
TEMPLATE(MemRef1DI32, int32_t);
TEMPLATE(MemRef1DI16, int16_t);
TEMPLATE(MemRef1DI8, int8_t);
@ -448,9 +450,10 @@ enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
enum PrimaryTypeEnum : uint64_t {
kF64 = 1,
kF32 = 2,
kI32 = 3,
kI16 = 4,
kI8 = 5
kI64 = 3,
kI32 = 4,
kI16 = 5,
kI8 = 6
};
void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
@ -499,6 +502,7 @@ void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
// Integral matrices with same overhead storage.
CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
@ -535,6 +539,7 @@ IMPL2(MemRef1DU16, sparseIndices16, uint16_t, getIndices)
IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices)
IMPL1(MemRef1DF64, sparseValuesF64, double, getValues)
IMPL1(MemRef1DF32, sparseValuesF32, float, getValues)
IMPL1(MemRef1DI64, sparseValuesI64, int64_t, getValues)
IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues)
IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)