[MLIR] Make gpu.launch implicitly capture uses of values defined above.

Summary:
In the original design, gpu.launch required values defined above to be
captured explicitly and passed as operands to the gpu.launch operation.
This was motivated by infrastructure restrictions rather than by design
considerations. This change lifts that requirement and removes the
concept of kernel arguments from gpu.launch. Instead, the kernel
outlining transformation now performs the explicit capture.

This is a breaking change for users of gpu.launch.
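
For users migrating, a minimal before/after sketch of the syntax change,
adapted from the tests updated in this commit (%arg0, %arg1, %cst, and %cst2
are assumed to be defined in the enclosing function):

```mlir
// Before: values from the enclosing function had to be forwarded via `args`.
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
           threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst)
           args(%kernel_arg0 = %arg0, %kernel_arg1 = %arg1) : f32, memref<?xf32> {
  store %kernel_arg0, %kernel_arg1[%tx] : memref<?xf32>
  gpu.terminator
}

// After: values defined above are used directly; -gpu-kernel-outlining later
// turns them into explicit arguments of the outlined kernel function.
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
           threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst) {
  store %arg0, %arg1[%tx] : memref<?xf32>
  gpu.terminator
}
```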

Differential Revision: https://reviews.llvm.org/D73769
Stephan Herhut 2020-01-31 10:29:29 +01:00
parent 2663a25fad
commit 283b5e733d
15 changed files with 84 additions and 381 deletions


@ -335,42 +335,32 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func">,
let verifier = [{ return ::verify(*this); }];
}
def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>,
def GPU_LaunchOp : GPU_Op<"launch">,
Arguments<(ins Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
Variadic<AnyType>:$operands)>,
Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ)>,
Results<(outs)> {
let summary = "GPU kernel launch operation";
let description = [{
Launch a kernel on the specified grid of thread blocks. The body of the
kernel is defined by the single region that this operation contains. The
operation takes at least six operands, with first three operands being grid
sizes along x,y,z dimensions, the following three arguments being block
sizes along x,y,z dimension, and the remaining operands are arguments of the
kernel. When a lower-dimensional kernel is required, unused sizes must be
explicitly set to `1`.
operation takes six operands, with first three operands being grid sizes
along x,y,z dimensions and the following three arguments being block sizes
along x,y,z dimension. When a lower-dimensional kernel is required,
unused sizes must be explicitly set to `1`.
The body region has at least _twelve_ arguments, grouped as follows:
The body region has _twelve_ arguments, grouped as follows:
- three arguments that contain block identifiers along x,y,z dimensions;
- three arguments that contain thread identifiers along x,y,z dimensions;
- operands of the `gpu.launch` operation as is, including six leading
operands for grid and block sizes.
Operations inside the body region, and any operations in the nested regions,
are _not_ allowed to use values defined outside the _body_ region, as if
this region was a function. If necessary, values must be passed as kernel
arguments into the body region. Nested regions inside the kernel body are
allowed to use values defined in their ancestor regions as long as they
don't cross the kernel body region boundary.
- operands of the `gpu.launch` operation as is (i.e. the operands for
grid and block sizes).
Syntax:
```
operation ::= `gpu.launch` `block` `(` ssa-id-list `)` `in` ssa-reassignment
`threads` `(` ssa-id-list `)` `in` ssa-reassignment
(`args` ssa-reassignment `:` type-list)?
region attr-dict?
ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
```
@ -379,32 +369,29 @@ def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>,
```mlir
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5)
args(%arg0 = %6, %arg1 = 7) : f32, memref<?xf32, 1> {
threads(%tx, %ty, %tz) in (%sz_tx = %3, %sz_ty = %4, %sz_tz = %5) {
// Block and thread identifiers, as well as block/grid sizes are
// immediately usable inside body region.
"some_op"(%bx, %tx) : (index, index) -> ()
%42 = load %arg1[%bx] : memref<?xf32, 1>
// Assuming %val1 is defined outside the gpu.launch region.
%42 = load %val1[%bx] : memref<?xf32, 1>
}
// Generic syntax explains how the pretty syntax maps to the IR structure.
"gpu.launch"(%cst, %cst, %c1, // Grid sizes.
%cst, %c1, %c1, // Block sizes.
%arg0, %arg1) // Actual arguments.
%cst, %c1, %c1) // Block sizes.
{/*attributes*/}
// All sizes and identifiers have "index" size.
: (index, index, index, index, index, index, f32, memref<?xf32, 1>)
-> () {
: (index, index, index, index, index, index) -> () {
// The operation passes block and thread identifiers, followed by grid and
// block sizes, followed by actual arguments to the entry block of the
// region.
// block sizes.
^bb0(%bx : index, %by : index, %bz : index,
%tx : index, %ty : index, %tz : index,
%num_bx : index, %num_by : index, %num_bz : index,
%num_tx : index, %num_ty : index, %num_tz : index,
%arg0 : f32, %arg1 : memref<?xf32, 1>):
%num_tx : index, %num_ty : index, %num_tz : index)
"some_op"(%bx, %tx) : (index, index) -> ()
%3 = "std.load"(%arg1, %bx) : (memref<?xf32, 1>, index) -> f32
%3 = "std.load"(%val1, %bx) : (memref<?xf32, 1>, index) -> f32
}
```
@ -422,12 +409,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>,
let builders = [
OpBuilder<"Builder *builder, OperationState &result, Value gridSizeX,"
"Value gridSizeY, Value gridSizeZ, Value blockSizeX,"
"Value blockSizeY, Value blockSizeZ,"
"ValueRange operands">
"Value blockSizeY, Value blockSizeZ">
];
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
/// Get the SSA values corresponding to kernel block identifiers.
KernelDim3 getBlockIds();
@ -437,26 +421,14 @@ def GPU_LaunchOp : GPU_Op<"launch", [IsolatedFromAbove]>,
KernelDim3 getGridSize();
/// Get the SSA values corresponding to kernel block size.
KernelDim3 getBlockSize();
/// Get the operand values passed as kernel arguments.
operand_range getKernelOperandValues();
/// Get the operand types passed as kernel arguments.
operand_type_range getKernelOperandTypes();
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();
/// Get the SSA values passed as operands to specify the block size.
KernelDim3 getBlockSizeOperandValues();
/// Get the SSA values of the kernel arguments.
iterator_range<Block::args_iterator> getKernelArguments();
/// Erase the `index`-th kernel argument. Both the entry block argument and
/// the operand will be dropped. The block argument must not have any uses.
void eraseKernelArgument(unsigned index);
static StringRef getBlocksKeyword() { return "blocks"; }
static StringRef getThreadsKeyword() { return "threads"; }
static StringRef getArgsKeyword() { return "args"; }
/// The number of launch configuration operands, placed at the leading
/// positions of the operand list.


@ -357,32 +357,15 @@ static LogicalResult createLaunchFromOp(OpTy rootForOp,
workGroupSize3D[workGroupSize.index()] = workGroupSize.value();
}
// Get the values used within the region of the rootForOp but defined above
// it.
llvm::SetVector<Value> valuesToForwardSet;
getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(),
valuesToForwardSet);
// Also add the values used for the lb, ub, and step of the rootForOp.
valuesToForwardSet.insert(rootForOp.getOperands().begin(),
rootForOp.getOperands().end());
auto valuesToForward = valuesToForwardSet.takeVector();
auto launchOp = builder.create<gpu::LaunchOp>(
rootForOp.getLoc(), numWorkGroups3D[0], numWorkGroups3D[1],
numWorkGroups3D[2], workGroupSize3D[0], workGroupSize3D[1],
workGroupSize3D[2], valuesToForward);
workGroupSize3D[2]);
if (failed(createLaunchBody(builder, rootForOp, launchOp,
numWorkGroups.size(), workGroupSizes.size()))) {
return failure();
}
// Replace values that are used within the region of the launchOp but are
// defined outside. They all are replaced with kernel arguments.
for (auto pair :
llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
Value from = std::get<0>(pair);
Value to = std::get<1>(pair);
replaceAllUsesInRegionWith(from, to, launchOp.body());
}
return success();
}
@ -411,24 +394,13 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp,
Value blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne;
// Create a launch op and move the body region of the innermost loop to the
// launch op. Pass the values defined outside the outermost loop and used
// inside the innermost loop and loop lower bounds as kernel data arguments.
// Still assuming perfect nesting so there are no values other than induction
// variables that are defined in one loop and used in deeper loops.
llvm::SetVector<Value> valuesToForwardSet;
getUsedValuesDefinedAbove(innermostForOp.region(), rootForOp.region(),
valuesToForwardSet);
auto valuesToForward = valuesToForwardSet.takeVector();
auto originallyForwardedValues = valuesToForward.size();
valuesToForward.insert(valuesToForward.end(), lbs.begin(), lbs.end());
valuesToForward.insert(valuesToForward.end(), steps.begin(), steps.end());
// launch op.
auto launchOp = builder.create<gpu::LaunchOp>(
rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX,
blockSizeY, blockSizeZ, valuesToForward);
valuesToForward.resize(originallyForwardedValues);
blockSizeY, blockSizeZ);
// Replace the loop terminator (loops contain only a single block) with the
// gpu return and move the operations from the loop body block to the gpu
// gpu terminator and move the operations from the loop body block to the gpu
// launch body block. Do not move the entire block because of the difference
// in block arguments.
Operation &terminator = innermostForOp.getBody()->back();
@ -445,9 +417,8 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp,
// from 0 to N with step 1. Therefore, loop induction variables are replaced
// with (gpu-thread/block-id * S) + LB.
builder.setInsertionPointToStart(&launchOp.body().front());
auto lbArgumentIt = std::next(launchOp.getKernelArguments().begin(),
originallyForwardedValues);
auto stepArgumentIt = std::next(lbArgumentIt, lbs.size());
auto lbArgumentIt = lbs.begin();
auto stepArgumentIt = steps.begin();
for (auto en : llvm::enumerate(ivs)) {
Value id =
en.index() < numBlockDims
@ -460,22 +431,10 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp,
Value ivReplacement =
builder.create<AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id);
en.value().replaceAllUsesWith(ivReplacement);
replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt,
launchOp.body());
std::advance(lbArgumentIt, 1);
std::advance(stepArgumentIt, 1);
}
// Remap the values defined outside the body to use kernel arguments instead.
// The list of kernel arguments also contains the lower bounds for loops at
// trailing positions, make sure we don't touch those.
for (auto pair :
llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
Value from = std::get<0>(pair);
Value to = std::get<1>(pair);
replaceAllUsesInRegionWith(from, to, launchOp.body());
}
// We are done and can erase the original outermost loop.
rootForOp.erase();
}


@ -196,11 +196,10 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
Value gridSizeY, Value gridSizeZ, Value blockSizeX,
Value blockSizeY, Value blockSizeZ, ValueRange operands) {
Value blockSizeY, Value blockSizeZ) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands(
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
result.addOperands(operands);
// Create a kernel body region with kNumConfigRegionAttributes + N arguments,
// where the first kNumConfigRegionAttributes arguments have `index` type and
@ -209,7 +208,6 @@ void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
Block *body = new Block();
body->addArguments(
std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
body->addArguments(llvm::to_vector<4>(operands.getTypes()));
kernelRegion->push_back(body);
}
@ -237,14 +235,6 @@ KernelDim3 LaunchOp::getBlockSize() {
return KernelDim3{args[9], args[10], args[11]};
}
LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
return llvm::drop_begin(getOperands(), kNumConfigOperands);
}
LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}
KernelDim3 LaunchOp::getGridSizeOperandValues() {
return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
}
@ -253,11 +243,6 @@ KernelDim3 LaunchOp::getBlockSizeOperandValues() {
return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
}
iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
auto args = body().getBlocks().front().getArguments();
return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
}
static LogicalResult verify(LaunchOp op) {
// Kernel launch takes kNumConfigOperands leading operands for grid/block
// sizes and transforms them into kNumConfigRegionAttributes region arguments
@ -312,25 +297,6 @@ static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
printSizeAssignment(p, op.getBlockSize(), operands.slice(3, 3),
op.getThreadIds());
// From now on, the first kNumConfigOperands operands corresponding to grid
// and block sizes are irrelevant, so we can drop them.
operands = operands.drop_front(LaunchOp::kNumConfigOperands);
// Print the data argument remapping.
if (!op.body().empty() && !operands.empty()) {
p << ' ' << op.getArgsKeyword() << '(';
Block *entryBlock = &op.body().front();
interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
p << entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
<< " = " << operands[i];
});
p << ") ";
}
// Print the types of data arguments.
if (!operands.empty())
p << ": " << operands.getTypes();
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op.getAttrs());
}
@ -368,8 +334,7 @@ parseSizeAssignment(OpAsmParser &parser,
// Parses a Launch operation.
// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
// (`args` ssa-reassignment `:` type-list)?
// region attr-dict?
// region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
// Sizes of the grid and block.
@ -402,103 +367,17 @@ static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();
// If kernel argument renaming segment is present, parse it. When present,
// the segment should have at least one element. If this segment is present,
// so is the trailing type list. Parse it as well and use the parsed types
// to resolve the operands passed to the kernel arguments.
SmallVector<Type, 4> dataTypes;
if (!parser.parseOptionalKeyword(LaunchOp::getArgsKeyword())) {
llvm::SMLoc argsLoc = parser.getCurrentLocation();
regionArgs.push_back({});
dataOperands.push_back({});
if (parser.parseLParen() || parser.parseRegionArgument(regionArgs.back()) ||
parser.parseEqual() || parser.parseOperand(dataOperands.back()))
return failure();
while (!parser.parseOptionalComma()) {
regionArgs.push_back({});
dataOperands.push_back({});
if (parser.parseRegionArgument(regionArgs.back()) ||
parser.parseEqual() || parser.parseOperand(dataOperands.back()))
return failure();
}
if (parser.parseRParen() || parser.parseColonTypeList(dataTypes) ||
parser.resolveOperands(dataOperands, dataTypes, argsLoc,
result.operands))
return failure();
}
// Introduce the body region and parse it. The region has
// kNumConfigRegionAttributes leading arguments that correspond to
// Introduce the body region and parse it. The region has
// kNumConfigRegionAttributes arguments that correspond to
// block/thread identifiers and grid/block sizes, all of the `index` type.
// Follow the actual kernel arguments.
Type index = parser.getBuilder().getIndexType();
dataTypes.insert(dataTypes.begin(), LaunchOp::kNumConfigRegionAttributes,
index);
SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
LaunchOp::kNumConfigRegionAttributes, index);
Region *body = result.addRegion();
return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
parser.parseOptionalAttrDict(result.attributes));
}
void LaunchOp::eraseKernelArgument(unsigned index) {
Block &entryBlock = body().front();
assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
"kernel argument index overflow");
entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
getOperation()->eraseOperand(kNumConfigOperands + index);
}
namespace {
// Clone any known constants passed as operands to the kernel into its body.
class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
using OpRewritePattern<LaunchOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(LaunchOp launchOp,
PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(launchOp);
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&launchOp.body().front());
// Traverse operands passed to kernel and check if some of them are known
// constants. If so, clone the constant operation inside the kernel region
// and use it instead of passing the value from the parent region. Perform
// the traversal in the inverse order to simplify index arithmetics when
// dropping arguments.
auto operands = launchOp.getKernelOperandValues();
auto kernelArgs = launchOp.getKernelArguments();
bool found = false;
for (unsigned i = operands.size(); i > 0; --i) {
unsigned index = i - 1;
Value operand = operands[index];
if (!isa_and_nonnull<ConstantOp>(operand.getDefiningOp()))
continue;
found = true;
Value internalConstant =
rewriter.clone(*operand.getDefiningOp())->getResult(0);
Value kernelArg = *std::next(kernelArgs.begin(), index);
kernelArg.replaceAllUsesWith(internalConstant);
launchOp.eraseKernelArgument(index);
}
if (!found) {
rewriter.cancelRootUpdate(launchOp);
return matchFailure();
}
rewriter.finalizeRootUpdate(launchOp);
return matchSuccess();
}
};
} // end namespace
void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<PropagateConstantBounds>(context);
}
//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//


@ -17,6 +17,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
using namespace mlir;
@ -100,13 +101,22 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
// Outline the `gpu.launch` operation body into a kernel function. Replace
// `gpu.terminator` operations by `gpu.return` in the generated function.
static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp,
llvm::SetVector<Value> &operands) {
Location loc = launchOp.getLoc();
// Create a builder with no insertion point, insertion will happen separately
// due to symbol table manipulation.
OpBuilder builder(launchOp.getContext());
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
// Identify uses from values defined outside of the scope of the launch
// operation.
getUsedValuesDefinedAbove(launchOp.body(), operands);
SmallVector<Type, 4> kernelOperandTypes;
kernelOperandTypes.reserve(operands.size());
for (Value operand : operands) {
kernelOperandTypes.push_back(operand.getType());
}
FunctionType type =
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
std::string kernelFuncName =
@ -116,6 +126,11 @@ static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
builder.getUnitAttr());
outlinedFunc.body().takeBody(launchOp.body());
injectGpuIndexOperations(loc, outlinedFunc.body());
Block &entryBlock = outlinedFunc.body().front();
for (Value operand : operands) {
BlockArgument newArg = entryBlock.addArgument(operand.getType());
replaceAllUsesInRegionWith(operand, newArg, outlinedFunc.body());
}
outlinedFunc.walk([](gpu::TerminatorOp op) {
OpBuilder replacer(op);
replacer.create<gpu::ReturnOp>(op.getLoc());
@ -129,11 +144,12 @@ static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
// `kernelFunc`. The kernel func contains the body of the `gpu.launch` with
// constant region arguments inlined.
static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
gpu::GPUFuncOp kernelFunc) {
gpu::GPUFuncOp kernelFunc,
ValueRange operands) {
OpBuilder builder(launchOp);
auto launchFuncOp = builder.create<gpu::LaunchFuncOp>(
launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
launchOp.getBlockSizeOperandValues(), launchOp.getKernelOperandValues());
launchOp.getBlockSizeOperandValues(), operands);
inlineBeneficiaryOps(kernelFunc, launchFuncOp);
launchOp.erase();
}
@ -158,7 +174,8 @@ public:
// Insert just after the function.
Block::iterator insertPt(func.getOperation()->getNextNode());
func.walk([&](gpu::LaunchOp op) {
gpu::GPUFuncOp outlinedFunc = outlineKernelFunc(op);
llvm::SetVector<Value> operands;
gpu::GPUFuncOp outlinedFunc = outlineKernelFunc(op, operands);
// Create nested module and insert outlinedFunc. The module will
// originally get the same name as the function, but may be renamed on
@ -167,7 +184,7 @@ public:
symbolTable.insert(kernelModule, insertPt);
// Potentially changes signature, pulling in constants.
convertToLaunchFuncOp(op, outlinedFunc);
convertToLaunchFuncOp(op, outlinedFunc, operands.getArrayRef());
modified = true;
});
}


@ -8,16 +8,16 @@ module {
%1 = dim %arg0, 1 : memref<?x?xf32>
%c0 = constant 0 : index
%c1 = constant 1 : index
// CHECK: gpu.launch blocks([[ARG5:%.*]], [[ARG6:%.*]], [[ARG7:%.*]]) in ([[ARG11:%.*]] = {{%.*}}, [[ARG12:%.*]] = {{%.*}}, [[ARG13:%.*]] = {{%.*}}) threads([[ARG8:%.*]], [[ARG9:%.*]], [[ARG10:%.*]]) in ([[ARG14:%.*]] = {{%.*}}, [[ARG15:%.*]] = {{%.*}}, [[ARG16:%.*]] = {{%.*}}) args([[ARG17:%.*]] = [[ARG3]], [[ARG18:%.*]] = [[ARG4]], [[ARG19:%.*]] = [[ARG1]], [[ARG20:%.*]] = {{%.*}}, {{%.*}} = {{%.*}}, [[ARG22:%.*]] = [[ARG0]], [[ARG23:%.*]] = [[ARG2]]
// CHECK: [[TEMP1:%.*]] = muli [[ARG17]], [[ARG6]] : index
// CHECK: gpu.launch blocks([[ARG5:%.*]], [[ARG6:%.*]], [[ARG7:%.*]]) in ([[ARG11:%.*]] = {{%.*}}, [[ARG12:%.*]] = {{%.*}}, [[ARG13:%.*]] = {{%.*}}) threads([[ARG8:%.*]], [[ARG9:%.*]], [[ARG10:%.*]]) in ([[ARG14:%.*]] = {{%.*}}, [[ARG15:%.*]] = {{%.*}}, [[ARG16:%.*]] = {{%.*}})
// CHECK: [[TEMP1:%.*]] = muli [[ARG3]], [[ARG6]] : index
// CHECK: [[BLOCKLOOPYLB:%.*]] = addi {{%.*}}, [[TEMP1]] : index
// CHECK: [[BLOCKLOOPYSTEP:%.*]] = muli [[ARG17]], [[ARG12]] : index
// CHECK: [[BLOCKLOOPYSTEP:%.*]] = muli [[ARG3]], [[ARG12]] : index
// CHECK: loop.for [[BLOCKLOOPYIV:%.*]] = [[BLOCKLOOPYLB]] to {{%.*}} step [[BLOCKLOOPYSTEP]]
loop.for %iv1 = %c0 to %0 step %arg3 {
// CHECK: [[TEMP2:%.*]] = muli [[ARG18]], [[ARG5]] : index
// CHECK: [[TEMP2:%.*]] = muli [[ARG4]], [[ARG5]] : index
// CHECK: [[BLOCKLOOPXLB:%.*]] = addi {{%.*}}, [[TEMP2]] : index
// CHECK: [[BLOCKLOOPXSTEP:%.*]] = muli [[ARG18]], [[ARG11]] : index
// CHECK: [[BLOCKLOOPXSTEP:%.*]] = muli [[ARG4]], [[ARG11]] : index
// CHECK: loop.for [[BLOCKLOOPXIV:%.*]] = [[BLOCKLOOPXLB]] to {{%.*}} step [[BLOCKLOOPXSTEP]]
loop.for %iv2 = %c0 to %1 step %arg4 {
@ -27,7 +27,7 @@ module {
%2 = alloc(%arg3, %arg4) : memref<?x?xf32>
// Load transpose tile
// CHECK: [[TEMP3:%.*]] = muli [[ARG20]], [[ARG9:%.*]] : index
// CHECK: [[TEMP3:%.*]] = muli [[ARG20:%.*]], [[ARG9:%.*]] : index
// CHECK: [[THREADLOOP1YLB:%.*]] = addi {{%.*}}, [[TEMP3]] : index
// CHECK: [[THREADLOOP1YSTEP:%.*]] = muli [[ARG20]], [[ARG15]] : index
// CHECK: loop.for [[THREADLOOP1YIV:%.*]] = [[THREADLOOP1YLB]] to {{%.*}} step [[THREADLOOP1YSTEP]]
@ -41,7 +41,7 @@ module {
%10 = addi %iv1, %iv3 : index
// CHECK: [[INDEX1:%.*]] = addi [[BLOCKLOOPXIV]], [[THREADLOOP1XIV]] : index
%11 = addi %iv2, %iv4 : index
// CHECK: [[VAL1:%.*]] = load [[ARG19]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} : memref<?x?xf32>
// CHECK: [[VAL1:%.*]] = load [[ARG1]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} : memref<?x?xf32>
%12 = load %arg1[%11, %10] : memref<?x?xf32>
// CHECK: store [[VAL1]], [[SCRATCHSPACE:%.*]]{{\[}}[[THREADLOOP1XIV]], [[THREADLOOP1YIV]]{{\]}} : memref<?x?xf32>
store %12, %2[%iv4, %iv3] : memref<?x?xf32>
@ -67,10 +67,10 @@ module {
%14 = addi %iv2, %iv4 : index
// CHECK: {{%.*}} = load [[SCRATCHSPACE]]{{\[}}[[THREADLOOP2XIV]], [[THREADLOOP2YIV]]{{\]}} : memref<?x?xf32>
%15 = load %2[%iv4, %iv3] : memref<?x?xf32>
// CHECK: {{%.*}} = load [[ARG22]]{{\[}}[[INDEX3]], [[INDEX4]]{{\]}}
// CHECK: {{%.*}} = load [[ARG0]]{{\[}}[[INDEX3]], [[INDEX4]]{{\]}}
%16 = load %arg0[%13, %14] : memref<?x?xf32>
%17 = mulf %15, %16 : f32
// CHECK: store {{%.*}}, [[ARG23]]{{\[}}[[INDEX3]], [[INDEX4]]{{\]}}
// CHECK: store {{%.*}}, [[ARG2]]{{\[}}[[INDEX3]], [[INDEX4]]{{\]}}
store %17, %arg2[%13, %14] : memref<?x?xf32>
}
}
@ -80,4 +80,4 @@ module {
}
return
}
}
}


@ -14,7 +14,6 @@ func @foo(%arg0: memref<?xf32>, %arg1 : index) {
// CHECK: gpu.launch
// CHECK-SAME: blocks
// CHECK-SAME: threads
// CHECK-SAME: args
// Replacements of loop induction variables. Take a product with the
// step and add the lower bound.


@ -30,7 +30,6 @@ func @step_1(%A : memref<?x?x?x?xf32>, %B : memref<?x?x?x?xf32>) {
// CHECK-11: gpu.launch
// CHECK-11-SAME: blocks
// CHECK-11-SAME: threads
// CHECK-11-SAME: args
// Remapping of the loop induction variables.
// CHECK-11: %[[i:.*]] = addi %{{.*}}, %{{.*}} : index
@ -57,7 +56,6 @@ func @step_1(%A : memref<?x?x?x?xf32>, %B : memref<?x?x?x?xf32>) {
// CHECK-22: gpu.launch
// CHECK-22-SAME: blocks
// CHECK-22-SAME: threads
// CHECK-22-SAME: args
// Remapping of the loop induction variables in the last mapped loop.
// CHECK-22: %[[i:.*]] = addi %{{.*}}, %{{.*}} : index


@ -1,28 +0,0 @@
// RUN: mlir-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s
// CHECK-LABEL: @propagate_constant
// CHECK-SAME: %[[arg1:.*]]: memref
func @propagate_constant(%arg1: memref<?xf32>) {
// The outer constant must be preserved because it still has uses.
// CHECK: %[[outer_cst:.*]] = constant 1
%c1 = constant 1 : index
// The constant must be dropped from the args list, but the memref should
// remain.
// CHECK: gpu.launch
// CHECK-SAME: args(%[[inner_arg:.*]] = %[[arg1]]) : memref
gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
args(%x = %c1, %y = %arg1) : index, memref<?xf32> {
// The constant is propagated into the kernel body and used.
// CHECK: %[[inner_cst:.*]] = constant 1
// CHECK: "foo"(%[[inner_cst]])
"foo"(%x) : (index) -> ()
// CHECK: "bar"(%[[inner_arg]])
"bar"(%y) : (memref<?xf32>) -> ()
gpu.terminator
}
return
}


@ -1,7 +1,7 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
func @not_enough_sizes(%sz : index) {
// expected-error@+1 {{expected 6 or more operands}}
// expected-error@+1 {{expected 6 operands, but found 5}}
"gpu.launch"(%sz, %sz, %sz, %sz, %sz) ({
gpu.return
}) : (index, index, index, index, index) -> ()
@ -22,59 +22,6 @@ func @no_region_attrs(%sz : index) {
// -----
func @isolation_arg(%sz : index) {
// expected-note@+1 {{required by region isolation constraints}}
"gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({
^bb1(%bx: index, %by: index, %bz: index,
%tx: index, %ty: index, %tz: index,
%szbx: index, %szby: index, %szbz: index,
%sztx: index, %szty: index, %sztz: index):
// expected-error@+1 {{using value defined outside the region}}
"use"(%sz) : (index) -> ()
gpu.return
}) : (index, index, index, index, index, index) -> ()
return
}
// -----
func @isolation_op(%sz : index) {
%val = "produce"() : () -> (index)
// expected-note@+1 {{required by region isolation constraints}}
"gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({
^bb1(%bx: index, %by: index, %bz: index,
%tx: index, %ty: index, %tz: index,
%szbx: index, %szby: index, %szbz: index,
%sztx: index, %szty: index, %sztz: index):
// expected-error@+1 {{using value defined outside the region}}
"use"(%val) : (index) -> ()
gpu.return
}) : (index, index, index, index, index, index) -> ()
return
}
// -----
func @nested_isolation(%sz : index) {
// expected-note@+1 {{required by region isolation constraints}}
"gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({
^bb1(%bx: index, %by: index, %bz: index,
%tx: index, %ty: index, %tz: index,
%szbx: index, %szby: index, %szbz: index,
%sztx: index, %szty: index, %sztz: index):
"region"() ({
"region"() ({
// expected-error@+1 {{using value defined outside the region}}
"use"(%sz) : (index) -> ()
}) : () -> ()
}) : () -> ()
gpu.return
}) : (index, index, index, index, index, index) -> ()
return
}
// -----
func @launch_requires_gpu_return(%sz : index) {
// @expected-note@+1 {{in 'gpu.launch' body region}}
gpu.launch blocks(%bx, %by, %bz) in (%sbx = %sz, %sby = %sz, %sbz = %sz)
@ -463,4 +410,4 @@ module {
gpu.return
}
}
}
}


@ -15,45 +15,11 @@ module attributes {gpu.container_module} {
// CHECK-LABEL:func @args(%{{.*}}: index, %{{.*}}: index, %{{.*}}: f32, %{{.*}}: memref<?xf32, 1>) {
func @args(%blk : index, %thrd : index, %float : f32, %data : memref<?xf32,1>) {
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : f32, memref<?xf32, 1>
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
threads(%tx, %ty, %tz) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd)
args(%kernel_arg0 = %float, %kernel_arg1 = %data) : f32, memref<?xf32, 1> {
// CHECK: gpu.terminator
gpu.terminator
}
return
}
// It is possible to use values passed into the region as arguments.
// CHECK-LABEL: func @passing_values
func @passing_values(%blk : index, %thrd : index, %float : f32, %data : memref<?xf32,1>) {
// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : f32, memref<?xf32, 1>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
threads(%tx, %ty, %tz) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd)
args(%kernel_arg0 = %float, %kernel_arg1 = %data) : f32, memref<?xf32, 1> {
// CHECK: "use"(%{{.*}})
"use"(%kernel_arg0): (f32) -> ()
// CHECK: gpu.terminator
gpu.terminator
}
return
}
// It is possible to use values defined in nested regions as long as they don't
// cross kernel launch region boundaries.
// CHECK-LABEL: func @nested_isolation
func @nested_isolation(%sz : index) {
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
"region"() ({
// CHECK: %{{.*}} = "produce"()
%val = "produce"() : () -> (index)
"region"() ({
// CHECK: "use"(%{{.*}})
"use"(%val) : (index) -> ()
}) : () -> ()
}) : () -> ()
threads(%tx, %ty, %tz) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd) {
"use"(%float) : (f32) -> ()
"use"(%data) : (memref<?xf32,1>) -> ()
// CHECK: gpu.terminator
gpu.terminator
}


@ -1,4 +1,4 @@
// RUN: mlir-opt -gpu-kernel-outlining -split-input-file -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -gpu-kernel-outlining -split-input-file -verify-diagnostics %s | FileCheck %s -dump-input-on-failure
// CHECK: module attributes {gpu.container_module}
@ -26,11 +26,10 @@ func @launch() {
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY,
%grid_z = %gDimZ)
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY,
%block_z = %bDimZ)
args(%arg0 = %0, %arg1 = %1) : f32, memref<?xf32, 1> {
"use"(%arg0): (f32) -> ()
%block_z = %bDimZ) {
"use"(%0): (f32) -> ()
"some_op"(%bx, %block_x) : (index, index) -> ()
%42 = load %arg1[%tx] : memref<?xf32, 1>
%42 = load %1[%tx] : memref<?xf32, 1>
gpu.terminator
}
return
@ -96,9 +95,8 @@ func @extra_constants(%arg0 : memref<?xf32>) {
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
%grid_z = %cst)
threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
%block_z = %cst)
args(%kernel_arg0 = %cst2, %kernel_arg1 = %arg0, %kernel_arg2 = %cst3) : index, memref<?xf32>, index {
"use"(%kernel_arg0, %kernel_arg1, %kernel_arg2) : (index, memref<?xf32>, index) -> ()
%block_z = %cst) {
"use"(%cst2, %arg0, %cst3) : (index, memref<?xf32>, index) -> ()
gpu.terminator
}
return


@ -10,8 +10,7 @@ func @main() {
%sz = dim %dst, 0 : memref<?x?x?xf32>
call @mcuMemHostRegisterMemRef3dFloat(%dst) : (memref<?x?x?xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz)
args(%kernel_dst = %dst) : memref<?x?x?xf32> {
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
%t0 = muli %tz, %block_y : index
%t1 = addi %ty, %t0 : index
%t2 = muli %t1, %block_x : index
@ -19,7 +18,7 @@ func @main() {
%t3 = index_cast %idx : index to i32
%val = sitofp %t3 : i32 to f32
%sum = "gpu.all_reduce"(%val) ({}) { op = "add" } : (f32) -> (f32)
store %sum, %kernel_dst[%tz, %ty, %tx] : memref<?x?x?xf32>
store %sum, %dst[%tz, %ty, %tx] : memref<?x?x?xf32>
gpu.terminator
}
%U = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>


@ -8,8 +8,7 @@ func @main() {
%sx = dim %dst, 0 : memref<?xf32>
call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one)
args(%kernel_dst = %dst) : memref<?xf32> {
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%val = index_cast %tx : index to i32
%xor = "gpu.all_reduce"(%val) ({
^bb(%lhs : i32, %rhs : i32):
@ -17,7 +16,7 @@ func @main() {
"gpu.yield"(%xor) : (i32) -> ()
}) : (i32) -> (i32)
%res = sitofp %xor : i32 to f32
store %res, %kernel_dst[%tx] : memref<?xf32>
store %res, %dst[%tx] : memref<?xf32>
gpu.terminator
}
%U = memref_cast %dst : memref<?xf32> to memref<*xf32>


@ -4,9 +4,8 @@ func @other_func(%arg0 : f32, %arg1 : memref<?xf32>) {
%cst = constant 1 : index
%cst2 = dim %arg1, 0 : memref<?xf32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst)
args(%kernel_arg0 = %arg0, %kernel_arg1 = %arg1) : f32, memref<?xf32> {
store %kernel_arg0, %kernel_arg1[%tx] : memref<?xf32>
threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst) {
store %arg0, %arg1[%tx] : memref<?xf32>
gpu.terminator
}
return


@ -8,8 +8,7 @@ func @main() {
%sx = dim %dst, 0 : memref<?xf32>
call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref<?xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one)
args(%kernel_dst = %dst) : memref<?xf32> {
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
%val = sitofp %t0 : i32 to f32
%width = index_cast %block_x : index to i32
@ -20,7 +19,7 @@ func @main() {
%m1 = constant -1.0 : f32
br ^bb1(%m1 : f32)
^bb1(%value : f32):
store %value, %kernel_dst[%tx] : memref<?xf32>
store %value, %dst[%tx] : memref<?xf32>
gpu.terminator
}
%U = memref_cast %dst : memref<?xf32> to memref<*xf32>