Add support for GLSL Binary ops, and use it to implement GLSL FMax.

A base class is added to implement all GLSL Binary operations and is
used to implement the FMax operation. The existing framework already
generates all the necessary (de)serialization code.

PiperOrigin-RevId: 271037166
This commit is contained in:
Mahesh Ravishankar 2019-09-24 19:41:31 -07:00 committed by A. Unique TensorFlower
parent cf00feed03
commit c5284fe85e
3 changed files with 92 additions and 8 deletions

View file

@ -37,9 +37,9 @@ class SPV_GLSLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits>;
// Base class for GLSL unary ops.
class SPV_GLSLUnaryOp<string mnemonic, Type resultType, Type operandType,
int opcode, list<OpTrait> traits = []> :
SPV_GLSLOp<mnemonic, opcode, traits> {
class SPV_GLSLUnaryOp<string mnemonic, int opcode, Type resultType,
Type operandType, list<OpTrait> traits = []> :
SPV_GLSLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
let arguments = (ins
SPV_ScalarOrVectorOf<operandType>:$operand
@ -60,7 +60,34 @@ class SPV_GLSLUnaryOp<string mnemonic, Type resultType, Type operandType,
// the operand type.
class SPV_GLSLUnaryArithmaticOp<string mnemonic, int opcode, Type type,
list<OpTrait> traits = []> :
SPV_GLSLUnaryOp<mnemonic, type, type, opcode, traits>;
SPV_GLSLUnaryOp<mnemonic, opcode, type, type, traits>;
// Base class for GLSL binary ops.
class SPV_GLSLBinaryOp<string mnemonic, int opcode, Type resultType,
Type operandType, list<OpTrait> traits = []> :
SPV_GLSLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
let arguments = (ins
SPV_ScalarOrVectorOf<operandType>:$lhs,
SPV_ScalarOrVectorOf<operandType>:$rhs
);
let results = (outs
SPV_ScalarOrVectorOf<resultType>:$result
);
let parser = [{ return impl::parseBinaryOp(parser, result); }];
let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
let verifier = [{ return success(); }];
}
// Base class for GLSL Binary arithmatic ops where operand types and
// return type matches.
class SPV_GLSLBinaryArithmaticOp<string mnemonic, int opcode, Type type,
list<OpTrait> traits = []> :
SPV_GLSLBinaryOp<mnemonic, opcode, type, type, traits>;
// -----
@ -82,16 +109,47 @@ def SPV_GLSLExpOp : SPV_GLSLUnaryArithmaticOp<"Exp", 27, FloatOfWidths<[16, 32]>
restricted-float-scalar-vector-type ::=
restricted-float-scalar-type |
`vector<` integer-literal `x` restricted-float-scalar-type `>`
exp-op ::= ssa-id `=` `spv.glsl.Exp` ssa-use `:`
exp-op ::= ssa-id `=` `spv.GLSL.Exp` ssa-use `:`
restricted-float-scalar-vector-type
```
For example:
```
%2 = spv.glsl.Exp %0 : f32
%3 = spv.glsl.Exp %1 : vector<3xf16>
%2 = spv.GLSL.Exp %0 : f32
%3 = spv.GLSL.Exp %1 : vector<3xf16>
```
}];
}
// -----
def SPV_GLSLFMaxOp : SPV_GLSLBinaryArithmaticOp<"FMax", 40, SPV_Float> {
let summary = "Return maximum of two floating-point operands";
let description = [{
Result is y if x < y; otherwise result is x. Which operand is the
result is undefined if one of the operands is a NaN.
The operands must all be a scalar or vector whose component type
is floating-point.
Result Type and the type of all operands must be the same
type. Results are computed per component.
### Custom assembly format
``` {.ebnf}
fmax-op ::= ssa-id `=` `spv.GLSL.FMax` ssa-use `:`
restricted-float-scalar-vector-type
```
For example:
```
%2 = spv.GLSL.FMax %0, %1 : f32
%3 = spv.GLSL.FMax %0, %1 : vector<3xf16>
```
}];
}
// -----
#endif // SPIRV_GLSL_OPS

View file

@ -1,9 +1,11 @@
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
spv.module "Logical" "GLSL450" {
func @fmul(%arg0 : f32) {
func @fmul(%arg0 : f32, %arg1 : f32) {
// CHECK: {{%.*}} = spv.GLSL.Exp {{%.*}} : f32
%0 = spv.GLSL.Exp %arg0 : f32
// CHECK: {{%.*}} = spv.GLSL.FMax {{%.*}}, {{%.*}} : f32
%1 = spv.GLSL.FMax %arg0, %arg1 : f32
spv.Return
}
}

View file

@ -47,3 +47,27 @@ func @exp(%arg0 : i32) -> () {
%2 = spv.GLSL.Exp %arg0 :
return
}
// -----
//===----------------------------------------------------------------------===//
// spv.GLSL.FMax
//===----------------------------------------------------------------------===//
func @fmax(%arg0 : f32, %arg1 : f32) -> () {
// CHECK: spv.GLSL.FMax {{%.*}}, {{%.*}} : f32
%2 = spv.GLSL.FMax %arg0, %arg1 : f32
return
}
func @fmaxvec(%arg0 : vector<3xf16>, %arg1 : vector<3xf16>) -> () {
// CHECK: spv.GLSL.FMax {{%.*}}, {{%.*}} : vector<3xf16>
%2 = spv.GLSL.FMax %arg0, %arg1 : vector<3xf16>
return
}
func @fmaxf64(%arg0 : f64, %arg1 : f64) -> () {
// CHECK: spv.GLSL.FMax {{%.*}}, {{%.*}} : f64
%2 = spv.GLSL.FMax %arg0, %arg1 : f64
return
}