Implement cublasDgetrsBatched.
This commit is contained in:
parent
605254b38e
commit
cd1e0a3d50
|
@ -4366,7 +4366,19 @@ pub unsafe extern "system" fn cublasDgetrsBatched(
|
|||
info: *mut ::std::os::raw::c_int,
|
||||
batchSize: ::std::os::raw::c_int,
|
||||
) -> cublasStatus_t {
|
||||
crate::unsupported()
|
||||
crate::dgetrs_batched(
|
||||
handle,
|
||||
trans,
|
||||
n,
|
||||
nrhs,
|
||||
Aarray,
|
||||
lda,
|
||||
devIpiv,
|
||||
Barray,
|
||||
ldb,
|
||||
info,
|
||||
batchSize,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
|
|
@ -9,6 +9,7 @@ use rocsolver_sys::{
|
|||
rocsolver_cgetrf_batched,
|
||||
rocsolver_cgetri_outofplace_batched,
|
||||
rocsolver_sgetrs_batched,
|
||||
rocsolver_dgetrs_batched,
|
||||
rocsolver_zgetrf_batched,
|
||||
rocsolver_zgetri_outofplace_batched,
|
||||
};
|
||||
|
@ -742,6 +743,36 @@ unsafe fn sgetrs_batched(
|
|||
))
|
||||
}
|
||||
|
||||
unsafe fn dgetrs_batched(
|
||||
handle: *mut cublasContext,
|
||||
trans: cublasOperation_t,
|
||||
n: i32,
|
||||
nrhs: i32,
|
||||
a: *const *const f64,
|
||||
lda: i32,
|
||||
dev_ipiv: *const i32,
|
||||
b: *const *mut f64,
|
||||
ldb: i32,
|
||||
info: *mut i32,
|
||||
batch_size: i32,
|
||||
) -> cublasStatus_t {
|
||||
let trans = op_from_cuda_for_solver(trans);
|
||||
let stride = n * nrhs;
|
||||
to_cuda_solver(rocsolver_dgetrs_batched(
|
||||
handle.cast(),
|
||||
trans,
|
||||
n,
|
||||
nrhs,
|
||||
a.cast(),
|
||||
lda,
|
||||
dev_ipiv,
|
||||
stride as _,
|
||||
b,
|
||||
ldb,
|
||||
batch_size,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn dtrmm_v2(
|
||||
handle: *mut cublasContext,
|
||||
side: cublasSideMode_t,
|
||||
|
|
Loading…
Reference in a new issue