Add support of cuBLAS, cuSPARSE for Windows.
This commit is contained in:
parent
1b9ba2b233
commit
1ef7ef3938
14 changed files with 244 additions and 17 deletions
|
@ -13,10 +13,15 @@ command = "cargo"
|
|||
args = [
|
||||
"build",
|
||||
"-p", "offline_compiler",
|
||||
"-p", "zluda_blas",
|
||||
"-p", "zluda_ccl",
|
||||
"-p", "zluda_dnn",
|
||||
"-p", "zluda_dump",
|
||||
"-p", "zluda_inject",
|
||||
"-p", "zluda_fft",
|
||||
"-p", "zluda_lib",
|
||||
"-p", "zluda_ml",
|
||||
"-p", "zluda_sparse",
|
||||
"-p", "zluda_redirect",
|
||||
]
|
||||
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
fn main() {
|
||||
use std::env::VarError;
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
fn main() -> Result<(), VarError> {
|
||||
println!("cargo:rustc-link-lib=dylib=hipfft");
|
||||
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
|
||||
let mut path = PathBuf::from(env::var("HIP_PATH")?);
|
||||
path.push("lib");
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
fn main() {
|
||||
use std::env::VarError;
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
fn main() -> Result<(), VarError> {
|
||||
println!("cargo:rustc-link-lib=dylib=rocblas");
|
||||
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
|
||||
let mut path = PathBuf::from(env::var("HIP_PATH")?);
|
||||
path.push("lib");
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
fn main() {
|
||||
use std::env::VarError;
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
fn main() -> Result<(), VarError> {
|
||||
println!("cargo:rustc-link-lib=dylib=rocsolver");
|
||||
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
|
||||
let mut path = PathBuf::from(env::var("HIP_PATH")?);
|
||||
path.push("lib");
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
fn main() {
|
||||
use std::env::VarError;
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
fn main() -> Result<(), VarError> {
|
||||
println!("cargo:rustc-link-lib=dylib=rocsparse");
|
||||
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
|
||||
let mut path = PathBuf::from(env::var("HIP_PATH")?);
|
||||
path.push("lib");
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -16,6 +16,5 @@ zluda_dark_api = { path = "../zluda_dark_api" }
|
|||
cuda_types = { path = "../cuda_types" }
|
||||
|
||||
[package.metadata.zluda]
|
||||
linux_only = true
|
||||
linux_names = ["libcublas.so.10", "libcublas.so.11"]
|
||||
dump_names = ["libcublas.so"]
|
||||
|
|
|
@ -3955,7 +3955,7 @@ pub unsafe extern "system" fn cublasGemmStridedBatchedEx(
|
|||
computeType: cublasComputeType_t,
|
||||
algo: cublasGemmAlgo_t,
|
||||
) -> cublasStatus_t {
|
||||
crate::unsupported()
|
||||
crate::gemm_strided_batched_ex(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
@ -6345,7 +6345,7 @@ pub unsafe extern "system" fn cublasZhpr2(
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/*#[no_mangle]
|
||||
pub unsafe extern "system" fn cublasSgemm(
|
||||
transa: ::std::os::raw::c_char,
|
||||
transb: ::std::os::raw::c_char,
|
||||
|
@ -6362,6 +6362,27 @@ pub unsafe extern "system" fn cublasSgemm(
|
|||
ldc: ::std::os::raw::c_int,
|
||||
) {
|
||||
unimplemented!()
|
||||
}*/
|
||||
#[no_mangle]
|
||||
pub unsafe extern "system" fn cublasSgemm(
|
||||
handle: cublasHandle_t,
|
||||
transa: cublasOperation_t,
|
||||
transb: cublasOperation_t,
|
||||
m: ::std::os::raw::c_int,
|
||||
n: ::std::os::raw::c_int,
|
||||
k: ::std::os::raw::c_int,
|
||||
alpha: *const f32,
|
||||
A: *const f32,
|
||||
lda: ::std::os::raw::c_int,
|
||||
B: *const f32,
|
||||
ldb: ::std::os::raw::c_int,
|
||||
beta: *const f32,
|
||||
C: *mut f32,
|
||||
ldc: ::std::os::raw::c_int,
|
||||
) -> cublasStatus_t {
|
||||
crate::sgemm_v2(
|
||||
handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
|
|
@ -386,6 +386,71 @@ fn to_compute_type(compute_type: cublasComputeType_t) -> rocblas_datatype {
|
|||
}
|
||||
}
|
||||
|
||||
unsafe fn gemm_strided_batched_ex(
|
||||
handle: *mut cublasContext,
|
||||
transa: cublasOperation_t,
|
||||
transb: cublasOperation_t,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const std::ffi::c_void,
|
||||
a: *const std::ffi::c_void,
|
||||
atype: cudaDataType_t,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const std::ffi::c_void,
|
||||
btype: cudaDataType_t,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const std::ffi::c_void,
|
||||
c: *mut std::ffi::c_void,
|
||||
ctype: cudaDataType_t,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: cublasComputeType_t,
|
||||
algo: cublasGemmAlgo_t,
|
||||
) -> cublasStatus_t {
|
||||
let transa = op_from_cuda(transa);
|
||||
let transb = op_from_cuda(transb);
|
||||
let atype = type_from_cuda(atype);
|
||||
let btype = type_from_cuda(btype);
|
||||
let ctype = type_from_cuda(ctype);
|
||||
let compute_type = to_compute_type(compute_type);
|
||||
let algo = to_algo(algo);
|
||||
to_cuda(rocblas_gemm_strided_batched_ex(
|
||||
handle.cast(),
|
||||
transa,
|
||||
transb,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
a,
|
||||
atype,
|
||||
lda,
|
||||
stride_a,
|
||||
b,
|
||||
btype,
|
||||
ldb,
|
||||
stride_b,
|
||||
beta,
|
||||
c,
|
||||
ctype,
|
||||
ldc,
|
||||
stride_c,
|
||||
c,
|
||||
ctype,
|
||||
ldc,
|
||||
stride_c,
|
||||
batch_count,
|
||||
compute_type,
|
||||
algo,
|
||||
0,
|
||||
0,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn zgemm_strided_batch(
|
||||
handle: *mut cublasContext,
|
||||
transa: cublasOperation_t,
|
||||
|
|
|
@ -11,6 +11,5 @@ crate-type = ["cdylib"]
|
|||
[dependencies]
|
||||
|
||||
[package.metadata.zluda]
|
||||
linux_only = true
|
||||
linux_names = ["libnccl.so.2"]
|
||||
dump_names = ["libnccl.so"]
|
||||
|
|
|
@ -13,6 +13,5 @@ miopen-sys = { path = "../miopen-sys" }
|
|||
hip_runtime-sys = { path = "../hip_runtime-sys" }
|
||||
|
||||
[package.metadata.zluda]
|
||||
linux_only = true
|
||||
linux_names = ["libcudnn.so.7", "libcudnn.so.8"]
|
||||
dump_names = ["libcudnn.so"]
|
||||
|
|
|
@ -17,6 +17,5 @@ slab = "0.4"
|
|||
lazy_static = "1.4.0"
|
||||
|
||||
[package.metadata.zluda]
|
||||
linux_only = true
|
||||
linux_names = ["libcufft.so.10"]
|
||||
dump_names = ["libcufft.so"]
|
||||
|
|
|
@ -23,6 +23,11 @@ use winapi::um::{
|
|||
use winapi::um::winbase::{INFINITE, WAIT_FAILED};
|
||||
|
||||
static REDIRECT_DLL: &'static str = "zluda_redirect.dll";
|
||||
static CUBLAS_DLL: &'static str = "cublas.dll";
|
||||
static CUDNN_DLL: &'static str = "cudnn.dll";
|
||||
static CUFFT_DLL: &'static str = "cufft.dll";
|
||||
static CUSPARSE_DLL: &'static str = "cusparse.dll";
|
||||
static NCCL_DLL: &'static str = "nccl.dll";
|
||||
static NVCUDA_DLL: &'static str = "nvcuda.dll";
|
||||
static NVML_DLL: &'static str = "nvml.dll";
|
||||
static NVAPI_DLL: &'static str = "nvapi64.dll";
|
||||
|
@ -33,6 +38,26 @@ include!("../../zluda_redirect/src/payload_guid.rs");
|
|||
#[derive(FromArgs)]
|
||||
/// Launch application with custom CUDA libraries
|
||||
struct ProgramArguments {
|
||||
/// DLL to be injected instead of system cublas.dll. If not provided {0}, will use cublas.dll from its own directory
|
||||
#[argh(option)]
|
||||
cublas: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system cudnn.dll. If not provided {0}, will use cudnn.dll from its own directory
|
||||
#[argh(option)]
|
||||
cudnn: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system cufft.dll. If not provided {0}, will use cufft.dll from its own directory
|
||||
#[argh(option)]
|
||||
cufft: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system cusparse.dll. If not provided {0}, will use cusparse.dll from its own directory
|
||||
#[argh(option)]
|
||||
cusparse: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system nccl.dll. If not provided {0}, will use nccl.dll from its own directory
|
||||
#[argh(option)]
|
||||
nccl: Option<PathBuf>,
|
||||
|
||||
/// DLL to be injected instead of system nvcuda.dll. If not provided {0}, will use nvcuda.dll from its own directory
|
||||
#[argh(option)]
|
||||
nvcuda: Option<PathBuf>,
|
||||
|
@ -65,6 +90,11 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||
let mut startup_info = unsafe { mem::zeroed::<detours_sys::_STARTUPINFOW>() };
|
||||
let mut proc_info = unsafe { mem::zeroed::<detours_sys::_PROCESS_INFORMATION>() };
|
||||
let mut dlls_to_inject = vec![
|
||||
environment.cublas_path_zero_terminated.as_ptr() as _,
|
||||
//environment.cudnn_path_zero_terminated.as_ptr() as _,
|
||||
environment.cufft_path_zero_terminated.as_ptr() as _,
|
||||
environment.cusparse_path_zero_terminated.as_ptr() as _,
|
||||
environment.nccl_path_zero_terminated.as_ptr() as _,
|
||||
environment.nvcuda_path_zero_terminated.as_ptr() as _,
|
||||
environment.nvml_path_zero_terminated.as_ptr() as *const i8,
|
||||
environment.redirect_path_zero_terminated.as_ptr() as _,
|
||||
|
@ -146,6 +176,11 @@ pub fn main_impl() -> Result<(), Box<dyn Error>> {
|
|||
}
|
||||
|
||||
struct NormalizedArguments {
|
||||
cublas_path: PathBuf,
|
||||
cudnn_path: PathBuf,
|
||||
cufft_path: PathBuf,
|
||||
cusparse_path: PathBuf,
|
||||
nccl_path: PathBuf,
|
||||
nvcuda_path: PathBuf,
|
||||
nvml_path: PathBuf,
|
||||
nvapi_path: Option<PathBuf>,
|
||||
|
@ -157,6 +192,16 @@ struct NormalizedArguments {
|
|||
impl NormalizedArguments {
|
||||
fn new(prog_args: ProgramArguments) -> Result<Self, Box<dyn Error>> {
|
||||
let current_exe = env::current_exe()?;
|
||||
let cublas_path =
|
||||
Self::get_absolute_path_or_default(¤t_exe, prog_args.cublas, CUBLAS_DLL)?;
|
||||
let cudnn_path =
|
||||
Self::get_absolute_path_or_default(¤t_exe, prog_args.cudnn, CUDNN_DLL)?;
|
||||
let cufft_path =
|
||||
Self::get_absolute_path_or_default(¤t_exe, prog_args.cufft, CUFFT_DLL)?;
|
||||
let cusparse_path =
|
||||
Self::get_absolute_path_or_default(¤t_exe, prog_args.cusparse, CUSPARSE_DLL)?;
|
||||
let nccl_path =
|
||||
Self::get_absolute_path_or_default(¤t_exe, prog_args.nccl, NCCL_DLL)?;
|
||||
let nvcuda_path =
|
||||
Self::get_absolute_path_or_default(¤t_exe, prog_args.nvcuda, NVCUDA_DLL)?;
|
||||
let nvml_path = Self::get_absolute_path_or_default(¤t_exe, prog_args.nvml, NVML_DLL)?;
|
||||
|
@ -167,6 +212,11 @@ impl NormalizedArguments {
|
|||
let mut redirect_path = current_exe.parent().unwrap().to_path_buf();
|
||||
redirect_path.push(REDIRECT_DLL);
|
||||
Ok(Self {
|
||||
cublas_path,
|
||||
cudnn_path,
|
||||
cufft_path,
|
||||
cusparse_path,
|
||||
nccl_path,
|
||||
nvcuda_path,
|
||||
nvml_path,
|
||||
nvapi_path,
|
||||
|
@ -224,6 +274,11 @@ impl NormalizedArguments {
|
|||
}
|
||||
|
||||
struct Environment {
|
||||
cublas_path_zero_terminated: String,
|
||||
cudnn_path_zero_terminated: String,
|
||||
cufft_path_zero_terminated: String,
|
||||
cusparse_path_zero_terminated: String,
|
||||
nccl_path_zero_terminated: String,
|
||||
nvcuda_path_zero_terminated: String,
|
||||
nvml_path_zero_terminated: String,
|
||||
nvapi_path_zero_terminated: Option<String>,
|
||||
|
@ -239,6 +294,31 @@ struct Environment {
|
|||
impl Environment {
|
||||
fn setup(args: NormalizedArguments) -> io::Result<Self> {
|
||||
let _temp_dir = TempDir::new()?;
|
||||
let cublas_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.cublas_path,
|
||||
&_temp_dir,
|
||||
CUBLAS_DLL,
|
||||
)?);
|
||||
let cudnn_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.cudnn_path,
|
||||
&_temp_dir,
|
||||
CUDNN_DLL,
|
||||
)?);
|
||||
let cufft_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.cufft_path,
|
||||
&_temp_dir,
|
||||
CUFFT_DLL,
|
||||
)?);
|
||||
let cusparse_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.cusparse_path,
|
||||
&_temp_dir,
|
||||
CUSPARSE_DLL,
|
||||
)?);
|
||||
let nccl_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nccl_path,
|
||||
&_temp_dir,
|
||||
NCCL_DLL,
|
||||
)?);
|
||||
let nvcuda_path_zero_terminated = Self::zero_terminate(Self::copy_to_correct_name(
|
||||
args.nvcuda_path,
|
||||
&_temp_dir,
|
||||
|
@ -269,6 +349,11 @@ impl Environment {
|
|||
.transpose()?;
|
||||
let redirect_path_zero_terminated = Self::zero_terminate(args.redirect_path);
|
||||
Ok(Self {
|
||||
cublas_path_zero_terminated,
|
||||
cudnn_path_zero_terminated,
|
||||
cufft_path_zero_terminated,
|
||||
cusparse_path_zero_terminated,
|
||||
nccl_path_zero_terminated,
|
||||
nvcuda_path_zero_terminated,
|
||||
nvml_path_zero_terminated,
|
||||
nvapi_path_zero_terminated,
|
||||
|
|
|
@ -52,6 +52,10 @@ use winapi::{
|
|||
include!("payload_guid.rs");
|
||||
|
||||
const WIN_MAX_PATH: usize = 260;
|
||||
const CUBLAS_UTF8: &'static str = "CUBLAS.DLL";
|
||||
const CUBLAS_UTF16: &[u16] = wch!("CUBLAS.DLL");
|
||||
const CUDNN_UTF8: &'static str = "CUDNN.DLL";
|
||||
const CUDNN_UTF16: &[u16] = wch!("CUDNN.DLL");
|
||||
const NVCUDA1_UTF8: &'static str = "NVCUDA.DLL";
|
||||
const NVCUDA1_UTF16: &[u16] = wch!("NVCUDA.DLL");
|
||||
const NVCUDA2_UTF8: &'static str = "NVCUDA.DLL";
|
||||
|
@ -64,6 +68,10 @@ const NVOPTIX_UTF8: &'static str = "OPTIX.6.6.0.DLL";
|
|||
const NVOPTIX_UTF16: &[u16] = wch!("OPTIX.6.6.0.DLL");
|
||||
static mut ZLUDA_PATH_UTF8: Option<&'static [u8]> = None;
|
||||
static mut ZLUDA_PATH_UTF16: Vec<u16> = Vec::new();
|
||||
static mut ZLUDA_BLAS_PATH_UTF8: Option<&'static [u8]> = None;
|
||||
static mut ZLUDA_BLAS_PATH_UTF16: Vec<u16> = Vec::new();
|
||||
static mut ZLUDA_DNN_PATH_UTF8: Option<&'static [u8]> = None;
|
||||
static mut ZLUDA_DNN_PATH_UTF16: Vec<u16> = Vec::new();
|
||||
static mut ZLUDA_ML_PATH_UTF8: Option<&'static [u8]> = None;
|
||||
static mut ZLUDA_ML_PATH_UTF16: Vec<u16> = Vec::new();
|
||||
static mut ZLUDA_API_PATH_UTF8: Option<&'static [u8]> = None;
|
||||
|
@ -199,7 +207,11 @@ unsafe fn get_library_name_utf8(raw_library_name: *const u8) -> *const u8 {
|
|||
}
|
||||
}
|
||||
}
|
||||
if is_nvcuda_dll_utf8(library_name) {
|
||||
if is_cublas_dll_utf8(library_name) {
|
||||
return ZLUDA_BLAS_PATH_UTF8.unwrap().as_ptr();
|
||||
} /*else if is_cudnn_dll_utf8(library_name) {
|
||||
return ZLUDA_DNN_PATH_UTF8.unwrap().as_ptr();
|
||||
}*/ else if is_nvcuda_dll_utf8(library_name) {
|
||||
return ZLUDA_PATH_UTF8.unwrap().as_ptr();
|
||||
} else if is_nvml_dll_utf8(library_name) {
|
||||
return ZLUDA_ML_PATH_UTF8.unwrap().as_ptr();
|
||||
|
@ -237,7 +249,11 @@ unsafe fn get_library_name_utf16(raw_library_name: *const u16) -> *const u16 {
|
|||
}
|
||||
}
|
||||
}
|
||||
if is_nvcuda_dll_utf16(library_name) {
|
||||
if is_cublas_dll_utf16(library_name) {
|
||||
return ZLUDA_BLAS_PATH_UTF16.as_ptr();
|
||||
} /*else if is_cudnn_dll_utf16(library_name) {
|
||||
return ZLUDA_DNN_PATH_UTF16.as_ptr();
|
||||
}*/ else if is_nvcuda_dll_utf16(library_name) {
|
||||
return ZLUDA_PATH_UTF16.as_ptr();
|
||||
} else if is_nvml_dll_utf16(library_name) {
|
||||
return ZLUDA_ML_PATH_UTF16.as_ptr();
|
||||
|
@ -313,6 +329,22 @@ unsafe fn is_driverstore_utf16(lib: &[u16]) -> bool {
|
|||
starts_with_ignore_case(lib, &DRIVERSTORE_UTF16, utf16_to_ascii_uppercase)
|
||||
}
|
||||
|
||||
fn is_cublas_dll_utf8(lib: &[u8]) -> bool {
|
||||
is_dll_utf8(lib, CUBLAS_UTF8.as_bytes())
|
||||
}
|
||||
|
||||
fn is_cublas_dll_utf16(lib: &[u16]) -> bool {
|
||||
is_dll_utf16(lib, CUBLAS_UTF16)
|
||||
}
|
||||
|
||||
fn is_cudnn_dll_utf8(lib: &[u8]) -> bool {
|
||||
is_dll_utf8(lib, CUDNN_UTF8.as_bytes())
|
||||
}
|
||||
|
||||
fn is_cudnn_dll_utf16(lib: &[u16]) -> bool {
|
||||
is_dll_utf16(lib, CUDNN_UTF16)
|
||||
}
|
||||
|
||||
fn is_nvcuda_dll_utf8(lib: &[u8]) -> bool {
|
||||
is_dll_utf8(lib, NVCUDA1_UTF8.as_bytes()) || is_dll_utf8(lib, NVCUDA2_UTF8.as_bytes())
|
||||
}
|
||||
|
|
|
@ -16,6 +16,5 @@ zluda_dark_api = { path = "../zluda_dark_api" }
|
|||
cuda_types = { path = "../cuda_types" }
|
||||
|
||||
[package.metadata.zluda]
|
||||
linux_only = true
|
||||
linux_names = ["libcusparse.so.11"]
|
||||
dump_names = ["libcusparse.so"]
|
||||
|
|
Loading…
Reference in a new issue