Fix cusparseDnMatGet.

This commit is contained in:
Seunghoon Lee 2024-03-21 22:07:36 +09:00
parent ff1bc6d9b6
commit 7c3891e6b3
No known key found for this signature in database
GPG key ID: 436E38F4E70BD152

View file

@ -58,7 +58,7 @@ unsafe fn create_csr(
let csr_row_offsets_type = index_type(csr_row_offsets_type);
let csr_col_ind_type = index_type(csr_col_ind_type);
let idx_base = index_base(idx_base);
let value_type = data_type(value_type);
let value_type = to_roc_data_type(value_type);
to_cuda(rocsparse_create_csr_descr(
descr as _,
rows,
@ -74,7 +74,7 @@ unsafe fn create_csr(
))
}
fn data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
fn to_roc_data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
match data_type {
cudaDataType_t::CUDA_R_32F => rocsparse_datatype::rocsparse_datatype_f32_r,
cudaDataType_t::CUDA_R_64F => rocsparse_datatype::rocsparse_datatype_f64_r,
@ -88,6 +88,20 @@ fn data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
}
}
fn to_cuda_data_type(data_type: rocsparse_datatype) -> cudaDataType_t {
match data_type {
rocsparse_datatype::rocsparse_datatype_f32_r => cudaDataType_t::CUDA_R_32F,
rocsparse_datatype::rocsparse_datatype_f64_r => cudaDataType_t::CUDA_R_64F,
rocsparse_datatype::rocsparse_datatype_f32_c => cudaDataType_t::CUDA_C_32F,
rocsparse_datatype::rocsparse_datatype_f64_c => cudaDataType_t::CUDA_C_64F,
rocsparse_datatype::rocsparse_datatype_i8_r => cudaDataType_t::CUDA_R_8I,
rocsparse_datatype::rocsparse_datatype_u8_r => cudaDataType_t::CUDA_R_8U,
rocsparse_datatype::rocsparse_datatype_i32_r => cudaDataType_t::CUDA_R_32I,
rocsparse_datatype::rocsparse_datatype_u32_r => cudaDataType_t::CUDA_R_32U,
_ => panic!(),
}
}
fn index_type(index_type: cusparseIndexType_t) -> rocsparse_indextype {
match index_type {
cusparseIndexType_t::CUSPARSE_INDEX_16U => rocsparse_indextype::rocsparse_indextype_u16,
@ -109,7 +123,7 @@ fn index_base(index_base: cusparseIndexBase_t) -> rocsparse_index_base {
}
}
fn order(order: cusparseOrder_t) -> rocsparse_order {
fn to_roc_order(order: cusparseOrder_t) -> rocsparse_order {
match order {
cusparseOrder_t::CUSPARSE_ORDER_COL => rocsparse_order::rocsparse_order_column,
cusparseOrder_t::CUSPARSE_ORDER_ROW => rocsparse_order::rocsparse_order_row,
@ -117,6 +131,14 @@ fn order(order: cusparseOrder_t) -> rocsparse_order {
}
}
fn to_cuda_order(order: rocsparse_order) -> cusparseOrder_t {
match order {
rocsparse_order::rocsparse_order_column => cusparseOrder_t::CUSPARSE_ORDER_COL,
rocsparse_order::rocsparse_order_row => cusparseOrder_t::CUSPARSE_ORDER_ROW,
_ => panic!(),
}
}
unsafe fn create_csrsv2_info(info: *mut *mut csrsv2Info) -> cusparseStatus_t {
to_cuda(rocsparse_create_mat_info(info.cast()))
}
@ -127,7 +149,7 @@ unsafe fn create_dn_vec(
values: *mut std::ffi::c_void,
value_type: cudaDataType_t,
) -> cusparseStatus_t {
let value_type = data_type(value_type);
let value_type = to_roc_data_type(value_type);
to_cuda(rocsparse_create_dnvec_descr(
dn_vec_descr.cast(),
size,
@ -466,7 +488,7 @@ unsafe fn spmv(
external_buffer: *mut std::ffi::c_void,
) -> cusparseStatus_t {
let op_a = operation(op_a);
let compute_type = data_type(compute_type);
let compute_type = to_roc_data_type(compute_type);
let alg = to_spmv_alg(alg);
// divide by 2 in case there's any arithmetic done on it
let mut size = usize::MAX / 2;
@ -498,7 +520,7 @@ unsafe fn spmv_buffersize(
buffer_size: *mut usize,
) -> cusparseStatus_t {
let op_a = operation(op_a);
let compute_type = data_type(compute_type);
let compute_type = to_roc_data_type(compute_type);
let alg = to_spmv_alg(alg);
to_cuda(rocsparse_spmv(
handle.cast(),
@ -1130,7 +1152,7 @@ unsafe fn spsv_buffersize(
buffer_size: *mut usize,
) -> cusparseStatus_t {
let op_a = operation(op_a);
let compute_type = data_type(compute_type);
let compute_type = to_roc_data_type(compute_type);
let alg = to_spsv_alg(alg);
to_cuda(rocsparse_spsv(
handle.cast(),
@ -1169,7 +1191,7 @@ unsafe fn spsv_analysis(
external_buffer: *mut c_void,
) -> cusparseStatus_t {
let op_a = operation(op_a);
let compute_type = data_type(compute_type);
let compute_type = to_roc_data_type(compute_type);
let alg = to_spsv_alg(alg);
let spsv_descr = spsv_descr.cast::<SpSvDescr>().as_mut().unwrap();
spsv_descr.external_buffer = external_buffer;
@ -1200,7 +1222,7 @@ unsafe fn spsv_solve(
spsv_descr: *mut cusparseSpSVDescr,
) -> cusparseStatus_t {
let op_a = operation(op_a);
let compute_type = data_type(compute_type);
let compute_type = to_roc_data_type(compute_type);
let alg = to_spsv_alg(alg);
let spsv_descr = spsv_descr.cast::<SpSvDescr>().as_ref().unwrap();
to_cuda(rocsparse_spsv(
@ -1322,10 +1344,10 @@ unsafe fn create_dn_mat(
ld: i64,
values: *mut ::std::os::raw::c_void,
value_type: cudaDataType,
o: cusparseOrder_t,
order: cusparseOrder_t,
) -> cusparseStatus_t {
let value_type = data_type(value_type);
let o = order(o);
let value_type = to_roc_data_type(value_type);
let order = to_roc_order(order);
to_cuda(rocsparse_create_dnmat_descr(
dn_mat_descr.cast(),
rows,
@ -1333,7 +1355,7 @@ unsafe fn create_dn_mat(
ld,
values,
value_type,
o,
order,
))
}
@ -1350,19 +1372,22 @@ unsafe fn dn_mat_get(
ld: *mut i64,
values: *mut *mut ::std::os::raw::c_void,
type_: *mut cudaDataType,
o: *mut cusparseOrder_t,
order: *mut cusparseOrder_t,
) -> cusparseStatus_t {
let mut type_ = data_type(*type_);
let mut o = order(*o);
to_cuda(rocsparse_dnmat_get(
let mut out_type = to_roc_data_type(*type_);
let mut out_order = to_roc_order(*order);
let status = to_cuda(rocsparse_dnmat_get(
dn_mat_descr.cast(),
rows,
cols,
ld,
values,
&mut type_,
&mut o,
))
&mut out_type,
&mut out_order,
));
*type_ = to_cuda_data_type(out_type);
*order = to_cuda_order(out_order);
status
}
unsafe fn dn_mat_get_values(