diff --git a/zluda_blas/src/cublas.rs b/zluda_blas/src/cublas.rs index 38fb165..198f9da 100644 --- a/zluda_blas/src/cublas.rs +++ b/zluda_blas/src/cublas.rs @@ -772,7 +772,19 @@ pub unsafe extern "system" fn cublasDotEx( resultType: cudaDataType, executionType: cudaDataType, ) -> cublasStatus_t { - crate::unsupported() + crate::dot_ex( + handle, + n, + x, + xType, + incx, + y, + yType, + incy, + result, + resultType, + executionType, + ) } #[no_mangle] diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index 0b4de70..e18a94c 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -356,6 +356,38 @@ unsafe fn dnrm_v2( to_cuda(rocblas_dnrm2(handle.cast(), n, x, incx, result)) } +unsafe fn dot_ex( + handle: *mut cublasContext, + n: i32, + x: *const ::std::os::raw::c_void, + x_type: cudaDataType, + incx: i32, + y: *const ::std::os::raw::c_void, + y_type: cudaDataType, + incy: i32, + result: *mut ::std::os::raw::c_void, + result_type: cudaDataType, + execution_type: cudaDataType, +) -> cublasStatus_t { + let x_type = type_from_cuda(x_type); + let y_type = type_from_cuda(y_type); + let result_type = type_from_cuda(result_type); + let execution_type = type_from_cuda(execution_type); + to_cuda(rocblas_dot_ex( + handle.cast(), + n, + x, + x_type, + incx, + y, + y_type, + incy, + result, + result_type, + execution_type, + )) +} + unsafe fn sdot_v2( handle: cublasHandle_t, n: i32,