[OpenMPIRBuilder] Detect and fix ambiguous InsertPoints for createSections.

Follow-up on D117226 for createSections.

Reviewed By: shraiysh

Differential Revision: https://reviews.llvm.org/D117835
This commit is contained in:
Michael Kruse 2022-04-04 20:13:02 -05:00
parent bb3980ae9f
commit c082ca16f1
4 changed files with 48 additions and 43 deletions

View file

@ -1064,6 +1064,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
const LocationDescription &Loc, InsertPointTy AllocaIP,
ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
if (!updateToLocation(Loc))
return Loc.IP;

View file

@ -4077,7 +4077,12 @@ TEST_F(OpenMPIRBuilderTest, CreateSectionsSimple) {
OMPBuilder.initialize();
F->setName("func");
IRBuilder<> Builder(BB);
BasicBlock *EnterBB = BasicBlock::Create(Ctx, "sections.enter", F);
Builder.CreateBr(EnterBB);
Builder.SetInsertPoint(EnterBB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
llvm::SmallVector<BodyGenCallbackTy, 4> SectionCBVector;
llvm::SmallVector<BasicBlock *, 4> CaseBBs;
@ -4232,7 +4237,11 @@ TEST_F(OpenMPIRBuilderTest, CreateSectionsNoWait) {
F->setName("func");
IRBuilder<> Builder(BB);
BasicBlock *EnterBB = BasicBlock::Create(Ctx, "sections.enter", F);
Builder.CreateBr(EnterBB);
Builder.SetInsertPoint(EnterBB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
F->getEntryBlock().getFirstInsertionPt());
llvm::SmallVector<BodyGenCallbackTy, 4> SectionCBVector;

View file

@ -72,6 +72,21 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
return allocaInsertPoint;
// Otherwise, insert to the entry block of the surrounding function.
// If the current IRBuilder InsertPoint is the function's entry, it cannot
// also be used for alloca insertion which would result in insertion order
// confusion. Create a new BasicBlock for the Builder and use the entry block
// for the allocs.
if (builder.GetInsertBlock() ==
&builder.GetInsertBlock()->getParent()->getEntryBlock()) {
assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
"Assuming end of basic block");
llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
builder.GetInsertBlock()->getNextNode());
builder.CreateBr(entryBB);
builder.SetInsertPoint(entryBB);
}
llvm::BasicBlock &funcEntryBlock =
builder.GetInsertBlock()->getParent()->getEntryBlock();
return llvm::OpenMPIRBuilder::InsertPointTy(
@ -255,23 +270,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// TODO: Is the Parallel construct cancellable?
bool isCancellable = false;
// Ensure that the BasicBlock for the the parallel region is sparate from the
// function entry which we may need to insert allocas.
if (builder.GetInsertBlock() ==
&builder.GetInsertBlock()->getParent()->getEntryBlock()) {
assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
"Assuming end of basic block");
llvm::BasicBlock *entryBB =
llvm::BasicBlock::Create(builder.getContext(), "parallel.entry",
builder.GetInsertBlock()->getParent(),
builder.GetInsertBlock()->getNextNode());
builder.CreateBr(entryBB);
builder.SetInsertPoint(entryBB);
}
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB,
privCB, finiCB, ifCond, numThreads, pbKind, isCancellable));
ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
isCancellable));
return bodyGenStatus;
}
@ -522,7 +526,6 @@ convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
SmallVector<llvm::Value *> vecValues =
moduleTranslation.lookupValues(orderedOp.depend_vec_vars());
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
size_t indexVecValues = 0;
while (indexVecValues < vecValues.size()) {
SmallVector<llvm::Value *> storeValues;
@ -531,9 +534,11 @@ convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
storeValues.push_back(vecValues[indexVecValues]);
indexVecValues++;
}
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
ompLoc, findAllocaInsertPoint(builder, moduleTranslation), numLoops,
storeValues, ".cnt.addr", isDependSource));
ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
}
return success();
}
@ -634,10 +639,12 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
// called for variables which have destructors/finalizers.
auto finiCB = [&](InsertPointTy codeGenIP) {};
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSections(
ompLoc, findAllocaInsertPoint(builder, moduleTranslation), sectionCBs,
privCB, finiCB, false, sectionsOp.nowait()));
ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
sectionsOp.nowait()));
return bodyGenStatus;
}
@ -1104,7 +1111,6 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
// Convert values and types.
auto &innerOpList = opInst.region().front().getOperations();
@ -1164,17 +1170,10 @@ convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
// Handle ambiguous alloca, if any.
auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
if (allocaIP.getPoint() == ompLoc.IP.getPoint()) {
// Same point => split basic block and make them unambigous.
llvm::UnreachableInst *unreachableInst = builder.CreateUnreachable();
builder.SetInsertPoint(builder.GetInsertBlock()->splitBasicBlock(
unreachableInst, "alloca_split"));
ompLoc.IP = builder.saveIP();
unreachableInst->eraseFromParent();
}
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(ompBuilder->createAtomicUpdate(
ompLoc, findAllocaInsertPoint(builder, moduleTranslation), llvmAtomicX,
llvmExpr, atomicOrdering, binop, updateFn, isXBinopExpr));
ompLoc, allocaIP, llvmAtomicX, llvmExpr, atomicOrdering, binop, updateFn,
isXBinopExpr));
return updateGenStatus;
}
@ -1183,7 +1182,6 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
mlir::Value mlirExpr;
bool isXBinopExpr = false, isPostfixUpdate = false;
llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
@ -1262,20 +1260,13 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
"argument");
return moduleTranslation.lookupValue(yieldop.results()[0]);
};
// Handle ambiguous alloca, if any.
auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
if (allocaIP.getPoint() == ompLoc.IP.getPoint()) {
// Same point => split basic block and make them unambigous.
llvm::UnreachableInst *unreachableInst = builder.CreateUnreachable();
builder.SetInsertPoint(builder.GetInsertBlock()->splitBasicBlock(
unreachableInst, "alloca_split"));
ompLoc.IP = builder.saveIP();
unreachableInst->eraseFromParent();
}
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
builder.restoreIP(ompBuilder->createAtomicCapture(
ompLoc, findAllocaInsertPoint(builder, moduleTranslation), llvmAtomicX,
llvmAtomicV, llvmExpr, atomicOrdering, binop, updateFn, atomicUpdateOp,
isPostfixUpdate, isXBinopExpr));
ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr));
return updateGenStatus;
}

View file

@ -1854,6 +1854,9 @@ llvm.func @omp_sections_empty() -> () {
// CHECK-LABEL: @omp_sections_trivial
llvm.func @omp_sections_trivial() -> () {
// CHECK: br label %[[ENTRY:[a-zA-Z_.]+]]
// CHECK: [[ENTRY]]:
// CHECK: br label %[[PREHEADER:.*]]
// CHECK: [[PREHEADER]]: