Add support of cuBLAS, cuSPARSE for Windows.

This commit is contained in:
Seunghoon Lee 2024-02-15 06:54:21 +09:00
parent 1b9ba2b233
commit 1ef7ef3938
No known key found for this signature in database
GPG key ID: 436E38F4E70BD152
14 changed files with 244 additions and 17 deletions

View file

@ -13,10 +13,15 @@ command = "cargo"
args = [ args = [
"build", "build",
"-p", "offline_compiler", "-p", "offline_compiler",
"-p", "zluda_blas",
"-p", "zluda_ccl",
"-p", "zluda_dnn",
"-p", "zluda_dump", "-p", "zluda_dump",
"-p", "zluda_inject", "-p", "zluda_inject",
"-p", "zluda_fft",
"-p", "zluda_lib", "-p", "zluda_lib",
"-p", "zluda_ml", "-p", "zluda_ml",
"-p", "zluda_sparse",
"-p", "zluda_redirect", "-p", "zluda_redirect",
] ]

View file

@ -1,4 +1,10 @@
fn main() { use std::env::VarError;
use std::{env, path::PathBuf};
fn main() -> Result<(), VarError> {
println!("cargo:rustc-link-lib=dylib=hipfft"); println!("cargo:rustc-link-lib=dylib=hipfft");
println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); let mut path = PathBuf::from(env::var("HIP_PATH")?);
path.push("lib");
println!("cargo:rustc-link-search=native={}", path.display());
Ok(())
} }

View file

@ -1,4 +1,10 @@
fn main() { use std::env::VarError;
use std::{env, path::PathBuf};
fn main() -> Result<(), VarError> {
println!("cargo:rustc-link-lib=dylib=rocblas"); println!("cargo:rustc-link-lib=dylib=rocblas");
println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); let mut path = PathBuf::from(env::var("HIP_PATH")?);
path.push("lib");
println!("cargo:rustc-link-search=native={}", path.display());
Ok(())
} }

View file

@ -1,4 +1,10 @@
fn main() { use std::env::VarError;
use std::{env, path::PathBuf};
fn main() -> Result<(), VarError> {
println!("cargo:rustc-link-lib=dylib=rocsolver"); println!("cargo:rustc-link-lib=dylib=rocsolver");
println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); let mut path = PathBuf::from(env::var("HIP_PATH")?);
path.push("lib");
println!("cargo:rustc-link-search=native={}", path.display());
Ok(())
} }

View file

@ -1,4 +1,10 @@
fn main() { use std::env::VarError;
use std::{env, path::PathBuf};
fn main() -> Result<(), VarError> {
println!("cargo:rustc-link-lib=dylib=rocsparse"); println!("cargo:rustc-link-lib=dylib=rocsparse");
println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); let mut path = PathBuf::from(env::var("HIP_PATH")?);
path.push("lib");
println!("cargo:rustc-link-search=native={}", path.display());
Ok(())
} }

View file

@ -16,6 +16,5 @@ zluda_dark_api = { path = "../zluda_dark_api" }
cuda_types = { path = "../cuda_types" } cuda_types = { path = "../cuda_types" }
[package.metadata.zluda] [package.metadata.zluda]
linux_only = true
linux_names = ["libcublas.so.10", "libcublas.so.11"] linux_names = ["libcublas.so.10", "libcublas.so.11"]
dump_names = ["libcublas.so"] dump_names = ["libcublas.so"]

View file

@ -3955,7 +3955,7 @@ pub unsafe extern "system" fn cublasGemmStridedBatchedEx(
computeType: cublasComputeType_t, computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t, algo: cublasGemmAlgo_t,
) -> cublasStatus_t { ) -> cublasStatus_t {
crate::unsupported() crate::gemm_strided_batched_ex(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo)
} }
#[no_mangle] #[no_mangle]
@ -6345,7 +6345,7 @@ pub unsafe extern "system" fn cublasZhpr2(
unimplemented!() unimplemented!()
} }
#[no_mangle] /*#[no_mangle]
pub unsafe extern "system" fn cublasSgemm( pub unsafe extern "system" fn cublasSgemm(
transa: ::std::os::raw::c_char, transa: ::std::os::raw::c_char,
transb: ::std::os::raw::c_char, transb: ::std::os::raw::c_char,
@ -6362,6 +6362,27 @@ pub unsafe extern "system" fn cublasSgemm(
ldc: ::std::os::raw::c_int, ldc: ::std::os::raw::c_int,
) { ) {
unimplemented!() unimplemented!()
}*/
#[no_mangle]
pub unsafe extern "system" fn cublasSgemm(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: ::std::os::raw::c_int,
n: ::std::os::raw::c_int,
k: ::std::os::raw::c_int,
alpha: *const f32,
A: *const f32,
lda: ::std::os::raw::c_int,
B: *const f32,
ldb: ::std::os::raw::c_int,
beta: *const f32,
C: *mut f32,
ldc: ::std::os::raw::c_int,
) -> cublasStatus_t {
crate::sgemm_v2(
handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc,
)
} }
#[no_mangle] #[no_mangle]

View file

@ -386,6 +386,71 @@ fn to_compute_type(compute_type: cublasComputeType_t) -> rocblas_datatype {
} }
} }
unsafe fn gemm_strided_batched_ex(
handle: *mut cublasContext,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: i32,
n: i32,
k: i32,
alpha: *const std::ffi::c_void,
a: *const std::ffi::c_void,
atype: cudaDataType_t,
lda: i32,
stride_a: i64,
b: *const std::ffi::c_void,
btype: cudaDataType_t,
ldb: i32,
stride_b: i64,
beta: *const std::ffi::c_void,
c: *mut std::ffi::c_void,
ctype: cudaDataType_t,
ldc: i32,
stride_c: i64,
batch_count: i32,
compute_type: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
let transa = op_from_cuda(transa);
let transb = op_from_cuda(transb);
let atype = type_from_cuda(atype);
let btype = type_from_cuda(btype);
let ctype = type_from_cuda(ctype);
let compute_type = to_compute_type(compute_type);
let algo = to_algo(algo);
to_cuda(rocblas_gemm_strided_batched_ex(
handle.cast(),
transa,
transb,
m,
n,
k,
alpha,
a,
atype,
lda,
stride_a,
b,
btype,
ldb,
stride_b,
beta,
c,
ctype,
ldc,
stride_c,
c,
ctype,
ldc,
stride_c,
batch_count,
compute_type,
algo,
0,
0,
))
}
unsafe fn zgemm_strided_batch( unsafe fn zgemm_strided_batch(
handle: *mut cublasContext, handle: *mut cublasContext,
transa: cublasOperation_t, transa: cublasOperation_t,

View file

@ -11,6 +11,5 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
[package.metadata.zluda] [package.metadata.zluda]
linux_only = true
linux_names = ["libnccl.so.2"] linux_names = ["libnccl.so.2"]
dump_names = ["libnccl.so"] dump_names = ["libnccl.so"]

View file

@ -13,6 +13,5 @@ miopen-sys = { path = "../miopen-sys" }
hip_runtime-sys = { path = "../hip_runtime-sys" } hip_runtime-sys = { path = "../hip_runtime-sys" }
[package.metadata.zluda] [package.metadata.zluda]
linux_only = true
linux_names = ["libcudnn.so.7", "libcudnn.so.8"] linux_names = ["libcudnn.so.7", "libcudnn.so.8"]
dump_names = ["libcudnn.so"] dump_names = ["libcudnn.so"]

View file

@ -17,6 +17,5 @@ slab = "0.4"
lazy_static = "1.4.0" lazy_static = "1.4.0"
[package.metadata.zluda] [package.metadata.zluda]
linux_only = true
linux_names = ["libcufft.so.10"] linux_names = ["libcufft.so.10"]
dump_names = ["libcufft.so"] dump_names = ["libcufft.so"]

View file

@ -23,6 +23,11 @@ use winapi::um::{
use winapi::um::winbase::{INFINITE, WAIT_FAILED}; use winapi::um::winbase::{INFINITE, WAIT_FAILED};
static REDIRECT_DLL: &'static str = "zluda_redirect.dll"; static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
static CUBLAS_DLL: &'static str = "cublas.dll";
static CUDNN_DLL: &'static str = "cudnn.dll";
static CUFFT_DLL: &'static str = "cufft.dll";
static CUSPARSE_DLL: &'static str = "cusparse.dll";
static NCCL_DLL: &'static str = "nccl.dll";
static NVCUDA_DLL: &'static str = "nvcuda.dll"; static NVCUDA_DLL: &'static str = "nvcuda.dll";
static NVML_DLL: &'static str = "nvml.dll"; static NVML_DLL: &'static str = "nvml.dll";
static NVAPI_DLL: &'static str = "nvapi64.dll"; static NVAPI_DLL: &'static str = "nvapi64.dll";
@ -33,6 +38,26 @@ include!("../../zluda_redirect/src/payload_guid.rs");
#[derive(FromArgs)] #[derive(FromArgs)]
/// Launch application with custom CUDA libraries /// Launch application with custom CUDA libraries
struct ProgramArguments { struct ProgramArguments {
/// DLL to be injected instead of system cublas.dll. If not provided {0}, will use cublas.dll from its own directory
#[argh(option)]
cublas: Option<PathBuf>,
/// DLL to be injected instead of system cudnn.dll. If not provided {0}, will use cudnn.dll from its own directory
#[argh(option)]
cudnn: Option<PathBuf>,
/// DLL to be injected instead of system cufft.dll. If not provided {0}, will use cufft.dll from its own directory
#[argh(option)]
cufft: Option<PathBuf>,
/// DLL to be injected instead of system cusparse.dll. If not provided {0}, will use cusparse.dll from its own directory
#[argh(option)]
cusparse: Option<PathBuf>,
/// DLL to be injected instead of system nccl.dll. If not provided {0}, will use nccl.dll from its own directory
#[argh(option)]
nccl: Option<PathBuf>,
/// DLL to be injected instead of system nvcuda.dll. If not provided {0}, will use nvcuda.dll from its own directory /// DLL to be injected instead of system nvcuda.dll. If not provided {0}, will use nvcuda.dll from its own directory
#[argh(option)] #[argh(option)]
nvcuda: Option<PathBuf>, nvcuda: Option<PathBuf>,
@ -65,6 +90,11 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() }; let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() }; let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
let mut dlls_to_inject = vec![ let mut dlls_to_inject = vec![
environment.cublas_path_zero_terminated.as_ptr() as _,
//environment.cudnn_path_zero_terminated.as_ptr() as _,
environment.cufft_path_zero_terminated.as_ptr() as _,
environment.cusparse_path_zero_terminated.as_ptr() as _,
environment.nccl_path_zero_terminated.as_ptr() as _,
environment.nvcuda_path_zero_terminated.as_ptr() as _, environment.nvcuda_path_zero_terminated.as_ptr() as _,
environment.nvml_path_zero_terminated.as_ptr() as *const i8, environment.nvml_path_zero_terminated.as_ptr() as *const i8,
environment.redirect_path_zero_terminated.as_ptr() as _, environment.redirect_path_zero_terminated.as_ptr() as _,
@ -146,6 +176,11 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
} }
struct NormalizedArguments { struct NormalizedArguments {
cublas_path: PathBuf,
cudnn_path: PathBuf,
cufft_path: PathBuf,
cusparse_path: PathBuf,
nccl_path: PathBuf,
nvcuda_path: PathBuf, nvcuda_path: PathBuf,
nvml_path: PathBuf, nvml_path: PathBuf,
nvapi_path: Option<PathBuf>, nvapi_path: Option<PathBuf>,
@ -157,6 +192,16 @@ struct NormalizedArguments {
impl NormalizedArguments { impl NormalizedArguments {
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> { fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
let current_exe = env::current_exe()?; let current_exe = env::current_exe()?;
let cublas_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.cublas, CUBLAS_DLL)?;
let cudnn_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.cudnn, CUDNN_DLL)?;
let cufft_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.cufft, CUFFT_DLL)?;
let cusparse_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.cusparse, CUSPARSE_DLL)?;
let nccl_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.nccl, NCCL_DLL)?;
let nvcuda_path = let nvcuda_path =
Self::get_absolute_path_or_default(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?; Self::get_absolute_path_or_default(&current_exe, prog_args.nvcuda, NVCUDA_DLL)?;
let nvml_path = Self::get_absolute_path_or_default(&current_exe, prog_args.nvml, NVML_DLL)?; let nvml_path = Self::get_absolute_path_or_default(&current_exe, prog_args.nvml, NVML_DLL)?;
@ -167,6 +212,11 @@ impl NormalizedArguments {
let mut redirect_path = current_exe.parent().unwrap().to_path_buf(); let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
redirect_path.push(REDIRECT_DLL); redirect_path.push(REDIRECT_DLL);
Ok(Self { Ok(Self {
cublas_path,
cudnn_path,
cufft_path,
cusparse_path,
nccl_path,
nvcuda_path, nvcuda_path,
nvml_path, nvml_path,
nvapi_path, nvapi_path,
@ -224,6 +274,11 @@ impl NormalizedArguments {
} }
struct Environment { struct Environment {
cublas_path_zero_terminated: String,
cudnn_path_zero_terminated: String,
cufft_path_zero_terminated: String,
cusparse_path_zero_terminated: String,
nccl_path_zero_terminated: String,
nvcuda_path_zero_terminated: String, nvcuda_path_zero_terminated: String,
nvml_path_zero_terminated: String, nvml_path_zero_terminated: String,
nvapi_path_zero_terminated: Option<String>, nvapi_path_zero_terminated: Option<String>,
@ -239,6 +294,31 @@ struct Environment {
impl Environment { impl Environment {
fn setup(args: NormalizedArguments) -> io::Result<Self> { fn setup(args: NormalizedArguments) -> io::Result<Self> {
let _temp_dir = TempDir::new()?; let _temp_dir = TempDir::new()?;
let cublas_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.cublas_path,
&_temp_dir,
CUBLAS_DLL,
)?);
let cudnn_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.cudnn_path,
&_temp_dir,
CUDNN_DLL,
)?);
let cufft_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.cufft_path,
&_temp_dir,
CUFFT_DLL,
)?);
let cusparse_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.cusparse_path,
&_temp_dir,
CUSPARSE_DLL,
)?);
let nccl_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nccl_path,
&_temp_dir,
NCCL_DLL,
)?);
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name( let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
args.nvcuda_path, args.nvcuda_path,
&_temp_dir, &_temp_dir,
@ -269,6 +349,11 @@ impl Environment {
.transpose()?; .transpose()?;
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path); let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
Ok(Self { Ok(Self {
cublas_path_zero_terminated,
cudnn_path_zero_terminated,
cufft_path_zero_terminated,
cusparse_path_zero_terminated,
nccl_path_zero_terminated,
nvcuda_path_zero_terminated, nvcuda_path_zero_terminated,
nvml_path_zero_terminated, nvml_path_zero_terminated,
nvapi_path_zero_terminated, nvapi_path_zero_terminated,

View file

@ -52,6 +52,10 @@ use winapi::{
include!("payload_guid.rs"); include!("payload_guid.rs");
const WIN_MAX_PATH: usize = 260; const WIN_MAX_PATH: usize = 260;
const CUBLAS_UTF8: &'static str = "CUBLAS.DLL";
const CUBLAS_UTF16: &[u16] = wch!("CUBLAS.DLL");
const CUDNN_UTF8: &'static str = "CUDNN.DLL";
const CUDNN_UTF16: &[u16] = wch!("CUDNN.DLL");
const NVCUDA1_UTF8: &'static str = "NVCUDA.DLL"; const NVCUDA1_UTF8: &'static str = "NVCUDA.DLL";
const NVCUDA1_UTF16: &[u16] = wch!("NVCUDA.DLL"); const NVCUDA1_UTF16: &[u16] = wch!("NVCUDA.DLL");
const NVCUDA2_UTF8: &'static str = "NVCUDA.DLL"; const NVCUDA2_UTF8: &'static str = "NVCUDA.DLL";
@ -64,6 +68,10 @@ const NVOPTIX_UTF8: &'static str = "OPTIX.6.6.0.DLL";
const NVOPTIX_UTF16: &[u16] = wch!("OPTIX.6.6.0.DLL"); const NVOPTIX_UTF16: &[u16] = wch!("OPTIX.6.6.0.DLL");
static mut ZLUDA_PATH_UTF8: Option<&'static [u8]> = None; static mut ZLUDA_PATH_UTF8: Option<&'static [u8]> = None;
static mut ZLUDA_PATH_UTF16: Vec<u16> = Vec::new(); static mut ZLUDA_PATH_UTF16: Vec<u16> = Vec::new();
static mut ZLUDA_BLAS_PATH_UTF8: Option<&'static [u8]> = None;
static mut ZLUDA_BLAS_PATH_UTF16: Vec<u16> = Vec::new();
static mut ZLUDA_DNN_PATH_UTF8: Option<&'static [u8]> = None;
static mut ZLUDA_DNN_PATH_UTF16: Vec<u16> = Vec::new();
static mut ZLUDA_ML_PATH_UTF8: Option<&'static [u8]> = None; static mut ZLUDA_ML_PATH_UTF8: Option<&'static [u8]> = None;
static mut ZLUDA_ML_PATH_UTF16: Vec<u16> = Vec::new(); static mut ZLUDA_ML_PATH_UTF16: Vec<u16> = Vec::new();
static mut ZLUDA_API_PATH_UTF8: Option<&'static [u8]> = None; static mut ZLUDA_API_PATH_UTF8: Option<&'static [u8]> = None;
@ -199,7 +207,11 @@ unsafe fn get_library_name_utf8(raw_library_name: *const u8) -> *const u8 {
} }
} }
} }
if is_nvcuda_dll_utf8(library_name) { if is_cublas_dll_utf8(library_name) {
return ZLUDA_BLAS_PATH_UTF8.unwrap().as_ptr();
} /*else if is_cudnn_dll_utf8(library_name) {
return ZLUDA_DNN_PATH_UTF8.unwrap().as_ptr();
}*/ else if is_nvcuda_dll_utf8(library_name) {
return ZLUDA_PATH_UTF8.unwrap().as_ptr(); return ZLUDA_PATH_UTF8.unwrap().as_ptr();
} else if is_nvml_dll_utf8(library_name) { } else if is_nvml_dll_utf8(library_name) {
return ZLUDA_ML_PATH_UTF8.unwrap().as_ptr(); return ZLUDA_ML_PATH_UTF8.unwrap().as_ptr();
@ -237,7 +249,11 @@ unsafe fn get_library_name_utf16(raw_library_name: *const u16) -> *const u16 {
} }
} }
} }
if is_nvcuda_dll_utf16(library_name) { if is_cublas_dll_utf16(library_name) {
return ZLUDA_BLAS_PATH_UTF16.as_ptr();
} /*else if is_cudnn_dll_utf16(library_name) {
return ZLUDA_DNN_PATH_UTF16.as_ptr();
}*/ else if is_nvcuda_dll_utf16(library_name) {
return ZLUDA_PATH_UTF16.as_ptr(); return ZLUDA_PATH_UTF16.as_ptr();
} else if is_nvml_dll_utf16(library_name) { } else if is_nvml_dll_utf16(library_name) {
return ZLUDA_ML_PATH_UTF16.as_ptr(); return ZLUDA_ML_PATH_UTF16.as_ptr();
@ -313,6 +329,22 @@ unsafe fn is_driverstore_utf16(lib: &[u16]) -> bool {
starts_with_ignore_case(lib, &DRIVERSTORE_UTF16, utf16_to_ascii_uppercase) starts_with_ignore_case(lib, &DRIVERSTORE_UTF16, utf16_to_ascii_uppercase)
} }
fn is_cublas_dll_utf8(lib: &[u8]) -> bool {
is_dll_utf8(lib, CUBLAS_UTF8.as_bytes())
}
fn is_cublas_dll_utf16(lib: &[u16]) -> bool {
is_dll_utf16(lib, CUBLAS_UTF16)
}
fn is_cudnn_dll_utf8(lib: &[u8]) -> bool {
is_dll_utf8(lib, CUDNN_UTF8.as_bytes())
}
fn is_cudnn_dll_utf16(lib: &[u16]) -> bool {
is_dll_utf16(lib, CUDNN_UTF16)
}
fn is_nvcuda_dll_utf8(lib: &[u8]) -> bool { fn is_nvcuda_dll_utf8(lib: &[u8]) -> bool {
is_dll_utf8(lib, NVCUDA1_UTF8.as_bytes()) || is_dll_utf8(lib, NVCUDA2_UTF8.as_bytes()) is_dll_utf8(lib, NVCUDA1_UTF8.as_bytes()) || is_dll_utf8(lib, NVCUDA2_UTF8.as_bytes())
} }

View file

@ -16,6 +16,5 @@ zluda_dark_api = { path = "../zluda_dark_api" }
cuda_types = { path = "../cuda_types" } cuda_types = { path = "../cuda_types" }
[package.metadata.zluda] [package.metadata.zluda]
linux_only = true
linux_names = ["libcusparse.so.11"] linux_names = ["libcusparse.so.11"]
dump_names = ["libcusparse.so"] dump_names = ["libcusparse.so"]