diff --git a/ptx/src/test/spirv_run/b64tof64.ptx b/ptx/src/test/spirv_run/b64tof64.ptx new file mode 100644 index 0000000..de028c7 --- /dev/null +++ b/ptx/src/test/spirv_run/b64tof64.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry b64tof64( + .param .u64 input, + .param .u64 output +) +{ + .reg .f64 in_addr_f; + .reg .b64 in_addr; + .reg .u64 out_addr; + + .reg.u64 temp; + + ld.param.f64 in_addr_f, [input]; + ld.param.u64 out_addr, [output]; + + mov.b64 in_addr, in_addr_f; + + ld.u64 temp, [in_addr]; + st.u64 [out_addr], temp; + + ret; +} diff --git a/ptx/src/test/spirv_run/b64tof64.spvtxt b/ptx/src/test/spirv_run/b64tof64.spvtxt new file mode 100644 index 0000000..9146c90 --- /dev/null +++ b/ptx/src/test/spirv_run/b64tof64.spvtxt @@ -0,0 +1,50 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + OpCapability Float64 + %26 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "b64tof64" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %29 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %double = OpTypeFloat 64 +%_ptr_Function_double = OpTypePointer Function %double +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %29 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %24 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_double Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %20 = OpBitcast %double %11 + %10 = OpCopyObject %double %20 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %6 %12 + %15 = OpLoad %double %4 + %21 = OpBitcast %ulong %15 + %14 = OpCopyObject %ulong %21 + OpStore %5 %14 + %17 = OpLoad %ulong %5 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %17 + %16 = OpLoad %ulong %22 + OpStore %7 %16 + %18 = OpLoad %ulong %6 + %19 = OpLoad %ulong %7 + %23 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %23 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/implicit_param.ptx b/ptx/src/test/spirv_run/implicit_param.ptx new file mode 100644 index 0000000..1d46bc1 --- /dev/null +++ b/ptx/src/test/spirv_run/implicit_param.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry implicit_param( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 temp; + .param .b32 temp_param; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.f32 temp, [in_addr]; + st.param.f32 [temp_param], temp; + ld.param.f32 temp, [temp_param]; + st.global.f32 [out_addr], temp; + + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/implicit_param.spvtxt b/ptx/src/test/spirv_run/implicit_param.spvtxt new file mode 100644 index 0000000..c30788c --- /dev/null +++ b/ptx/src/test/spirv_run/implicit_param.spvtxt @@ -0,0 +1,55 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + OpCapability Float64 + %28 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "implicit_param" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %31 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float + %1 = OpFunction %void None %31 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %26 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 + %14 = OpLoad %float %22 + OpStore %6 %14 + %17 = OpLoad %float %6 + %23 = OpCopyObject %float %17 + %16 = OpBitcast %uint %23 + OpStore %7 %16 + %19 = OpLoad %uint %7 + %24 = OpBitcast %float %19 + %18 = OpCopyObject %float %24 + OpStore %6 %18 + %20 = OpLoad %ulong %5 + %21 = OpLoad %float %6 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %20 + OpStore %25 %21 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt index e7dba5a..249af90 100644 --- a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt +++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt @@ -4,6 +4,7 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 + OpCapability Float64 %23 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "ld_st_implicit" @@ -34,14 +35,14 @@ %14 = OpLoad %ulong %4 %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 %18 = OpLoad %float %17 - %30 = OpBitcast %ulong %18 - %32 = OpUConvert %uint %30 - %13 = OpBitcast %uint %32 + %31 = OpBitcast %uint %18 + %13 = OpUConvert %ulong %31 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %ulong %6 - %33 = OpBitcast %uint %16 - %19 = OpUConvert %ulong %33 + %32 = OpBitcast %ulong %16 + %33 = OpUConvert %uint %32 + %19 = OpBitcast %float %33 %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 OpStore %20 %19 OpReturn diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 4c9d779..06843f0 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -60,8 +60,10 @@ test_ptx!(call, [1u64], [2u64]); test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); test_ptx!(ntid, [3u32], [4u32]); -test_ptx!(reg_local, [12u64], [12u64]); +test_ptx!(reg_local, [12u64], [13u64]); test_ptx!(mov_address, [0xDEADu64], [0u64]); +test_ptx!(b64tof64, [111u64], [111u64]); +test_ptx!(implicit_param, [34u32], [34u32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/reg_local.ptx b/ptx/src/test/spirv_run/reg_local.ptx index fb234d8..f09b95a 100644 --- a/ptx/src/test/spirv_run/reg_local.ptx +++ b/ptx/src/test/spirv_run/reg_local.ptx @@ -17,7 +17,7 @@ ld.param.u64 out_addr, [output]; ld.global.u64 temp, [in_addr]; - st.u64 [local_x], temp; + st.u64 [local_x], temp + 1; ld.u64 temp, [local_x]; st.global.u64 [out_addr], temp; ret; diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 6810fec..2d6bd08 100644 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -4,43 +4,60 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" + %35 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add" + OpEntryPoint Kernel %1 "reg_local" + OpDecorate %4 Alignment 8 %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong + %38 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %uchar = OpTypeInt 8 0 + %uint = OpTypeInt 32 0 + %uint_8 = OpConstant %uint 8 +%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8 +%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8 +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %28 - %8 = OpFunctionParameter %ulong + %1 = OpFunction %void None %38 %9 = OpFunctionParameter %ulong - %23 = OpLabel + %10 = OpFunctionParameter %ulong + %33 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %11 = OpLoad %ulong %2 - %10 = OpCopyObject %ulong %11 - OpStore %4 %10 - %13 = OpLoad %ulong %3 - %12 = OpCopyObject %ulong %13 - OpStore %5 %12 - %15 = OpLoad %ulong %4 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 - %14 = OpLoad %ulong %21 - OpStore %6 %14 - %17 = OpLoad %ulong %6 - %16 = OpIAdd %ulong %17 %ulong_1 - OpStore %7 %16 - %18 = OpLoad %ulong %5 - %19 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %22 %19 + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %9 + OpStore %3 %10 + %12 = OpLoad %ulong %2 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %3 + %13 = OpCopyObject %ulong %14 + OpStore %6 %13 + %16 = OpLoad %ulong %5 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + %26 = OpLoad %ulong %25 + %15 = OpCopyObject %ulong %26 + OpStore %7 %15 + %18 = OpLoad %ulong %7 + %27 = OpCopyObject %ulong %18 + %24 = OpIAdd %ulong %27 %ulong_1 + %28 = OpCopyObject %ulong %24 + %17 = OpBitcast %ulong %28 + OpStore %4 %17 + %20 = OpLoad %_arr_uchar_uint_8 %4 + %29 = OpBitcast %ulong %20 + %30 = OpCopyObject %ulong %29 + %19 = OpCopyObject %ulong %30 + OpStore %7 %19 + %21 = OpLoad %ulong %6 + %22 = OpLoad %ulong %7 + %31 = OpCopyObject %ulong %22 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 + OpStore %32 %31 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7726040..22e16ff 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -299,6 +299,7 @@ fn emit_function_header<'a>( pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result, TranslateError> { let module = to_spirv_module(ast)?; + eprintln!("{}", module.disassemble()); Ok(module.assemble()) } @@ -309,6 +310,7 @@ fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Kernel); builder.capability(spirv::Capability::Int64); builder.capability(spirv::Capability::Int8); + builder.capability(spirv::Capability::Float64); } fn emit_extensions(_: &mut dr::Builder) {} @@ -990,8 +992,13 @@ fn insert_implicit_conversions( Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call)?, Statement::Instruction(inst) => match inst { ast::Instruction::Ld(ld, arg) => { - let pre_conv = - get_implicit_conversions_ld_src(id_def, ld.typ, ld.state_space, arg.src)?; + let pre_conv = get_implicit_conversions_ld_src( + id_def, + ld.typ, + ld.state_space, + arg.src, + false, + )?; let post_conv = get_implicit_conversions_ld_dst( id_def, ld.typ, @@ -1024,8 +1031,11 @@ fn insert_implicit_conversions( st.typ, st.state_space.to_ld_ss(), arg.src1, + true, )?; - let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param { + let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param + || st.state_space == ast::StStateSpace::Local + { (Vec::new(), post_conv) } else { (post_conv, Vec::new()) @@ -1667,7 +1677,7 @@ fn emit_implicit_conversion( } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.from)); + let dst_type = map.get_or_add(builder, SpirvType::from(cv.to)); if from_parts.scalar_kind != ScalarKind::Float && to_parts.scalar_kind != ScalarKind::Float { @@ -1714,7 +1724,7 @@ fn emit_implicit_conversion( } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(), (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { let into_type = map.get_or_add(builder, SpirvType::from(cv.to)); builder.bitcast(into_type, Some(cv.dst), cv.src)?; @@ -2409,7 +2419,8 @@ impl ast::Instruction { ast::Instruction::Call(_) => unreachable!(), ast::Instruction::Ld(d, a) => { let inst_type = d.typ; - let is_param = d.state_space == ast::LdStateSpace::Param; + let is_param = d.state_space == ast::LdStateSpace::Param + || d.state_space == ast::LdStateSpace::Local; ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)?) } ast::Instruction::Mov(d, a) => { @@ -2432,7 +2443,9 @@ impl ast::Instruction { let inst_type = d.typ; ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))?) } - ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, false, t.to_type())?), + ast::Instruction::Not(t, a) => { + ast::Instruction::Not(t, a.map(visitor, false, t.to_type())?) + } ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( @@ -2459,7 +2472,8 @@ impl ast::Instruction { } ast::Instruction::St(d, a) => { let inst_type = d.typ; - let is_param = d.state_space == ast::StStateSpace::Param; + let is_param = d.state_space == ast::StStateSpace::Param + || d.state_space == ast::StStateSpace::Local; ast::Instruction::St(d, a.map(visitor, inst_type, is_param)?) } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?), @@ -3419,8 +3433,8 @@ fn get_implicit_conversions_ld_dst< Ok(Some(ImplicitConversion { src: u32::max_value(), dst: u32::max_value(), - from: if !in_reverse { dst_type } else { instr_type }, - to: if !in_reverse { instr_type } else { dst_type }, + from: if !in_reverse { instr_type } else { dst_type }, + to: if !in_reverse { dst_type } else { instr_type }, kind: conv, })) } else { @@ -3433,6 +3447,7 @@ fn get_implicit_conversions_ld_src( instr_type: ast::Type, state_space: ast::LdStateSpace, src: spirv::Word, + in_reverse_param_local: bool, ) -> Result, TranslateError> { let src_type = id_def.get_typed(src)?; match state_space { @@ -3442,8 +3457,16 @@ fn get_implicit_conversions_ld_src( ImplicitConversion { src: u32::max_value(), dst: u32::max_value(), - from: src_type, - to: instr_type, + from: if !in_reverse_param_local { + src_type + } else { + instr_type + }, + to: if !in_reverse_param_local { + instr_type + } else { + src_type + }, kind: ConversionKind::Default, }; 1 @@ -3512,32 +3535,6 @@ fn insert_conversion_src( temp_src } -/* -fn insert_with_implicit_conversion_dst< - T, - ShouldConvert: FnOnce(ast::StateSpace, ast::Type, ast::Type) -> Option, - Setter: Fn(&mut T) -> &mut spirv::Word, - ToInstruction: FnOnce(T) -> ast::Instruction, ->( - func: &mut Vec, - instr_type: ast::Type, - id_def: &mut NumericIdResolver, - should_convert: ShouldConvert, - mut t: T, - setter: Setter, - to_inst: ToInstruction, -) { - let dst = setter(&mut t); - let dst_type = id_def.get_type(*dst); - let dst_coercion = should_convert(dst_type.unwrap(), instr_type) - .map(|conv| get_conversion_dst(id_def, dst, instr_type, dst_type.unwrap(), conv)); - func.push(Statement::Instruction(to_inst(t))); - if let Some(conv) = dst_coercion { - func.push(conv); - } -} -*/ - #[must_use] fn get_conversion_dst( id_def: &mut MutableNumericIdResolver,