diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 253ba4b..c56321a 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -569,6 +569,29 @@ impl<'a> CommandList<'a> { Ok(()) } + pub unsafe fn append_memory_fill_unsafe( + &mut self, + dst: *mut c_void, + pattern: &T, + byte_size: usize, + signal: Option<&mut Event<'a>>, + wait: &mut [Event<'a>], + ) -> Result<()> { + let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut()); + let (wait_len, wait_ptr) = Event::raw_slice(wait); + check!(sys::zeCommandListAppendMemoryFill( + self.0, + dst, + pattern as *const T as *const _, + mem::size_of::(), + byte_size, + signal_event, + wait_len, + wait_ptr + )); + Ok(()) + } + pub fn append_launch_kernel( &mut self, kernel: &'a Kernel, diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs index ea7fc4b..a957ba0 100644 --- a/notcuda/src/cuda.rs +++ b/notcuda/src/cuda.rs @@ -2932,7 +2932,7 @@ pub extern "C" fn cuMemsetD32_v2( ui: ::std::os::raw::c_uint, N: usize, ) -> CUresult { - r#impl::unimplemented() + r#impl::memory::set_d32_v2(dstDevice.decuda(), ui, N).encuda() } #[cfg_attr(not(test), no_mangle)] diff --git a/notcuda/src/impl/memory.rs b/notcuda/src/impl/memory.rs index 62dc1cc..1e7dcb7 100644 --- a/notcuda/src/impl/memory.rs +++ b/notcuda/src/impl/memory.rs @@ -1,5 +1,5 @@ use super::{stream, CUresult, GlobalState}; -use std::ffi::c_void; +use std::{ffi::c_void, mem}; pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> { let ptr = GlobalState::lock_current_context(|ctx| { @@ -27,6 +27,17 @@ pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> { .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)? 
} +pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> { + GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| { + let mut cmd_list = stream.command_list()?; + unsafe { + cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::() * n, None, &mut []) + }?; + stream.queue.execute(cmd_list)?; + Ok::<_, CUresult>(()) + })? +} + #[cfg(test)] mod test { use super::super::test::CudaDriverFns; diff --git a/ptx/lib/notcuda_ptx_impl.cl b/ptx/lib/notcuda_ptx_impl.cl index 4249f2b..c633ddc 100644 --- a/ptx/lib/notcuda_ptx_impl.cl +++ b/ptx/lib/notcuda_ptx_impl.cl @@ -120,22 +120,27 @@ atomic_dec(atom_acquire_sys_shared_dec, memory_order_acquire, memory_order_acqui atomic_dec(atom_release_sys_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local); atomic_dec(atom_acq_rel_sys_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); -uint FUNC(bfe_u32)(uint base, uint pos, uint len) -{ +uint FUNC(bfe_u32)(uint base, uint pos, uint len) { return intel_ubfe(base, pos, len); } -ulong FUNC(bfe_u64)(ulong base, uint pos, uint len) -{ +ulong FUNC(bfe_u64)(ulong base, uint pos, uint len) { return intel_ubfe(base, pos, len); } -int FUNC(bfe_s32)(int base, uint pos, uint len) -{ +int FUNC(bfe_s32)(int base, uint pos, uint len) { return intel_sbfe(base, pos, len); } -long FUNC(bfe_s64)(long base, uint pos, uint len) -{ +long FUNC(bfe_s64)(long base, uint pos, uint len) { return intel_sbfe(base, pos, len); -} \ No newline at end of file +} + +void FUNC(__assertfail)( + __private ulong* message, + __private ulong* file, + __private uint* line, + __private ulong* function, + __private ulong* charSize +) { +} diff --git a/ptx/lib/notcuda_ptx_impl.spv b/ptx/lib/notcuda_ptx_impl.spv index 1ef470f..aa30fb8 100644 Binary files a/ptx/lib/notcuda_ptx_impl.spv and b/ptx/lib/notcuda_ptx_impl.spv differ diff --git a/ptx/src/test/spirv_run/assertfail.ptx b/ptx/src/test/spirv_run/assertfail.ptx new 
file mode 100644 index 0000000..47ecb53 --- /dev/null +++ b/ptx/src/test/spirv_run/assertfail.ptx @@ -0,0 +1,79 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.extern .func __assertfail +( + .param .b64 __assertfail_param_0, + .param .b64 __assertfail_param_1, + .param .b32 __assertfail_param_2, + .param .b64 __assertfail_param_3, + .param .b64 __assertfail_param_4 +); + +.extern .func __assertfail +( + .param .b64 __assertfail_param_0, + .param .b64 __assertfail_param_1, + .param .b32 __assertfail_param_2, + .param .b64 __assertfail_param_3, + .param .b64 __assertfail_param_4 +); + +.visible .entry assertfail( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + + + { + .reg .b32 temp_param_reg; + mov.u32 temp_param_reg, 0; + // } + .param .b64 param0; + st.param.b64 [param0+0], in_addr; + .param .b64 param1; + st.param.b64 [param1+0], in_addr; + .param .b32 param2; + st.param.b32 [param2+0], temp_param_reg; + .param .b64 param3; + st.param.b64 [param3+0], in_addr; + .param .b64 param4; + st.param.b64 [param4+0], in_addr; + call.uni + __assertfail, + ( + param0, + param1, + param2, + param3, + param4 + ); + + //{ + } + + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; +} + + +.extern .func __assertfail +( + .param .b64 __assertfail_param_0, + .param .b64 __assertfail_param_1, + .param .b32 __assertfail_param_2, + .param .b64 __assertfail_param_3, + .param .b64 __assertfail_param_4 +); diff --git a/ptx/src/test/spirv_run/assertfail.spvtxt b/ptx/src/test/spirv_run/assertfail.spvtxt new file mode 100644 index 0000000..09f9abf --- /dev/null +++ b/ptx/src/test/spirv_run/assertfail.spvtxt @@ -0,0 +1,105 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability 
Float16 + OpCapability Float64 + %67 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %12 "assertfail" + OpDecorate %1 LinkageAttributes "__notcuda_ptx_impl____assertfail" Import + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint + %73 = OpTypeFunction %void %_ptr_Function_ulong %_ptr_Function_ulong %_ptr_Function_uint %_ptr_Function_ulong %_ptr_Function_ulong + %74 = OpTypeFunction %void %ulong %ulong + %uint_0 = OpConstant %uint 0 + %ulong_0 = OpConstant %ulong 0 + %uchar = OpTypeInt 8 0 +%_ptr_Function_uchar = OpTypePointer Function %uchar + %ulong_0_0 = OpConstant %ulong 0 + %ulong_0_1 = OpConstant %ulong 0 + %ulong_0_2 = OpConstant %ulong 0 + %ulong_0_3 = OpConstant %ulong 0 +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %73 + %61 = OpFunctionParameter %_ptr_Function_ulong + %62 = OpFunctionParameter %_ptr_Function_ulong + %63 = OpFunctionParameter %_ptr_Function_uint + %64 = OpFunctionParameter %_ptr_Function_ulong + %65 = OpFunctionParameter %_ptr_Function_ulong + OpFunctionEnd + %12 = OpFunction %void None %74 + %25 = OpFunctionParameter %ulong + %26 = OpFunctionParameter %ulong + %60 = OpLabel + %13 = OpVariable %_ptr_Function_ulong Function + %14 = OpVariable %_ptr_Function_ulong Function + %15 = OpVariable %_ptr_Function_ulong Function + %16 = OpVariable %_ptr_Function_ulong Function + %17 = OpVariable %_ptr_Function_ulong Function + %18 = OpVariable %_ptr_Function_ulong Function + %19 = OpVariable %_ptr_Function_uint Function + %20 = OpVariable %_ptr_Function_ulong Function + %21 = OpVariable %_ptr_Function_ulong Function + %22 = OpVariable %_ptr_Function_uint Function + %23 = OpVariable %_ptr_Function_ulong Function + %24 = OpVariable %_ptr_Function_ulong Function + OpStore %13 %25 + OpStore %14 %26 + %27 = OpLoad %ulong %13 
+ OpStore %15 %27 + %28 = OpLoad %ulong %14 + OpStore %16 %28 + %53 = OpCopyObject %uint %uint_0 + %29 = OpCopyObject %uint %53 + OpStore %19 %29 + %30 = OpLoad %ulong %15 + %77 = OpBitcast %_ptr_Function_uchar %20 + %78 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %77 %ulong_0 + %43 = OpBitcast %_ptr_Function_ulong %78 + %54 = OpCopyObject %ulong %30 + OpStore %43 %54 + %31 = OpLoad %ulong %15 + %79 = OpBitcast %_ptr_Function_uchar %21 + %80 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %79 %ulong_0_0 + %45 = OpBitcast %_ptr_Function_ulong %80 + %55 = OpCopyObject %ulong %31 + OpStore %45 %55 + %32 = OpLoad %uint %19 + %81 = OpBitcast %_ptr_Function_uchar %22 + %82 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %81 %ulong_0_1 + %47 = OpBitcast %_ptr_Function_uint %82 + OpStore %47 %32 + %33 = OpLoad %ulong %15 + %83 = OpBitcast %_ptr_Function_uchar %23 + %84 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %83 %ulong_0_2 + %49 = OpBitcast %_ptr_Function_ulong %84 + %56 = OpCopyObject %ulong %33 + OpStore %49 %56 + %34 = OpLoad %ulong %15 + %85 = OpBitcast %_ptr_Function_uchar %24 + %86 = OpInBoundsPtrAccessChain %_ptr_Function_uchar %85 %ulong_0_3 + %51 = OpBitcast %_ptr_Function_ulong %86 + %57 = OpCopyObject %ulong %34 + OpStore %51 %57 + %87 = OpFunctionCall %void %1 %20 %21 %22 %23 %24 + %36 = OpLoad %ulong %15 + %58 = OpConvertUToPtr %_ptr_Generic_ulong %36 + %35 = OpLoad %ulong %58 + OpStore %17 %35 + %38 = OpLoad %ulong %17 + %37 = OpIAdd %ulong %38 %ulong_1 + OpStore %18 %37 + %39 = OpLoad %ulong %16 + %40 = OpLoad %ulong %18 + %59 = OpConvertUToPtr %_ptr_Generic_ulong %39 + OpStore %59 %40 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index c70ab5c..d6a8038 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -139,6 +139,8 @@ test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]); test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]); 
test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]); test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]); +// For now, we just check that it builds and links +test_ptx!(assertfail, [716523871u64], [716523872u64]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt index ce19fa5..5841146 100644 --- a/ptx/src/test/spirv_run/shl.spvtxt +++ b/ptx/src/test/spirv_run/shl.spvtxt @@ -39,7 +39,8 @@ OpStore %6 %12 %15 = OpLoad %ulong %6 %21 = OpCopyObject %ulong %15 - %20 = OpShiftLeftLogical %ulong %21 %uint_2 + %32 = OpUConvert %ulong %uint_2 + %20 = OpShiftLeftLogical %ulong %21 %32 %14 = OpCopyObject %ulong %20 OpStore %7 %14 %16 = OpLoad %ulong %5 diff --git a/ptx/src/test/spirv_run/shl_link_hack.ptx b/ptx/src/test/spirv_run/shl_link_hack.ptx new file mode 100644 index 0000000..a32555c --- /dev/null +++ b/ptx/src/test/spirv_run/shl_link_hack.ptx @@ -0,0 +1,30 @@ +// HACK ALERT +// This test is for testing a workaround for a bug in IGC where linking fails +// if there is shl/shr with different width of value and shift + +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shl_link_hack( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + // Here only to trigger linking + .reg .u32 unused; + atom.inc.u32 unused, [out_addr], 2000000; + + ld.u64 temp, [in_addr]; + shl.b64 temp2, temp, 2; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/shl_link_hack.spvtxt b/ptx/src/test/spirv_run/shl_link_hack.spvtxt new file mode 100644 index 0000000..0114a55 --- /dev/null +++ b/ptx/src/test/spirv_run/shl_link_hack.spvtxt @@ -0,0 +1,65 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + 
OpCapability Float64 + %34 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "shl_link_hack" + OpDecorate %29 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import + %void = OpTypeVoid + %uint = OpTypeInt 32 0 +%_ptr_Generic_uint = OpTypePointer Generic %uint + %38 = OpTypeFunction %uint %_ptr_Generic_uint %uint + %ulong = OpTypeInt 64 0 + %40 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint +%uint_2000000 = OpConstant %uint 2000000 +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %uint_2 = OpConstant %uint 2 + %29 = OpFunction %uint None %38 + %31 = OpFunctionParameter %_ptr_Generic_uint + %32 = OpFunctionParameter %uint + OpFunctionEnd + %1 = OpFunction %void None %40 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %28 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %11 = OpLoad %ulong %2 + OpStore %4 %11 + %12 = OpLoad %ulong %3 + OpStore %5 %12 + %14 = OpLoad %ulong %5 + %23 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpFunctionCall %uint %29 %23 %uint_2000000 + OpStore %8 %13 + %16 = OpLoad %ulong %4 + %24 = OpConvertUToPtr %_ptr_Generic_ulong %16 + %15 = OpLoad %ulong %24 + OpStore %6 %15 + %18 = OpLoad %ulong %6 + %26 = OpCopyObject %ulong %18 + %44 = OpUConvert %ulong %uint_2 + %25 = OpShiftLeftLogical %ulong %26 %44 + %17 = OpCopyObject %ulong %25 + OpStore %7 %17 + %19 = OpLoad %ulong %5 + %20 = OpLoad %ulong %7 + %27 = OpConvertUToPtr %_ptr_Generic_ulong %19 + OpStore %27 %20 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs 
b/ptx/src/translate.rs index 2b14bd7..ed752e9 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -465,7 +465,9 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result, _>>()?; let must_link_ptx_impl = ptx_impl_imports.len() > 0; let directives = ptx_impl_imports @@ -1173,13 +1175,13 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn translate_directive<'input>( id_defs: &mut GlobalStringIdResolver<'input>, - ptx_impl_imports: &mut HashMap, + ptx_impl_imports: &mut HashMap>, d: ast::Directive<'input, ast::ParsedArgParams<'input>>, -) -> Result, TranslateError> { +) -> Result>, TranslateError> { Ok(match d { - ast::Directive::Variable(v) => Directive::Variable(translate_variable(id_defs, v)?), + ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)), ast::Directive::Method(f) => { - Directive::Method(translate_function(id_defs, ptx_impl_imports, f)?) + translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method) } }) } @@ -1219,11 +1221,27 @@ fn translate_variable<'a>( fn translate_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, - ptx_impl_imports: &mut HashMap, + ptx_impl_imports: &mut HashMap>, f: ast::ParsedFunction<'a>, -) -> Result, TranslateError> { +) -> Result>, TranslateError> { + let import_as = match &f.func_directive { + ast::MethodDecl::Func(_, "__assertfail", _) => { + Some("__notcuda_ptx_impl____assertfail".to_owned()) + } + _ => None, + }; let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; - to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body) + let mut func = to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)?; + func.import_as = import_as; + if func.import_as.is_some() { + ptx_impl_imports.insert( + func.import_as.as_ref().unwrap().clone(), + Directive::Method(func), + ); + Ok(None) + } else { + Ok(Some(func)) + } } fn expand_kernel_params<'a, 'b>( @@ -5302,7 +5320,6 @@ pub enum StateSpace { Local, Shared, 
Param, - ParamReg, } impl From for StateSpace {