[MLIR] [Python] Add a method to clear live operations map

Introduce a method on PyMlirContext (and plumb it through to Python) to
invalidate all of the operations in the live operations map and clear
it. Since Python has no notion of private data, an end-developer could
reach into some 3rd party API which uses the MLIR Python API (that is
behaving correctly with regard to holding references) and grab a
reference to an MLIR Python Operation, preventing it from being
deconstructed out of the live operations map. This allows the API
developer to clear the map when it calls C++ code which could delete
operations, protecting itself from its users.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D123895
This commit is contained in:
John Demme 2022-04-19 15:03:15 -07:00
parent 6db0afb44e
commit 6b0bed7ea5
3 changed files with 28 additions and 0 deletions

View file

@ -505,6 +505,14 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
size_t PyMlirContext::clearLiveOperations() {
for (auto &op : liveOperations)
op.second.second->setInvalid();
size_t numInvalidated = liveOperations.size();
liveOperations.clear();
return numInvalidated;
}
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
pybind11::object PyMlirContext::contextEnter() {
@ -2208,6 +2216,7 @@ void mlir::python::populateIRCore(py::module &m) {
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)

View file

@ -201,6 +201,12 @@ public:
/// Used for testing.
size_t getLiveOperationCount();
/// Clears the live operations map, returning the number of entries which were
/// invalidated. To be used as a safety mechanism so that API end-users can't
/// corrupt by holding references they shouldn't have accessed in the first
/// place.
size_t clearLiveOperations();
/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();
@ -575,6 +581,9 @@ public:
/// parent context's live operations map, and sets the valid bit false.
void erase();
/// Invalidate the operation.
void setInvalid() { valid = false; }
/// Clones this operation.
pybind11::object clone(const pybind11::object &ip);

View file

@ -104,6 +104,16 @@ def testModuleOperation():
assert ctx._get_live_operation_count() == 1
assert op1 is op2
# Test live operation clearing.
op1 = module.operation
assert ctx._get_live_operation_count() == 1
num_invalidated = ctx._clear_live_operations()
assert num_invalidated == 1
assert ctx._get_live_operation_count() == 0
op1 = None
gc.collect()
op1 = module.operation
# Ensure that if module is de-referenced, the operations are still valid.
module = None
gc.collect()