diff --git a/miopen-sys/build.rs b/miopen-sys/build.rs index 1271246..f6c0300 100644 --- a/miopen-sys/build.rs +++ b/miopen-sys/build.rs @@ -1,19 +1,4 @@ -use std::env::VarError; -use std::{env, path::PathBuf}; - -fn main() -> Result<(), VarError> { +fn main() { println!("cargo:rustc-link-lib=dylib=MIOpen"); - if cfg!(windows) { - let env = env::var("CARGO_CFG_TARGET_ENV")?; - if env == "msvc" { - let mut path = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); - path.push("lib"); - println!("cargo:rustc-link-search=native={}", path.display()); - } else { - println!("cargo:rustc-link-search=native=C:\\Windows\\System32"); - }; - } else { - println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); - } - Ok(()) + println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); } diff --git a/miopen-sys/lib/MIOpen.lib b/miopen-sys/lib/MIOpen.lib deleted file mode 100644 index 6103042..0000000 Binary files a/miopen-sys/lib/MIOpen.lib and /dev/null differ diff --git a/zluda_dnn/src/cudnn_types_v8.rs b/zluda_dnn/src/cudnn_types_v8.rs index c99e039..b6b5ebd 100644 --- a/zluda_dnn/src/cudnn_types_v8.rs +++ b/zluda_dnn/src/cudnn_types_v8.rs @@ -151,33 +151,6 @@ pub struct cudnnTensorTransformStruct { _unused: [u8; 0], } pub type cudnnTensorTransformDescriptor_t = *mut cudnnTensorTransformStruct; -#[repr(C)] -#[derive(Copy, Clone)] -pub struct cudnnEngineHeurStruct { - pub operation_graph: cudnnOperationGraphDescriptor_t, -} -pub type cudnnEngineHeurDescriptor_t = *mut cudnnEngineHeurStruct; -#[repr(C)] -#[derive(Copy, Clone)] -pub struct cudnnOperationConvolutionForwardStruct { - pub x_desc: cudnnTensorDescriptor_t, - pub y_desc: cudnnTensorDescriptor_t, - pub w_desc: cudnnFilterDescriptor_t, - pub conv_desc: cudnnConvolutionDescriptor_t, -} -pub type cudnnOperationConvolutionForwardDescriptor_t = *mut cudnnOperationConvolutionForwardStruct; -#[repr(C)] -#[derive(Copy, Clone)] -pub struct cudnnOperationGraphStruct { - pub handle: cudnnHandle_t, - pub ops: *const cudnnBackendDescriptorType_t, -} -pub type cudnnOperationGraphDescriptor_t = *mut cudnnOperationGraphStruct; -#[repr(C)] -#[derive(Copy, Clone)] -pub struct cudnnVariantPackStruct { -} -pub type cudnnVariantPackDescriptor_t = *mut cudnnVariantPackStruct; impl cudnnDataType_t { pub const CUDNN_DATA_FLOAT: cudnnDataType_t = cudnnDataType_t(0); } diff --git a/zluda_dnn/src/cudnn_v8.rs b/zluda_dnn/src/cudnn_v8.rs index f89d82e..8acde31 100644 --- a/zluda_dnn/src/cudnn_v8.rs +++ b/zluda_dnn/src/cudnn_v8.rs @@ -4,7 +4,7 @@ use crate::types::*; #[no_mangle] pub unsafe extern "system" fn cudnnGetVersion() -> usize { - 8700 as usize + unimplemented!() } #[no_mangle] @@ -65,7 +65,7 @@ pub unsafe extern "system" fn cudnnGetStream( handle: cudnnHandle_t, streamId: *mut cudaStream_t, ) -> cudnnStatus_t { - crate::get_stream(handle, streamId) + crate::unsupported() } #[no_mangle] @@ -2793,18 +2793,7 @@ pub unsafe extern "system" fn cudnnSetConvolution2dDescriptor( mode: cudnnConvolutionMode_t, computeType: cudnnDataType_t, ) -> cudnnStatus_t { - let pad_a = [pad_h, pad_w]; - let filter_stride_a = [u, v]; - let dilation_a = [dilation_h, dilation_w]; - crate::set_convolution_nd_descriptor( - convDesc, - 2, - pad_a.as_ptr(), - filter_stride_a.as_ptr(), - dilation_a.as_ptr(), - mode, - computeType, - ) + crate::unsupported() } #[no_mangle] @@ -3422,27 +3411,14 @@ pub unsafe extern "system" fn cudnnBackendCreateDescriptor( descriptorType: cudnnBackendDescriptorType_t, descriptor: *mut cudnnBackendDescriptor_t, ) -> cudnnStatus_t { - match descriptorType { - cudnnBackendDescriptorType_t::CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR => crate::cudnn_create_convolution_descriptor(descriptor as _), - cudnnBackendDescriptorType_t::CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR => crate::cudnn_create_engineheur_descriptor(descriptor as _), - cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR => crate::cudnn_create_operation_convolution_forward_descriptor(descriptor as _), - cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR => crate::cudnn_create_operationgraph_descriptor(descriptor as _), - cudnnBackendDescriptorType_t::CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - cudnnBackendDescriptorType_t::CUDNN_BACKEND_TENSOR_DESCRIPTOR => crate::cudnn_create_tensor_descriptor(descriptor as _), - _ => { - println!("[ZLUDA] Unsupported descriptor type: {}", descriptorType.0); - crate::unsupported() - }, - } + crate::unsupported() } #[no_mangle] pub unsafe extern "system" fn cudnnBackendDestroyDescriptor( descriptor: cudnnBackendDescriptor_t, ) -> cudnnStatus_t { - // TODO - // Do not know how to destroy unknown descriptor. - cudnnStatus_t::CUDNN_STATUS_SUCCESS + crate::unsupported() } #[no_mangle] @@ -3456,7 +3432,7 @@ pub unsafe extern "system" fn cudnnBackendInitialize( pub unsafe extern "system" fn cudnnBackendFinalize( descriptor: cudnnBackendDescriptor_t, ) -> cudnnStatus_t { - cudnnStatus_t::CUDNN_STATUS_SUCCESS + crate::unsupported() } #[no_mangle] @@ -3467,18 +3443,7 @@ pub unsafe extern "system" fn cudnnBackendSetAttribute( elementCount: i64, arrayOfElements: *const ::std::os::raw::c_void, ) -> cudnnStatus_t { - match attributeName.0 { - 100..=199 => crate::set_convolution_nd_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements), - 200..=299 => crate::set_engineheur_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements), - 700..=799 => crate::set_operation_convolution_forward_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements as _), - 800..=899 => crate::set_operationgraph_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements as _), - 900..=999 => crate::set_tensor_nd_decriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements), - 1000..=1099 => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - _ => { - println!("[ZLUDA] Tried to set unsupported attribute: {}", attributeName.0); - crate::unsupported() - }, - } + crate::unsupported() } #[no_mangle] @@ -3490,13 +3455,7 @@ pub unsafe extern "system" fn cudnnBackendGetAttribute( elementCount: *mut i64, arrayOfElements: *mut ::std::os::raw::c_void, ) -> cudnnStatus_t { - match attributeName { - cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_RESULTS => crate::get_engineheur_results(descriptor as _, requestedElementCount, elementCount, arrayOfElements), - _ => { - println!("[ZLUDA] Tried to get unsupported attribute: {}", attributeName.0); - crate::unsupported() - }, - } + crate::unsupported() } #[no_mangle] diff --git a/zluda_dnn/src/lib.rs b/zluda_dnn/src/lib.rs index 41331e6..3f1d19f 100644 --- a/zluda_dnn/src/lib.rs +++ b/zluda_dnn/src/lib.rs @@ -20,7 +20,7 @@ use types::*; use hip_runtime_sys::*; use miopen_sys::*; -use std::{mem, ptr, alloc::{self, Layout}}; +use std::{mem, ptr}; macro_rules! call { ($expr:expr) => {{ @@ -44,10 +44,7 @@ fn unsupported() -> cudnnStatus_t { fn to_cudnn(status: miopen_sys::miopenStatus_t) -> cudnnStatus_t { match status { miopen_sys::miopenStatus_t::miopenStatusSuccess => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - miopen_sys::miopenStatus_t::miopenStatusInvalidValue => cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE, - miopen_sys::miopenStatus_t::miopenStatusBadParm => cudnnStatus_t::CUDNN_STATUS_BAD_PARAM, - miopen_sys::miopenStatus_t::miopenStatusUnknownError => cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR, - err => panic!("[ZLUDA] MIOpen failed: {}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR, + err => panic!("{}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR, } } @@ -90,132 +87,7 @@ unsafe fn cudnn_create_lrn_descriptor(norm_desc: *mut cudnnLRNDescriptor_t) -> c unsafe fn cudnn_create_pooling_descriptor( pooling_desc: *mut cudnnPoolingDescriptor_t, ) -> cudnnStatus_t { - to_cudnn(miopen_sys::miopenCreatePoolingDescriptor( - pooling_desc as _, - )) -} - -unsafe fn cudnn_create_engineheur_descriptor( - engineheur_desc: *mut cudnnEngineHeurDescriptor_t, -) -> cudnnStatus_t { - let layout = Layout::new::(); - *engineheur_desc = alloc::alloc(layout) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS -} - -unsafe fn cudnn_create_operation_convolution_forward_descriptor( - operation_convolution_forward_desc: *mut cudnnOperationConvolutionForwardDescriptor_t, -) -> cudnnStatus_t { - let layout = Layout::new::(); - *operation_convolution_forward_desc = alloc::alloc(layout) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS -} - -unsafe fn cudnn_create_operationgraph_descriptor( - operationgraph_desc: *mut cudnnOperationGraphDescriptor_t, -) -> cudnnStatus_t { - let layout = Layout::new::(); - *operationgraph_desc = alloc::alloc(layout) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS -} - -unsafe fn get_tensor_size( - tensor_desc: *mut cudnnTensorStruct, - size: *mut i32, -) -> cudnnStatus_t { - to_cudnn(miopen_sys::miopenGetTensorDescriptorSize( - tensor_desc as _, - size, - )) -} - -unsafe fn set_tensor_nd_decriptor_by_attribute( - tensor_desc: *mut cudnnTensorStruct, - attribute_name: cudnnBackendAttributeName_t, - count: i64, - elements: *const ::std::os::raw::c_void, -) -> cudnnStatus_t { - let mut size = 0; - get_tensor_size( - tensor_desc, - &mut size, - ); - let mut data_type = cudnnDataType_t::CUDNN_DATA_FLOAT; - let mut dim_a = [0; 5]; - let mut stride_a = [0; 5]; - get_tensor_nd_decriptor( - tensor_desc, - &mut data_type, - dim_a.as_mut_ptr(), - stride_a.as_mut_ptr(), - ); - match attribute_name { - cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DATA_TYPE => { - let mut nb_index: usize = 0; - while nb_index < 5 && dim_a[nb_index] != 0 { - nb_index += 1; - } - let parameters = elements as *const cudnnDataType_t; - data_type = *parameters; - if dim_a[0] == 0 { // This tensor is not initialized yet. - dim_a[0] = 1; - } - if stride_a[0] == 0 { // This tensor is not initialized yet. - stride_a[0] = 1; - } - set_tensor_nd_decriptor( - tensor_desc, - data_type, - (nb_index + 1) as _, - dim_a[0..=nb_index].as_ptr(), - stride_a[0..=nb_index].as_ptr(), - ) - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DIMENSIONS => { - let parameters = elements as *const i64; - let count_u = count as usize; - for i in 0..count_u { - dim_a[i] = *parameters.add(i) as i32; - if stride_a[i] == 0 { - stride_a[i] = 1; // fill invalid value - } - } - set_tensor_nd_decriptor( - tensor_desc, - data_type, - count as _, - dim_a[0..count_u].as_ptr(), - stride_a[0..count_u].as_ptr(), - ) - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_STRIDES => { - let parameters = elements as *const i64; - let count_u = count as usize; - for i in 0..count_u { - stride_a[i] = *parameters.add(i) as i32; - if dim_a[i] == 0 { - dim_a[i] = 1; // fill invalid value - } - } - let mut dim_last_index: usize = 4; - while stride_a[dim_last_index] == 0 { - dim_last_index -= 1; - } - set_tensor_nd_decriptor( - tensor_desc, - data_type, - count as _, - dim_a[0..count_u].as_ptr(), - stride_a[0..count_u].as_ptr(), - ) - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_UNIQUE_ID => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - _ => { - println!("[ZLUDA] Unsupported tensor attribute: {}", attribute_name.0); - crate::unsupported() - }, - } + to_cudnn(miopen_sys::miopenCreatePoolingDescriptor(pooling_desc as _)) } unsafe fn set_tensor_nd_decriptor( @@ -234,39 +106,11 @@ unsafe fn set_tensor_nd_decriptor( )) } -unsafe fn get_tensor_nd_decriptor( - tensor_desc: *mut cudnnTensorStruct, - data_type: *mut cudnnDataType_t, - dim_a: *mut i32, - stride_a: *mut i32, -) -> cudnnStatus_t { - let mut miopen_data_type = from_data_type(*data_type); - let status = miopen_sys::miopenGetTensorDescriptor( - tensor_desc as _, - &mut miopen_data_type, - dim_a as _, - stride_a as _, - ); - *data_type = to_data_type(miopen_data_type); - to_cudnn(status) -} - -fn to_data_type(type_: miopenDataType_t) -> cudnnDataType_t { - match type_ { - miopenDataType_t::miopenFloat => cudnnDataType_t::CUDNN_DATA_FLOAT, - miopenDataType_t::miopenDouble => cudnnDataType_t::CUDNN_DATA_DOUBLE, - miopenDataType_t::miopenHalf => cudnnDataType_t::CUDNN_DATA_HALF, - miopenDataType_t::miopenBFloat16 => cudnnDataType_t::CUDNN_DATA_BFLOAT16, - _ => todo!(), - } -} - fn from_data_type(type_: cudnnDataType_t) -> miopenDataType_t { match type_ { cudnnDataType_t::CUDNN_DATA_FLOAT => miopenDataType_t::miopenFloat, cudnnDataType_t::CUDNN_DATA_DOUBLE => miopenDataType_t::miopenDouble, cudnnDataType_t::CUDNN_DATA_HALF => miopenDataType_t::miopenHalf, - cudnnDataType_t::CUDNN_DATA_BFLOAT16 => miopenDataType_t::miopenBFloat16, _ => todo!(), } } @@ -288,114 +132,6 @@ unsafe fn set_filter_nd_descriptor( )) } -unsafe fn set_convolution_nd_descriptor_by_attribute( - conv_desc: cudnnConvolutionDescriptor_t, - attribute_name: cudnnBackendAttributeName_t, - count: i64, - elements: *const ::std::os::raw::c_void, -) -> cudnnStatus_t { - let mut array_length = 2; - let mut pad_a = [0; 2]; - let mut filter_stride_a = [0; 2]; - let mut dilation_a = [0; 2]; - let mut mode = cudnnConvolutionMode_t::CUDNN_CONVOLUTION; - get_convolution_nd_descriptor( - conv_desc, - &mut array_length, // TODO - pad_a.as_mut_ptr(), - filter_stride_a.as_mut_ptr(), - dilation_a.as_mut_ptr(), - &mut mode, - cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused - ); - match attribute_name { - cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_COMP_TYPE => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_CONV_MODE => { - let parameters = elements as *const cudnnConvolutionMode_t; - set_convolution_nd_descriptor( - conv_desc, - array_length, // TODO - pad_a.as_ptr(), - filter_stride_a.as_ptr(), - dilation_a.as_ptr(), - *parameters, - cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused - ) - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_DILATIONS => { - if count != 2 { - todo!() - } - let parameters = elements as *const i64; - for i in 0..(array_length as usize) { - dilation_a[i] = *parameters.add(i) as i32; - } - set_convolution_nd_descriptor( - conv_desc, - count as i32, // TODO - pad_a.as_ptr(), - filter_stride_a.as_ptr(), - dilation_a.as_ptr(), - mode, - cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused - ) - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES => { - if count != 2 { - todo!() - } - let parameters = elements as *const i64; - for i in 0..(array_length as usize) { - filter_stride_a[i] = *parameters.add(i) as i32; - } - set_convolution_nd_descriptor( - conv_desc, - count as i32, // TODO - pad_a.as_ptr(), - filter_stride_a.as_ptr(), - dilation_a.as_ptr(), - mode, - cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused - ) - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_POST_PADDINGS => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS => { - if count != 2 { - todo!() - } - let parameters = elements as *const i64; - for i in 0..(array_length as usize) { - pad_a[i] = *parameters.add(i) as i32; - } - set_convolution_nd_descriptor( - conv_desc, - count as i32, // TODO - pad_a.as_ptr(), - filter_stride_a.as_ptr(), - dilation_a.as_ptr(), - mode, - cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused - ) - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS => { - let parameters = elements as *const i64; - set_convolution_nd_descriptor( - conv_desc, - (*parameters) as i32, - pad_a.as_ptr(), - filter_stride_a.as_ptr(), - dilation_a.as_ptr(), - mode, - cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused - ) - }, - _ => { - println!("[ZLUDA] Unsupported convolution attribute: {}", attribute_name.0); - crate::unsupported() - }, - } -} - unsafe fn set_convolution_nd_descriptor( conv_desc: cudnnConvolutionDescriptor_t, array_length: i32, @@ -427,31 +163,6 @@ unsafe fn set_convolution_nd_descriptor( )) } -unsafe fn get_convolution_nd_descriptor( - conv_desc: cudnnConvolutionDescriptor_t, - array_length: *mut i32, - pad_a: *mut i32, - filter_stride_a: *mut i32, - dilation_a: *mut i32, - mode: *mut cudnnConvolutionMode_t, - _compute_type: cudnnDataType_t, -) -> cudnnStatus_t { - *array_length = 2; // TODO - let mut miopen_conv_mode = conv_mode_to_cudnn(*mode); - let status = miopen_sys::miopenGetConvolutionDescriptor( - conv_desc as _, - &mut miopen_conv_mode, - pad_a.add(0), - pad_a.add(1), - filter_stride_a.add(0), - filter_stride_a.add(1), - dilation_a.add(0), - dilation_a.add(1), - ); - *mode = conv_mode_from_cudnn(miopen_conv_mode); - to_cudnn(status) -} - fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t { match mode { cudnnConvolutionMode_t::CUDNN_CONVOLUTION => miopenConvolutionMode_t::miopenTranspose, @@ -462,16 +173,6 @@ fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t { } } -fn conv_mode_from_cudnn(mode: miopenConvolutionMode_t) -> cudnnConvolutionMode_t { - match mode { - miopenConvolutionMode_t::miopenTranspose => cudnnConvolutionMode_t::CUDNN_CONVOLUTION, - miopenConvolutionMode_t::miopenConvolution => { - cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION - } - _ => panic!(), - } -} - unsafe fn get_convolution_nd_forward_output_dim( conv_desc: cudnnConvolutionDescriptor_t, input_tensor_desc: cudnnTensorDescriptor_t, @@ -1398,102 +1099,3 @@ unsafe fn convolution_backward_data( work_space_size_in_bytes, )) } - -unsafe fn set_engineheur_descriptor_by_attribute( - engineheur_desc: *mut cudnnEngineHeurStruct, - attribute_name: cudnnBackendAttributeName_t, - count: i64, - elements: *const ::std::os::raw::c_void, -) -> cudnnStatus_t { - match attribute_name { - cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_MODE => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH => { - (*engineheur_desc).operation_graph = *(elements as *const cudnnOperationGraphDescriptor_t); - cudnnStatus_t::CUDNN_STATUS_SUCCESS - }, - _ => panic!(), - } -} - -unsafe fn get_engineheur_results( - engineheur_desc: *mut cudnnEngineHeurStruct, - requested_algo_count: i64, - returned_algo_count: *mut i64, - perf_results: *mut std::ffi::c_void, -) -> cudnnStatus_t { - let operation_graph = *(*engineheur_desc).operation_graph; - let ops = *(operation_graph.ops as cudnnOperationConvolutionForwardDescriptor_t); - let mut req = requested_algo_count as i32; - if requested_algo_count == 0 { // total? - // TODO - req = 10; - } - find_convolution_forward_algorithm( - operation_graph.handle, - ops.x_desc, - ops.w_desc, - ops.conv_desc, - ops.y_desc, - req, - returned_algo_count as _, - perf_results as _, - ) -} - -unsafe fn set_operation_convolution_forward_descriptor_by_attribute( - operation_convolution_forward_desc: *mut cudnnOperationConvolutionForwardStruct, - attribute_name: cudnnBackendAttributeName_t, - count: i64, - elements: *const cudnnBackendDescriptor_t, // *const ::std::os::raw::c_void -) -> cudnnStatus_t { - match attribute_name { - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA => cudnnStatus_t::CUDNN_STATUS_SUCCESS, - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC => { - (*operation_convolution_forward_desc).conv_desc = (*elements) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W => { - (*operation_convolution_forward_desc).w_desc = (*elements) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X => { - (*operation_convolution_forward_desc).x_desc = (*elements) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y => { - (*operation_convolution_forward_desc).y_desc = (*elements) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS - }, - _ => panic!(), - } -} - -unsafe fn set_operationgraph_descriptor_by_attribute( - operationgraph_desc: *mut cudnnOperationGraphStruct, - attribute_name: cudnnBackendAttributeName_t, - count: i64, - elements: *const *mut ::std::os::raw::c_void, // *const ::std::os::raw::c_void -) -> cudnnStatus_t { - match attribute_name { - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_HANDLE => { - (*operationgraph_desc).handle = (*elements) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS - }, - cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_OPS => { - (*operationgraph_desc).ops = (*elements) as _; - cudnnStatus_t::CUDNN_STATUS_SUCCESS - }, - _ => crate::unsupported(), - } -} - -unsafe fn get_stream( - handle: *mut cudnnContext, - stream_id: *mut cudaStream_t, -) -> cudnnStatus_t { - to_cudnn(miopenGetStream( - handle as _, - stream_id as _, - )) -}