//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "../GPUCommon/IndexIntrinsicsOpLowering.h"

using namespace mlir;

namespace {

// Converts all_reduce op to LLVM/NVVM ops.
//
// The lowering is a two-stage block reduction: every warp first reduces its
// own values with shuffles, the per-warp results are exchanged through a
// shared-memory buffer, and a second warp-level reduction combines them.
//
// NOTE(review): angle-bracketed template argument lists appear to be missing
// throughout this copy of the file (e.g. `template void createIf`,
// `rewriter.create(loc, ...)`, `cast()`); confirm against the upstream file
// before building. The op kinds mentioned in the comments below are taken from
// the pseudo-IR comments and the variable names, not from the builder calls.
struct GPUAllReduceOpLowering : public LLVMOpLowering {
  // Registers the pattern for the operation name of gpu::AllReduce and caches
  // the 32-bit integer LLVM type used for all index arithmetic below.
  explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
      : LLVMOpLowering(gpu::AllReduce::getOperationName(),
                       lowering_.getDialect()->getContext(), lowering_),
        int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {}

  // Replaces `op` with the value produced by createBlockReduce() for the
  // first (already converted) operand. Only `operands.front()` is read.
  // Always returns matchSuccess().
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    Value *result = createBlockReduce(op->getLoc(), operands.front(), rewriter);
    rewriter.replaceOp(op, {result});
    return matchSuccess();
  }

private:
  // Creates an all_reduce across the block.
  //
  // First reduce the elements within a warp. The first thread of each warp
  // writes the intermediate result to shared memory. After synchronizing the
  // block, the first warp reduces the values from shared memory. The result
  // is broadcasted to all threads through shared memory.
  //
  //     %warp_reduce = `createWarpReduce(%operand)`
  //     %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
  //     %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
  //     %lane_id = nvvm.read.ptx.sreg.laneid  : !llvm.i32
  //     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
  //     %thread_idx = `getLinearThreadIndex()` : !llvm.i32
  //     llvm.cond_br %is_first_lane, ^then1, ^continue1
  //   ^then1:
  //     %warp_id = `getWarpId()`
  //     %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
  //     llvm.store %warp_reduce, %store_dst
  //     llvm.br ^continue1
  //   ^continue1:
  //     nvvm.barrier0
  //     %num_warps = `getNumWarps()` : !llvm.i32
  //     %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
  //     %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
  //     llvm.cond_br %is_valid_warp, ^then2, ^continue2
  //   ^then2:
  //     %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
  //     %value = llvm.load %load_src
  //     %result = `createWarpReduce(%value)`
  //     llvm.store %result, %result_ptr
  //     llvm.br ^continue2
  //   ^continue2:
  //     nvvm.barrier0
  //     %result = llvm.load %result_ptr
  //     return %result
  //
  // The sketch above is simplified: the conditional blocks are emitted through
  // createPredicatedBlock()/createIf(), which also create an (empty) else
  // block, and the block-size / active-width computations are omitted.
  Value *createBlockReduce(Location loc, Value *operand,
                           ConversionPatternRewriter &rewriter) const {
    auto type = operand->getType().cast();

    // Create shared memory array to store the warp reduction.
    // NOTE(review): the result of getDefiningOp() is dereferenced unchecked;
    // presumably it is null when `operand` is not produced by an operation
    // (e.g. a block argument) - confirm that cannot happen here.
    auto module = operand->getDefiningOp()->getParentOfType();
    assert(module && "op must belong to a module");
    // NOTE(review): the buffer holds kWarpSize elements and is indexed by warp
    // id below, so this presumably assumes at most kWarpSize warps per block -
    // confirm.
    Value *sharedMemPtr =
        createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);

