diff --git a/zluda_rtc/src/lib.rs b/zluda_rtc/src/lib.rs index 1725d1a..427a2e0 100644 --- a/zluda_rtc/src/lib.rs +++ b/zluda_rtc/src/lib.rs @@ -19,19 +19,22 @@ fn to_nvrtc(status: hiprtc_sys::hiprtcResult) -> nvrtcResult { match status { hiprtc_sys::hiprtcResult::HIPRTC_SUCCESS => nvrtcResult::NVRTC_SUCCESS, hiprtc_sys::hiprtcResult::HIPRTC_ERROR_INVALID_PROGRAM => nvrtcResult::NVRTC_ERROR_INVALID_PROGRAM, + hiprtc_sys::hiprtcResult::HIPRTC_ERROR_COMPILATION => nvrtcResult::NVRTC_ERROR_COMPILATION, err => panic!("[ZLUDA] HIPRTC failed: {}", err.0), } } -fn get_error_string(result: nvrtcResult) -> *const ::std::os::raw::c_char { - let error_string = - match result { - nvrtcResult::NVRTC_ERROR_INTERNAL_ERROR => String::from("NVRTC_ERROR_INTERNAL_ERROR"), - _ => result.0.to_string(), - }; - println!("[ZLUDA] HIPRTC failed: {}", error_string); - let cstr = std::ffi::CString::new("").unwrap(); - cstr.as_ptr() +fn to_hiprtc(status: nvrtcResult) -> hiprtc_sys::hiprtcResult { + match status { + nvrtcResult::NVRTC_SUCCESS => hiprtc_sys::hiprtcResult::HIPRTC_SUCCESS, + nvrtcResult::NVRTC_ERROR_INVALID_PROGRAM => hiprtc_sys::hiprtcResult::HIPRTC_ERROR_INVALID_PROGRAM, + nvrtcResult::NVRTC_ERROR_COMPILATION => hiprtc_sys::hiprtcResult::HIPRTC_ERROR_COMPILATION, + err => panic!("[ZLUDA] HIPRTC failed: {}", err.0), + } +} + +unsafe fn get_error_string(result: nvrtcResult) -> *const ::std::os::raw::c_char { + hiprtcGetErrorString(to_hiprtc(result)) } unsafe fn create_program( diff --git a/zluda_rtc/src/nvrtc.rs b/zluda_rtc/src/nvrtc.rs index 05f3acd..2473091 100644 --- a/zluda_rtc/src/nvrtc.rs +++ b/zluda_rtc/src/nvrtc.rs @@ -43,7 +43,7 @@ pub struct nvrtcResult(pub ::std::os::raw::c_int); #[doc = " \\ingroup error\n \\brief nvrtcGetErrorString is a helper function that returns a string\n describing the given nvrtcResult code, e.g., NVRTC_SUCCESS to\n \\c \"NVRTC_SUCCESS\".\n For unrecognized enumeration values, it returns\n \\c \"NVRTC_ERROR unknown\".\n\n \\param [in] result CUDA Runtime Compilation API result code.\n \\return Message string for the given #nvrtcResult code."] #[no_mangle] -pub extern "system" fn nvrtcGetErrorString(result: nvrtcResult) -> *const ::std::os::raw::c_char { +pub unsafe extern "system" fn nvrtcGetErrorString(result: nvrtcResult) -> *const ::std::os::raw::c_char { crate::get_error_string(result) }