[MLIR] [Python] Add a method to clear live operations map
Introduce a method on PyMlirContext (and plumb it through to Python) to invalidate all of the operations in the live operations map and clear it. Since Python has no notion of private data, an end-developer could reach into some 3rd party API which uses the MLIR Python API (that is behaving correctly with regard to holding references) and grab a reference to an MLIR Python Operation, preventing it from being deconstructed out of the live operations map. This allows the API developer to clear the map when it calls C++ code which could delete operations, protecting itself from its users. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D123895
This commit is contained in:
parent
6db0afb44e
commit
6b0bed7ea5
|
@ -505,6 +505,14 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
|
|||
|
||||
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
|
||||
|
||||
size_t PyMlirContext::clearLiveOperations() {
|
||||
for (auto &op : liveOperations)
|
||||
op.second.second->setInvalid();
|
||||
size_t numInvalidated = liveOperations.size();
|
||||
liveOperations.clear();
|
||||
return numInvalidated;
|
||||
}
|
||||
|
||||
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
|
||||
|
||||
pybind11::object PyMlirContext::contextEnter() {
|
||||
|
@ -2208,6 +2216,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
|||
return ref.releaseObject();
|
||||
})
|
||||
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
|
||||
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
|
||||
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyMlirContext::getCapsule)
|
||||
|
|
|
@ -201,6 +201,12 @@ public:
|
|||
/// Used for testing.
|
||||
size_t getLiveOperationCount();
|
||||
|
||||
/// Clears the live operations map, returning the number of entries which were
|
||||
/// invalidated. To be used as a safety mechanism so that API end-users can't
|
||||
/// corrupt by holding references they shouldn't have accessed in the first
|
||||
/// place.
|
||||
size_t clearLiveOperations();
|
||||
|
||||
/// Gets the count of live modules associated with this context.
|
||||
/// Used for testing.
|
||||
size_t getLiveModuleCount();
|
||||
|
@ -575,6 +581,9 @@ public:
|
|||
/// parent context's live operations map, and sets the valid bit false.
|
||||
void erase();
|
||||
|
||||
/// Invalidate the operation.
|
||||
void setInvalid() { valid = false; }
|
||||
|
||||
/// Clones this operation.
|
||||
pybind11::object clone(const pybind11::object &ip);
|
||||
|
||||
|
|
|
@ -104,6 +104,16 @@ def testModuleOperation():
|
|||
assert ctx._get_live_operation_count() == 1
|
||||
assert op1 is op2
|
||||
|
||||
# Test live operation clearing.
|
||||
op1 = module.operation
|
||||
assert ctx._get_live_operation_count() == 1
|
||||
num_invalidated = ctx._clear_live_operations()
|
||||
assert num_invalidated == 1
|
||||
assert ctx._get_live_operation_count() == 0
|
||||
op1 = None
|
||||
gc.collect()
|
||||
op1 = module.operation
|
||||
|
||||
# Ensure that if module is de-referenced, the operations are still valid.
|
||||
module = None
|
||||
gc.collect()
|
||||
|
|
Loading…
Reference in a new issue