Implement cublasSdot.

This commit is contained in:
Seunghoon Lee 2024-03-25 11:30:23 +09:00
parent 7c3891e6b3
commit 1122cc0e83
No known key found for this signature in database
GPG key ID: 91024D13C6CA4722
2 changed files with 68 additions and 6 deletions

View file

@ -802,7 +802,15 @@ pub unsafe extern "system" fn cublasSdot_v2(
incy: ::std::os::raw::c_int, incy: ::std::os::raw::c_int,
result: *mut f32, result: *mut f32,
) -> cublasStatus_t { ) -> cublasStatus_t {
crate::unsupported() crate::sdot_v2(
handle,
n,
x,
incx,
y,
incy,
result,
)
} }
#[no_mangle] #[no_mangle]
@ -4920,7 +4928,13 @@ pub unsafe extern "system" fn cublasSdot(
y: *const f32, y: *const f32,
incy: ::std::os::raw::c_int, incy: ::std::os::raw::c_int,
) -> f32 { ) -> f32 {
unimplemented!() crate::sdot(
n,
x,
incx,
y,
incy,
)
} }
#[no_mangle] #[no_mangle]

View file

@ -209,13 +209,13 @@ unsafe fn sgemm(
c: *mut f32, c: *mut f32,
ldc: i32, ldc: i32,
) -> cublasStatus_t { ) -> cublasStatus_t {
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
let mut handle = mem::zeroed(); let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle)); let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS { if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status; return status;
} }
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_sgemm( status = to_cuda(rocblas_sgemm(
handle.cast(), handle.cast(),
transa, transa,
@ -279,6 +279,34 @@ unsafe fn init() -> cublasStatus_t {
cublasStatus_t::CUBLAS_STATUS_SUCCESS cublasStatus_t::CUBLAS_STATUS_SUCCESS
} }
unsafe fn sdot(
n: i32,
x: *const f32,
incx: i32,
y: *const f32,
incy: i32,
) -> cublasStatus_t {
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let result = mem::zeroed();
status = to_cuda(rocblas_sdot(
handle.cast(),
n,
x,
incx,
y,
incy,
result,
));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
to_cuda(rocblas_destroy_handle(*handle))
}
unsafe fn dasum_v2( unsafe fn dasum_v2(
handle: *mut cublasContext, handle: *mut cublasContext,
n: i32, n: i32,
@ -333,6 +361,26 @@ unsafe fn dnrm_v2(
to_cuda(rocblas_dnrm2(handle.cast(), n, x, incx, result)) to_cuda(rocblas_dnrm2(handle.cast(), n, x, incx, result))
} }
/// Handle-based single-precision dot product: forwards every argument to
/// `rocblas_sdot` (casting the cuBLAS handle to the rocBLAS handle type) and
/// translates the returned rocBLAS status with `to_cuda`.
///
/// # Safety
///
/// All pointers are passed through to `rocblas_sdot` unchanged; the caller is
/// responsible for their validity.
unsafe fn sdot_v2(
    handle: cublasHandle_t,
    n: i32,
    x: *const f32,
    incx: i32,
    y: *const f32,
    incy: i32,
    result: *mut f32,
) -> cublasStatus_t {
    let rocblas_status = rocblas_sdot(handle.cast(), n, x, incx, y, incy, result);
    to_cuda(rocblas_status)
}
unsafe fn idamax_v2( unsafe fn idamax_v2(
handle: *mut cublasContext, handle: *mut cublasContext,
n: i32, n: i32,
@ -979,13 +1027,13 @@ unsafe fn dgemm(
c: *mut f64, c: *mut f64,
ldc: i32, ldc: i32,
) -> cublasStatus_t { ) -> cublasStatus_t {
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
let mut handle = mem::zeroed(); let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle)); let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS { if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status; return status;
} }
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_dgemm( status = to_cuda(rocblas_dgemm(
handle.cast(), handle.cast(),
transa, transa,