From 78fb4f9d5dd95d26424919f2da184e1119ccb023 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 23 Feb 2022 14:08:51 -0500 Subject: [PATCH] [SCF][MemRef] Enable SCF.Parallel Lowering to use Scope Op As discussed in https://reviews.llvm.org/D119743 scf.parallel would continuously stack allocate since the alloca op was placed in the wsloop rather than the omp.parallel. This PR is the second stage of the fix for that problem. Specifically, we now introduce an alloca scope around the inlined body of the scf.parallel and enable a canonicalization to hoist the allocations to the surrounding allocation scope (e.g. omp.parallel). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D120423 --- mlir/include/mlir/Conversion/Passes.td | 3 +- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 1 + .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 45 +++-- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 154 ++++++++++++++++++ .../Conversion/SCFToOpenMP/reductions.mlir | 4 +- .../Conversion/SCFToOpenMP/scf-to-openmp.mlir | 5 + mlir/test/Dialect/MemRef/canonicalize.mlir | 91 +++++++++++ 7 files changed, 277 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index adb97abf925d..09a5b9358b94 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -501,7 +501,8 @@ def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> { let summary = "Convert SCF parallel loop to OpenMP parallel + workshare " "constructs."; let constructor = "mlir::createConvertSCFToOpenMPPass()"; - let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect"]; + let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect", + "memref::MemRefDialect"]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index
1ee0b866a00a..ae4ce7d4b090 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -274,6 +274,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope", let regions = (region SizedRegion<1>:$bodyRegion); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index a9e7759aa75e..c9c5017e0382 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -364,8 +365,6 @@ struct ParallelOpLowering : public OpRewritePattern { loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); SmallVector reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); - Value token = rewriter.create( - loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8))); for (Value init : parallelOp.getInitVals()) { assert((LLVM::isCompatibleType(init.getType()) || init.getType().isa()) && @@ -392,31 +391,31 @@ struct ParallelOpLowering : public OpRewritePattern { // Create the parallel wrapper. auto ompParallel = rewriter.create(loc); { + OpBuilder::InsertionGuard guard(rewriter); rewriter.createBlock(&ompParallel.region()); - // Replace SCF yield with OpenMP yield. 
{ - OpBuilder::InsertionGuard innerGuard(rewriter); - rewriter.setInsertionPointToEnd(parallelOp.getBody()); - assert(llvm::hasSingleElement(parallelOp.getRegion()) && - "expected scf.parallel to have one block"); - rewriter.replaceOpWithNewOp( - parallelOp.getBody()->getTerminator(), ValueRange()); - } + auto scope = rewriter.create(parallelOp.getLoc(), + TypeRange()); + rewriter.create(loc); + OpBuilder::InsertionGuard allocaGuard(rewriter); + rewriter.createBlock(&scope.getBodyRegion()); + rewriter.setInsertionPointToStart(&scope.getBodyRegion().front()); - // Replace the loop. - auto loop = rewriter.create( - parallelOp.getLoc(), parallelOp.getLowerBound(), - parallelOp.getUpperBound(), parallelOp.getStep()); - rewriter.create(loc); + // Replace the loop. + auto loop = rewriter.create( + parallelOp.getLoc(), parallelOp.getLowerBound(), + parallelOp.getUpperBound(), parallelOp.getStep()); + rewriter.create(loc); - rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), - loop.region().begin()); - if (!reductionVariables.empty()) { - loop.reductionsAttr( - ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); - loop.reduction_varsMutable().append(reductionVariables); + rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), + loop.region().begin()); + if (!reductionVariables.empty()) { + loop.reductionsAttr( + ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); + loop.reduction_varsMutable().append(reductionVariables); + } } } @@ -429,7 +428,6 @@ struct ParallelOpLowering : public OpRewritePattern { } rewriter.replaceOp(parallelOp, results); - rewriter.create(loc, token); return success(); } }; @@ -438,7 +436,8 @@ struct ParallelOpLowering : public OpRewritePattern { static LogicalResult applyPatterns(ModuleOp module) { ConversionTarget target(*module.getContext()); target.addIllegalOp(); - target.addLegalDialect(); + target.addLegalDialect(); RewritePatternSet patterns(module.getContext()); 
patterns.add(module.getContext()); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 4af1f5a25ba1..1fc3ab2eb280 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -258,6 +259,159 @@ void AllocaScopeOp::getSuccessorRegions( regions.push_back(RegionSuccessor(&bodyRegion())); } +/// Given an operation, return whether this op is guaranteed to +/// allocate an AutomaticAllocationScopeResource +static bool isGuaranteedAutomaticAllocationScope(Operation *op) { + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) + return false; + for (auto res : op->getResults()) { + if (auto effect = + interface.getEffectOnValue(res)) { + if (isa( + effect->getResource())) + return true; + } + } + return false; +} + +/// Given an operation, return whether this op could +/// allocate an AutomaticAllocationScopeResource +static bool isPotentialAutomaticAllocationScope(Operation *op) { + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) + return true; + for (auto res : op->getResults()) { + if (auto effect = + interface.getEffectOnValue(res)) { + if (isa( + effect->getResource())) + return true; + } + } + return false; +} + +/// Return whether this op is the last non terminating op +/// in a region. That is to say, it is in a one-block region +/// and is only followed by a terminator. This prevents +/// extending the lifetime of allocations.
+static bool lastNonTerminatorInRegion(Operation *op) { + return op->getNextNode() == op->getBlock()->getTerminator() && + op->getParentRegion()->getBlocks().size() == 1; +} + +/// Inline an AllocaScopeOp if either the direct parent is an allocation scope +/// or it contains no allocation. +struct AllocaScopeInliner : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocaScopeOp op, + PatternRewriter &rewriter) const override { + if (!op->getParentOp()->hasTrait()) { + bool hasPotentialAlloca = + op->walk([&](Operation *alloc) { + if (isPotentialAutomaticAllocationScope(alloc)) + return WalkResult::interrupt(); + return WalkResult::skip(); + }).wasInterrupted(); + if (hasPotentialAlloca) + return failure(); + } + + // Only apply to if this is this last non-terminator + // op in the block (lest lifetime be extended) of a one + // block region + if (!lastNonTerminatorInRegion(op)) + return failure(); + + Block *block = &op.getRegion().front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); + return success(); + } +}; + +/// Move allocations into an allocation scope, if it is legal to +/// move them (e.g. their operands are available at the location +/// the op would be moved to). 
+struct AllocaScopeHoister : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocaScopeOp op, + PatternRewriter &rewriter) const override { + + if (!op->getParentWithTrait()) + return failure(); + + Operation *lastParentWithoutScope = op->getParentOp(); + + if (!lastParentWithoutScope || + lastParentWithoutScope->hasTrait()) + return failure(); + + // Only apply to if this is this last non-terminator + // op in the block (lest lifetime be extended) of a one + // block region + if (!lastNonTerminatorInRegion(op) || + !lastNonTerminatorInRegion(lastParentWithoutScope)) + return failure(); + + while (!lastParentWithoutScope->getParentOp() + ->hasTrait()) { + lastParentWithoutScope = lastParentWithoutScope->getParentOp(); + if (!lastParentWithoutScope || + !lastNonTerminatorInRegion(lastParentWithoutScope)) + return failure(); + } + Operation *scope = lastParentWithoutScope->getParentOp(); + assert(scope->hasTrait()); + + Region *containingRegion = nullptr; + for (auto &r : lastParentWithoutScope->getRegions()) { + if (r.isAncestor(op->getParentRegion())) { + assert(containingRegion == nullptr && + "only one region can contain the op"); + containingRegion = &r; + } + } + assert(containingRegion && "op must be contained in a region"); + + SmallVector toHoist; + op->walk([&](Operation *alloc) { + if (!isGuaranteedAutomaticAllocationScope(alloc)) + return WalkResult::skip(); + + // If any operand is not defined before the location of + // lastParentWithoutScope (i.e. where we would hoist to), skip. 
+ if (llvm::any_of(alloc->getOperands(), [&](Value v) { + return containingRegion->isAncestor(v.getParentRegion()); + })) + return WalkResult::skip(); + toHoist.push_back(alloc); + return WalkResult::advance(); + }); + + if (!toHoist.size()) + return failure(); + rewriter.setInsertionPoint(lastParentWithoutScope); + for (auto op : toHoist) { + auto cloned = rewriter.clone(*op); + rewriter.replaceOp(op, cloned->getResults()); + } + return success(); + } +}; + +void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // AssumeAlignmentOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir index cb07089a3a68..3e8881ff1d97 100644 --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -21,12 +21,12 @@ func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) { // CHECK: %[[CST:.*]] = arith.constant 0.0 // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 - // CHECK: llvm.intr.stacksave // CHECK: %[[BUF:.*]] = llvm.alloca %[[ONE]] x f32 // CHECK: llvm.store %[[CST]], %[[BUF]] %step = arith.constant 1 : index %zero = arith.constant 0.0 : f32 // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]] scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) @@ -43,7 +43,6 @@ func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index, } // CHECK: omp.terminator // CHECK: llvm.load %[[BUF]] - // CHECK: llvm.intr.stackrestore return } @@ -162,6 +161,7 @@ func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index, // CHECK: llvm.store %[[IONE]], %[[BUF2]] // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop // 
CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]] // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]] diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir index 1507f927b9f0..0c16b85dde48 100644 --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -4,6 +4,7 @@ func @parallel(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> () @@ -20,9 +21,11 @@ func @parallel(%arg0: index, %arg1: index, %arg2: index, func @nested_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: omp.parallel + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> () @@ -41,6 +44,7 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index, func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> () @@ -52,6 +56,7 @@ func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, // CHECK: } // CHECK: 
omp.parallel { + // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> () diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 58083437ca47..12c5a529e901 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -552,3 +552,94 @@ func @self_copy(%m1: memref) { // CHECK-LABEL: func @self_copy // CHECK-NEXT: return + +// ----- + +func @scopeMerge() { + memref.alloca_scope { + %cnt = "test.count"() : () -> index + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + return +} +// CHECK: func @scopeMerge() { +// CHECK-NOT: alloca_scope +// CHECK: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK: return + +func @scopeMerge2() { + "test.region"() ({ + memref.alloca_scope { + %cnt = "test.count"() : () -> index + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + "test.terminator"() : () -> () + }) : () -> () + return +} + +// CHECK: func @scopeMerge2() { +// CHECK: "test.region"() ({ +// CHECK: memref.alloca_scope { +// CHECK: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK: } +// CHECK: "test.terminator"() : () -> () +// CHECK: }) : () -> () +// CHECK: return +// CHECK: } + +func @scopeMerge3() { + %cnt = "test.count"() : () -> index + "test.region"() ({ + memref.alloca_scope { + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + "test.terminator"() : () -> () + }) : () -> () + return +} + +// CHECK: func @scopeMerge3() { +// CHECK: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK: %[[alloc:.+]] = 
memref.alloca(%[[cnt]]) : memref +// CHECK: "test.region"() ({ +// CHECK: memref.alloca_scope { +// CHECK: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK: } +// CHECK: "test.terminator"() : () -> () +// CHECK: }) : () -> () +// CHECK: return +// CHECK: } + +func @scopeMerge4() { + %cnt = "test.count"() : () -> index + "test.region"() ({ + memref.alloca_scope { + %a = memref.alloca(%cnt) : memref + "test.use"(%a) : (memref) -> () + } + "test.op"() : () -> () + "test.terminator"() : () -> () + }) : () -> () + return +} + +// CHECK: func @scopeMerge4() { +// CHECK: %[[cnt:.+]] = "test.count"() : () -> index +// CHECK: "test.region"() ({ +// CHECK: memref.alloca_scope { +// CHECK: %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref +// CHECK: "test.use"(%[[alloc]]) : (memref) -> () +// CHECK: } +// CHECK: "test.op"() : () -> () +// CHECK: "test.terminator"() : () -> () +// CHECK: }) : () -> () +// CHECK: return +// CHECK: }