    Value *zero = rewriter.create(
        loc, int32Type, rewriter.getI32IntegerAttr(0u));
    Value *laneId = rewriter.create(loc, int32Type);
    Value *isFirstLane = rewriter.create(
        loc, LLVM::ICmpPredicate::eq, laneId, zero);
    Value *threadIdx = getLinearThreadIndex(loc, rewriter);
    Value *blockSize = getBlockSize(loc, rewriter);
    Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);

    // Reduce elements within each warp to produce the intermediate results.
    Value *warpReduce =
        createWarpReduce(loc, activeWidth, laneId, operand, rewriter);

    // Write the intermediate results to shared memory, using the first lane of
    // each warp.
    createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
      Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
      Value *storeDst = rewriter.create(
          loc, type, sharedMemPtr, ArrayRef({zero, warpId}));
      rewriter.create(loc, warpReduce, storeDst);
    });
    // Emitted between the per-warp stores and the loads below (the
    // `nvvm.barrier0` of the sketch above).
    rewriter.create(loc);

    Value *numWarps = getNumWarps(loc, blockSize, rewriter);
    Value *isValidWarp = rewriter.create(
        loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
    // Element 0 of the buffer doubles as the slot for the final result.
    Value *resultPtr = rewriter.create(
        loc, type, sharedMemPtr, ArrayRef({zero, zero}));

    // Use the first numWarps threads to reduce the intermediate results from
    // shared memory. The final result is written to shared memory again.
    createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
      Value *loadSrc = rewriter.create(
          loc, type, sharedMemPtr, ArrayRef({zero, threadIdx}));
      Value *value = rewriter.create(loc, type, loadSrc);
      // numWarps is passed as the active width of this second reduction.
      Value *result = createWarpReduce(loc, numWarps, laneId, value, rewriter);
      rewriter.create(loc, result, resultPtr);
    });
    // Second barrier, emitted between the store of the final result and the
    // load below.
    rewriter.create(loc);

    // Load and return result from shared memory.
    Value *result = rewriter.create(loc, type, resultPtr);
    return result;
  }

  // Creates an if-block skeleton and calls the two factories to generate the
  // ops in the `then` and `else` block.
  //
  //     llvm.cond_br %condition, ^then, ^else
  //   ^then:
  //     %then_operands = `thenOpsFactory()`
  //     llvm.br ^continue(%then_operands)
  //   ^else:
  //     %else_operands = `elseOpsFactory()`
  //     llvm.br ^continue(%else_operands)
  //   ^continue(%block_operands):
  //
  // Both factories must return the same number of values (asserted). On
  // return the insertion point is at the start of the continue block, which
  // has one block argument per value returned by the `then` factory, typed
  // after those values.
  template
  void createIf(Location loc, ConversionPatternRewriter &rewriter,
                Value *condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) const {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    // Split three times: first at the current insertion point, then twice at
    // the very beginning of the freshly created block. Assuming splitBlock
    // moves the ops from the given position onward into the returned block,
    // this leaves `thenBlock` and `elseBlock` empty and puts the ops that
    // followed the insertion point into `continueBlock`.
    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    // Terminate the original block with the conditional branch.
    rewriter.setInsertionPointToEnd(currentBlock);
    rewriter.create(loc, llvm::makeArrayRef(condition),
                    ArrayRef{thenBlock, elseBlock});

    // Emits the branch to the continue block at the *current* insertion
    // point, forwarding `operands` as block arguments.
    auto addBranch = [&](ArrayRef operands) {
      rewriter.create(loc, ArrayRef{},
                      llvm::makeArrayRef(continueBlock),
                      llvm::makeArrayRef(operands));
    };

    // The insertion point is not reset between calling a factory and adding
    // the branch, so a factory that itself calls createIf() (and thereby moves
    // the insertion point into a nested continue block) gets its branch
    // appended there.
    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    addBranch(thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    addBranch(elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto *operand : thenOperands)
      continueBlock->addArgument(operand->getType());
  }

  // Shortcut for createIf with empty else block and no block operands.
  // Any value returned by `predicatedOpsFactory` is discarded.
  template
  void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
                             Value *condition,
                             Factory &&predicatedOpsFactory) const {
    createIf(
        loc, rewriter, condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef();
        },
        [&] { return ArrayRef(); });
  }

  // Creates a reduction across the first activeWidth lanes of a warp.
  // Only the first lane holds the result; the values returned on all other
  // lanes are undefined.
  //
  // Two code paths are emitted and selected at runtime by comparing
  // `activeWidth` against the warp size: a guarded reduction for partial
  // warps and an unguarded one for full warps. The returned value is the
  // first argument of the continue block that joins the two paths.
  Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId,
                          Value *operand,
                          ConversionPatternRewriter &rewriter) const {
    // TODO(csigg): Generalize to other types of accumulation.
    Value *warpSize = rewriter.create(
        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
    // Constant kWarpSize - 1, passed as the last operand of every shuffle.
    Value *maskAndClamp = rewriter.create(
        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
    Value *isPartialWarp = rewriter.create(
        loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
    auto type = operand->getType();

    createIf(
        loc, rewriter, isPartialWarp,
        // Generate reduction over a (potentially) partial warp.
        [&] {
          Value *value = operand;
          Value *one = rewriter.create(
              loc, int32Type, rewriter.getI32IntegerAttr(1));
          // Bit mask of active lanes: `(1 << activeWidth) - 1`.
          Value *activeMask = rewriter.create(
              loc, int32Type,
              rewriter.create(loc, int32Type, one, activeWidth),
              one);
          // Bound of offsets which read from a lane within the active range.
          Value *offsetBound =
              rewriter.create(loc, activeWidth, laneId);

          // Repeatedly shuffle value from 'laneId + i' and accumulate if source
          // lane is within the active range. The first lane contains the final
          // result, all other lanes contain some undefined partial result.
          // The offset doubles every iteration (1, 2, 4, ... < kWarpSize).
          for (int i = 1; i < kWarpSize; i <<= 1) {
            Value *offset = rewriter.create(
                loc, int32Type, rewriter.getI32IntegerAttr(i));
            // ShflDownOp instead of ShflBflyOp would produce a scan. ShflBflyOp
            // also produces the correct reduction on lane 0 though.
            Value *shfl = rewriter.create(
                loc, type, activeMask, value, offset, maskAndClamp);

            // TODO(csigg): use the second result from the shuffle op instead.
            Value *isActiveSrcLane = rewriter.create(
                loc, LLVM::ICmpPredicate::slt, offset, offsetBound);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                loc, rewriter, isActiveSrcLane,
                [&] {
                  return llvm::SmallVector{
                      rewriter.create(loc, type, value, shfl)};
                },
                [&] { return llvm::makeArrayRef(value); });
            // The nested createIf() leaves the insertion point in its continue
            // block; its single block argument is the merged running value.
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return llvm::SmallVector{value};
        },
        // Generate a reduction over the entire warp. This is a specialization
        // of the above reduction with unconditional accumulation.
        [&] {
          Value *value = operand;
          // All-ones mask (~0u).
          Value *activeMask = rewriter.create(
              loc, int32Type, rewriter.getI32IntegerAttr(~0u));
          for (int i = 1; i < kWarpSize; i <<= 1) {
            Value *offset = rewriter.create(
                loc, int32Type, rewriter.getI32IntegerAttr(i));
            Value *shfl = rewriter.create(
                loc, type, activeMask, value, offset, maskAndClamp);
            value = rewriter.create(loc, type, value, shfl);
          }
          return llvm::SmallVector{value};
        });
    // Result of whichever path was taken, delivered as the argument of the
    // outer continue block.
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  // Creates a global array stored in shared memory.
  //
  // The global is built with a separate builder on the module's body region,
  // has `numElements` elements of `elementType`, is non-constant, has no
  // initial value, and carries the attribute `addr_space = 3`. Returns a value
  // created at the rewriter's insertion point that refers to the global (the
  // `llvm.mlir.addressof` of the block-reduce sketch).
  //
  // NOTE(review): the symbol name is the fixed string "reduce_buffer" and a
  // new global is created on every call, so lowering more than one all_reduce
  // in the same module presumably yields duplicate symbols - confirm.
  Value *createSharedMemoryArray(Location loc, ModuleOp module,
                                 LLVM::LLVMType elementType, int numElements,
                                 ConversionPatternRewriter &rewriter) const {
    OpBuilder builder(module.getBodyRegion());

    auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
    StringRef name = "reduce_buffer";
    auto addrSpace =
        builder.getNamedAttr("addr_space", builder.getI32IntegerAttr(3));
    auto globalOp = builder.create(
        loc, arrayType.cast(),
        /*isConstant=*/false, name, /*value=*/Attribute(),
        llvm::makeArrayRef(addrSpace));

    return rewriter.create(loc, globalOp);
  }

  // Returns the index of the thread within the block.
  //
  // Presumably computed as `(idZ * dimY + idY) * dimX + idX`, judging by the
  // operand order; the op kinds are not visible in this copy - confirm.
  Value *getLinearThreadIndex(Location loc,
                              ConversionPatternRewriter &rewriter) const {
    Value *dimX = rewriter.create(loc, int32Type);
    Value *dimY = rewriter.create(loc, int32Type);
    Value *idX = rewriter.create(loc, int32Type);
    Value *idY = rewriter.create(loc, int32Type);
    Value *idZ = rewriter.create(loc, int32Type);
    Value *tmp1 = rewriter.create(loc, int32Type, idZ, dimY);
    Value *tmp2 = rewriter.create(loc, int32Type, tmp1, idY);
    Value *tmp3 = rewriter.create(loc, int32Type, tmp2, dimX);
    return rewriter.create(loc, int32Type, tmp3, idX);
  }

  // Returns the number of threads in the block.
  // Presumably the product dimX * dimY * dimZ (op kinds not visible - confirm).
  Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
    Value *dimX = rewriter.create(loc, int32Type);
    Value *dimY = rewriter.create(loc, int32Type);
    Value *dimZ = rewriter.create(loc, int32Type);
    Value *dimXY = rewriter.create(loc, int32Type, dimX, dimY);
    return rewriter.create(loc, int32Type, dimXY, dimZ);
  }

  // Returns the number of warps in the block.
  // The block size is combined with the constant kWarpSize - 1 before being
  // divided by the warp size, presumably so that a trailing partial warp is
  // counted as well (rounding up) - confirm.
  Value *getNumWarps(Location loc, Value *blockSize,
                     ConversionPatternRewriter &rewriter) const {
    auto warpSizeMinusOne = rewriter.create(
        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
    auto biasedBlockSize = rewriter.create(
        loc, int32Type, blockSize, warpSizeMinusOne);
    return getDivideByWarpSize(biasedBlockSize, rewriter);
  }

  // Returns the number of active threads in the warp, not clamped to 32.
  // `threadIdx` is combined with the constant ~(kWarpSize - 1) to obtain the
  // number of threads in warps with a smaller warp id, which is then combined
  // with `blockSize` (presumably masked with `and` and subtracted,
  // respectively - confirm).
  Value *getActiveWidth(Location loc, Value *threadIdx, Value *blockSize,
                        ConversionPatternRewriter &rewriter) const {
    Value *threadIdxMask = rewriter.create(
        loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
    Value *numThreadsWithSmallerWarpId =
        rewriter.create(loc, threadIdx, threadIdxMask);
    return rewriter.create(loc, blockSize,
                           numThreadsWithSmallerWarpId);
  }

  // Returns value divided by the warp size (i.e. 32).
  // The new ops reuse the location of `value` rather than taking one as a
  // parameter.
  Value *getDivideByWarpSize(Value *value,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = value->getLoc();
    auto warpSize = rewriter.create(
        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
    return rewriter.create(loc, int32Type, value, warpSize);
  }

  // 32-bit integer LLVM type, obtained once in the constructor.
  LLVM::LLVMType int32Type;

  // Number of lanes per warp assumed by this lowering.
  static constexpr int kWarpSize = 32;
};

// A pass that replaces all occurrences of GPU device operations with their
// corresponding NVVM equivalent.
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
class LowerGpuOpsToNVVMOpsPass : public ModulePass {
public:
  void runOnModule() override {
    ModuleOp m = getModule();
    // Leave modules untouched unless they carry the attribute named by
    // gpu::GPUDialect::getKernelModuleAttrName().
    if (!m.getAttrOfType(gpu::GPUDialect::getKernelModuleAttrName()))
      return;

    // Collect the standard-to-LLVM patterns plus the GPU-specific ones (four
    // index intrinsic lowerings and the all-reduce lowering defined above).
    OwningRewritePatternList patterns;
    LLVMTypeConverter converter(m.getContext());
    populateStdToLLVMConversionPatterns(converter, patterns);
    patterns.insert<
        GPUIndexIntrinsicOpLowering,
        GPUIndexIntrinsicOpLowering,
        GPUIndexIntrinsicOpLowering,
        GPUIndexIntrinsicOpLowering,
        GPUAllReduceOpLowering>(converter);

    // Two dialects are marked legal (their names are not visible in this
    // copy); the dynamically legal op is accepted once its signature is legal
    // according to the type converter.
    ConversionTarget target(getContext());
    target.addLegalDialect();
    target.addLegalDialect();
    target.addDynamicallyLegalOp(
        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
    // Partial conversion: a failure is reported as a pass failure.
    if (failed(applyPartialConversion(m, target, patterns, &converter)))
      signalPassFailure();
  }
};

} // anonymous namespace

// Factory for the pass defined above.
std::unique_ptr> mlir::createLowerGpuOpsToNVVMOpsPass() {
  return std::make_unique();
}

// Registers the pass under the command-line name "lower-gpu-ops-to-nvvm-ops".
static PassRegistration
    pass("lower-gpu-ops-to-nvvm-ops",
         "Generate NVVM operations for gpu operations");