Finish up cleanup for PTX function support

This commit is contained in:
Andrzej Janik 2020-09-03 20:58:35 +02:00
parent de734305cf
commit bbb3a6c5cb
3 changed files with 83 additions and 104 deletions

View file

@ -8,7 +8,7 @@ use super::{context, device, module, Decuda, Encuda};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
ffi::{c_void, CStr, CString},
ffi::{c_void, CStr},
ptr, slice,
};

View file

@ -190,14 +190,17 @@ fn test_spvtxt_assert<'a>(
ptr::null_mut()
)
};
assert_eq!(result, spv_result_t::SPV_SUCCESS);
let raw_text = unsafe {
std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length)
};
let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) };
// TODO: stop leaking kernel text
unsafe { spirv_tools::spvContextDestroy(spv_context) };
panic!(spv_from_ptx_text);
if result == spv_result_t::SPV_SUCCESS {
let raw_text = unsafe {
std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length)
};
let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) };
// TODO: stop leaking kernel text
panic!(spv_from_ptx_text);
} else {
panic!(ptx_mod.disassemble());
}
}
unsafe { spirv_tools::spvContextDestroy(spv_context) };
Ok(())

View file

@ -155,7 +155,14 @@ impl TypeWordMap {
}
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error> {
let mut id_defs = GlobalStringIdResolver::new(1);
let ssa_functions = ast
.functions
.into_iter()
.map(|f| to_ssa_function(&mut id_defs, f))
.collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 3);
emit_capabilities(&mut builder);
@ -163,13 +170,8 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
let mut id_defs = GlobalStringIdResolver::new(builder.id());
let ssa_functions = ast
.functions
.into_iter()
.map(|f| to_ssa_function(&mut id_defs, opencl_id, f))
.collect::<Vec<_>>();
for f in ssa_functions {
emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive, &*f.args)?;
emit_function_args(&mut builder, &mut map, &*f.args);
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.body)?;
builder.end_function()?;
@ -177,6 +179,31 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error
Ok(builder.module())
}
fn emit_function_header(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
global: &GlobalStringIdResolver,
func_directive: ast::FunctionHeader<ExpandedArgParams>,
params: &[ast::Argument<ExpandedArgParams>],
) -> Result<(), dr::Error> {
let func_type = get_function_type(builder, map, params);
let (fn_id, ret_type) = match func_directive {
ast::FunctionHeader::Kernel(name) => {
let fn_id = global.get_id(name);
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
(fn_id, map.void())
}
ast::FunctionHeader::Func(params, name) => todo!(),
};
builder.begin_function(
ret_type,
Some(fn_id),
spirv::FunctionControl::NONE,
func_type,
)?;
Ok(())
}
pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, dr::Error> {
let module = to_spirv_module(ast)?;
Ok(module.assemble())
@ -206,21 +233,19 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn to_ssa_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
opencl_id: spirv::Word,
f: ast::ParsedFunction<'a>,
) -> ExpandedFunction<'a> {
let ids_start = id_defs.current_id();
let fn_resolver = FnStringIdResolver::new(id_defs);
let mut fn_resolver = FnStringIdResolver::new(id_defs, f.func_directive.name());
let f_header = match f.func_directive {
ast::FunctionHeader::Kernel(name) => todo!(),
ast::FunctionHeader::Func(ret_params, name) => todo!(),
ast::FunctionHeader::Kernel(name) => ast::FunctionHeader::Kernel(name),
ast::FunctionHeader::Func(ret_params, name) => {
let name_id = fn_resolver.add_global_def(name);
let ret_ids = expand_fn_params(&mut fn_resolver, ret_params);
ast::FunctionHeader::Func(ret_ids, name_id)
}
};
let f_args = todo!();
let f_body = Some(to_ssa(
fn_resolver,
&f.args,
f.body.unwrap_or_else(|| todo!()),
));
let f_args = expand_fn_params(&mut fn_resolver, f.args);
let f_body = Some(to_ssa(fn_resolver, f.body.unwrap_or_else(|| Vec::new())));
ExpandedFunction {
func_directive: f_header,
args: f_args,
@ -228,19 +253,24 @@ fn to_ssa_function<'a>(
}
}
fn apply_id_offset(func_body: Vec<ExpandedStatement>, id_offset: u32) -> Vec<ExpandedStatement> {
func_body
.into_iter()
.map(|s| s.visit_variable(&mut |id| id + id_offset))
fn expand_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
args: Vec<ast::Argument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::Argument<ExpandedArgParams>> {
args.into_iter()
.map(|a| ast::Argument {
name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
a_type: a.a_type,
length: a.length,
})
.collect()
}
fn to_ssa<'a, 'b>(
mut id_defs: FnStringIdResolver<'a, 'b>,
f_args: &'b [ast::Argument<ast::ParsedArgParams<'a>>],
f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> Vec<ExpandedStatement> {
let normalized_ids = normalize_identifiers(&mut id_defs, &f_args, f_body);
let normalized_ids = normalize_identifiers(&mut id_defs, f_body);
let mut numeric_id_defs = id_defs.finish();
let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs);
@ -593,7 +623,7 @@ fn insert_implicit_conversions(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
args: &[ast::Argument<ast::ParsedArgParams>],
args: &[ast::Argument<ExpandedArgParams>],
) -> spirv::Word {
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type)))
}
@ -603,17 +633,15 @@ fn emit_function_args(
map: &mut TypeWordMap,
args: &[ast::Argument<ExpandedArgParams>],
) {
let mut id = todo!();
for arg in args {
let result_type = map.get_or_add_scalar(builder, arg.a_type);
let inst = dr::Instruction::new(
spirv::Op::FunctionParameter,
Some(result_type),
Some(id),
Some(arg.name),
Vec::new(),
);
builder.function.as_mut().unwrap().parameters.push(inst);
id += 1;
}
}
@ -1095,12 +1123,8 @@ fn emit_implicit_conversion(
// TODO: support scopes
fn normalize_identifiers<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
args: &[ast::Argument<ast::ParsedArgParams<'a>>],
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::Statement<NormalizedArgParams>> {
for arg in args {
id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
}
for s in func.iter() {
match s {
ast::Statement::Label(id) => {
@ -1180,8 +1204,8 @@ impl<'a> GlobalStringIdResolver<'a> {
numeric_id
}
fn reserve_id(&mut self) {
self.current_id += 1;
fn get_id(&self, id: &str) -> spirv::Word {
self.variables[id]
}
fn current_id(&self) -> spirv::Word {
@ -1196,7 +1220,8 @@ struct FnStringIdResolver<'a, 'b> {
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
fn new(global: &'b mut GlobalStringIdResolver<'a>) -> Self {
fn new(global: &'b mut GlobalStringIdResolver<'a>, f_name: &'a str) -> Self {
global.add_def(f_name);
Self {
global: global,
variables: vec![HashMap::new(); 1],
@ -1229,6 +1254,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
self.global.variables[id]
}
fn add_global_def(&mut self, id: &'a str) -> spirv::Word {
self.global.add_def(id)
}
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
let numeric_id = self.global.current_id;
self.variables
@ -1294,25 +1323,6 @@ enum Statement<I> {
Constant(ConstantDefinition),
}
impl Statement<ast::Instruction<ExpandedArgParams>> {
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
match self {
Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align),
Statement::LoadVar(a, t) => {
Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t)
}
Statement::StoreVar(a, t) => {
Statement::StoreVar(a.map(&mut reduced_visitor(f), Some(t)), t)
}
Statement::Label(id) => Statement::Label(f(id)),
Statement::Instruction(inst) => Statement::Instruction(inst.visit_variable(f)),
Statement::Conditional(bra) => Statement::Conditional(bra.map(f)),
Statement::Conversion(conv) => Statement::Conversion(conv.map(f)),
Statement::Constant(cons) => Statement::Constant(cons.map(f)),
}
}
}
enum NormalizedArgParams {}
type NormalizedStatement = Statement<ast::Instruction<NormalizedArgParams>>;
@ -1513,18 +1523,7 @@ where
}
}
fn reduced_visitor<'a>(
f: &'a mut impl FnMut(spirv::Word) -> spirv::Word,
) -> impl FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word + 'a {
move |desc| f(desc.op)
}
impl ast::Instruction<ExpandedArgParams> {
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
let mut visitor = reduced_visitor(f);
self.map(&mut visitor)
}
fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
self,
f: &mut F,
@ -1562,32 +1561,12 @@ struct ConstantDefinition {
pub value: i128,
}
impl ConstantDefinition {
fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
Self {
dst: f(self.dst),
typ: self.typ,
value: self.value,
}
}
}
struct BrachCondition {
predicate: spirv::Word,
if_true: spirv::Word,
if_false: spirv::Word,
}
impl BrachCondition {
fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
Self {
predicate: f(self.predicate),
if_true: f(self.if_true),
if_false: f(self.if_false),
}
}
}
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
@ -1604,18 +1583,6 @@ enum ConversionKind {
Ptr(ast::LdStateSpace),
}
impl ImplicitConversion {
fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
Self {
src: f(self.src),
dst: f(self.dst),
from: self.from,
to: self.to,
kind: self.kind,
}
}
}
impl<T> ast::PredAt<T> {
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
ast::PredAt {
@ -2354,6 +2321,15 @@ fn insert_implicit_bitcasts(
}
}
impl<'a> ast::FunctionHeader<'a, ast::ParsedArgParams<'a>> {
fn name(&self) -> &'a str {
match self {
ast::FunctionHeader::Kernel(name) => name,
ast::FunctionHeader::Func(_, name) => name,
}
}
}
// CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)]
mod tests {