Implement cusparseCreateDnMat, cusparseDestroyDnMat, cusparseDnMat*.
This commit is contained in:
parent
2812b1db44
commit
6b2488395d
|
@ -7980,14 +7980,22 @@ pub unsafe extern "system" fn cusparseCreateDnMat(
|
|||
valueType: cudaDataType,
|
||||
order: cusparseOrder_t,
|
||||
) -> cusparseStatus_t {
|
||||
crate::unsupported()
|
||||
crate::create_dn_mat(
|
||||
dnMatDescr,
|
||||
rows,
|
||||
cols,
|
||||
ld,
|
||||
values,
|
||||
valueType,
|
||||
order,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "system" fn cusparseDestroyDnMat(
|
||||
dnMatDescr: cusparseDnMatDescr_t,
|
||||
) -> cusparseStatus_t {
|
||||
crate::unsupported()
|
||||
crate::destroy_dn_mat(dnMatDescr)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
@ -8000,7 +8008,15 @@ pub unsafe extern "system" fn cusparseDnMatGet(
|
|||
type_: *mut cudaDataType,
|
||||
order: *mut cusparseOrder_t,
|
||||
) -> cusparseStatus_t {
|
||||
crate::unsupported()
|
||||
crate::dn_mat_get(
|
||||
dnMatDescr,
|
||||
rows,
|
||||
cols,
|
||||
ld,
|
||||
values,
|
||||
type_,
|
||||
order,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
@ -8008,7 +8024,10 @@ pub unsafe extern "system" fn cusparseDnMatGetValues(
|
|||
dnMatDescr: cusparseDnMatDescr_t,
|
||||
values: *mut *mut ::std::os::raw::c_void,
|
||||
) -> cusparseStatus_t {
|
||||
crate::unsupported()
|
||||
crate::dn_mat_get_values(
|
||||
dnMatDescr,
|
||||
values,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
@ -8016,7 +8035,10 @@ pub unsafe extern "system" fn cusparseDnMatSetValues(
|
|||
dnMatDescr: cusparseDnMatDescr_t,
|
||||
values: *mut ::std::os::raw::c_void,
|
||||
) -> cusparseStatus_t {
|
||||
crate::unsupported()
|
||||
crate::dn_mat_set_values(
|
||||
dnMatDescr,
|
||||
values,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
@ -8025,7 +8047,11 @@ pub unsafe extern "system" fn cusparseDnMatSetStridedBatch(
|
|||
batchCount: ::std::os::raw::c_int,
|
||||
batchStride: i64,
|
||||
) -> cusparseStatus_t {
|
||||
crate::unsupported()
|
||||
crate::dn_mat_set_strided_batch(
|
||||
dnMatDescr,
|
||||
batchCount,
|
||||
batchStride,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
@ -8034,7 +8060,11 @@ pub unsafe extern "system" fn cusparseDnMatGetStridedBatch(
|
|||
batchCount: *mut ::std::os::raw::c_int,
|
||||
batchStride: *mut i64,
|
||||
) -> cusparseStatus_t {
|
||||
crate::unsupported()
|
||||
crate::dn_mat_get_strided_batch(
|
||||
dnMatDescr,
|
||||
batchCount,
|
||||
batchStride,
|
||||
)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
|
|
|
@ -109,6 +109,13 @@ fn index_base(index_base: cusparseIndexBase_t) -> rocsparse_index_base {
|
|||
}
|
||||
}
|
||||
|
||||
fn order(order: cusparseOrder_t) -> rocsprase_order {
|
||||
match order {
|
||||
cusparseOrder_t::CUSPARSE_ORDER_COL => rocsparse_order::rocsparse_order_column,
|
||||
cusparseOrder_t::CUSPARSE_ORDER_ROW => rocsparse_order::rocsparse_order_row,
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn create_csrsv2_info(info: *mut *mut csrsv2Info) -> cusparseStatus_t {
|
||||
to_cuda(rocsparse_create_mat_info(info.cast()))
|
||||
}
|
||||
|
@ -1306,3 +1313,97 @@ unsafe fn dcsrilu02(
|
|||
p_buffer,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn create_dn_mat(
|
||||
dn_mat_descr: *mut cusparseDnMatDescr_t,
|
||||
rows: i64,
|
||||
cols: i64,
|
||||
ld: i64,
|
||||
values: *mut ::std::os::raw::c_void,
|
||||
value_type: cudaDataType,
|
||||
o: cusparseOrder_t,
|
||||
) -> cusparseStatus_t {
|
||||
let value_type = data_type(value_type);
|
||||
let o = order(o);
|
||||
to_cuda(rocsparse_create_dnmat_descr(
|
||||
dn_mat_descr.cast(),
|
||||
rows,
|
||||
cols,
|
||||
ld,
|
||||
values,
|
||||
value_type,
|
||||
o,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn destroy_dn_mat(
|
||||
dn_mat_descr: cusparseDnMatDescr_t,
|
||||
) -> cusparseStatus_t {
|
||||
to_cuda(rocsparse_destroy_dnmat_descr(dn_mat_descr.cast()))
|
||||
}
|
||||
|
||||
unsafe fn dn_mat_get(
|
||||
dn_mat_descr: cusparseDnMatDescr_t,
|
||||
rows: *mut i64,
|
||||
cols: *mut i64,
|
||||
ld: *mut i64,
|
||||
values: *mut *mut ::std::os::raw::c_void,
|
||||
type_: *mut cudaDataType,
|
||||
o: *mut cusparseOrder_t,
|
||||
) -> cusparseStatus_t {
|
||||
let type_ = data_type(type_);
|
||||
let o = order(o);
|
||||
to_cuda(rocsparse_dnmat_get(
|
||||
dn_mat_descr.cast(),
|
||||
rows,
|
||||
cols,
|
||||
ld,
|
||||
values,
|
||||
type_,
|
||||
o,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn dn_mat_get_values(
|
||||
dn_mat_descr: cusparseDnMatDescr_t,
|
||||
values: *mut *mut ::std::os::raw::c_void,
|
||||
) -> cusparseStatus_t {
|
||||
to_cuda(rocsparse_dnmat_get_values(
|
||||
dn_mat_descr.cast(),
|
||||
values,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn dn_mat_set_values(
|
||||
dn_mat_descr: cusparseDnMatDescr_t,
|
||||
values: *mut ::std::os::raw::c_void,
|
||||
) -> cusparseStatus_t {
|
||||
to_cuda(rocsparse_dnmat_set_values(
|
||||
dn_mat_descr.cast(),
|
||||
values,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn dn_mat_set_strided_batch(
|
||||
dn_mat_descr: cusparseDnMatDescr_t,
|
||||
batch_count: i32,
|
||||
batch_stride: i64,
|
||||
) -> cusparseStatus_t {
|
||||
to_cuda(rocsparse_dnmat_set_strided_batch(
|
||||
dn_mat_descr.cast(),
|
||||
batch_count,
|
||||
batch_stride,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn dn_mat_get_strided_batch(
|
||||
dn_mat_descr: cusparseDnMatDescr_t,
|
||||
batch_count: *mut i32,
|
||||
batch_stride: *mut i64,
|
||||
) -> cusparseStatus_t {
|
||||
to_cuda(rocsparse_dnmat_get_strided_batch(
|
||||
dn_mat_descr.cast(),
|
||||
batch_count,
|
||||
batch_stride,
|
||||
))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue