From 1ef7ef39385201d56e3cf320be843a2f8d02a702 Mon Sep 17 00:00:00 2001 From: Seunghoon Lee Date: Thu, 15 Feb 2024 06:54:21 +0900 Subject: [PATCH] Add support of cuBLAS, cuSPARSE for Windows. --- Makefile.toml | 5 +++ hipfft-sys/build.rs | 10 ++++- rocblas-sys/build.rs | 10 ++++- rocsolver-sys/build.rs | 10 ++++- rocsparse-sys/build.rs | 10 ++++- zluda_blas/Cargo.toml | 1 - zluda_blas/src/cublas.rs | 25 +++++++++++- zluda_blas/src/lib.rs | 65 ++++++++++++++++++++++++++++++ zluda_ccl/Cargo.toml | 1 - zluda_dnn/Cargo.toml | 1 - zluda_fft/Cargo.toml | 1 - zluda_inject/src/bin.rs | 85 +++++++++++++++++++++++++++++++++++++++ zluda_redirect/src/lib.rs | 36 ++++++++++++++++- zluda_sparse/Cargo.toml | 1 - 14 files changed, 244 insertions(+), 17 deletions(-) diff --git a/Makefile.toml b/Makefile.toml index adab2b9..114750a 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -13,10 +13,15 @@ command = "cargo" args = [ "build", "-p", "offline_compiler", + "-p", "zluda_blas", + "-p", "zluda_ccl", + "-p", "zluda_dnn", "-p", "zluda_dump", "-p", "zluda_inject", + "-p", "zluda_fft", "-p", "zluda_lib", "-p", "zluda_ml", + "-p", "zluda_sparse", "-p", "zluda_redirect", ] diff --git a/hipfft-sys/build.rs b/hipfft-sys/build.rs index 61a5e9b..46958a2 100644 --- a/hipfft-sys/build.rs +++ b/hipfft-sys/build.rs @@ -1,4 +1,10 @@ -fn main() { +use std::env::VarError; +use std::{env, path::PathBuf}; + +fn main() -> Result<(), VarError> { println!("cargo:rustc-link-lib=dylib=hipfft"); - println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); + let mut path = PathBuf::from(env::var("HIP_PATH")?); + path.push("lib"); + println!("cargo:rustc-link-search=native={}", path.display()); + Ok(()) } diff --git a/rocblas-sys/build.rs b/rocblas-sys/build.rs index cd0dd1b..e54039c 100644 --- a/rocblas-sys/build.rs +++ b/rocblas-sys/build.rs @@ -1,4 +1,10 @@ -fn main() { +use std::env::VarError; +use std::{env, path::PathBuf}; + +fn main() -> Result<(), VarError> { 
println!("cargo:rustc-link-lib=dylib=rocblas"); - println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); + let mut path = PathBuf::from(env::var("HIP_PATH")?); + path.push("lib"); + println!("cargo:rustc-link-search=native={}", path.display()); + Ok(()) } diff --git a/rocsolver-sys/build.rs b/rocsolver-sys/build.rs index b44db89..7f9abdd 100644 --- a/rocsolver-sys/build.rs +++ b/rocsolver-sys/build.rs @@ -1,4 +1,10 @@ -fn main() { +use std::env::VarError; +use std::{env, path::PathBuf}; + +fn main() -> Result<(), VarError> { println!("cargo:rustc-link-lib=dylib=rocsolver"); - println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); + let mut path = PathBuf::from(env::var("HIP_PATH")?); + path.push("lib"); + println!("cargo:rustc-link-search=native={}", path.display()); + Ok(()) } diff --git a/rocsparse-sys/build.rs b/rocsparse-sys/build.rs index 8470be3..fc95b5b 100644 --- a/rocsparse-sys/build.rs +++ b/rocsparse-sys/build.rs @@ -1,4 +1,10 @@ -fn main() { +use std::env::VarError; +use std::{env, path::PathBuf}; + +fn main() -> Result<(), VarError> { println!("cargo:rustc-link-lib=dylib=rocsparse"); - println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); + let mut path = PathBuf::from(env::var("HIP_PATH")?); + path.push("lib"); + println!("cargo:rustc-link-search=native={}", path.display()); + Ok(()) } diff --git a/zluda_blas/Cargo.toml b/zluda_blas/Cargo.toml index 8cc41fd..05c3315 100644 --- a/zluda_blas/Cargo.toml +++ b/zluda_blas/Cargo.toml @@ -16,6 +16,5 @@ zluda_dark_api = { path = "../zluda_dark_api" } cuda_types = { path = "../cuda_types" } [package.metadata.zluda] -linux_only = true linux_names = ["libcublas.so.10", "libcublas.so.11"] dump_names = ["libcublas.so"] diff --git a/zluda_blas/src/cublas.rs b/zluda_blas/src/cublas.rs index b0bf587..c833071 100644 --- a/zluda_blas/src/cublas.rs +++ b/zluda_blas/src/cublas.rs @@ -3955,7 +3955,7 @@ pub unsafe extern "system" fn cublasGemmStridedBatchedEx( computeType: cublasComputeType_t, algo: 
cublasGemmAlgo_t, ) -> cublasStatus_t { - crate::unsupported() + crate::gemm_strided_batched_ex(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo) } #[no_mangle] @@ -6345,7 +6345,7 @@ pub unsafe extern "system" fn cublasZhpr2( unimplemented!() } -#[no_mangle] +/*#[no_mangle] pub unsafe extern "system" fn cublasSgemm( transa: ::std::os::raw::c_char, transb: ::std::os::raw::c_char, @@ -6362,6 +6362,27 @@ pub unsafe extern "system" fn cublasSgemm( ldc: ::std::os::raw::c_int, ) { unimplemented!() +}*/ +#[no_mangle] +pub unsafe extern "system" fn cublasSgemm( + handle: cublasHandle_t, + transa: cublasOperation_t, + transb: cublasOperation_t, + m: ::std::os::raw::c_int, + n: ::std::os::raw::c_int, + k: ::std::os::raw::c_int, + alpha: *const f32, + A: *const f32, + lda: ::std::os::raw::c_int, + B: *const f32, + ldb: ::std::os::raw::c_int, + beta: *const f32, + C: *mut f32, + ldc: ::std::os::raw::c_int, +) -> cublasStatus_t { + crate::sgemm_v2( + handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, + ) } #[no_mangle] diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index 0d50a99..1eb587a 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -386,6 +386,71 @@ fn to_compute_type(compute_type: cublasComputeType_t) -> rocblas_datatype { } } +unsafe fn gemm_strided_batched_ex( + handle: *mut cublasContext, + transa: cublasOperation_t, + transb: cublasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const std::ffi::c_void, + a: *const std::ffi::c_void, + atype: cudaDataType_t, + lda: i32, + stride_a: i64, + b: *const std::ffi::c_void, + btype: cudaDataType_t, + ldb: i32, + stride_b: i64, + beta: *const std::ffi::c_void, + c: *mut std::ffi::c_void, + ctype: cudaDataType_t, + ldc: i32, + stride_c: i64, + batch_count: i32, + compute_type: cublasComputeType_t, + algo: cublasGemmAlgo_t, +) -> cublasStatus_t { + let transa = op_from_cuda(transa); 
+ let transb = op_from_cuda(transb); + let atype = type_from_cuda(atype); + let btype = type_from_cuda(btype); + let ctype = type_from_cuda(ctype); + let compute_type = to_compute_type(compute_type); + let algo = to_algo(algo); + to_cuda(rocblas_gemm_strided_batched_ex( + handle.cast(), + transa, + transb, + m, + n, + k, + alpha, + a, + atype, + lda, + stride_a, + b, + btype, + ldb, + stride_b, + beta, + c, + ctype, + ldc, + stride_c, + c, + ctype, + ldc, + stride_c, + batch_count, + compute_type, + algo, + 0, + 0, + )) +} + unsafe fn zgemm_strided_batch( handle: *mut cublasContext, transa: cublasOperation_t, diff --git a/zluda_ccl/Cargo.toml b/zluda_ccl/Cargo.toml index c429a00..30aec4b 100644 --- a/zluda_ccl/Cargo.toml +++ b/zluda_ccl/Cargo.toml @@ -11,6 +11,5 @@ crate-type = ["cdylib"] [dependencies] [package.metadata.zluda] -linux_only = true linux_names = ["libnccl.so.2"] dump_names = ["libnccl.so"] diff --git a/zluda_dnn/Cargo.toml b/zluda_dnn/Cargo.toml index 1ba5dce..74c81a4 100644 --- a/zluda_dnn/Cargo.toml +++ b/zluda_dnn/Cargo.toml @@ -13,6 +13,5 @@ miopen-sys = { path = "../miopen-sys" } hip_runtime-sys = { path = "../hip_runtime-sys" } [package.metadata.zluda] -linux_only = true linux_names = ["libcudnn.so.7", "libcudnn.so.8"] dump_names = ["libcudnn.so"] diff --git a/zluda_fft/Cargo.toml b/zluda_fft/Cargo.toml index d8d8efb..b949bbc 100644 --- a/zluda_fft/Cargo.toml +++ b/zluda_fft/Cargo.toml @@ -17,6 +17,5 @@ slab = "0.4" lazy_static = "1.4.0" [package.metadata.zluda] -linux_only = true linux_names = ["libcufft.so.10"] dump_names = ["libcufft.so"] diff --git a/zluda_inject/src/bin.rs b/zluda_inject/src/bin.rs index df664cf..0219c26 100644 --- a/zluda_inject/src/bin.rs +++ b/zluda_inject/src/bin.rs @@ -23,6 +23,11 @@ use winapi::um::{ use winapi::um::winbase::{INFINITE, WAIT_FAILED}; static REDIRECT_DLL: &'static str = "zluda_redirect.dll"; +static CUBLAS_DLL: &'static str = "cublas.dll"; +static CUDNN_DLL: &'static str = "cudnn.dll"; +static 
CUFFT_DLL: &'static str = "cufft.dll"; +static CUSPARSE_DLL: &'static str = "cusparse.dll"; +static NCCL_DLL: &'static str = "nccl.dll"; static NVCUDA_DLL: &'static str = "nvcuda.dll"; static NVML_DLL: &'static str = "nvml.dll"; static NVAPI_DLL: &'static str = "nvapi64.dll"; @@ -33,6 +38,26 @@ include!("../../zluda_redirect/src/payload_guid.rs"); #[derive(FromArgs)] /// Launch application with custom CUDA libraries struct ProgramArguments { + /// DLL to be injected instead of system cublas.dll. If not provided {0}, will use cublas.dll from its own directory + #[argh(option)] + cublas: Option, + + /// DLL to be injected instead of system cudnn.dll. If not provided {0}, will use cudnn.dll from its own directory + #[argh(option)] + cudnn: Option, + + /// DLL to be injected instead of system cufft.dll. If not provided {0}, will use cufft.dll from its own directory + #[argh(option)] + cufft: Option, + + /// DLL to be injected instead of system cusparse.dll. If not provided {0}, will use cusparse.dll from its own directory + #[argh(option)] + cusparse: Option, + + /// DLL to be injected instead of system nccl.dll. If not provided {0}, will use nccl.dll from its own directory + #[argh(option)] + nccl: Option, + /// DLL to be injected instead of system nvcuda.dll. 
If not provided {0}, will use nvcuda.dll from its own directory #[argh(option)] nvcuda: Option, @@ -65,6 +90,11 @@ pub fn main_impl() -> Result<(), Box> { let mut startup_info = unsafe { mem::zeroed::() }; let mut proc_info = unsafe { mem::zeroed::() }; let mut dlls_to_inject = vec![ + environment.cublas_path_zero_terminated.as_ptr() as _, + //environment.cudnn_path_zero_terminated.as_ptr() as _, + environment.cufft_path_zero_terminated.as_ptr() as _, + environment.cusparse_path_zero_terminated.as_ptr() as _, + environment.nccl_path_zero_terminated.as_ptr() as _, environment.nvcuda_path_zero_terminated.as_ptr() as _, environment.nvml_path_zero_terminated.as_ptr() as *const i8, environment.redirect_path_zero_terminated.as_ptr() as _, @@ -146,6 +176,11 @@ pub fn main_impl() -> Result<(), Box> { } struct NormalizedArguments { + cublas_path: PathBuf, + cudnn_path: PathBuf, + cufft_path: PathBuf, + cusparse_path: PathBuf, + nccl_path: PathBuf, nvcuda_path: PathBuf, nvml_path: PathBuf, nvapi_path: Option, @@ -157,6 +192,16 @@ struct NormalizedArguments { impl NormalizedArguments { fn new(prog_args: ProgramArguments) -> Result> { let current_exe = env::current_exe()?; + let cublas_path = + Self::get_absolute_path_or_default(&current_exe, prog_args.cublas, CUBLAS_DLL)?; + let cudnn_path = + Self::get_absolute_path_or_default(&current_exe, prog_args.cudnn, CUDNN_DLL)?; + let cufft_path = + Self::get_absolute_path_or_default(&current_exe, prog_args.cufft, CUFFT_DLL)?; + let cusparse_path = + Self::get_absolute_path_or_default(&current_exe, prog_args.cusparse, CUSPARSE_DLL)?; + let nccl_path = + Self::get_absolute_path_or_default(&current_exe, prog_args.nccl, NCCL_DLL)?; let nvcuda_path = Self::get_absolute_path_or_default(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?; let nvml_path = Self::get_absolute_path_or_default(&current_exe, prog_args.nvml, NVML_DLL)?; @@ -167,6 +212,11 @@ impl NormalizedArguments { let mut redirect_path = current_exe.parent().unwrap().to_path_buf(); redirect_path.push(REDIRECT_DLL); Ok(Self { 
+ cublas_path, + cudnn_path, + cufft_path, + cusparse_path, + nccl_path, nvcuda_path, nvml_path, nvapi_path, @@ -224,6 +274,11 @@ impl NormalizedArguments { } struct Environment { + cublas_path_zero_terminated: String, + cudnn_path_zero_terminated: String, + cufft_path_zero_terminated: String, + cusparse_path_zero_terminated: String, + nccl_path_zero_terminated: String, nvcuda_path_zero_terminated: String, nvml_path_zero_terminated: String, nvapi_path_zero_terminated: Option, @@ -239,6 +294,31 @@ struct Environment { impl Environment { fn setup(args: NormalizedArguments) -> io::Result { let _temp_dir = TempDir::new()?; + let cublas_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.cublas_path, + &_temp_dir, + CUBLAS_DLL, + )?); + let cudnn_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.cudnn_path, + &_temp_dir, + CUDNN_DLL, + )?); + let cufft_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.cufft_path, + &_temp_dir, + CUFFT_DLL, + )?); + let cusparse_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.cusparse_path, + &_temp_dir, + CUSPARSE_DLL, + )?); + let nccl_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( + args.nccl_path, + &_temp_dir, + NCCL_DLL, + )?); let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( args.nvcuda_path, &_temp_dir, @@ -269,6 +349,11 @@ impl Environment { .transpose()?; let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path); Ok(Self { + cublas_path_zero_terminated, + cudnn_path_zero_terminated, + cufft_path_zero_terminated, + cusparse_path_zero_terminated, + nccl_path_zero_terminated, nvcuda_path_zero_terminated, nvml_path_zero_terminated, nvapi_path_zero_terminated, diff --git a/zluda_redirect/src/lib.rs b/zluda_redirect/src/lib.rs index ccc905f..9eabea7 100644 --- a/zluda_redirect/src/lib.rs +++ b/zluda_redirect/src/lib.rs @@ -52,6 +52,10 @@ use 
winapi::{ include!("payload_guid.rs"); const WIN_MAX_PATH: usize = 260; +const CUBLAS_UTF8: &'static str = "CUBLAS.DLL"; +const CUBLAS_UTF16: &[u16] = wch!("CUBLAS.DLL"); +const CUDNN_UTF8: &'static str = "CUDNN.DLL"; +const CUDNN_UTF16: &[u16] = wch!("CUDNN.DLL"); const NVCUDA1_UTF8: &'static str = "NVCUDA.DLL"; const NVCUDA1_UTF16: &[u16] = wch!("NVCUDA.DLL"); const NVCUDA2_UTF8: &'static str = "NVCUDA.DLL"; @@ -64,6 +68,10 @@ const NVOPTIX_UTF8: &'static str = "OPTIX.6.6.0.DLL"; const NVOPTIX_UTF16: &[u16] = wch!("OPTIX.6.6.0.DLL"); static mut ZLUDA_PATH_UTF8: Option<&'static [u8]> = None; static mut ZLUDA_PATH_UTF16: Vec<u16> = Vec::new(); +static mut ZLUDA_BLAS_PATH_UTF8: Option<&'static [u8]> = None; +static mut ZLUDA_BLAS_PATH_UTF16: Vec<u16> = Vec::new(); +static mut ZLUDA_DNN_PATH_UTF8: Option<&'static [u8]> = None; +static mut ZLUDA_DNN_PATH_UTF16: Vec<u16> = Vec::new(); static mut ZLUDA_ML_PATH_UTF8: Option<&'static [u8]> = None; static mut ZLUDA_ML_PATH_UTF16: Vec<u16> = Vec::new(); static mut ZLUDA_API_PATH_UTF8: Option<&'static [u8]> = None; @@ -199,7 +207,11 @@ unsafe fn get_library_name_utf8(raw_library_name: *const u8) -> *const u8 { } } } - if is_nvcuda_dll_utf8(library_name) { + if is_cublas_dll_utf8(library_name) { + return ZLUDA_BLAS_PATH_UTF8.unwrap().as_ptr(); + } /*else if is_cudnn_dll_utf8(library_name) { + return ZLUDA_DNN_PATH_UTF8.unwrap().as_ptr(); + }*/ else if is_nvcuda_dll_utf8(library_name) { return ZLUDA_PATH_UTF8.unwrap().as_ptr(); } else if is_nvml_dll_utf8(library_name) { return ZLUDA_ML_PATH_UTF8.unwrap().as_ptr(); @@ -237,7 +249,11 @@ unsafe fn get_library_name_utf16(raw_library_name: *const u16) -> *const u16 { } } } - if is_nvcuda_dll_utf16(library_name) { + if is_cublas_dll_utf16(library_name) { + return ZLUDA_BLAS_PATH_UTF16.as_ptr(); + } /*else if is_cudnn_dll_utf16(library_name) { + return ZLUDA_DNN_PATH_UTF16.as_ptr(); + }*/ else if is_nvcuda_dll_utf16(library_name) { return ZLUDA_PATH_UTF16.as_ptr(); } else if 
is_nvml_dll_utf16(library_name) { return ZLUDA_ML_PATH_UTF16.as_ptr(); @@ -313,6 +329,22 @@ unsafe fn is_driverstore_utf16(lib: &[u16]) -> bool { starts_with_ignore_case(lib, &DRIVERSTORE_UTF16, utf16_to_ascii_uppercase) } +fn is_cublas_dll_utf8(lib: &[u8]) -> bool { + is_dll_utf8(lib, CUBLAS_UTF8.as_bytes()) +} + +fn is_cublas_dll_utf16(lib: &[u16]) -> bool { + is_dll_utf16(lib, CUBLAS_UTF16) +} + +fn is_cudnn_dll_utf8(lib: &[u8]) -> bool { + is_dll_utf8(lib, CUDNN_UTF8.as_bytes()) +} + +fn is_cudnn_dll_utf16(lib: &[u16]) -> bool { + is_dll_utf16(lib, CUDNN_UTF16) +} + fn is_nvcuda_dll_utf8(lib: &[u8]) -> bool { is_dll_utf8(lib, NVCUDA1_UTF8.as_bytes()) || is_dll_utf8(lib, NVCUDA2_UTF8.as_bytes()) } diff --git a/zluda_sparse/Cargo.toml b/zluda_sparse/Cargo.toml index 3badf52..55e56d3 100644 --- a/zluda_sparse/Cargo.toml +++ b/zluda_sparse/Cargo.toml @@ -16,6 +16,5 @@ zluda_dark_api = { path = "../zluda_dark_api" } cuda_types = { path = "../cuda_types" } [package.metadata.zluda] -linux_only = true linux_names = ["libcusparse.so.11"] dump_names = ["libcusparse.so"]