[spirv] Add support for function calls.

Add spv.FunctionCall operation and (de)serialization.

Closes tensorflow/mlir#137

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/137 from denis0x0D:sandbox/function_call_op e2e6f07d21e7f23e8b44c7df8a8ab784f3356ce4
PiperOrigin-RevId: 269437167
This commit is contained in:
Denis Khalikov 2019-09-16 15:39:16 -07:00 committed by A. Unique TensorFlower
parent 9619ba10d4
commit 8a34d5d18c
7 changed files with 419 additions and 14 deletions

View file

@ -100,6 +100,7 @@ def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite",
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
@ -161,13 +162,13 @@ def SPV_OpcodeAttr :
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
@ -1113,7 +1114,7 @@ def SPV_SamplerUseAttr:
// Check that an op can only be used within the scope of a FuncOp.
def InFunctionScope : PredOpTrait<
"op must appear in a 'func' block",
CPred<"llvm::isa_and_nonnull<FuncOp>($_op.getParentOp())">>;
CPred<"($_op.getParentOfType<FuncOp>())">>;
// Check that an op can only be used within the scope of a SPIR-V ModuleOp.
def InModuleScope : PredOpTrait<

View file

@ -151,6 +151,52 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
// -----
def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> {
let summary = "Call a function.";
let description = [{
Result Type is the type of the return value of the function. It must be
the same as the Return Type operand of the Function Type operand of the
Function operand.
Function is an OpFunction instruction. This could be a forward
reference.
Argument N is the object to copy to parameter N of Function.
Note: A forward call is possible because there is no missing type
information: Result Type must match the Return Type of the function, and
the calling argument types must match the formal parameter types.
### Custom assembly form
``` {.ebnf}
function-call-op ::= `spv.FunctionCall` function-id `(` ssa-use-list `)`
`:` function-type
```
For example:
```
spv.FunctionCall @f_void(%arg0) : (i32) -> ()
%0 = spv.FunctionCall @f_iadd(%arg0, %arg1) : (i32, i32) -> i32
```
}];
let arguments = (ins
SymbolRefAttr:$callee,
Variadic<SPV_Type>:$arguments
);
let results = (outs
SPV_Optional<SPV_Type>:$result
);
let autogenSerialization = 0;
}
// -----
def SPV_LoopOp : SPV_Op<"loop"> {
let summary = "Define a structured loop.";

View file

@ -35,6 +35,7 @@ using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
@ -912,6 +913,108 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
[&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
}
//===----------------------------------------------------------------------===//
// spv.FuncionCall
//===----------------------------------------------------------------------===//
static ParseResult parseFunctionCallOp(OpAsmParser *parser,
OperationState *state) {
SymbolRefAttr calleeAttr;
FunctionType type;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto loc = parser->getNameLoc();
if (parser->parseAttribute(calleeAttr, kCallee, state->attributes) ||
parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser->parseColonType(type)) {
return failure();
}
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType) {
return parser->emitError(loc, "expected function type, but provided ")
<< type;
}
if (funcType.getNumResults() > 1) {
return parser->emitError(loc, "expected callee function to have 0 or 1 "
"result, but provided ")
<< funcType.getNumResults();
}
return failure(parser->addTypesToList(funcType.getResults(), state->types) ||
parser->resolveOperands(operands, funcType.getInputs(), loc,
state->operands));
}
static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter *printer) {
SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
Type functionType =
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
*printer << spirv::FunctionCallOp::getOperationName() << ' '
<< functionCallOp.getAttr(kCallee) << '(';
printer->printOperands(functionCallOp.arguments());
*printer << ") : " << functionType;
}
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
auto fnName = functionCallOp.callee();
auto moduleOp = functionCallOp.getParentOfType<spirv::ModuleOp>();
if (!moduleOp) {
return functionCallOp.emitOpError(
"must appear in a function inside 'spv.module'");
}
auto funcOp = moduleOp.lookupSymbol<FuncOp>(fnName);
if (!funcOp) {
return functionCallOp.emitOpError("callee function '")
<< fnName << "' not found in 'spv.module'";
}
auto functionType = funcOp.getType();
if (functionCallOp.getNumResults() > 1) {
return functionCallOp.emitOpError(
"expected callee function to have 0 or 1 result, but provided ")
<< functionCallOp.getNumResults();
}
if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
return functionCallOp.emitOpError(
"has incorrect number of operands for callee: expected ")
<< functionType.getNumInputs() << ", but provided "
<< functionCallOp.getNumOperands();
}
for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
if (functionCallOp.getOperand(i)->getType() != functionType.getInput(i)) {
return functionCallOp.emitOpError(
"operand type mismatch: expected operand type ")
<< functionType.getInput(i) << ", but provided "
<< functionCallOp.getOperand(i)->getType()
<< " for operand number " << i;
}
}
if (functionType.getNumResults() != functionCallOp.getNumResults()) {
return functionCallOp.emitOpError(
"has incorrect number of results has for callee: expected ")
<< functionType.getNumResults() << ", but provided "
<< functionCallOp.getNumResults();
}
if (functionCallOp.getNumResults() &&
(functionCallOp.getResult(0)->getType() != functionType.getResult(0))) {
return functionCallOp.emitOpError("result type mismatch: expected ")
<< functionType.getResult(0) << ", but provided "
<< functionCallOp.getResult(0)->getType();
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.globalVariable
//===----------------------------------------------------------------------===//

View file

@ -128,6 +128,11 @@ private:
/// Gets the constant's attribute and type associated with the given <id>.
Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
/// Returns a symbol to be used for the function name with the given
/// result <id>. This tries to use the function's OpName if
/// exists; otherwise creates one based on the <id>.
std::string getFunctionSymbol(uint32_t id);
/// Returns a symbol to be used for the specialization constant with the given
/// result <id>. This tries to use the specialization constant's OpName if
/// exists; otherwise creates one based on the <id>.
@ -637,10 +642,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
<< functionType << " and return type " << resultType << " specified";
}
std::string fnName = nameMap.lookup(operands[1]).str();
if (fnName.empty()) {
fnName = "spirv_fn_" + std::to_string(operands[2]);
}
std::string fnName = getFunctionSymbol(operands[1]);
auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
ArrayRef<NamedAttribute>());
curFunction = funcMap[operands[1]] = funcOp;
@ -762,6 +764,14 @@ Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
return constIt->getSecond();
}
std::string Deserializer::getFunctionSymbol(uint32_t id) {
auto funcName = nameMap.lookup(id).str();
if (funcName.empty()) {
funcName = "spirv_fn_" + std::to_string(id);
}
return funcName;
}
std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
@ -1779,6 +1789,50 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
return success();
}
template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
if (operands.size() < 3) {
return emitError(unknownLoc,
"OpFunctionCall must have at least 3 operands");
}
Type resultType = getType(operands[0]);
if (!resultType) {
return emitError(unknownLoc, "undefined result type from <id> ")
<< operands[0];
}
auto resultID = operands[1];
auto functionID = operands[2];
auto functionName = getFunctionSymbol(functionID);
llvm::SmallVector<Value *, 4> arguments;
for (auto operand : llvm::drop_begin(operands, 3)) {
auto *value = getValue(operand);
if (!value) {
return emitError(unknownLoc, "unknown <id> ")
<< operand << " used by OpFunctionCall";
}
arguments.push_back(value);
}
SmallVector<Type, 1> resultTypes;
if (!isVoidType(resultType)) {
resultTypes.push_back(resultType);
}
auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName),
arguments);
if (!resultTypes.empty()) {
valueMap[resultID] = opFunctionCall.getResult(0);
}
return success();
}
// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
// various Deserializer::processOp<...>() specializations.
#define GET_DESERIALIZATION_FNS

