ZLUDA/zluda/tests/common.rs
Andrzej Janik 1b9ba2b233 Nobody expects the Red Team
Too many changes to list, but broadly:
* Remove Intel GPU support from the compiler
* Add AMD GPU support to the compiler
* Remove Intel GPU host code
* Add AMD GPU host code
* More device instructions. From 40 to 68
* More host functions. From 48 to 184
* Add proof of concept implementation of OptiX framework
* Add minimal support of cuDNN, cuBLAS, cuSPARSE, cuFFT, NCCL, NVML
* Improve ZLUDA launcher for Windows
2024-02-11 20:45:51 +01:00

129 lines
4 KiB
Rust

#![allow(non_snake_case)]
use cuda_base::cuda_function_declarations;
use std::ffi::c_void;
// Callback macro that receives a `;`-separated list of driver function
// declarations of the form `"abi" fn name(arg: Type, ...) -> Ret` and
// generates from it:
//   * `CudaDriverFns` - a trait with one `unsafe` method per declaration,
//   * `Cuda`          - an implementation that resolves every function by name
//                       from the dynamically loaded library on each call,
//   * `Zluda`         - an implementation that forwards to `zluda::cuda::<name>`.
macro_rules! unimplemented_cuda_fn {
    ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => {
        /// Common interface over the driver implementations under test.
        pub trait CudaDriverFns {
            /// Creates a handle to the driver implementation.
            fn new() -> Self;
            /// Returns `true` for the `Cuda` implementation and `false` for `Zluda`.
            fn is_nvidia() -> bool;
            $(
                unsafe fn $fn_name (&self, $( $arg_id : $arg_type),* ) -> $ret_type;
            )*
        }

        /// Driver functions resolved at runtime from the library returned by
        /// `os::load_cuda()`.
        #[derive(Copy, Clone)]
        pub struct Cuda {
            // Raw library handle as returned by `os::load_cuda()`.
            lib: *mut c_void
        }

        // A raw pointer field makes the struct `!Send`/`!Sync` by default; these
        // impls opt back in.
        // NOTE(review): presumably needed because tests run on multiple threads - confirm.
        unsafe impl Send for Cuda {}
        unsafe impl Sync for Cuda {}

        impl CudaDriverFns for Cuda {
            fn new() -> Self {
                let lib = unsafe { os::load_cuda() };
                Self { lib }
            }

            fn is_nvidia() -> bool { true }

            $(
                unsafe fn $fn_name (&self, $( $arg_id : $arg_type),* ) -> $ret_type {
                    // Symbol lookup needs a NUL-terminated name.
                    let fn_ptr = os::get_proc_address(self.lib, concat!(stringify!($fn_name), "\0").as_bytes());
                    // A function pointer must never be null: transmuting a failed
                    // lookup would be undefined behavior, so fail loudly instead.
                    assert!(
                        !fn_ptr.is_null(),
                        "CUDA driver library does not export `{}`",
                        stringify!($fn_name)
                    );
                    let cu_fn = std::mem::transmute::<_, unsafe extern $abi fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr);
                    cu_fn ( $( $arg_id),* )
                }
            )*
        }

        /// Driver functions forwarded directly to the `zluda::cuda` module.
        #[derive(Copy, Clone)]
        pub struct Zluda;

        impl CudaDriverFns for Zluda {
            fn new() -> Self { Self }

            fn is_nvidia() -> bool { false }

            $(
                unsafe fn $fn_name (&self, $( $arg_id : $arg_type),* ) -> $ret_type {
                    zluda::cuda::$fn_name ( $( $arg_id),* )
                }
            )*
        }
    };
}
// Instantiates `CudaDriverFns`, `Cuda` and `Zluda` by handing
// `unimplemented_cuda_fn` to the `cuda_base` declaration macro.
// NOTE(review): presumably `UNUSED` names a second callback that never fires
// because the override list `[]` is empty - confirm against `cuda_base`.
cuda_function_declarations!(cuda_types, unimplemented_cuda_fn, UNUSED, []);
/// Expands a generic test body `$func` into two `#[test]` functions:
/// `<func>_zluda`, which runs it with `crate::common::Zluda`, and
/// `<func>_cuda`, which runs it with `crate::common::Cuda`.
///
/// `$func` is called as `$func::<Driver>(driver)` inside an `unsafe` block.
#[macro_export]
macro_rules! cuda_driver_test {
    ($func:ident) => {
        paste::paste! {
            #[test]
            #[allow(non_snake_case)]
            fn [<$func _zluda>]() {
                let driver = crate::common::Zluda::new();
                unsafe { $func::<crate::common::Zluda>(driver) }
            }

            #[test]
            #[allow(non_snake_case)]
            fn [<$func _cuda>]() {
                let driver = crate::common::Cuda::new();
                unsafe { $func::<crate::common::Cuda>(driver) }
            }
        }
    };
}
/// Sentinel stream handle with address value 1, mirroring the CUDA driver
/// API's `CU_STREAM_LEGACY`.
#[allow(dead_code)]
pub const CU_STREAM_LEGACY: cuda_types::CUstream = 1usize as cuda_types::CUstream;

