Fix cusparseDnMatGet.
This commit is contained in:
parent
ff1bc6d9b6
commit
7c3891e6b3
|
@ -58,7 +58,7 @@ unsafe fn create_csr(
|
|||
let csr_row_offsets_type = index_type(csr_row_offsets_type);
|
||||
let csr_col_ind_type = index_type(csr_col_ind_type);
|
||||
let idx_base = index_base(idx_base);
|
||||
let value_type = data_type(value_type);
|
||||
let value_type = to_roc_data_type(value_type);
|
||||
to_cuda(rocsparse_create_csr_descr(
|
||||
descr as _,
|
||||
rows,
|
||||
|
@ -74,7 +74,7 @@ unsafe fn create_csr(
|
|||
))
|
||||
}
|
||||
|
||||
fn data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
|
||||
fn to_roc_data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
|
||||
match data_type {
|
||||
cudaDataType_t::CUDA_R_32F => rocsparse_datatype::rocsparse_datatype_f32_r,
|
||||
cudaDataType_t::CUDA_R_64F => rocsparse_datatype::rocsparse_datatype_f64_r,
|
||||
|
@ -88,6 +88,20 @@ fn data_type(data_type: cudaDataType_t) -> rocsparse_datatype {
|
|||
}
|
||||
}
|
||||
|
||||
fn to_cuda_data_type(data_type: rocsparse_datatype) -> cudaDataType_t {
|
||||
match data_type {
|
||||
rocsparse_datatype::rocsparse_datatype_f32_r => cudaDataType_t::CUDA_R_32F,
|
||||
rocsparse_datatype::rocsparse_datatype_f64_r => cudaDataType_t::CUDA_R_64F,
|
||||
rocsparse_datatype::rocsparse_datatype_f32_c => cudaDataType_t::CUDA_C_32F,
|
||||
rocsparse_datatype::rocsparse_datatype_f64_c => cudaDataType_t::CUDA_C_64F,
|
||||
rocsparse_datatype::rocsparse_datatype_i8_r => cudaDataType_t::CUDA_R_8I,
|
||||
rocsparse_datatype::rocsparse_datatype_u8_r => cudaDataType_t::CUDA_R_8U,
|
||||
rocsparse_datatype::rocsparse_datatype_i32_r => cudaDataType_t::CUDA_R_32I,
|
||||
rocsparse_datatype::rocsparse_datatype_u32_r => cudaDataType_t::CUDA_R_32U,
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn index_type(index_type: cusparseIndexType_t) -> rocsparse_indextype {
|
||||
match index_type {
|
||||
cusparseIndexType_t::CUSPARSE_INDEX_16U => rocsparse_indextype::rocsparse_indextype_u16,
|
||||
|
@ -109,7 +123,7 @@ fn index_base(index_base: cusparseIndexBase_t) -> rocsparse_index_base {
|
|||
}
|
||||
}
|
||||
|
||||
fn order(order: cusparseOrder_t) -> rocsparse_order {
|
||||
fn to_roc_order(order: cusparseOrder_t) -> rocsparse_order {
|
||||
match order {
|
||||
cusparseOrder_t::CUSPARSE_ORDER_COL => rocsparse_order::rocsparse_order_column,
|
||||
cusparseOrder_t::CUSPARSE_ORDER_ROW => rocsparse_order::rocsparse_order_row,
|
||||
|
@ -117,6 +131,14 @@ fn order(order: cusparseOrder_t) -> rocsparse_order {
|
|||
}
|
||||
}
|
||||
|
||||
fn to_cuda_order(order: rocsparse_order) -> cusparseOrder_t {
|
||||
match order {
|
||||
rocsparse_order::rocsparse_order_column => cusparseOrder_t::CUSPARSE_ORDER_COL,
|
||||
rocsparse_order::rocsparse_order_row => cusparseOrder_t::CUSPARSE_ORDER_ROW,
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn create_csrsv2_info(info: *mut *mut csrsv2Info) -> cusparseStatus_t {
|
||||
to_cuda(rocsparse_create_mat_info(info.cast()))
|
||||
}
|
||||
|
@ -127,7 +149,7 @@ unsafe fn create_dn_vec(
|
|||
values: *mut std::ffi::c_void,
|
||||
value_type: cudaDataType_t,
|
||||
) -> cusparseStatus_t {
|
||||
let value_type = data_type(value_type);
|
||||
let value_type = to_roc_data_type(value_type);
|
||||
to_cuda(rocsparse_create_dnvec_descr(
|
||||
dn_vec_descr.cast(),
|
||||
size,
|
||||
|
@ -466,7 +488,7 @@ unsafe fn spmv(
|
|||
external_buffer: *mut std::ffi::c_void,
|
||||
) -> cusparseStatus_t {
|
||||
let op_a = operation(op_a);
|
||||
let compute_type = data_type(compute_type);
|
||||
let compute_type = to_roc_data_type(compute_type);
|
||||
let alg = to_spmv_alg(alg);
|
||||
// divide by 2 in case there's any arithmetic done on it
|
||||
let mut size = usize::MAX / 2;
|
||||
|
@ -498,7 +520,7 @@ unsafe fn spmv_buffersize(
|
|||
buffer_size: *mut usize,
|
||||
) -> cusparseStatus_t {
|
||||
let op_a = operation(op_a);
|
||||
let compute_type = data_type(compute_type);
|
||||
let compute_type = to_roc_data_type(compute_type);
|
||||
let alg = to_spmv_alg(alg);
|
||||
to_cuda(rocsparse_spmv(
|
||||
handle.cast(),
|
||||
|
@ -1130,7 +1152,7 @@ unsafe fn spsv_buffersize(
|
|||
buffer_size: *mut usize,
|
||||
) -> cusparseStatus_t {
|
||||
let op_a = operation(op_a);
|
||||
let compute_type = data_type(compute_type);
|
||||
let compute_type = to_roc_data_type(compute_type);
|
||||
let alg = to_spsv_alg(alg);
|
||||
to_cuda(rocsparse_spsv(
|
||||
handle.cast(),
|
||||
|
@ -1169,7 +1191,7 @@ unsafe fn spsv_analysis(
|
|||
external_buffer: *mut c_void,
|
||||
) -> cusparseStatus_t {
|
||||
let op_a = operation(op_a);
|
||||
let compute_type = data_type(compute_type);
|
||||
let compute_type = to_roc_data_type(compute_type);
|
||||
let alg = to_spsv_alg(alg);
|
||||
let spsv_descr = spsv_descr.cast::<SpSvDescr>().as_mut().unwrap();
|
||||
spsv_descr.external_buffer = external_buffer;
|
||||
|
@ -1200,7 +1222,7 @@ unsafe fn spsv_solve(
|
|||
spsv_descr: *mut cusparseSpSVDescr,
|
||||
) -> cusparseStatus_t {
|
||||
let op_a = operation(op_a);
|
||||
let compute_type = data_type(compute_type);
|
||||
let compute_type = to_roc_data_type(compute_type);
|
||||
let alg = to_spsv_alg(alg);
|
||||
let spsv_descr = spsv_descr.cast::<SpSvDescr>().as_ref().unwrap();
|
||||
to_cuda(rocsparse_spsv(
|
||||
|
@ -1322,10 +1344,10 @@ unsafe fn create_dn_mat(
|
|||
ld: i64,
|
||||
values: *mut ::std::os::raw::c_void,
|
||||
value_type: cudaDataType,
|
||||
o: cusparseOrder_t,
|
||||
order: cusparseOrder_t,
|
||||
) -> cusparseStatus_t {
|
||||
let value_type = data_type(value_type);
|
||||
let o = order(o);
|
||||
let value_type = to_roc_data_type(value_type);
|
||||
let order = to_roc_order(order);
|
||||
to_cuda(rocsparse_create_dnmat_descr(
|
||||
dn_mat_descr.cast(),
|
||||
rows,
|
||||
|
@ -1333,7 +1355,7 @@ unsafe fn create_dn_mat(
|
|||
ld,
|
||||
values,
|
||||
value_type,
|
||||
o,
|
||||
order,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -1350,19 +1372,22 @@ unsafe fn dn_mat_get(
|
|||
ld: *mut i64,
|
||||
values: *mut *mut ::std::os::raw::c_void,
|
||||
type_: *mut cudaDataType,
|
||||
o: *mut cusparseOrder_t,
|
||||
order: *mut cusparseOrder_t,
|
||||
) -> cusparseStatus_t {
|
||||
let mut type_ = data_type(*type_);
|
||||
let mut o = order(*o);
|
||||
to_cuda(rocsparse_dnmat_get(
|
||||
let mut out_type = to_roc_data_type(*type_);
|
||||
let mut out_order = to_roc_order(*order);
|
||||
let status = to_cuda(rocsparse_dnmat_get(
|
||||
dn_mat_descr.cast(),
|
||||
rows,
|
||||
cols,
|
||||
ld,
|
||||
values,
|
||||
&mut type_,
|
||||
&mut o,
|
||||
))
|
||||
&mut out_type,
|
||||
&mut out_order,
|
||||
));
|
||||
*type_ = to_cuda_data_type(out_type);
|
||||
*order = to_cuda_order(out_order);
|
||||
status
|
||||
}
|
||||
|
||||
unsafe fn dn_mat_get_values(
|
||||
|
|
Loading…
Reference in a new issue