View file

@ -131,6 +131,10 @@ private:
return funcIDMap.lookup(fnName);
}
/// Gets the <id> for the function with the given name. Assigns the next
/// available <id> if the function haven't been deserialized.
uint32_t getOrCreateFunctionID(StringRef fnName);
void processCapability();
void processExtension();
@ -392,6 +396,15 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
// Module structure
//===----------------------------------------------------------------------===//
uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
auto funcID = funcIDMap.lookup(fnName);
if (!funcID) {
funcID = getNextID();
funcIDMap[fnName] = funcID;
}
return funcID;
}
void Serializer::processCapability() {
auto caps = module.getAttrOfType<ArrayAttr>("capabilities");
if (!caps)
@ -537,8 +550,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
return failure();
}
operands.push_back(resTypeID);
auto funcID = getNextID();
funcIDMap[op.getName()] = funcID;
auto funcID = getOrCreateFunctionID(op.getName());
operands.push_back(funcID);
// TODO : Support other function control options.
operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
@ -1461,6 +1473,37 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
operands);
}
template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
auto funcName = op.callee();
uint32_t resTypeID = 0;
llvm::SmallVector<Type, 1> resultTypes(op.getResultTypes());
if (failed(processType(op.getLoc(),
(resultTypes.empty() ? getVoidType() : resultTypes[0]),
resTypeID))) {
return failure();
}
auto funcID = getOrCreateFunctionID(funcName);
auto funcCallID = getNextID();
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
for (auto *value : op.arguments()) {
auto valueID = findValueID(value);
assert(valueID && "cannot find a value for spv.FunctionCall");
operands.push_back(valueID);
}
if (!resultTypes.empty()) {
valueIDMap[op.getResult(0)] = funcCallID;
}
return encodeInstructionInto(functions, spirv::Opcode::OpFunctionCall,
operands);
}
// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
// various Serializer::processOp<...>() specializations.
#define GET_SERIALIZATION_FNS

