[mlir][sparse] Enhancing sparse=>sparse conversion.

Fixes: https://github.com/llvm/llvm-project/issues/51652

Depends On D122060

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122061
wren romano 2022-05-11 16:05:13 -07:00
parent e0c3b94c80
commit 8cb332406c
5 changed files with 539 additions and 28 deletions


@@ -355,6 +355,32 @@ static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
builder.create<memref::StoreOp>(loc, elemV, tensor, ivs);
}
/// Determine if the runtime library supports direct conversion to the
/// given target `dimTypes`.
static bool canUseDirectConversion(
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes) {
bool alreadyCompressed = false;
for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) {
switch (dimTypes[r]) {
case SparseTensorEncodingAttr::DimLevelType::Compressed:
if (alreadyCompressed)
return false; // Multiple compressed dimensions not yet supported.
alreadyCompressed = true;
break;
case SparseTensorEncodingAttr::DimLevelType::Dense:
if (alreadyCompressed)
return false; // Dense after Compressed not yet supported.
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
// Although Singleton isn't generally supported yet, the direct
// conversion method doesn't have any particular problems with
// singleton after compressed.
break;
}
}
return true;
}
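
// Illustrative standalone sketch, not part of the patch above: the same
// admissibility rule restated over a plain enum, with a few example
// level-type sequences. The names `LevelType` and `acceptsDirect` are
// hypothetical.
#include <cassert>
#include <initializer_list>

enum class LevelType { Dense, Compressed, Singleton };

static bool acceptsDirect(std::initializer_list<LevelType> types) {
  bool alreadyCompressed = false;
  for (LevelType t : types) {
    if (t == LevelType::Compressed) {
      if (alreadyCompressed)
        return false; // at most one compressed dimension for now
      alreadyCompressed = true;
    } else if (t == LevelType::Dense && alreadyCompressed) {
      return false; // dense after compressed not supported for the target
    } // Singleton never disqualifies the direct method.
  }
  return true;
}

int main() {
  using LT = LevelType;
  assert(acceptsDirect({LT::Dense, LT::Dense, LT::Compressed}));  // CSR-like
  assert(!acceptsDirect({LT::Compressed, LT::Compressed}));       // DCSR-like
  assert(!acceptsDirect({LT::Dense, LT::Compressed, LT::Dense})); // dense after compressed
  assert(acceptsDirect({LT::Compressed, LT::Singleton}));         // COO-like tail
  return 0;
}
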
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -492,21 +518,41 @@ public:
SmallVector<Value, 8> params;
ShapedType stp = srcType.cast<ShapedType>();
sizesFromPtr(rewriter, sizes, op, encSrc, stp, src);
// Set up the encoding with the right mix of src and dst so that the
// two method calls can share most parameters, while still providing
// the correct sparsity information to either of them.
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
Value dst = genNewCall(rewriter, op, params);
genDelCOOCall(rewriter, op, stp.getElementType(), coo);
rewriter.replaceOp(op, dst);
bool useDirectConversion;
switch (options.sparseToSparseStrategy) {
case SparseToSparseConversionStrategy::kViaCOO:
useDirectConversion = false;
break;
case SparseToSparseConversionStrategy::kDirect:
useDirectConversion = true;
assert(canUseDirectConversion(encDst.getDimLevelType()) &&
"Unsupported target for direct sparse-to-sparse conversion");
break;
case SparseToSparseConversionStrategy::kAuto:
useDirectConversion = canUseDirectConversion(encDst.getDimLevelType());
break;
}
if (useDirectConversion) {
newParams(rewriter, params, op, stp, encDst, Action::kSparseToSparse,
sizes, src);
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
} else { // use via-COO conversion.
// Set up the encoding with the right mix of src and dst so that the
// two method calls can share most parameters, while still providing
// the correct sparsity information to either of them.
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, op, stp, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
Value dst = genNewCall(rewriter, op, params);
genDelCOOCall(rewriter, op, stp.getElementType(), coo);
rewriter.replaceOp(op, dst);
}
return success();
}
if (!encDst && encSrc) {


@@ -84,6 +84,25 @@ static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
return lhs * rhs;
}
// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
// That version doesn't have the permutation, and its `sizes` are
// passed as a raw pointer/C-array rather than as a `std::vector`.
//
/// Asserts that the `sizes` (in target-order) under the `perm` (mapping
/// semantic-order to target-order) are a refinement of the desired `shape`
/// (in semantic-order).
///
/// Precondition: `perm` and `shape` must be valid for `rank`.
static inline void
assertPermutedSizesMatchShape(const std::vector<uint64_t> &sizes, uint64_t rank,
const uint64_t *perm, const uint64_t *shape) {
assert(perm && shape);
assert(rank == sizes.size() && "Rank mismatch");
for (uint64_t r = 0; r < rank; r++)
assert((shape[r] == 0 || shape[r] == sizes[perm[r]]) &&
"Dimension size mismatch");
}
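
// Illustrative standalone sketch, not part of the patch above: how `perm`
// relates the semantic-order `shape` to the target-order `sizes`. Semantic
// dimension `r` is stored as target dimension `perm[r]`, so the check is
// `shape[r] == sizes[perm[r]]`, with `shape[r] == 0` meaning "dynamic,
// accept any size". The concrete numbers below are made up.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<uint64_t> sizes = {2, 4, 3}; // target-order
  const uint64_t perm[] = {0, 2, 1};             // semantic dim -> target dim
  const uint64_t shape[] = {2, 0, 4};            // semantic-order; dim 1 is dynamic
  for (uint64_t r = 0; r < 3; r++)
    assert(shape[r] == 0 || shape[r] == sizes[perm[r]]);
  return 0;
}
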
/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
/// ({i}, a[i])
@@ -215,6 +234,10 @@ private:
unsigned iteratorPos = 0;
};
// Forward.
template <typename V>
class SparseTensorEnumeratorBase;
/// Abstract base class for `SparseTensorStorage<P,I,V>`. This class
/// takes responsibility for all the `<P,I,V>`-independent aspects
/// of the tensor (e.g., shape, sparsity, permutation). In addition,
@@ -274,6 +297,40 @@ public:
return (dimTypes[d] == DimLevelType::kCompressed);
}
/// Allocate a new enumerator.
virtual void newEnumerator(SparseTensorEnumeratorBase<double> **, uint64_t,
const uint64_t *) const {
fatal("enumf64");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<float> **, uint64_t,
const uint64_t *) const {
fatal("enumf32");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int64_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi64");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int32_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi32");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int16_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi16");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<int8_t> **, uint64_t,
const uint64_t *) const {
fatal("enumi8");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<complex64> **, uint64_t,
const uint64_t *) const {
fatal("enumc64");
}
virtual void newEnumerator(SparseTensorEnumeratorBase<complex32> **, uint64_t,
const uint64_t *) const {
fatal("enumc32");
}
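
// Illustrative standalone sketch, not part of the patch above: the pattern
// behind the overload set above. C++ has no virtual member templates, so
// the base class exposes one virtual per supported value type (out-parameter
// style), each failing by default, and the templated subclass overrides only
// the overload that matches its own `V`. `Base`, `Impl`, and `EnumBase` are
// hypothetical stand-ins for the real classes.
#include <cstdio>
#include <cstdlib>

template <typename V>
class EnumBase { /* enumeration machinery elided */ };

class Base {
public:
  virtual ~Base() = default;
  virtual void newEnum(EnumBase<double> **) const { die("f64"); }
  virtual void newEnum(EnumBase<float> **) const { die("f32"); }

private:
  static void die(const char *msg) {
    std::fprintf(stderr, "unsupported value type: %s\n", msg);
    std::exit(1);
  }
};

template <typename V>
class Impl final : public Base {
public:
  using Base::newEnum; // keep the other overloads visible
  void newEnum(EnumBase<V> **out) const override { *out = new EnumBase<V>(); }
};

int main() {
  Impl<float> tensor;
  Base &base = tensor;
  EnumBase<float> *e = nullptr;
  base.newEnum(&e); // virtual dispatch lands in Impl<float>'s override
  delete e;
  return 0;
}
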
/// Overhead storage.
virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
@@ -368,6 +425,17 @@ class SparseTensorEnumerator;
/// and annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
/// Private constructor to share code between the other constructors.
/// Beware that the object is not guaranteed to be in a valid state
/// after this constructor alone; e.g., `isCompressedDim(d)` doesn't
/// entail `!(pointers[d].empty())`.
///
/// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
const DimLevelType *sparsity)
: SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
indices(getRank()), idx(getRank()) {}
public:
/// Constructs a sparse tensor storage scheme with the given dimensions,
/// permutation, and per-dimension dense/sparse annotations, using
@@ -375,10 +443,8 @@ public:
///
/// Precondition: `perm` and `sparsity` must be valid for `szs.size()`.
SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
const DimLevelType *sparsity,
SparseTensorCOO<V> *coo = nullptr)
: SparseTensorStorageBase(szs, perm, sparsity), pointers(getRank()),
indices(getRank()), idx(getRank()) {
const DimLevelType *sparsity, SparseTensorCOO<V> *coo)
: SparseTensorStorage(szs, perm, sparsity) {
// Provide hints on capacity of pointers and indices.
// TODO: needs much fine-tuning based on actual sparsity; currently
// we reserve pointer/index space based on all previous dense
@@ -414,6 +480,17 @@ public:
}
}
/// Constructs a sparse tensor storage scheme with the given dimensions,
/// permutation, and per-dimension dense/sparse annotations, using
/// the given sparse tensor for the initial contents.
///
/// Preconditions:
/// * `perm` and `sparsity` must be valid for `szs.size()`.
/// * The `tensor` must have the same value type `V`.
SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
const DimLevelType *sparsity,
const SparseTensorStorageBase &tensor);
~SparseTensorStorage() override = default;
/// Partially specialize these getter methods based on template types.
@@ -478,21 +555,28 @@ public:
endPath(0);
}
void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
const uint64_t *perm) const override {
*out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
}
/// Returns this sparse tensor storage scheme as a new memory-resident
/// sparse tensor in coordinate scheme with the given dimension order.
///
/// Precondition: `perm` must be valid for `getRank()`.
SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
SparseTensorEnumerator<P, I, V> enumerator(*this, getRank(), perm);
SparseTensorEnumeratorBase<V> *enumerator;
newEnumerator(&enumerator, getRank(), perm);
SparseTensorCOO<V> *coo =
new SparseTensorCOO<V>(enumerator.permutedSizes(), values.size());
enumerator.forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
coo->add(ind, val);
});
// TODO: This assertion assumes there are no stored zeros,
// or if there are then that we don't filter them out.
// Cf., <https://github.com/llvm/llvm-project/issues/54179>
assert(coo->getElements().size() == values.size());
delete enumerator;
return coo;
}
@@ -508,10 +592,8 @@ public:
const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
SparseTensorStorage<P, I, V> *n = nullptr;
if (coo) {
assert(coo->getRank() == rank && "Tensor rank mismatch");
const auto &coosz = coo->getSizes();
for (uint64_t r = 0; r < rank; r++)
assert(shape[r] == 0 || shape[r] == coosz[perm[r]]);
assertPermutedSizesMatchShape(coosz, rank, perm, shape);
n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
} else {
std::vector<uint64_t> permsz(rank);
@@ -519,11 +601,34 @@ public:
assert(shape[r] > 0 && "Dimension size zero has trivial storage");
permsz[perm[r]] = shape[r];
}
n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity);
// We pass the null `coo` to ensure we select the intended constructor.
n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
}
return n;
}
/// Factory method. Constructs a sparse tensor storage scheme with
/// the given dimensions, permutation, and per-dimension dense/sparse
/// annotations, using the sparse tensor for the initial contents.
///
/// Preconditions:
/// * `shape`, `perm`, and `sparsity` must be valid for `rank`.
/// * The `tensor` must have the same value type `V`.
static SparseTensorStorage<P, I, V> *
newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
const DimLevelType *sparsity,
const SparseTensorStorageBase *source) {
assert(source && "Got nullptr for source");
SparseTensorEnumeratorBase<V> *enumerator;
source->newEnumerator(&enumerator, rank, perm);
const auto &permsz = enumerator->permutedSizes();
assertPermutedSizesMatchShape(permsz, rank, perm, shape);
auto *tensor =
new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
delete enumerator;
return tensor;
}
private:
/// Appends an arbitrary new position to `pointers[d]`. This method
/// checks that `pos` is representable in the `P` type; however, it
@@ -561,6 +666,36 @@ private:
}
}
/// Writes the given coordinate to `indices[d][pos]`. This method
/// checks that `i` is representable in the `I` type; however, it
/// does not check that `i` is semantically valid (i.e., in bounds
/// for `sizes[d]` and not elsewhere occurring in the same segment).
void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
assert(isCompressedDim(d));
// Subscript assignment to `std::vector` requires that the `pos`-th
// entry has been initialized; thus we must be sure to check `size()`
// here, instead of `capacity()` as would be ideal.
assert(pos < indices[d].size() && "Index position is out of bounds");
assert(i <= std::numeric_limits<I>::max() &&
"Index value is too large for the I-type");
indices[d][pos] = static_cast<I>(i);
}
/// Computes the assembled-size associated with the `d`-th dimension,
/// given the assembled-size associated with the `(d-1)`-th dimension.
/// "Assembled-sizes" correspond to the (nominal) sizes of overhead
/// storage, as opposed to "dimension-sizes" which are the cardinality
/// of coordinates for that dimension.
///
/// Precondition: the `pointers[d]` array must be fully initialized
/// before calling this method.
uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
if (isCompressedDim(d))
return pointers[d][parentSz];
// else if dense:
return parentSz * getDimSizes()[d];
}
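
// Illustrative standalone sketch, not part of the patch above: assembled-size
// vs. dimension-size for a 3x4 matrix stored as ("dense", "compressed"),
// i.e. CSR. A dense dimension multiplies the parent size out; a compressed
// dimension's assembled size is the last entry of its pointers array, which
// is exactly the length needed for its indices (and, at the innermost
// dimension, for the values). The concrete numbers below are made up.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const uint64_t dimSizes[] = {3, 4};
  // Row pointers of a 3x4 matrix whose rows hold 2, 0, and 3 stored entries.
  const std::vector<uint64_t> rowPointers = {0, 2, 2, 5};

  uint64_t parentSz = 1;
  // Dimension 0 is dense: assembled size is 1 * 3 = 3 (one segment per row).
  parentSz = parentSz * dimSizes[0];
  assert(parentSz == 3);
  // Dimension 1 is compressed: assembled size is rowPointers[parentSz] = 5.
  parentSz = rowPointers[parentSz];
  assert(parentSz == 5);
  return 0;
}
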
/// Initializes sparse tensor storage scheme from a memory-resident sparse
/// tensor in coordinate scheme. This method prepares the pointers and
/// indices arrays under the given per-dimension dense/sparse annotations.
@@ -798,6 +933,206 @@ private:
}
};
/// Statistics regarding the number of nonzero subtensors in
/// a source tensor, for direct sparse=>sparse conversion a la
/// <https://arxiv.org/abs/2001.02609>.
///
/// N.B., this class stores references to the parameters passed to
/// the constructor; thus, objects of this class must not outlive
/// those parameters.
class SparseTensorNNZ {
public:
/// Allocate the statistics structure for the desired sizes and
/// sparsity (in the target tensor's storage-order). This constructor
/// does not actually populate the statistics, however; for that see
/// `initialize`.
///
/// Precondition: `szs` must not contain zeros.
SparseTensorNNZ(const std::vector<uint64_t> &szs,
const std::vector<DimLevelType> &sparsity)
: dimSizes(szs), dimTypes(sparsity), nnz(getRank()) {
assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
bool uncompressed = true;
uint64_t sz = 1; // product of all `dimSizes[i]` for `i` strictly less than `r`.
for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
switch (dimTypes[r]) {
case DimLevelType::kCompressed:
assert(uncompressed &&
"Multiple compressed layers not currently supported");
uncompressed = false;
nnz[r].resize(sz, 0); // Both allocate and zero-initialize.
break;
case DimLevelType::kDense:
assert(uncompressed &&
"Dense after compressed not currently supported");
break;
case DimLevelType::kSingleton:
// Singleton after Compressed causes no problems for allocating
// `nnz` nor for the yieldPos loop. This remains true even
// when adding support for multiple compressed dimensions or
// for dense-after-compressed.
break;
}
sz = checkedMul(sz, dimSizes[r]);
}
}
// We disallow copying to help avoid leaking the stored references.
SparseTensorNNZ(const SparseTensorNNZ &) = delete;
SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
/// Returns the rank of the target tensor.
uint64_t getRank() const { return dimSizes.size(); }
/// Enumerate the source tensor to fill in the statistics. The
/// enumerator should already incorporate the permutation (from
/// semantic-order to the target storage-order).
template <typename V>
void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
enumerator.forallElements(
[this](const std::vector<uint64_t> &ind, V) { add(ind); });
}
/// The type of callback functions which receive an nnz-statistic.
using NNZConsumer = const std::function<void(uint64_t)> &;
/// Lexicographically enumerates all indices for dimensions strictly
/// less than `stopDim`, and passes their nnz statistic to the callback.
/// Since our use-case only requires the statistic, not the coordinates
/// themselves, we do not bother to construct those coordinates.
void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
"Cannot look up non-compressed dimensions");
forallIndices(yield, stopDim, 0, 0);
}
private:
/// Adds a new element (i.e., increments its statistics). We use
/// a method rather than inlining into the lambda in `initialize`,
/// to avoid spurious templating over `V`. And this method is private
/// to avoid needing to re-assert validity of `ind` (which is guaranteed
/// by `forallElements`).
void add(const std::vector<uint64_t> &ind) {
uint64_t parentPos = 0;
for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
if (dimTypes[r] == DimLevelType::kCompressed)
nnz[r][parentPos]++;
parentPos = parentPos * dimSizes[r] + ind[r];
}
}
/// Recursive component of the public `forallIndices`.
void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
uint64_t d) const {
assert(d <= stopDim);
if (d == stopDim) {
assert(parentPos < nnz[d].size() && "Cursor is out of range");
yield(nnz[d][parentPos]);
} else {
const uint64_t sz = dimSizes[d];
const uint64_t pstart = parentPos * sz;
for (uint64_t i = 0; i < sz; i++)
forallIndices(yield, stopDim, pstart + i, d + 1);
}
}
// All of these are in the target storage-order.
const std::vector<uint64_t> &dimSizes;
const std::vector<DimLevelType> &dimTypes;
std::vector<std::vector<uint64_t>> nnz;
};
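
// Illustrative standalone sketch, not part of the patch above: how per-prefix
// nonzero counts become a "pointers" array. For a 3x4 matrix stored as
// ("dense", "compressed") with stored entries at (0,1), (0,3), and (2,2),
// the per-row counts are {2, 0, 1}; a running sum over those counts yields
// row pointers {0, 2, 2, 3}, which is what the constructor below builds by
// pairing SparseTensorNNZ::forallIndices with appendPointer.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const uint64_t numRows = 3;
  const uint64_t coords[][2] = {{0, 1}, {0, 3}, {2, 2}};
  // Counting pass (what SparseTensorNNZ::add does per enumerated element).
  std::vector<uint64_t> nnzPerRow(numRows, 0);
  for (const auto &ij : coords)
    nnzPerRow[ij[0]]++;
  // Pointer construction (what the forallIndices callback does).
  std::vector<uint64_t> pointers = {0};
  uint64_t currentPos = 0;
  for (uint64_t r = 0; r < numRows; r++) {
    currentPos += nnzPerRow[r];
    pointers.push_back(currentPos);
  }
  assert((pointers == std::vector<uint64_t>{0, 2, 2, 3}));
  return 0;
}
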
template <typename P, typename I, typename V>
SparseTensorStorage<P, I, V>::SparseTensorStorage(
const std::vector<uint64_t> &szs, const uint64_t *perm,
const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
: SparseTensorStorage(szs, perm, sparsity) {
SparseTensorEnumeratorBase<V> *enumerator;
tensor.newEnumerator(&enumerator, getRank(), perm);
{
// Initialize the statistics structure.
SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
nnz.initialize(*enumerator);
// Initialize "pointers" overhead (and allocate "indices", "values").
uint64_t parentSz = 1; // assembled-size (not dimension-size) of `r-1`.
for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
if (isCompressedDim(r)) {
pointers[r].reserve(parentSz + 1);
pointers[r].push_back(0);
uint64_t currentPos = 0;
nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
currentPos += n;
appendPointer(r, currentPos);
});
assert(pointers[r].size() == parentSz + 1 &&
"Final pointers size doesn't match allocated size");
// That assertion entails `assembledSize(parentSz, r)`
// is now in a valid state. That is, `pointers[r][parentSz]`
// equals the present value of `currentPos`, which is the
// correct assembled-size for `indices[r]`.
}
// Update assembled-size for the next iteration.
parentSz = assembledSize(parentSz, r);
// Ideally we need only `indices[r].reserve(parentSz)`, however
// the `std::vector` implementation forces us to initialize it too.
// That is, in the yieldPos loop we need random-access assignment
// to `indices[r]`; however, `std::vector`'s subscript-assignment
// only allows assigning to already-initialized positions.
if (isCompressedDim(r))
indices[r].resize(parentSz, 0);
}
values.resize(parentSz, 0); // Both allocate and zero-initialize.
}
// The yieldPos loop
enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
uint64_t parentSz = 1, parentPos = 0;
for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
if (isCompressedDim(r)) {
// If `parentPos == parentSz` then it's valid as an array-lookup;
// however, it's semantically invalid here since that entry
// does not represent a segment of `indices[r]`. Moreover, that
// entry must be immutable for `assembledSize` to remain valid.
assert(parentPos < parentSz && "Pointers position is out of bounds");
const uint64_t currentPos = pointers[r][parentPos];
// This increment won't overflow the `P` type, since it can't
// exceed the original value of `pointers[r][parentPos+1]`
// which was already verified to be within bounds for `P`
// when it was written to the array.
pointers[r][parentPos]++;
writeIndex(r, currentPos, ind[r]);
parentPos = currentPos;
} else { // Dense dimension.
parentPos = parentPos * getDimSizes()[r] + ind[r];
}
parentSz = assembledSize(parentSz, r);
}
assert(parentPos < values.size() && "Value position is out of bounds");
values[parentPos] = val;
});
// No longer need the enumerator, so we'll delete it ASAP.
delete enumerator;
// The finalizeYieldPos loop
for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
if (isCompressedDim(r)) {
assert(parentSz == pointers[r].size() - 1 &&
"Actual pointers size doesn't match the expected size");
// Can't check all of them, but at least we can check the last one.
assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
"Pointers got corrupted");
// TODO: optimize this by using `memmove` or similar.
for (uint64_t n = 0; n < parentSz; n++) {
const uint64_t parentPos = parentSz - n;
pointers[r][parentPos] = pointers[r][parentPos - 1];
}
pointers[r][0] = 0;
}
parentSz = assembledSize(parentSz, r);
}
}
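
// Illustrative standalone sketch, not part of the patch above: the pointer
// bookkeeping of the two loops above, on a small 3x4 CSR example (the one
// with per-row counts {2, 0, 1}). During the yieldPos loop,
// pointers[r][parentPos] doubles as the write cursor for that segment and is
// incremented per inserted element; the finalizeYieldPos loop then shifts the
// array right by one entry to restore the proper segment-start offsets.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Row pointers as built from the nnz statistics: rows hold 2, 0, 1 entries.
  std::vector<uint64_t> ptr = {0, 2, 2, 3};
  std::vector<uint64_t> idx(3, 0); // column indices for the 3 stored entries
  const uint64_t coords[][2] = {{0, 1}, {0, 3}, {2, 2}};
  // The yieldPos loop: ptr[row] is the cursor for the next write in that row.
  for (const auto &ij : coords) {
    const uint64_t pos = ptr[ij[0]]++;
    idx[pos] = ij[1];
  }
  assert((ptr == std::vector<uint64_t>{2, 2, 3, 3})); // cursors ran to segment ends
  assert((idx == std::vector<uint64_t>{1, 3, 2}));
  // The finalizeYieldPos loop: shift right by one to recover start offsets.
  const uint64_t parentSz = 3;
  for (uint64_t n = 0; n < parentSz; n++)
    ptr[parentSz - n] = ptr[parentSz - n - 1];
  ptr[0] = 0;
  assert((ptr == std::vector<uint64_t>{0, 2, 2, 3}));
  return 0;
}
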
/// Helper to convert string to lower case.
static char *toLower(char *token) {
for (char *c = token; *c; c++)
@@ -1088,6 +1423,11 @@ extern "C" {
delete coo; \
return tensor; \
} \
if (action == Action::kSparseToSparse) { \
auto *tensor = static_cast<SparseTensorStorageBase *>(ptr); \
return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm, \
sparsity, tensor); \
} \
if (action == Action::kEmptyCOO) \
return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm); \
coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \


@@ -1,4 +1,10 @@
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
// First use with `kViaCOO` for sparse2sparse conversion (the old way).
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=1" \
// RUN: --canonicalize --cse | FileCheck %s
//
// Now again with `kAuto` (the new default).
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \
// RUN: --canonicalize --cse | FileCheck %s -check-prefix=CHECKAUTO
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = ["compressed"]
@@ -210,6 +216,17 @@ func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVecto
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
// CHECK: call @delSparseTensorCOOF32(%[[C]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
// CHECKAUTO-LABEL: func @sparse_convert_1d_ss(
// CHECKAUTO-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECKAUTO-DAG: %[[SparseToSparse:.*]] = arith.constant 3 : i32
// CHECKAUTO-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
// CHECKAUTO-DAG: %[[Q:.*]] = memref.alloca() : memref<1xindex>
// CHECKAUTO-DAG: %[[R:.*]] = memref.alloca() : memref<1xindex>
// CHECKAUTO-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
// CHECKAUTO-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xindex> to memref<?xindex>
// CHECKAUTO-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xindex> to memref<?xindex>
// CHECKAUTO: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[SparseToSparse]], %[[A]])
// CHECKAUTO: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
%0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
return %0 : tensor<?xf32, #SparseVector32>


@@ -0,0 +1,102 @@
// Force this file to use the kDirect method for sparse2sparse.
// RUN: mlir-opt %s --sparse-compiler="s2s-strategy=2" | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
#Tensor1 = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "dense", "compressed" ]
}>
// NOTE: dense after compressed is not currently supported for the target
// of direct-sparse2sparse conversion. (It's fine for the source though.)
#Tensor2 = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed", "dense" ]
}>
#Tensor3 = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "dense", "compressed" ],
dimOrdering = affine_map<(i,j,k) -> (i,k,j)>
}>
module {
//
// Utilities for output and releasing memory.
//
func.func @dump(%arg0: tensor<2x3x4xf64>) {
%c0 = arith.constant 0 : index
%d0 = arith.constant -1.0 : f64
%0 = vector.transfer_read %arg0[%c0, %c0, %c0], %d0: tensor<2x3x4xf64>, vector<2x3x4xf64>
vector.print %0 : vector<2x3x4xf64>
return
}
func.func @dumpAndRelease_234(%arg0: tensor<2x3x4xf64>) {
call @dump(%arg0) : (tensor<2x3x4xf64>) -> ()
%1 = bufferization.to_memref %arg0 : memref<2x3x4xf64>
memref.dealloc %1 : memref<2x3x4xf64>
return
}
//
// Main driver.
//
func.func @entry() {
//
// Initialize a 3-dim dense tensor.
//
%src = arith.constant dense<[
[ [ 1.0, 2.0, 3.0, 4.0 ],
[ 5.0, 6.0, 7.0, 8.0 ],
[ 9.0, 10.0, 11.0, 12.0 ] ],
[ [ 13.0, 14.0, 15.0, 16.0 ],
[ 17.0, 18.0, 19.0, 20.0 ],
[ 21.0, 22.0, 23.0, 24.0 ] ]
]> : tensor<2x3x4xf64>
//
// Convert dense tensor directly to various sparse tensors.
//
%s1 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor1>
%s2 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor2>
%s3 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor3>
//
// Convert sparse tensor directly to another sparse format.
//
%t13 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64, #Tensor3>
%t21 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor1>
%t23 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf64, #Tensor3>
%t31 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64, #Tensor1>
//
// Convert sparse tensor back to dense.
//
%d13 = sparse_tensor.convert %t13 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64>
%d21 = sparse_tensor.convert %t21 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64>
%d23 = sparse_tensor.convert %t23 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf64>
%d31 = sparse_tensor.convert %t31 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf64>
//
// Check round-trip equality. And release dense tensors.
//
// CHECK-COUNT-5: ( ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ), ( ( 13, 14, 15, 16 ), ( 17, 18, 19, 20 ), ( 21, 22, 23, 24 ) ) )
call @dump(%src) : (tensor<2x3x4xf64>) -> ()
call @dumpAndRelease_234(%d13) : (tensor<2x3x4xf64>) -> ()
call @dumpAndRelease_234(%d21) : (tensor<2x3x4xf64>) -> ()
call @dumpAndRelease_234(%d23) : (tensor<2x3x4xf64>) -> ()
call @dumpAndRelease_234(%d31) : (tensor<2x3x4xf64>) -> ()
//
// Release sparse tensors.
//
sparse_tensor.release %t13 : tensor<2x3x4xf64, #Tensor3>
sparse_tensor.release %t21 : tensor<2x3x4xf64, #Tensor1>
sparse_tensor.release %t23 : tensor<2x3x4xf64, #Tensor3>
sparse_tensor.release %t31 : tensor<2x3x4xf64, #Tensor1>
sparse_tensor.release %s1 : tensor<2x3x4xf64, #Tensor1>
sparse_tensor.release %s2 : tensor<2x3x4xf64, #Tensor2>
sparse_tensor.release %s3 : tensor<2x3x4xf64, #Tensor3>
return
}
}


@@ -186,11 +186,17 @@ def main():
vec = 0
vl = 1
e = False
# Disable direct sparse2sparse conversion, because it doubles the time!
# TODO: While direct s2s is far too slow for per-commit testing,
# we should have some framework ensure that we run this test with
# `s2s=0` on a regular basis, to ensure that it does continue to work.
s2s = 1
sparsification_options = (
f'parallelization-strategy={par} '
f'vectorization-strategy={vec} '
f'vl={vl} '
f'enable-simd-index32={e}')
f'enable-simd-index32={e} '
f's2s-strategy={s2s}')
compiler = sparse_compiler.SparseCompiler(
options=sparsification_options, opt_level=0, shared_libs=[support_lib])
f64 = ir.F64Type.get()