[mlir] Async: check awaited operand error state after sync await

Previously only await inside the async function (coroutine after lowering to async runtime) would check the error state

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D109229
This commit is contained in:
Eugene Zhulenev 2021-09-03 05:27:30 -07:00
parent 2833a2edac
commit fd52b4357a
5 changed files with 32 additions and 12 deletions

View file

@ -525,10 +525,6 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
bool isGroup = type.isa<GroupType>();
bool isValue = type.isa<ValueType>();
// Drop reference after async token or group await (sync await)
if (auto await = dyn_cast<RuntimeAwaitOp>(op))
return (isToken || isGroup) ? -1 : 0;
// Drop reference after async token or group error check (coro await).
if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
return (isToken || isGroup) ? -1 : 0;

View file

@ -397,10 +397,23 @@ public:
Location loc = op->getLoc();
Value operand = AwaitAdaptor(operands).operand();
Type i1 = rewriter.getI1Type();
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine)
rewriter.create<RuntimeAwaitOp>(loc, operand);
if (!isInCoroutine) {
ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
builder.create<RuntimeAwaitOp>(loc, operand);
// Assert that the awaited operands is not in the error state.
Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
Value notError = builder.create<XOrOp>(
isError,
builder.create<ConstantOp>(loc, i1, builder.getIntegerAttr(i1, 1)));
builder.create<AssertOp>(notError,
"Awaited async operand is in error state");
}
// Inside the coroutine we convert await operation into coroutine suspension
// point, and resume execution asynchronously.
@ -430,8 +443,7 @@ public:
// Check if the awaited value is in the error state.
builder.setInsertionPointToStart(resume);
auto isError =
builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
builder.create<CondBranchOp>(isError,
/*trueDest=*/setupSetErrorBlock(coro),
/*trueArgs=*/ArrayRef<Value>(),
@ -772,7 +784,8 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
});
return !walkResult.wasInterrupted();
});
runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
runtimeTarget
.addLegalOp<AssertOp, XOrOp, ConstantOp, BranchOp, CondBranchOp>();
// Assertions must be converted to runtime errors inside async functions.
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {

View file

@ -24,6 +24,10 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
async.yield
}
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
// CHECK: %[[TRUE:.*]] = constant true
// CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]]
// CHECK-NEXT: return
async.await %token : !async.token
return
@ -83,7 +87,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
async.yield
}
// CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
// CHECK-NEXT: return
// CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
// CHECK: %[[TRUE:.*]] = constant true
// CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]]
async.await %token0 : !async.token
return
}

View file

@ -4,7 +4,7 @@
// CHECK: %[[TOKEN:.*]]: !async.token
func @token_await(%arg0: !async.token) {
// CHECK: async.runtime.await %[[TOKEN]]
// CHECK: async.runtime.drop_ref %[[TOKEN]] {count = 1 : i32}
// CHECK-NOT: async.runtime.drop_ref
async.runtime.await %arg0 : !async.token
return
}
@ -13,7 +13,7 @@ func @token_await(%arg0: !async.token) {
// CHECK: %[[GROUP:.*]]: !async.group
func @group_await(%arg0: !async.group) {
// CHECK: async.runtime.await %[[GROUP]]
// CHECK: async.runtime.drop_ref %[[GROUP]] {count = 1 : i32}
// CHECK-NOT: async.runtime.drop_ref
async.runtime.await %arg0 : !async.group
return
}

View file

@ -60,6 +60,10 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
async.yield
}
// CHECK: async.runtime.await %[[TOKEN]]
// CHECK: %[[IS_ERROR:.*]] = async.runtime.is_error %[[TOKEN]]
// CHECK: %[[TRUE:.*]] = constant true
// CHECK: %[[NOT_ERROR:.*]] = xor %[[IS_ERROR]], %[[TRUE]] : i1
// CHECK: assert %[[NOT_ERROR]]
// CHECK-NEXT: return
async.await %token0 : !async.token
return