Implement cusparseCreateDnMat, cusparseDestroyDnMat, cusparseDnMat*.

This commit is contained in:
Seunghoon Lee 2024-03-20 14:12:00 +09:00
parent 2812b1db44
commit 6b2488395d
No known key found for this signature in database
GPG key ID: 91024D13C6CA4722
2 changed files with 138 additions and 7 deletions

View file

@ -7980,14 +7980,22 @@ pub unsafe extern "system" fn cusparseCreateDnMat(
valueType: cudaDataType,
order: cusparseOrder_t,
) -> cusparseStatus_t {
crate::unsupported()
crate::create_dn_mat(
dnMatDescr,
rows,
cols,
ld,
values,
valueType,
order,
)
}
#[no_mangle]
pub unsafe extern "system" fn cusparseDestroyDnMat(
dnMatDescr: cusparseDnMatDescr_t,
) -> cusparseStatus_t {
crate::unsupported()
crate::destroy_dn_mat(dnMatDescr)
}
#[no_mangle]
@ -8000,7 +8008,15 @@ pub unsafe extern "system" fn cusparseDnMatGet(
type_: *mut cudaDataType,
order: *mut cusparseOrder_t,
) -> cusparseStatus_t {
crate::unsupported()
crate::dn_mat_get(
dnMatDescr,
rows,
cols,
ld,
values,
type_,
order,
)
}
#[no_mangle]
@ -8008,7 +8024,10 @@ pub unsafe extern "system" fn cusparseDnMatGetValues(
dnMatDescr: cusparseDnMatDescr_t,
values: *mut *mut ::std::os::raw::c_void,
) -> cusparseStatus_t {
crate::unsupported()
crate::dn_mat_get_values(
dnMatDescr,
values,
)
}
#[no_mangle]
@ -8016,7 +8035,10 @@ pub unsafe extern "system" fn cusparseDnMatSetValues(
dnMatDescr: cusparseDnMatDescr_t,
values: *mut ::std::os::raw::c_void,
) -> cusparseStatus_t {
crate::unsupported()
crate::dn_mat_set_values(
dnMatDescr,
values,
)
}
#[no_mangle]
@ -8025,7 +8047,11 @@ pub unsafe extern "system" fn cusparseDnMatSetStridedBatch(
batchCount: ::std::os::raw::c_int,
batchStride: i64,
) -> cusparseStatus_t {
crate::unsupported()
crate::dn_mat_set_strided_batch(
dnMatDescr,
batchCount,
batchStride,
)
}
#[no_mangle]
@ -8034,7 +8060,11 @@ pub unsafe extern "system" fn cusparseDnMatGetStridedBatch(
batchCount: *mut ::std::os::raw::c_int,
batchStride: *mut i64,
) -> cusparseStatus_t {
crate::unsupported()
crate::dn_mat_get_strided_batch(
dnMatDescr,
batchCount,
batchStride,
)
}
#[no_mangle]

View file

@ -109,6 +109,13 @@ fn index_base(index_base: cusparseIndexBase_t) -> rocsparse_index_base {
}
}
fn order(order: cusparseOrder_t) -> rocsprase_order {
match order {
cusparseOrder_t::CUSPARSE_ORDER_COL => rocsparse_order::rocsparse_order_column,
cusparseOrder_t::CUSPARSE_ORDER_ROW => rocsparse_order::rocsparse_order_row,
}
}
unsafe fn create_csrsv2_info(info: *mut *mut csrsv2Info) -> cusparseStatus_t {
to_cuda(rocsparse_create_mat_info(info.cast()))
}
@ -1306,3 +1313,97 @@ unsafe fn dcsrilu02(
p_buffer,
))
}
unsafe fn create_dn_mat(
dn_mat_descr: *mut cusparseDnMatDescr_t,
rows: i64,
cols: i64,
ld: i64,
values: *mut ::std::os::raw::c_void,
value_type: cudaDataType,
o: cusparseOrder_t,
) -> cusparseStatus_t {
let value_type = data_type(value_type);
let o = order(o);
to_cuda(rocsparse_create_dnmat_descr(
dn_mat_descr.cast(),
rows,
cols,
ld,
values,
value_type,
o,
))
}
unsafe fn destroy_dn_mat(
dn_mat_descr: cusparseDnMatDescr_t,
) -> cusparseStatus_t {
to_cuda(rocsparse_destroy_dnmat_descr(dn_mat_descr.cast()))
}
unsafe fn dn_mat_get(
dn_mat_descr: cusparseDnMatDescr_t,
rows: *mut i64,
cols: *mut i64,
ld: *mut i64,
values: *mut *mut ::std::os::raw::c_void,
type_: *mut cudaDataType,
o: *mut cusparseOrder_t,
) -> cusparseStatus_t {
let type_ = data_type(type_);
let o = order(o);
to_cuda(rocsparse_dnmat_get(
dn_mat_descr.cast(),
rows,
cols,
ld,
values,
type_,
o,
))
}
unsafe fn dn_mat_get_values(
dn_mat_descr: cusparseDnMatDescr_t,
values: *mut *mut ::std::os::raw::c_void,
) -> cusparseStatus_t {
to_cuda(rocsparse_dnmat_get_values(
dn_mat_descr.cast(),
values,
))
}
unsafe fn dn_mat_set_values(
dn_mat_descr: cusparseDnMatDescr_t,
values: *mut ::std::os::raw::c_void,
) -> cusparseStatus_t {
to_cuda(rocsparse_dnmat_set_values(
dn_mat_descr.cast(),
values,
))
}
unsafe fn dn_mat_set_strided_batch(
dn_mat_descr: cusparseDnMatDescr_t,
batch_count: i32,
batch_stride: i64,
) -> cusparseStatus_t {
to_cuda(rocsparse_dnmat_set_strided_batch(
dn_mat_descr.cast(),
batch_count,
batch_stride,
))
}
unsafe fn dn_mat_get_strided_batch(
dn_mat_descr: cusparseDnMatDescr_t,
batch_count: *mut i32,
batch_stride: *mut i64,
) -> cusparseStatus_t {
to_cuda(rocsparse_dnmat_get_strided_batch(
dn_mat_descr.cast(),
batch_count,
batch_stride,
))
}