Add AllReduceOp to GPU dialect with lowering to NVVM.

The reduction operation is currently fixed to "add", and the scope is fixed to "workgroup".

The implementation is currently limited to sizes that are multiple 32 (warp size) and no larger than 1024.

PiperOrigin-RevId: 271290265
This commit is contained in:
Christian Sigg 2019-09-26 00:17:13 -07:00 committed by A. Unique TensorFlower
parent 94298cea93
commit 116dac00ba
5 changed files with 223 additions and 2 deletions

View file

@ -59,4 +59,23 @@ def gpu_Return : GPU_Op<"return", [Terminator]>, Arguments<(ins)>,
let printer = [{ p << getOperationName(); }];
}
def gpu_AllReduce : GPU_Op<"all_reduce", [SameOperandsAndResultType]>,
Arguments<(ins AnyType)>, Results<(outs AnyType)> {
let summary = "Reduce values among workgroup.";
let description = [{
The "all_reduce" op reduces the value of every invocation across a local
workgroup.
For example,
```
%1 = gpu.all_reduce %0 : f32
```
computes the sum of each invocation's %0 value. The value of %1 is always
equal for all invocations of a local workgroup.
Either none or all invocations of a local workgroup need to execute this op
in convergence.
}];
}
#endif // GPU_OPS

View file

@ -103,6 +103,173 @@ public:
}
};
// Converts all_reduce op to LLVM/NVVM ops.
struct GPUAllReduceOpLowering : public LLVMOpLowering {
explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
: LLVMOpLowering(gpu::AllReduce::getOperationName(),
lowering_.getDialect()->getContext(), lowering_),
int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
Value *result = createBlockReduce(op->getLoc(), operands.front(), rewriter);
rewriter.replaceOp(op, {result});
return matchSuccess();
}
private:
// Creates an all_reduce across the local workgroup.
//
// First reduce the elements within a subgroup (i.e. warp). The first
// invocation of each subgroup writes the intermediate result to shared
// memory. After synchronizing the local workgroup, each subgroup reduces all
// values from shared memory.
//
// %warp_reduce = ... (see createWarpReduce)
// %buffer = llvm.mlir.addressof @reduce_buffer : !llvm<"[32 x float]*">
// %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
// %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32
// %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i32
// llvm.cond_br %is_first_lane, ^then, ^continue
// ^then:
// %warp_id = ... (see getWarpId)
// %store_dst = llvm.getelementptr %buffer[%zero, %warp_id]
// llvm.store %store_dst, %warp_reduce : !llvm.float
// llvm.br ^continue
// ^continue:
// nvvm.barrier0
// %load_src = llvm.getelementptr %buffer[%zero, %lane_id]
// %value = llvm.load %load_src : !llvm.float
// %result = ... (see createWarpReduce)
Value *createBlockReduce(Location loc, Value *operand,
ConversionPatternRewriter &rewriter) const {
auto type = operand->getType().cast<LLVM::LLVMType>();
Value *warpReduce = createWarpReduce(loc, operand, rewriter);
auto module = warpReduce->getDefiningOp()->getParentOfType<ModuleOp>();
assert(module && "op must belong to a module");
Value *sharedMemPtr =
createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
Value *zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(0u));
Value *laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq, laneId, zero);
Block *currentBlock = rewriter.getInsertionBlock();
auto currentPoint = rewriter.getInsertionPoint();
Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
Block *continueBlock = rewriter.splitBlock(thenBlock, currentPoint);
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(
loc, llvm::makeArrayRef(isFirstLane),
ArrayRef<Block *>{thenBlock, continueBlock});
rewriter.setInsertionPointToEnd(thenBlock);
Value *warpId = getWarpId(loc, rewriter);
Value *storeDst = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>(),
llvm::makeArrayRef(continueBlock));
rewriter.setInsertionPointToStart(continueBlock);
rewriter.create<NVVM::Barrier0Op>(loc);
Value *loadSrc = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, laneId}));
Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
Value *result = createWarpReduce(loc, value, rewriter);
return result;
}
// Creates an all_reduce across the subgroup. Creates a preamble
//
// %active_mask = llvm.mlir.constant(-1 : i32) : !llvm.i32
// %mask_and_clamp = llvm.mlir.constant(31 : i32) : !llvm.i32
//
// plus the accumulation for i = 1, 2, 4, 8, 16:
//
// %offset = llvm.mlir.constant(i : i32) : !llvm.i32
// %value = nvvm.shfl.sync.bfly
// %active_mask, %operand, %offset, %mask_and_clamp : !llvm.float
// %operand = llvm.fadd %operand, %value : !llvm.float
//
// Each invocation returns the same result.
//
// Note: this currently only supports reducing exactly 32 values.
Value *createWarpReduce(Location loc, Value *operand,
ConversionPatternRewriter &rewriter) const {
// TODO(csigg): Generalize to partial warps and other types of accumulation.
static_assert(kWarpSize == 32, "Only warp size of 32 is supported.");
auto activeMask = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(~0u));
auto maskAndClamp = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
auto resultType = operand->getType();
for (int i = 1; i < kWarpSize; i <<= 1) {
auto offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
auto value = rewriter.create<NVVM::ShflBflyOp>(
loc, resultType, activeMask, operand, offset, maskAndClamp);
operand = rewriter.create<LLVM::FAddOp>(loc, resultType, operand, value);
}
return operand;
}
// Creates a global array stored in shared memory.
//
// llvm.mlir.global @reduce_buffer()
// {addr_space = 3 : i32} : !llvm<"[32 x float]">
//
Value *createSharedMemoryArray(Location loc, ModuleOp module,
LLVM::LLVMType elementType, int numElements,
ConversionPatternRewriter &rewriter) const {
OpBuilder builder(module.getBodyRegion());
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
StringRef name = "reduce_buffer";
auto addrSpace =
builder.getNamedAttr("addr_space", builder.getI32IntegerAttr(3));
auto globalOp = builder.create<LLVM::GlobalOp>(
loc, arrayType.cast<LLVM::LLVMType>(),
/*isConstant=*/false, name, /*value=*/Attribute(),
llvm::makeArrayRef(addrSpace));
return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
}
// Returns the index of the subgroup within the local workgroup.
//
// %warp_size = llvm.mlir.constant(32 : i32) : !llvm.i32
// %thread_idx = nvvm.read.ptx.sreg.tid.x : !llvm.i32
// %warp_idx = llvm.sdiv %thread_idx, %warp_size : !llvm.i32
//
Value *getWarpId(Location loc, ConversionPatternRewriter &rewriter) const {
auto warpSize = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
auto threadIdx = getLinearThreadIndex(loc, rewriter);
return rewriter.create<LLVM::SDivOp>(loc, int32Type, threadIdx, warpSize);
}
Value *getLinearThreadIndex(Location loc,
ConversionPatternRewriter &rewriter) const {
// TODO(csigg): support 2- and 3-dimensional blocks.
return rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
}
LLVM::LLVMType int32Type;
// TODO(csigg): Support other warp sizes.
static constexpr int kWarpSize = 32;
};
// A pass that replaces all occurences of GPU device operations with their
// corresponding NVVM equivalent.
//
@ -126,8 +293,8 @@ public:
GPUIndexIntrinsicOpLowering<gpu::BlockId, NVVM::BlockIdXOp,
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
GPUIndexIntrinsicOpLowering<gpu::GridDim, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>>(
converter);
NVVM::GridDimYOp, NVVM::GridDimZOp>,
GPUAllReduceOpLowering>(converter);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();

