Start emitting ptx module from compilation

This commit is contained in:
Andrzej Janik 2020-04-03 01:23:11 +02:00
parent e981e20aae
commit b8129aab20
6 changed files with 132 additions and 81 deletions

View file

@ -9,6 +9,7 @@ edition = "2018"
[dependencies]
lalrpop-util = "0.18.1"
regex = "1"
rspirv = "0.6"
[build-dependencies.lalrpop]
version = "0.18.1"

View file

@ -3,4 +3,16 @@ I'm convinced nobody actually uses parser generators in Rust:
* pest can't do parse actions, you have to convert your parse tree to ast manually
* lalrpop can't do comments
* and the day I wrote the line above it can
* antlr4rust is untried and requires java to build
* reports parsing errors as byte offsets
* if you want to skip parsing one of the alternatives functional design gets quite awkward
* antlr4rust is untried and requires java to build
* no library supports island grammars
What to emit?
* SPIR-V
* Better library support, easier to emit
* Can by optimized by IGC
* Can't do some things (not sure what exactly yet)
* But we can work around things with inline VISA
* VISA
* Quicker compilation

View file

@ -1,18 +1,39 @@
pub struct Module {
version: (u8, u8),
target: Target
pub struct Module<'a> {
pub version: (u8, u8),
pub functions: Vec<Function<'a>>
}
pub struct Target {
arch: String,
texturing: TexturingMode,
debug: bool,
f64_to_f32: bool
pub struct Function<'a> {
pub kernel: bool,
pub name: &'a str,
pub args: Vec<Argument>,
pub body: Vec<Statement<'a>>,
}
pub enum TexturingMode {
Unspecified,
Unified,
Independent
pub struct Argument {
}
pub enum Statement<'a> {
Label(&'a str),
Variable(Variable),
Instruction(Instruction)
}
pub struct Variable {
}
pub enum Instruction {
Ld,
Mov,
Mul,
Add,
Setp,
Not,
Bra,
Cvt,
Shl,
At,
Ret
}

View file

@ -1,9 +1,11 @@
#[macro_use]
extern crate lalrpop_util;
lalrpop_mod!(pub ptx);
lalrpop_mod!(ptx);
mod test;
mod spirv;
pub mod ast;
pub use ast::Module as Module;
pub use spirv::translate as to_spirv;

View file

@ -1,5 +1,6 @@
use std::str::FromStr;
use super::ast;
use crate::ast;
use std::convert::identity;
grammar;
@ -15,15 +16,21 @@ match {
_
}
pub Module: () = {
Version
Target
Directive*
pub Module: Option<ast::Module<'input>> = {
<v:Version> Target <f:Directive*> => v.map(|v| ast::Module { version: v, functions: f.into_iter().filter_map(identity).collect::<Vec<_>>() })
};
Version = {
".version" VersionNumber
};
Version: Option<(u8, u8)> = {
".version" <v:VersionNumber> => {
let dot = v.find('.').unwrap();
let major = v[..dot].parse::<u8>();
major.ok().and_then(|major| {
v[dot+1..].parse::<u8>().ok().map(|minor| {
(major, minor)
})
})
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-module-directives-target
Target = {
@ -38,19 +45,19 @@ TargetSpecifier = {
"map_f64_to_f32"
};
Directive : () = {
AddressSize,
Function,
File,
Section
Directive: Option<ast::Function<'input>> = {
AddressSize => None,
<f:Function> => Some(f),
File => None,
Section => None
};
AddressSize = {
".address_size" Num
};
Function: (bool, &'input str) = {
LinkingDirective* <is:IsKernel> <id:ID> "(" Comma<FunctionInput> ")" FunctionBody => (is, id)
Function: ast::Function<'input> = {
LinkingDirective* <kernel:IsKernel> <name:ID> "(" <args:Comma<FunctionInput>> ")" <body:FunctionBody> => ast::Function {<>}
};
LinkingDirective = {
@ -64,12 +71,12 @@ IsKernel: bool = {
".func" => false
};
FunctionInput = {
".param" Type ID
FunctionInput: ast::Argument = {
".param" Type ID => ast::Argument {}
};
FunctionBody = {
"{" Statement* "}"
FunctionBody: Vec<ast::Statement<'input>> = {
"{" <s:Statement*> "}" => { s.into_iter().filter_map(identity).collect() }
};
StateSpaceSpecifier = {
@ -95,14 +102,14 @@ BaseType = {
".f32", ".f64"
};
Statement: () = {
Label,
DebugDirective,
Variable ";",
Instruction ";"
Statement: Option<ast::Statement<'input>> = {
<l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None,
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
<i:Instruction> ";" => Some(ast::Statement::Instruction(i))
};
DebugDirective = {
DebugDirective: () = {
DebugLocation
};
@ -111,12 +118,12 @@ DebugLocation = {
".loc" Num Num Num
};
Label = {
ID ":"
Label: &'input str = {
<id:ID> ":" => id
};
Variable = {
StateSpaceSpecifier Type VariableName
Variable: ast::Variable = {
StateSpaceSpecifier Type VariableName => ast::Variable {}
};
VariableName = {
@ -124,7 +131,7 @@ VariableName = {
ParametrizedID
};
Instruction: () = {
Instruction = {
InstLd,
InstMov,
InstMul,
@ -139,8 +146,8 @@ Instruction: () = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd = {
"ld" LdQualifier? LdStateSpace? LdCacheOperator? Vector? BaseType ID "," "[" ID "]"
InstLd: ast::Instruction = {
"ld" LdQualifier? LdStateSpace? LdCacheOperator? Vector? BaseType ID "," "[" ID "]" => ast::Instruction::Ld
};
LdQualifier: () = {
@ -171,8 +178,8 @@ LdCacheOperator = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov = {
"mov" MovType ID "," Operand
InstMov: ast::Instruction = {
"mov" MovType ID "," Operand => ast::Instruction::Mov
};
MovType = {
@ -186,12 +193,12 @@ MovType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
InstMul: () = {
"mul" MulIntControl? IntType ID "," Operand "," Operand,
"mul" RoundingMode? ".ftz"? ".sat"? ".f32" ID "," Operand "," Operand,
"mul" RoundingMode? ".f64" ID "," Operand "," Operand,
"mul" ".rn"? ".ftz"? ".sat"? ".f16" ID "," Operand "," Operand,
"mul" ".rn"? ".ftz"? ".sat"? ".f16x2" ID "," Operand "," Operand,
InstMul: ast::Instruction = {
"mul" MulIntControl? IntType ID "," Operand "," Operand => ast::Instruction::Mul,
"mul" RoundingMode? ".ftz"? ".sat"? ".f32" ID "," Operand "," Operand => ast::Instruction::Mul,
"mul" RoundingMode? ".f64" ID "," Operand "," Operand => ast::Instruction::Mul,
"mul" ".rn"? ".ftz"? ".sat"? ".f16" ID "," Operand "," Operand => ast::Instruction::Mul,
"mul" ".rn"? ".ftz"? ".sat"? ".f16x2" ID "," Operand "," Operand => ast::Instruction::Mul,
};
MulIntControl = {
@ -211,19 +218,19 @@ IntType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
InstAdd: () = {
"add" IntType ID "," Operand "," Operand,
"add" ".sat" ".s32" ID "," Operand "," Operand,
"add" RoundingMode? ".ftz"? ".sat"? ".f32" ID "," Operand "," Operand,
"add" RoundingMode? ".f64" ID "," Operand "," Operand,
"add" ".rn"? ".ftz"? ".sat"? ".f16" ID "," Operand "," Operand,
"add" ".rn"? ".ftz"? ".sat"? ".f16x2" ID "," Operand "," Operand,
InstAdd: ast::Instruction = {
"add" IntType ID "," Operand "," Operand => ast::Instruction::Add,
"add" ".sat" ".s32" ID "," Operand "," Operand => ast::Instruction::Add,
"add" RoundingMode? ".ftz"? ".sat"? ".f32" ID "," Operand "," Operand => ast::Instruction::Add,
"add" RoundingMode? ".f64" ID "," Operand "," Operand => ast::Instruction::Add,
"add" ".rn"? ".ftz"? ".sat"? ".f16" ID "," Operand "," Operand => ast::Instruction::Add,
"add" ".rn"? ".ftz"? ".sat"? ".f16x2" ID "," Operand "," Operand => ast::Instruction::Add,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
InstSetp: () = {
"setp" SetpCmpOp ".ftz"? SetpType ID ("|" ID)? "," Operand "," Operand,
"setp" SetpCmpOp SetpBoolOp ".ftz"? SetpType ID ("|" ID)? "," Operand "," Operand "," "!"? ID
InstSetp: ast::Instruction = {
"setp" SetpCmpOp ".ftz"? SetpType ID ("|" ID)? "," Operand "," Operand => ast::Instruction::Setp,
"setp" SetpCmpOp SetpBoolOp ".ftz"? SetpType ID ("|" ID)? "," Operand "," Operand "," "!"? ID => ast::Instruction::Setp
};
SetpCmpOp = {
@ -243,8 +250,8 @@ SetpType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
InstNot: () = {
"not" NotType ID "," Operand
InstNot: ast::Instruction = {
"not" NotType ID "," Operand => ast::Instruction::Not
};
NotType = {
@ -252,18 +259,18 @@ NotType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
InstAt = {
"@" "!"? ID
InstAt: ast::Instruction = {
"@" "!"? ID => ast::Instruction::At
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
InstBra = {
InstAt? "bra" ".uni"? ID
InstBra: ast::Instruction = {
InstAt? "bra" ".uni"? ID => ast::Instruction::Bra
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
InstCvt = {
"cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType ID "," Operand
InstCvt: ast::Instruction = {
"cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType ID "," Operand => ast::Instruction::Cvt
};
CvtRnd = {
@ -286,8 +293,8 @@ CvtType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
InstShl = {
"shl" ShlType ID "," Operand "," Operand
InstShl: ast::Instruction = {
"shl" ShlType ID "," Operand "," Operand => ast::Instruction::Shl
};
ShlType = {
@ -295,8 +302,8 @@ ShlType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
InstSt = {
"st" LdQualifier? StStateSpace? StCacheOperator? Vector? BaseType "[" ID "]" "," Operand
InstSt: ast::Instruction = {
"st" LdQualifier? StStateSpace? StCacheOperator? Vector? BaseType "[" ID "]" "," Operand => ast::Instruction::Shl
};
StStateSpace = {
@ -314,8 +321,8 @@ StCacheOperator = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
InstRet: () = {
"ret" ".uni"?
InstRet: ast::Instruction = {
"ret" ".uni"? => ast::Instruction::Ret
};
Operand: () = {
@ -384,7 +391,6 @@ Comma<T>: Vec<T> = {
String = r#""[^"]*""#;
VersionNumber = r"[0-9]+\.[0-9]+";
//Num: i128 = <s:r"[?:0x][0-9]+"> => i128::from_str(s).unwrap();
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
ID: &'input str = <s:r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+"> => s;
DotID: &'input str = <s:r"\.[a-zA-Z][a-zA-Z0-9_$]*"> => s;

9
ptx/src/spirv.rs Normal file
View file

@ -0,0 +1,9 @@
use super::ast;
pub struct TranslateError {
}
pub fn translate(ast: ast::Module) -> Result<Vec<u32>, TranslateError> {
Ok(vec!())
}