[mlir][sparse] Adding Action::kSparseToSparse for @newSparseTensor

This is work towards: https://github.com/llvm/llvm-project/issues/51652

This differential doesn't yet make use of the new kSparseToSparse, just introduces it.  The differential that finally makes use of them is D122061, which is the final differential in the chain that fixes bug 51652.

Depends On D122054

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122055
This commit is contained in:
wren romano 2022-03-18 19:20:33 -07:00
parent d8beb2c33d
commit df948127ac
2 changed files with 11 additions and 10 deletions

View file

@ -50,9 +50,10 @@ enum class Action : uint32_t {
kEmpty = 0,
kFromFile = 1,
kFromCOO = 2,
kEmptyCOO = 3,
kToCOO = 4,
kToIterator = 5
kSparseToSparse = 3,
kEmptyCOO = 4,
kToCOO = 5,
kToIterator = 6
};
/// This enum mimics `SparseTensorEncodingAttr::DimLevelType` for

View file

@ -29,7 +29,7 @@
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
@ -68,7 +68,7 @@ func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32>
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
@ -110,7 +110,7 @@ func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32
// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[I1]], %[[PermS]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<2xindex> to memref<?xindex>
@ -154,7 +154,7 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64
// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[I1]], %[[PermS]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<2xindex> to memref<?xindex>
@ -198,7 +198,7 @@ func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x
// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[I1]], %[[PermS]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<2xindex> to memref<?xindex>
@ -242,7 +242,7 @@ func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x
// CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<2xindex>
// CHECK-DAG: memref.store %[[I1]], %[[PermS]][%[[I1]]] : memref<2xindex>
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<2xindex> to memref<?xindex>
@ -290,7 +290,7 @@ func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x
// CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<3xindex>
// CHECK-DAG: memref.store %[[I1]], %[[PermS]][%[[I1]]] : memref<3xindex>
// CHECK-DAG: memref.store %[[I2]], %[[PermS]][%[[I2]]] : memref<3xindex>
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 6 : i32
// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<3xindex> to memref<?xindex>