View file

@ -0,0 +1,53 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
spv.module "Logical" "GLSL450" {
spv.globalVariable @var1 : !spv.ptr<!spv.array<4xf32>, Input>
func @fmain() -> i32 {
%0 = spv.constant 16 : i32
%1 = spv._address_of @var1 : !spv.ptr<!spv.array<4xf32>, Input>
// CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}) : (i32) -> i32
%3 = spv.FunctionCall @f_0(%0) : (i32) -> i32
// CHECK: spv.FunctionCall @f_1({{%.*}}, {{%.*}}) : (i32, !spv.ptr<!spv.array<4 x f32>, Input>) -> ()
spv.FunctionCall @f_1(%3, %1) : (i32, !spv.ptr<!spv.array<4xf32>, Input>) -> ()
// CHECK: {{%.*}} = spv.FunctionCall @f_2({{%.*}}) : (!spv.ptr<!spv.array<4 x f32>, Input>) -> !spv.ptr<!spv.array<4 x f32>, Input>
%4 = spv.FunctionCall @f_2(%1) : (!spv.ptr<!spv.array<4xf32>, Input>) -> !spv.ptr<!spv.array<4xf32>, Input>
spv.ReturnValue %3 : i32
}
func @f_0(%arg0 : i32) -> i32 {
spv.ReturnValue %arg0 : i32
}
func @f_1(%arg0 : i32, %arg1 : !spv.ptr<!spv.array<4xf32>, Input>) -> () {
spv.Return
}
func @f_2(%arg0 : !spv.ptr<!spv.array<4xf32>, Input>) -> !spv.ptr<!spv.array<4xf32>, Input> {
spv.ReturnValue %arg0 : !spv.ptr<!spv.array<4xf32>, Input>
}
func @f_loop_with_function_call(%count : i32) -> () {
%zero = spv.constant 0: i32
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
spv.loop {
spv.Branch ^header
^header:
%val0 = spv.Load "Function" %var : i32
%cmp = spv.SLessThan %val0, %count : i32
spv.BranchConditional %cmp, ^body, ^merge
^body:
spv.Branch ^continue
^continue:
// CHECK: spv.FunctionCall @f_inc({{%.*}}) : (!spv.ptr<i32, Function>) -> ()
spv.FunctionCall @f_inc(%var) : (!spv.ptr<i32, Function>) -> ()
spv.Branch ^header
^merge:
spv._merge
}
spv.Return
}
func @f_inc(%arg0 : !spv.ptr<i32, Function>) -> () {
%one = spv.constant 1 : i32
%0 = spv.Load "Function" %arg0 : i32
%1 = spv.IAdd %0, %one : i32
spv.Store "Function" %arg0, %1 : i32
spv.Return
}
}

