[mlir][sparse] add int64 storage type to sparse tensor runtime support library
This format was missing from the support library. Although there are some subtleties reading in an external format for int64 as double, there is no good reason to omit support for this data type from the support library. Reviewed By: gussmith23 Differential Revision: https://reviews.llvm.org/D106016
This commit is contained in:
parent
d774b4aa5e
commit
afc760ef35
|
@ -27,7 +27,19 @@ using namespace mlir::sparse_tensor;
|
|||
|
||||
namespace {
|
||||
|
||||
/// Internal encoding of primary storage. Keep this enum consistent
/// with the equivalent enum in the sparse runtime support library.
enum PrimaryTypeEnum : uint64_t {
  kF64 = 1, // 64-bit floating-point values
  kF32 = 2, // 32-bit floating-point values
  kI64 = 3, // 64-bit signed integer values (added by this change)
  kI32 = 4, // 32-bit signed integer values
  kI16 = 5, // 16-bit signed integer values
  kI8 = 6   // 8-bit signed integer values
};
|
||||
|
||||
/// Returns internal type encoding for overhead storage. Keep these
|
||||
/// values consistent with the sparse runtime support library.
|
||||
static unsigned getOverheadTypeEncoding(unsigned width) {
|
||||
switch (width) {
|
||||
default:
|
||||
|
@ -41,7 +53,8 @@ static unsigned getOverheadTypeEncoding(unsigned width) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns internal dimension level type encoding.
|
||||
/// Returns internal dimension level type encoding. Keep these
|
||||
/// values consistent with the sparse runtime support library.
|
||||
static unsigned
|
||||
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
|
||||
switch (dlt) {
|
||||
|
@ -159,15 +172,17 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
|
|||
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
|
||||
unsigned primary;
|
||||
if (eltType.isF64())
|
||||
primary = 1;
|
||||
primary = kF64;
|
||||
else if (eltType.isF32())
|
||||
primary = 2;
|
||||
primary = kF32;
|
||||
else if (eltType.isInteger(64))
|
||||
primary = kI64;
|
||||
else if (eltType.isInteger(32))
|
||||
primary = 3;
|
||||
primary = kI32;
|
||||
else if (eltType.isInteger(16))
|
||||
primary = 4;
|
||||
primary = kI16;
|
||||
else if (eltType.isInteger(8))
|
||||
primary = 5;
|
||||
primary = kI8;
|
||||
else
|
||||
return failure();
|
||||
params.push_back(
|
||||
|
@ -256,6 +271,8 @@ public:
|
|||
name = "sparseValuesF64";
|
||||
else if (eltType.isF32())
|
||||
name = "sparseValuesF32";
|
||||
else if (eltType.isInteger(64))
|
||||
name = "sparseValuesI64";
|
||||
else if (eltType.isInteger(32))
|
||||
name = "sparseValuesI32";
|
||||
else if (eltType.isInteger(16))
|
||||
|
|
|
@ -129,6 +129,7 @@ public:
|
|||
// Primary storage.
|
||||
virtual void getValues(std::vector<double> **) { fatal("valf64"); }
|
||||
virtual void getValues(std::vector<float> **) { fatal("valf32"); }
|
||||
// int64 primary-storage overload added by this change (hunk +1 line).
virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
|
||||
virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
|
||||
virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
|
||||
virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
|
||||
|
@ -437,6 +438,7 @@ TEMPLATE(MemRef1DU64, uint64_t);
|
|||
TEMPLATE(MemRef1DU32, uint32_t);
|
||||
TEMPLATE(MemRef1DU16, uint16_t);
|
||||
TEMPLATE(MemRef1DU8, uint8_t);
|
||||
TEMPLATE(MemRef1DI64, int64_t);
|
||||
TEMPLATE(MemRef1DI32, int32_t);
|
||||
TEMPLATE(MemRef1DI16, int16_t);
|
||||
TEMPLATE(MemRef1DI8, int8_t);
|
||||
|
@ -448,9 +450,10 @@ enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
|
|||
/// Internal encoding of primary storage. Keep this enum consistent
/// with the equivalent enum in the MLIR sparse tensor conversion pass.
enum PrimaryTypeEnum : uint64_t {
  kF64 = 1,
  kF32 = 2,
  kI64 = 3, // int64 support added by this change; later values renumbered
  kI32 = 4,
  kI16 = 5,
  kI8 = 6
};
|
||||
|
||||
void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
|
||||
|
@ -499,6 +502,7 @@ void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
|
|||
CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
|
||||
|
||||
// Integral matrices with same overhead storage.
|
||||
CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
|
||||
CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
|
||||
CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
|
||||
CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
|
||||
|
@ -535,6 +539,7 @@ IMPL2(MemRef1DU16, sparseIndices16, uint16_t, getIndices)
|
|||
IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices)
|
||||
IMPL1(MemRef1DF64, sparseValuesF64, double, getValues)
|
||||
IMPL1(MemRef1DF32, sparseValuesF32, float, getValues)
|
||||
IMPL1(MemRef1DI64, sparseValuesI64, int64_t, getValues)
|
||||
IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues)
|
||||
IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
|
||||
IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)
|
||||
|
|
Loading…
Reference in a new issue