View file

@ -32,6 +32,13 @@ module attributes {gpu.kernel_module} {
// CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
%gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
%one = constant 1.0 : f32
// TODO(csigg): Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync.bfly
// CHECK: nvvm.barrier0
// CHECK: nvvm.shfl.sync.bfly
%result = "gpu.all_reduce"(%one) {scope = "workgroup", kernel = "add"} : (f32) -> (f32)
std.return
}
}

View file

@ -76,6 +76,9 @@ func @kernel_1(%arg0 : f32, %arg1 : memref<?xf32, 1>)
%gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
%gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
%one = constant 1.0 : f32
%sum = "gpu.all_reduce"(%one) : (f32) -> (f32)
"some_op"(%bIdX, %tIdX) : (index, index) -> ()
%42 = load %arg1[%bIdX] : memref<?xf32, 1>
return

View file

@ -0,0 +1,25 @@
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
// CHECK: [8.128000e+03, 8.128000e+03, {{.*}}, 8.128000e+03, 8.128000e+03]
func @main() {
%arg = alloc() : memref<128xf32>
%dst = memref_cast %arg : memref<128xf32> to memref<?xf32>
%zero = constant 0 : i32
%one = constant 1 : index
%size = dim %dst, 0 : memref<?xf32>
call @mcuMemHostRegister(%dst, %zero) : (memref<?xf32>, i32) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %size, %block_y = %one, %block_z = %one)
args(%kernel_dst = %dst) : memref<?xf32> {
%idx = index_cast %tx : index to i32
%val = sitofp %idx : i32 to f32
%sum = "gpu.all_reduce"(%val) { op = "add" } : (f32) -> (f32)
store %sum, %kernel_dst[%tx] : memref<?xf32>
gpu.return
}
call @mcuPrintFloat(%dst) : (memref<?xf32>) -> ()
return
}
func @mcuMemHostRegister(%ptr : memref<?xf32>, %flags : i32)
func @mcuPrintFloat(%ptr : memref<?xf32>)