From bbb3a6c5cbaff3430191ef4858aa16be8320ce77 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 3 Sep 2020 20:58:35 +0200 Subject: [PATCH] Finish up cleanup for PTX function support --- notcuda/src/impl/export_table.rs | 2 +- ptx/src/test/spirv_run/mod.rs | 17 ++-- ptx/src/translate.rs | 168 +++++++++++++------------------ 3 files changed, 83 insertions(+), 104 deletions(-) diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs index afd9077..233c496 100644 --- a/notcuda/src/impl/export_table.rs +++ b/notcuda/src/impl/export_table.rs @@ -8,7 +8,7 @@ use super::{context, device, module, Decuda, Encuda}; use std::mem; use std::os::raw::{c_uint, c_ulong, c_ushort}; use std::{ - ffi::{c_void, CStr, CString}, + ffi::{c_void, CStr}, ptr, slice, }; diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 9ea0100..9f62292 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -190,14 +190,17 @@ fn test_spvtxt_assert<'a>( ptr::null_mut() ) }; - assert_eq!(result, spv_result_t::SPV_SUCCESS); - let raw_text = unsafe { - std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length) - }; - let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) }; - // TODO: stop leaking kernel text unsafe { spirv_tools::spvContextDestroy(spv_context) }; - panic!(spv_from_ptx_text); + if result == spv_result_t::SPV_SUCCESS { + let raw_text = unsafe { + std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length) + }; + let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) }; + // TODO: stop leaking kernel text + panic!(spv_from_ptx_text); + } else { + panic!(ptx_mod.disassemble()); + } } unsafe { spirv_tools::spvContextDestroy(spv_context) }; Ok(()) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 8cf3aca..34d8c12 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -155,7 +155,14 @@ impl TypeWordMap { } pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result { + let mut id_defs = GlobalStringIdResolver::new(1); + let ssa_functions = ast + .functions + .into_iter() + .map(|f| to_ssa_function(&mut id_defs, f)) + .collect::>(); let mut builder = dr::Builder::new(); + builder.reserve_ids(id_defs.current_id()); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module builder.set_version(1, 3); emit_capabilities(&mut builder); @@ -163,13 +170,8 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result>(); for f in ssa_functions { + emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive, &*f.args)?; emit_function_args(&mut builder, &mut map, &*f.args); emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.body)?; builder.end_function()?; @@ -177,6 +179,31 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result, + params: &[ast::Argument], +) -> Result<(), dr::Error> { + let func_type = get_function_type(builder, map, params); + let (fn_id, ret_type) = match func_directive { + ast::FunctionHeader::Kernel(name) => { + let fn_id = global.get_id(name); + builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]); + (fn_id, map.void()) + } + ast::FunctionHeader::Func(params, name) => todo!(), + }; + builder.begin_function( + ret_type, + Some(fn_id), + spirv::FunctionControl::NONE, + func_type, + )?; + Ok(()) +} + pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result, dr::Error> { let module = to_spirv_module(ast)?; Ok(module.assemble()) @@ -206,21 +233,19 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn to_ssa_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, - opencl_id: spirv::Word, f: ast::ParsedFunction<'a>, ) -> ExpandedFunction<'a> { - let ids_start = id_defs.current_id(); - let fn_resolver = FnStringIdResolver::new(id_defs); + let mut fn_resolver = FnStringIdResolver::new(id_defs, f.func_directive.name()); let f_header = match f.func_directive { - ast::FunctionHeader::Kernel(name) => todo!(), - ast::FunctionHeader::Func(ret_params, name) => todo!(), + ast::FunctionHeader::Kernel(name) => ast::FunctionHeader::Kernel(name), + ast::FunctionHeader::Func(ret_params, name) => { + let name_id = fn_resolver.add_global_def(name); + let ret_ids = expand_fn_params(&mut fn_resolver, ret_params); + ast::FunctionHeader::Func(ret_ids, name_id) + } }; - let f_args = todo!(); - let f_body = Some(to_ssa( - fn_resolver, - &f.args, - f.body.unwrap_or_else(|| todo!()), - )); + let f_args = expand_fn_params(&mut fn_resolver, f.args); + let f_body = Some(to_ssa(fn_resolver, f.body.unwrap_or_else(|| Vec::new()))); ExpandedFunction { func_directive: f_header, args: f_args, @@ -228,19 +253,24 @@ fn to_ssa_function<'a>( } } -fn apply_id_offset(func_body: Vec, id_offset: u32) -> Vec { - func_body - .into_iter() - .map(|s| s.visit_variable(&mut |id| id + id_offset)) +fn expand_fn_params<'a, 'b>( + fn_resolver: &mut FnStringIdResolver<'a, 'b>, + args: Vec>>, +) -> Vec> { + args.into_iter() + .map(|a| ast::Argument { + name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))), + a_type: a.a_type, + length: a.length, + }) .collect() } fn to_ssa<'a, 'b>( mut id_defs: FnStringIdResolver<'a, 'b>, - f_args: &'b [ast::Argument>], f_body: Vec>>, ) -> Vec { - let normalized_ids = normalize_identifiers(&mut id_defs, &f_args, f_body); + let normalized_ids = normalize_identifiers(&mut id_defs, f_body); let mut numeric_id_defs = id_defs.finish(); let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs); @@ -593,7 +623,7 @@ fn insert_implicit_conversions( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - args: &[ast::Argument], + args: &[ast::Argument], ) -> spirv::Word { map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type))) } @@ -603,17 +633,15 @@ fn emit_function_args( map: &mut TypeWordMap, args: &[ast::Argument], ) { - let mut id = todo!(); for arg in args { let result_type = map.get_or_add_scalar(builder, arg.a_type); let inst = dr::Instruction::new( spirv::Op::FunctionParameter, Some(result_type), - Some(id), + Some(arg.name), Vec::new(), ); builder.function.as_mut().unwrap().parameters.push(inst); - id += 1; } } @@ -1095,12 +1123,8 @@ fn emit_implicit_conversion( // TODO: support scopes fn normalize_identifiers<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, - args: &[ast::Argument>], func: Vec>>, ) -> Vec> { - for arg in args { - id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type))); - } for s in func.iter() { match s { ast::Statement::Label(id) => { @@ -1180,8 +1204,8 @@ impl<'a> GlobalStringIdResolver<'a> { numeric_id } - fn reserve_id(&mut self) { - self.current_id += 1; + fn get_id(&self, id: &str) -> spirv::Word { + self.variables[id] } fn current_id(&self) -> spirv::Word { @@ -1196,7 +1220,8 @@ struct FnStringIdResolver<'a, 'b> { } impl<'a, 'b> FnStringIdResolver<'a, 'b> { - fn new(global: &'b mut GlobalStringIdResolver<'a>) -> Self { + fn new(global: &'b mut GlobalStringIdResolver<'a>, f_name: &'a str) -> Self { + global.add_def(f_name); Self { global: global, variables: vec![HashMap::new(); 1], @@ -1229,6 +1254,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { self.global.variables[id] } + fn add_global_def(&mut self, id: &'a str) -> spirv::Word { + self.global.add_def(id) + } + fn add_def(&mut self, id: &'a str, typ: Option) -> spirv::Word { let numeric_id = self.global.current_id; self.variables @@ -1294,25 +1323,6 @@ enum Statement { Constant(ConstantDefinition), } -impl Statement> { - fn visit_variable spirv::Word>(self, f: &mut F) -> Self { - match self { - Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align), - Statement::LoadVar(a, t) => { - Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t) - } - Statement::StoreVar(a, t) => { - Statement::StoreVar(a.map(&mut reduced_visitor(f), Some(t)), t) - } - Statement::Label(id) => Statement::Label(f(id)), - Statement::Instruction(inst) => Statement::Instruction(inst.visit_variable(f)), - Statement::Conditional(bra) => Statement::Conditional(bra.map(f)), - Statement::Conversion(conv) => Statement::Conversion(conv.map(f)), - Statement::Constant(cons) => Statement::Constant(cons.map(f)), - } - } -} - enum NormalizedArgParams {} type NormalizedStatement = Statement>; @@ -1513,18 +1523,7 @@ where } } -fn reduced_visitor<'a>( - f: &'a mut impl FnMut(spirv::Word) -> spirv::Word, -) -> impl FnMut(ArgumentDescriptor) -> spirv::Word + 'a { - move |desc| f(desc.op) -} - impl ast::Instruction { - fn visit_variable spirv::Word>(self, f: &mut F) -> Self { - let mut visitor = reduced_visitor(f); - self.map(&mut visitor) - } - fn visit_variable_extended) -> spirv::Word>( self, f: &mut F, @@ -1562,32 +1561,12 @@ struct ConstantDefinition { pub value: i128, } -impl ConstantDefinition { - fn map spirv::Word>(self, f: &mut F) -> Self { - Self { - dst: f(self.dst), - typ: self.typ, - value: self.value, - } - } -} - struct BrachCondition { predicate: spirv::Word, if_true: spirv::Word, if_false: spirv::Word, } -impl BrachCondition { - fn map spirv::Word>(self, f: &mut F) -> Self { - Self { - predicate: f(self.predicate), - if_true: f(self.if_true), - if_false: f(self.if_false), - } - } -} - struct ImplicitConversion { src: spirv::Word, dst: spirv::Word, @@ -1604,18 +1583,6 @@ enum ConversionKind { Ptr(ast::LdStateSpace), } -impl ImplicitConversion { - fn map spirv::Word>(self, f: &mut F) -> Self { - Self { - src: f(self.src), - dst: f(self.dst), - from: self.from, - to: self.to, - kind: self.kind, - } - } -} - impl ast::PredAt { fn map_variable U>(self, f: &mut F) -> ast::PredAt { ast::PredAt { @@ -2354,6 +2321,15 @@ fn insert_implicit_bitcasts( } } +impl<'a> ast::FunctionHeader<'a, ast::ParsedArgParams<'a>> { + fn name(&self) -> &'a str { + match self { + ast::FunctionHeader::Kernel(name) => name, + ast::FunctionHeader::Func(_, name) => name, + } + } +} + // CFGs below taken from "Modern Compiler Implementation in Java" #[cfg(test)] mod tests {