View file

@ -144,6 +144,111 @@ func @weights_cannot_both_be_zero() -> () {
// -----
//===----------------------------------------------------------------------===//
// spv.FunctionCall
//===----------------------------------------------------------------------===//
spv.module "Logical" "GLSL450" {
func @fmain(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>, %arg2 : i32) -> i32 {
// CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%0 = spv.FunctionCall @f_0(%arg0, %arg1) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
// CHECK: spv.FunctionCall @f_1({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> ()
spv.FunctionCall @f_1(%0, %arg1) : (vector<4xf32>, vector<4xf32>) -> ()
// CHECK: spv.FunctionCall @f_2() : () -> ()
spv.FunctionCall @f_2() : () -> ()
// CHECK: {{%.*}} = spv.FunctionCall @f_3({{%.*}}) : (i32) -> i32
%1 = spv.FunctionCall @f_3(%arg2) : (i32) -> i32
spv.ReturnValue %1 : i32
}
func @f_0(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>) {
spv.ReturnValue %arg0 : vector<4xf32>
}
func @f_1(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> () {
spv.Return
}
func @f_2() -> () {
spv.Return
}
func @f_3(%arg0 : i32) -> (i32) {
spv.ReturnValue %arg0 : i32
}
}
// -----
spv.module "Logical" "GLSL450" {
func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () {
// expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}}
%0 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32)
spv.Return
}
}
// -----
spv.module "Logical" "GLSL450" {
func @f_result_type_mismatch(%arg0 : i32, %arg1 : i32) -> () {
// expected-error @+1 {{has incorrect number of results has for callee: expected 0, but provided 1}}
%1 = spv.FunctionCall @f_result_type_mismatch(%arg0, %arg0) : (i32, i32) -> (i32)
spv.Return
}
}
// -----
spv.module "Logical" "GLSL450" {
func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () {
// expected-error @+1 {{has incorrect number of operands for callee: expected 2, but provided 1}}
spv.FunctionCall @f_type_mismatch(%arg0) : (i32) -> ()
spv.Return
}
}
// -----
spv.module "Logical" "GLSL450" {
func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () {
%0 = spv.constant 2.0 : f32
// expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32' for operand number 1}}
spv.FunctionCall @f_type_mismatch(%arg0, %0) : (i32, f32) -> ()
spv.Return
}
}
// -----
spv.module "Logical" "GLSL450" {
func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> i32 {
// expected-error @+1 {{result type mismatch: expected 'i32', but provided 'f32'}}
%0 = spv.FunctionCall @f_type_mismatch(%arg0, %arg0) : (i32, i32) -> f32
spv.Return
}
}
// -----
spv.module "Logical" "GLSL450" {
func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 {
// expected-error @+1 {{op callee function 'f_undefined' not found in 'spv.module'}}
%0 = spv.FunctionCall @f_undefined(%arg0, %arg0) : (i32, i32) -> i32
spv.Return
}
}
// -----
func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 {
// expected-error @+1 {{must appear in a function inside 'spv.module'}}
%0 = spv.FunctionCall @f_foo(%arg0, %arg0) : (i32, i32) -> i32
spv.Return
}
// -----
//===----------------------------------------------------------------------===//
// spv.loop
//===----------------------------------------------------------------------===//