Implement cublasDgetrsBatched.

This commit is contained in:
Seunghoon Lee 2024-03-18 22:38:57 +09:00
parent 605254b38e
commit cd1e0a3d50
No known key found for this signature in database
GPG key ID: 436E38F4E70BD152
2 changed files with 44 additions and 1 deletions

View file

@ -4366,7 +4366,19 @@ pub unsafe extern "system" fn cublasDgetrsBatched(
info: *mut ::std::os::raw::c_int,
batchSize: ::std::os::raw::c_int,
) -> cublasStatus_t {
crate::unsupported()
crate::dgetrs_batched(
handle,
trans,
n,
nrhs,
Aarray,
lda,
devIpiv,
Barray,
ldb,
info,
batchSize,
)
}
#[no_mangle]

View file

@ -9,6 +9,7 @@ use rocsolver_sys::{
rocsolver_cgetrf_batched,
rocsolver_cgetri_outofplace_batched,
rocsolver_sgetrs_batched,
rocsolver_dgetrs_batched,
rocsolver_zgetrf_batched,
rocsolver_zgetri_outofplace_batched,
};
@ -742,6 +743,36 @@ unsafe fn sgetrs_batched(
))
}
unsafe fn dgetrs_batched(
handle: *mut cublasContext,
trans: cublasOperation_t,
n: i32,
nrhs: i32,
a: *const *const f64,
lda: i32,
dev_ipiv: *const i32,
b: *const *mut f64,
ldb: i32,
info: *mut i32,
batch_size: i32,
) -> cublasStatus_t {
let trans = op_from_cuda_for_solver(trans);
let stride = n * nrhs;
to_cuda_solver(rocsolver_dgetrs_batched(
handle.cast(),
trans,
n,
nrhs,
a.cast(),
lda,
dev_ipiv,
stride as _,
b,
ldb,
batch_size,
))
}
unsafe fn dtrmm_v2(
handle: *mut cublasContext,
side: cublasSideMode_t,