Implement cublasSdot.
This commit is contained in:
parent
7c3891e6b3
commit
1122cc0e83
|
@ -802,7 +802,15 @@ pub unsafe extern "system" fn cublasSdot_v2(
|
|||
incy: ::std::os::raw::c_int,
|
||||
result: *mut f32,
|
||||
) -> cublasStatus_t {
|
||||
crate::unsupported()
|
||||
crate::sdot_v2(
|
||||
handle,
|
||||
n,
|
||||
x,
|
||||
incx,
|
||||
y,
|
||||
incy,
|
||||
result,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
@ -4920,7 +4928,13 @@ pub unsafe extern "system" fn cublasSdot(
|
|||
y: *const f32,
|
||||
incy: ::std::os::raw::c_int,
|
||||
) -> f32 {
|
||||
unimplemented!()
|
||||
crate::sdot(
|
||||
n,
|
||||
x,
|
||||
incx,
|
||||
y,
|
||||
incy,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
|
|
@ -209,13 +209,13 @@ unsafe fn sgemm(
|
|||
c: *mut f32,
|
||||
ldc: i32,
|
||||
) -> cublasStatus_t {
|
||||
let transa = op_from_cuda(cublasOperation_t(transa as _));
|
||||
let transb = op_from_cuda(cublasOperation_t(transb as _));
|
||||
let mut handle = mem::zeroed();
|
||||
let mut status = to_cuda(rocblas_create_handle(handle));
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return status;
|
||||
}
|
||||
let transa = op_from_cuda(cublasOperation_t(transa as _));
|
||||
let transb = op_from_cuda(cublasOperation_t(transb as _));
|
||||
status = to_cuda(rocblas_sgemm(
|
||||
handle.cast(),
|
||||
transa,
|
||||
|
@ -279,6 +279,34 @@ unsafe fn init() -> cublasStatus_t {
|
|||
cublasStatus_t::CUBLAS_STATUS_SUCCESS
|
||||
}
|
||||
|
||||
unsafe fn sdot(
|
||||
n: i32,
|
||||
x: *const f32,
|
||||
incx: i32,
|
||||
y: *const f32,
|
||||
incy: i32,
|
||||
) -> cublasStatus_t {
|
||||
let mut handle = mem::zeroed();
|
||||
let mut status = to_cuda(rocblas_create_handle(handle));
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return status;
|
||||
}
|
||||
let result = mem::zeroed();
|
||||
status = to_cuda(rocblas_sdot(
|
||||
handle.cast(),
|
||||
n,
|
||||
x,
|
||||
incx,
|
||||
y,
|
||||
incy,
|
||||
result,
|
||||
));
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return status;
|
||||
}
|
||||
to_cuda(rocblas_destroy_handle(*handle))
|
||||
}
|
||||
|
||||
unsafe fn dasum_v2(
|
||||
handle: *mut cublasContext,
|
||||
n: i32,
|
||||
|
@ -333,6 +361,26 @@ unsafe fn dnrm_v2(
|
|||
to_cuda(rocblas_dnrm2(handle.cast(), n, x, incx, result))
|
||||
}
|
||||
|
||||
unsafe fn sdot_v2(
|
||||
handle: cublasHandle_t,
|
||||
n: i32,
|
||||
x: *const f32,
|
||||
incx: i32,
|
||||
y: *const f32,
|
||||
incy: i32,
|
||||
result: *mut f32,
|
||||
) -> cublasStatus_t {
|
||||
to_cuda(rocblas_sdot(
|
||||
handle.cast(),
|
||||
n,
|
||||
x,
|
||||
incx,
|
||||
y,
|
||||
incy,
|
||||
result,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn idamax_v2(
|
||||
handle: *mut cublasContext,
|
||||
n: i32,
|
||||
|
@ -979,13 +1027,13 @@ unsafe fn dgemm(
|
|||
c: *mut f64,
|
||||
ldc: i32,
|
||||
) -> cublasStatus_t {
|
||||
let transa = op_from_cuda(cublasOperation_t(transa as _));
|
||||
let transb = op_from_cuda(cublasOperation_t(transb as _));
|
||||
let mut handle = mem::zeroed();
|
||||
let mut status = to_cuda(rocblas_create_handle(handle));
|
||||
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
|
||||
return status;
|
||||
}
|
||||
let transa = op_from_cuda(cublasOperation_t(transa as _));
|
||||
let transb = op_from_cuda(cublasOperation_t(transb as _));
|
||||
status = to_cuda(rocblas_dgemm(
|
||||
handle.cast(),
|
||||
transa,
|
||||
|
|
Loading…
Reference in a new issue