/// Sentinel stream handle with address value 2, mirroring the CUDA driver
/// API's `CU_STREAM_PER_THREAD`.
#[allow(dead_code)]
pub const CU_STREAM_PER_THREAD: cuda_types::CUstream = 2usize as cuda_types::CUstream;
/// Windows implementation of dynamic library loading, built on `LoadLibraryA`
/// and `GetProcAddress`.
#[cfg(windows)]
mod os {
    use std::ffi::c_void;

    /// Loads `nvcuda.dll` from its absolute System32 path and returns the
    /// module handle.
    ///
    /// # Panics
    /// Panics with the last OS error if the library cannot be loaded.
    pub unsafe fn load_cuda() -> *mut c_void {
        use winapi::um::libloaderapi::LoadLibraryA;

        let module = LoadLibraryA(b"C:\\Windows\\System32\\nvcuda.dll\0".as_ptr().cast());
        if module.is_null() {
            panic!("{:?}", std::io::Error::last_os_error());
        }
        module.cast()
    }

    /// Looks up the exported symbol `func` in `handle`.
    ///
    /// `func` is passed to `GetProcAddress` as a C string, so it must be
    /// NUL-terminated. The raw result is returned unchecked.
    pub unsafe fn get_proc_address(handle: *mut c_void, func: &[u8]) -> *mut c_void {
        use winapi::um::libloaderapi::GetProcAddress;

        let symbol = GetProcAddress(handle.cast(), func.as_ptr().cast());
        symbol as *mut c_void
    }
}
/// Non-Windows implementation of dynamic library loading, built on
/// `dlopen`/`dlsym` from `libc`.
#[cfg(not(windows))]
mod os {
    use libc;
    use std::ffi::c_void;
    use std::ffi::CStr;

    /// Opens `libcuda.so.1` from the first known location that loads and
    /// returns the library handle.
    ///
    /// # Panics
    /// Panics with the `dlerror()` message if no candidate can be opened.
    #[cfg(test)]
    pub unsafe fn load_cuda() -> *mut c_void {
        // Absolute, NUL-terminated candidate paths, tried in this order.
        const CANDIDATES: [&[u8]; 2] = [
            // Ubuntu path
            b"/usr/lib/x86_64-linux-gnu/libcuda.so.1\0",
            // RHEL path
            b"/usr/lib64/libcuda.so.1\0",
        ];

        for path in CANDIDATES.iter() {
            let handle = libc::dlopen(path.as_ptr().cast(), libc::RTLD_LOCAL | libc::RTLD_LAZY);
            if !handle.is_null() {
                return handle;
            }
        }
        // Every candidate failed; report the most recent loader error.
        panic!("{}", CStr::from_ptr(libc::dlerror()).to_string_lossy());
    }

    /// Looks up the symbol `func` in `handle` via `dlsym`.
    ///
    /// `func` is passed as a C string, so it must be NUL-terminated. The raw
    /// result is returned unchecked.
    #[cfg(test)]
    pub unsafe fn get_proc_address(handle: *mut c_void, func: &[u8]) -> *mut c_void {
        let symbol_name = func.as_ptr().cast();
        libc::dlsym(handle, symbol_name)
    }
}