[mlir][bufferize] Use rewriter instead of replacing all uses directly

This is important for compatibility with DialectConversion.
This commit is contained in:
Matthias Springer 2022-02-12 02:31:00 +09:00
parent 541c9ba842
commit 9106d35b91

View file

@ -293,14 +293,13 @@ FailureOr<Value> BufferizationState::getBuffer(
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
Operation *op,
ValueRange values) {
assert(values.size() == op->getNumResults() &&
"expected one value per OpResult");
OpBuilder::InsertionGuard g(rewriter);
// Replace all OpResults with the given values.
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
// Skip OpResult if it has no uses.
if (opResult.getUses().empty())
continue;
Value replacement = values[opResult.getResultNumber()];
if (opResult.getType().isa<TensorType>()) {
// The OpResult is a tensor. Such values are replaced with memrefs during
@ -315,10 +314,10 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
replacement = rewriter.create<bufferization::ToTensorOp>(
replacement.getLoc(), replacement);
}
opResult.replaceAllUsesWith(replacement);
replacements.push_back(replacement);
}
rewriter.eraseOp(op);
rewriter.replaceOp(op, replacements);
}
AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(