Implement cublasDotEx.

This commit is contained in:
Seunghoon Lee 2024-04-10 15:20:58 +09:00
parent 9e97c717c3
commit 4baac34d4e
No known key found for this signature in database
GPG key ID: 436E38F4E70BD152
2 changed files with 45 additions and 1 deletions

View file

@ -772,7 +772,19 @@ pub unsafe extern "system" fn cublasDotEx(
resultType: cudaDataType,
executionType: cudaDataType,
) -> cublasStatus_t {
crate::unsupported()
crate::dot_ex(
handle,
n,
x,
xType,
incx,
y,
yType,
incy,
result,
resultType,
executionType,
)
}
#[no_mangle]

View file

@ -356,6 +356,38 @@ unsafe fn dnrm_v2(
to_cuda(rocblas_dnrm2(handle.cast(), n, x, incx, result))
}
unsafe fn dot_ex(
handle: *mut cublasContext,
n: i32,
x: *const ::std::os::raw::c_void,
x_type: cudaDataType,
incx: i32,
y: *const ::std::os::raw::c_void,
y_type: cudaDataType,
incy: i32,
result: *mut ::std::os::raw::c_void,
result_type: cudaDataType,
execution_type: cudaDataType,
) -> cublasStatus_t {
let x_type = type_from_cuda(x_type);
let y_type = type_from_cuda(y_type);
let result_type = type_from_cuda(result_type);
let execution_type = type_from_cuda(execution_type);
to_cuda(rocblas_dot_ex(
handle.cast(),
n,
x,
x_type,
incx,
y,
y_type,
incy,
result,
result_type,
execution_type,
))
}
unsafe fn sdot_v2(
handle: cublasHandle_t,
n: i32,