diff --git a/ptx/src/test/spirv_run/local_align.ptx b/ptx/src/test/spirv_run/local_align.ptx new file mode 100644 index 0000000..6e10de3 --- /dev/null +++ b/ptx/src/test/spirv_run/local_align.ptx @@ -0,0 +1,21 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry local_align( + .param .u64 input, + .param .u64 output +) +{ + .local .align 8 .b8 __local_depot0[8]; + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + st.u64 [out_addr], temp; + ret; +} diff --git a/ptx/src/test/spirv_run/local_align.spvtxt b/ptx/src/test/spirv_run/local_align.spvtxt new file mode 100644 index 0000000..beefb76 --- /dev/null +++ b/ptx/src/test/spirv_run/local_align.spvtxt @@ -0,0 +1,38 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %5 "local_align" + OpDecorate %8 Alignment 8 + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %4 = OpTypeFunction %void %ulong %ulong + %uchar = OpTypeInt 8 0 +%_arr_uchar_8 = OpTypeArray %uchar %8 +%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %5 = OpFunction %void None %4 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %18 = OpLabel + %8 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_ulong Function + %11 = OpVariable %_ptr_Function_ulong Function + OpStore %9 %6 + OpStore %10 %7 + %13 = OpLoad %ulong %9 + %16 = OpConvertUToPtr %_ptr_Generic_ulong %13 + %12 = OpLoad %ulong %16 + OpStore %11 %12 + %14 = OpLoad %ulong %10 + %15 = OpLoad %ulong %11 + %17 = OpConvertUToPtr %_ptr_Generic_ulong %14 + OpStore %17 %15 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 23852a1..8883669 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -51,6 +51,7 @@ test_ptx!(shl, [11u64], [44u64]); test_ptx!(cvt_sat_s_u, [-1i32], [0i32]); test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); +test_ptx!(local_align, [1u64], [1u64]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 3fe01cf..642e6ec 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -5,20 +5,21 @@ use std::{borrow::Cow, iter, mem}; use rspirv::binary::Assemble; -#[derive(PartialEq, Eq, Hash, Clone, Copy)] +#[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), - Pointer(SpirvScalarKey, spirv::StorageClass), + Array(SpirvScalarKey, u32), + Pointer(Box, spirv::StorageClass), } impl SpirvType { fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { let key = match t { - ast::Type::Scalar(typ) => SpirvScalarKey::from(typ), - ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ), - ast::Type::Array(_, _) => todo!(), + ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)), + ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)), + ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len), }; - SpirvType::Pointer(key, sc) + SpirvType::Pointer(Box::new(key), sc) } } @@ -27,7 +28,7 @@ impl From for SpirvType { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()), - ast::Type::Array(_, _) => todo!(), + ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), } } } @@ -126,13 +127,20 @@ impl TypeWordMap { fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word { match t { SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), - SpirvType::Pointer(typ, storage) => { - let base = self.get_or_add_spirv_scalar(b, typ); + SpirvType::Pointer(ref typ, storage) => { + let base = self.get_or_add(b, *typ.clone()); *self .complex .entry(t) .or_insert_with(|| b.type_pointer(None, storage, base)) } + SpirvType::Array(typ, len) => { + let base = self.get_or_add_spirv_scalar(b, typ); + *self + .complex + .entry(t) + .or_insert_with(|| b.type_array(base, len)) + } } } @@ -248,7 +256,7 @@ fn normalize_labels( labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } - Statement::Variable(_, _, _) + Statement::Variable(_, _, _, _) | Statement::LoadVar(_, _) | Statement::StoreVar(_, _) | Statement::Conversion(_) @@ -298,9 +306,9 @@ fn normalize_predicates( result.push(Statement::Instruction(inst)); } } - ast::Statement::Variable(var) => { - result.push(Statement::Variable(var.name, var.v_type, var.space)) - } + ast::Statement::Variable(var) => result.push(Statement::Variable( + var.name, var.v_type, var.space, var.align, + )), // Blocks are flattened when resolving ids ast::Statement::Block(_) => unreachable!(), } @@ -373,7 +381,7 @@ fn insert_mem_ssa_statements( bra.predicate = generated_id; result.push(Statement::Conditional(bra)); } - s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s), + s @ Statement::Variable(_, _, _, _) | s @ Statement::Label(_) => result.push(s), Statement::LoadVar(_, _) | Statement::StoreVar(_, _) | Statement::Conversion(_) @@ -395,7 +403,9 @@ fn expand_arguments( let new_inst = inst.map(&mut visitor); result.push(Statement::Instruction(new_inst)); } - Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)), + Statement::Variable(id, typ, ss, align) => { + result.push(Statement::Variable(id, typ, ss, align)) + } Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), @@ -555,7 +565,7 @@ fn insert_implicit_conversions( s @ Statement::Conditional(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) - | s @ Statement::Variable(_, _, _) + | s @ Statement::Variable(_, _, _, _) | s @ Statement::LoadVar(_, _) | s @ Statement::StoreVar(_, _) => result.push(s), Statement::Conversion(_) => unreachable!(), @@ -614,15 +624,24 @@ fn emit_function_body_ops( } match s { Statement::Label(_) => (), - Statement::Variable(id, typ, ss) => { + Statement::Variable(id, typ, ss, align) => { let type_id = map.get_or_add( builder, SpirvType::new_pointer(*typ, spirv::StorageClass::Function), ); - if *ss != ast::StateSpace::Reg { - todo!() + let st_class = match ss { + ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Local => spirv::StorageClass::Workgroup, + _ => todo!(), + }; + builder.variable(type_id, Some(*id), st_class, None); + if let Some(align) = align { + builder.decorate( + *id, + spirv::Decoration::Alignment, + &[dr::Operand::LiteralInt32(*align)], + ); } - builder.variable(type_id, Some(*id), spirv::StorageClass::Function, None); } Statement::Constant(cnst) => { let typ_id = map.get_or_add_scalar(builder, cnst.typ); @@ -1006,7 +1025,10 @@ fn emit_implicit_conversion( ConversionKind::Ptr(space) => { let dst_type = map.get_or_add( builder, - SpirvType::Pointer(SpirvScalarKey::from(to_type), space.to_spirv()), + SpirvType::Pointer( + Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))), + space.to_spirv(), + ), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } @@ -1221,7 +1243,7 @@ impl NumericIdResolver { } enum Statement { - Variable(spirv::Word, ast::Type, ast::StateSpace), + Variable(spirv::Word, ast::Type, ast::StateSpace, Option), LoadVar(ast::Arg2, ast::Type), StoreVar(ast::Arg2St, ast::Type), Label(u32), @@ -1235,7 +1257,7 @@ enum Statement { impl Statement> { fn visit_variable spirv::Word>(self, f: &mut F) -> Self { match self { - Statement::Variable(id, t, ss) => Statement::Variable(f(id), t, ss), + Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align), Statement::LoadVar(a, t) => { Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t) }