[mlir][sparse][capi][python] add sparse tensor passes

First set of "boilerplate" to get sparse tensor
passes available through CAPI and Python.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D102362
This commit is contained in:
Aart Bik 2021-05-12 14:51:16 -07:00
parent bd00106d1e
commit 58d12332a4
8 changed files with 89 additions and 0 deletions

View file

@ -74,4 +74,6 @@ mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr);
}
#endif
#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc"
#endif // MLIR_C_DIALECT_SPARSE_TENSOR_H

View file

@ -1,5 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name SparseTensor)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix SparseTensor)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix SparseTensor)
add_public_tablegen_target(MLIRSparseTensorPassIncGen)
add_mlir_doc(Passes SparseTensorPasses ./ -gen-pass-doc)

View file

@ -33,6 +33,14 @@ add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasse
)
add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension)
add_mlir_python_extension(MLIRSparseTensorPassesBindingsPythonExtension _mlirSparseTensorPasses
INSTALL_DIR
python
SOURCES
SparseTensorPasses.cpp
)
add_dependencies(MLIRBindingsPythonExtension MLIRSparseTensorPassesBindingsPythonExtension)
add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses
INSTALL_DIR
python

View file

@ -0,0 +1,22 @@
//===- SparseTensorPasses.cpp - Pybind module for the SparseTensor passes -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/SparseTensor.h"
#include <pybind11/pybind11.h>
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
PYBIND11_MODULE(_mlirSparseTensorPasses, m) {
m.doc() = "MLIR SparseTensor Dialect Passes";
// Register all SparseTensor passes on load.
mlirRegisterSparseTensorPasses();
}

View file

@ -62,11 +62,13 @@ add_mlir_public_c_api_library(MLIRCAPIShape
add_mlir_public_c_api_library(MLIRCAPISparseTensor
SparseTensor.cpp
SparseTensorPasses.cpp
PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRSparseTensor
MLIRSparseTensorTransforms
)
add_mlir_public_c_api_library(MLIRCAPIStandard

View file

@ -0,0 +1,26 @@
//===- SparseTensorPasses.cpp - C API for SparseTensor Dialect Passes -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/CAPI/Pass.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
// Must include the declarations as they carry important visibility attributes.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc"
using namespace mlir;
#ifdef __cplusplus
extern "C" {
#endif
#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.cpp.inc"
#ifdef __cplusplus
}
#endif

View file

@ -3,5 +3,10 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._cext_loader import _reexport_cext
from .._cext_loader import _load_extension
_reexport_cext("dialects.sparse_tensor", __name__)
_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses")
del _reexport_cext
del _load_extension

View file

@ -0,0 +1,22 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects import sparse_tensor as st
def run(f):
print('\nTEST:', f.__name__)
f()
return f
# CHECK-LABEL: TEST: testSparseTensorPass
@run
def testSparseTensorPass():
with Context() as context:
PassManager.parse('sparsification')
PassManager.parse('sparse-tensor-conversion')
# CHECK: SUCCESS
print('SUCCESS')