Remove zluda_dnn remains.

This commit is contained in:
Seunghoon Lee 2024-04-29 22:56:19 +09:00
parent 2804604c29
commit 7538ae61c6
No known key found for this signature in database
GPG key ID: 436E38F4E70BD152
5 changed files with 13 additions and 494 deletions

View file

@ -1,19 +1,4 @@
use std::env::VarError; fn main() {
use std::{env, path::PathBuf};
fn main() -> Result<(), VarError> {
println!("cargo:rustc-link-lib=dylib=MIOpen"); println!("cargo:rustc-link-lib=dylib=MIOpen");
if cfg!(windows) { println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
let env = env::var("CARGO_CFG_TARGET_ENV")?;
if env == "msvc" {
let mut path = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
path.push("lib");
println!("cargo:rustc-link-search=native={}", path.display());
} else {
println!("cargo:rustc-link-search=native=C:\\Windows\\System32");
};
} else {
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
}
Ok(())
} }

Binary file not shown.

View file

@ -151,33 +151,6 @@ pub struct cudnnTensorTransformStruct {
_unused: [u8; 0], _unused: [u8; 0],
} }
pub type cudnnTensorTransformDescriptor_t = *mut cudnnTensorTransformStruct; pub type cudnnTensorTransformDescriptor_t = *mut cudnnTensorTransformStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnEngineHeurStruct {
pub operation_graph: cudnnOperationGraphDescriptor_t,
}
pub type cudnnEngineHeurDescriptor_t = *mut cudnnEngineHeurStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnOperationConvolutionForwardStruct {
pub x_desc: cudnnTensorDescriptor_t,
pub y_desc: cudnnTensorDescriptor_t,
pub w_desc: cudnnFilterDescriptor_t,
pub conv_desc: cudnnConvolutionDescriptor_t,
}
pub type cudnnOperationConvolutionForwardDescriptor_t = *mut cudnnOperationConvolutionForwardStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnOperationGraphStruct {
pub handle: cudnnHandle_t,
pub ops: *const cudnnBackendDescriptorType_t,
}
pub type cudnnOperationGraphDescriptor_t = *mut cudnnOperationGraphStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnVariantPackStruct {
}
pub type cudnnVariantPackDescriptor_t = *mut cudnnVariantPackStruct;
impl cudnnDataType_t { impl cudnnDataType_t {
pub const CUDNN_DATA_FLOAT: cudnnDataType_t = cudnnDataType_t(0); pub const CUDNN_DATA_FLOAT: cudnnDataType_t = cudnnDataType_t(0);
} }

View file

@ -4,7 +4,7 @@ use crate::types::*;
#[no_mangle] #[no_mangle]
pub unsafe extern "system" fn cudnnGetVersion() -> usize { pub unsafe extern "system" fn cudnnGetVersion() -> usize {
8700 as usize unimplemented!()
} }
#[no_mangle] #[no_mangle]
@ -65,7 +65,7 @@ pub unsafe extern "system" fn cudnnGetStream(
handle: cudnnHandle_t, handle: cudnnHandle_t,
streamId: *mut cudaStream_t, streamId: *mut cudaStream_t,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
crate::get_stream(handle, streamId) crate::unsupported()
} }
#[no_mangle] #[no_mangle]
@ -2793,18 +2793,7 @@ pub unsafe extern "system" fn cudnnSetConvolution2dDescriptor(
mode: cudnnConvolutionMode_t, mode: cudnnConvolutionMode_t,
computeType: cudnnDataType_t, computeType: cudnnDataType_t,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
let pad_a = [pad_h, pad_w]; crate::unsupported()
let filter_stride_a = [u, v];
let dilation_a = [dilation_h, dilation_w];
crate::set_convolution_nd_descriptor(
convDesc,
2,
pad_a.as_ptr(),
filter_stride_a.as_ptr(),
dilation_a.as_ptr(),
mode,
computeType,
)
} }
#[no_mangle] #[no_mangle]
@ -3422,27 +3411,14 @@ pub unsafe extern "system" fn cudnnBackendCreateDescriptor(
descriptorType: cudnnBackendDescriptorType_t, descriptorType: cudnnBackendDescriptorType_t,
descriptor: *mut cudnnBackendDescriptor_t, descriptor: *mut cudnnBackendDescriptor_t,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
match descriptorType { crate::unsupported()
cudnnBackendDescriptorType_t::CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR => crate::cudnn_create_convolution_descriptor(descriptor as _),
cudnnBackendDescriptorType_t::CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR => crate::cudnn_create_engineheur_descriptor(descriptor as _),
cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR => crate::cudnn_create_operation_convolution_forward_descriptor(descriptor as _),
cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR => crate::cudnn_create_operationgraph_descriptor(descriptor as _),
cudnnBackendDescriptorType_t::CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
cudnnBackendDescriptorType_t::CUDNN_BACKEND_TENSOR_DESCRIPTOR => crate::cudnn_create_tensor_descriptor(descriptor as _),
_ => {
println!("[ZLUDA] Unsupported descriptor type: {}", descriptorType.0);
crate::unsupported()
},
}
} }
#[no_mangle] #[no_mangle]
pub unsafe extern "system" fn cudnnBackendDestroyDescriptor( pub unsafe extern "system" fn cudnnBackendDestroyDescriptor(
descriptor: cudnnBackendDescriptor_t, descriptor: cudnnBackendDescriptor_t,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
// TODO crate::unsupported()
// Do not know how to destroy unknown descriptor.
cudnnStatus_t::CUDNN_STATUS_SUCCESS
} }
#[no_mangle] #[no_mangle]
@ -3456,7 +3432,7 @@ pub unsafe extern "system" fn cudnnBackendInitialize(
pub unsafe extern "system" fn cudnnBackendFinalize( pub unsafe extern "system" fn cudnnBackendFinalize(
descriptor: cudnnBackendDescriptor_t, descriptor: cudnnBackendDescriptor_t,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
cudnnStatus_t::CUDNN_STATUS_SUCCESS crate::unsupported()
} }
#[no_mangle] #[no_mangle]
@ -3467,18 +3443,7 @@ pub unsafe extern "system" fn cudnnBackendSetAttribute(
elementCount: i64, elementCount: i64,
arrayOfElements: *const ::std::os::raw::c_void, arrayOfElements: *const ::std::os::raw::c_void,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
match attributeName.0 { crate::unsupported()
100..=199 => crate::set_convolution_nd_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements),
200..=299 => crate::set_engineheur_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements),
700..=799 => crate::set_operation_convolution_forward_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements as _),
800..=899 => crate::set_operationgraph_descriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements as _),
900..=999 => crate::set_tensor_nd_decriptor_by_attribute(descriptor as _, attributeName, elementCount, arrayOfElements),
1000..=1099 => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
_ => {
println!("[ZLUDA] Tried to set unsupported attribute: {}", attributeName.0);
crate::unsupported()
},
}
} }
#[no_mangle] #[no_mangle]
@ -3490,13 +3455,7 @@ pub unsafe extern "system" fn cudnnBackendGetAttribute(
elementCount: *mut i64, elementCount: *mut i64,
arrayOfElements: *mut ::std::os::raw::c_void, arrayOfElements: *mut ::std::os::raw::c_void,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
match attributeName { crate::unsupported()
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_RESULTS => crate::get_engineheur_results(descriptor as _, requestedElementCount, elementCount, arrayOfElements),
_ => {
println!("[ZLUDA] Tried to get unsupported attribute: {}", attributeName.0);
crate::unsupported()
},
}
} }
#[no_mangle] #[no_mangle]

View file

@ -20,7 +20,7 @@ use types::*;
use hip_runtime_sys::*; use hip_runtime_sys::*;
use miopen_sys::*; use miopen_sys::*;
use std::{mem, ptr, alloc::{self, Layout}}; use std::{mem, ptr};
macro_rules! call { macro_rules! call {
($expr:expr) => {{ ($expr:expr) => {{
@ -44,10 +44,7 @@ fn unsupported() -> cudnnStatus_t {
fn to_cudnn(status: miopen_sys::miopenStatus_t) -> cudnnStatus_t { fn to_cudnn(status: miopen_sys::miopenStatus_t) -> cudnnStatus_t {
match status { match status {
miopen_sys::miopenStatus_t::miopenStatusSuccess => cudnnStatus_t::CUDNN_STATUS_SUCCESS, miopen_sys::miopenStatus_t::miopenStatusSuccess => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
miopen_sys::miopenStatus_t::miopenStatusInvalidValue => cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE, err => panic!("{}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
miopen_sys::miopenStatus_t::miopenStatusBadParm => cudnnStatus_t::CUDNN_STATUS_BAD_PARAM,
miopen_sys::miopenStatus_t::miopenStatusUnknownError => cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
err => panic!("[ZLUDA] MIOpen failed: {}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
} }
} }
@ -90,132 +87,7 @@ unsafe fn cudnn_create_lrn_descriptor(norm_desc: *mut cudnnLRNDescriptor_t) -> c
unsafe fn cudnn_create_pooling_descriptor( unsafe fn cudnn_create_pooling_descriptor(
pooling_desc: *mut cudnnPoolingDescriptor_t, pooling_desc: *mut cudnnPoolingDescriptor_t,
) -> cudnnStatus_t { ) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreatePoolingDescriptor( to_cudnn(miopen_sys::miopenCreatePoolingDescriptor(pooling_desc as _))
pooling_desc as _,
))
}
unsafe fn cudnn_create_engineheur_descriptor(
engineheur_desc: *mut cudnnEngineHeurDescriptor_t,
) -> cudnnStatus_t {
let layout = Layout::new::<cudnnEngineHeurStruct>();
*engineheur_desc = alloc::alloc(layout) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
}
unsafe fn cudnn_create_operation_convolution_forward_descriptor(
operation_convolution_forward_desc: *mut cudnnOperationConvolutionForwardDescriptor_t,
) -> cudnnStatus_t {
let layout = Layout::new::<cudnnOperationConvolutionForwardStruct>();
*operation_convolution_forward_desc = alloc::alloc(layout) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
}
unsafe fn cudnn_create_operationgraph_descriptor(
operationgraph_desc: *mut cudnnOperationGraphDescriptor_t,
) -> cudnnStatus_t {
let layout = Layout::new::<cudnnOperationGraphStruct>();
*operationgraph_desc = alloc::alloc(layout) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
}
unsafe fn get_tensor_size(
tensor_desc: *mut cudnnTensorStruct,
size: *mut i32,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenGetTensorDescriptorSize(
tensor_desc as _,
size,
))
}
unsafe fn set_tensor_nd_decriptor_by_attribute(
tensor_desc: *mut cudnnTensorStruct,
attribute_name: cudnnBackendAttributeName_t,
count: i64,
elements: *const ::std::os::raw::c_void,
) -> cudnnStatus_t {
let mut size = 0;
get_tensor_size(
tensor_desc,
&mut size,
);
let mut data_type = cudnnDataType_t::CUDNN_DATA_FLOAT;
let mut dim_a = [0; 5];
let mut stride_a = [0; 5];
get_tensor_nd_decriptor(
tensor_desc,
&mut data_type,
dim_a.as_mut_ptr(),
stride_a.as_mut_ptr(),
);
match attribute_name {
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DATA_TYPE => {
let mut nb_index: usize = 0;
while nb_index < 5 && dim_a[nb_index] != 0 {
nb_index += 1;
}
let parameters = elements as *const cudnnDataType_t;
data_type = *parameters;
if dim_a[0] == 0 { // This tensor is not initialized yet.
dim_a[0] = 1;
}
if stride_a[0] == 0 { // This tensor is not initialized yet.
stride_a[0] = 1;
}
set_tensor_nd_decriptor(
tensor_desc,
data_type,
(nb_index + 1) as _,
dim_a[0..=nb_index].as_ptr(),
stride_a[0..=nb_index].as_ptr(),
)
},
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DIMENSIONS => {
let parameters = elements as *const i64;
let count_u = count as usize;
for i in 0..count_u {
dim_a[i] = *parameters.add(i) as i32;
if stride_a[i] == 0 {
stride_a[i] = 1; // fill invalid value
}
}
set_tensor_nd_decriptor(
tensor_desc,
data_type,
count as _,
dim_a[0..count_u].as_ptr(),
stride_a[0..count_u].as_ptr(),
)
},
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_STRIDES => {
let parameters = elements as *const i64;
let count_u = count as usize;
for i in 0..count_u {
stride_a[i] = *parameters.add(i) as i32;
if dim_a[i] == 0 {
dim_a[i] = 1; // fill invalid value
}
}
let mut dim_last_index: usize = 4;
while stride_a[dim_last_index] == 0 {
dim_last_index -= 1;
}
set_tensor_nd_decriptor(
tensor_desc,
data_type,
count as _,
dim_a[0..count_u].as_ptr(),
stride_a[0..count_u].as_ptr(),
)
},
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_UNIQUE_ID => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
_ => {
println!("[ZLUDA] Unsupported tensor attribute: {}", attribute_name.0);
crate::unsupported()
},
}
} }
unsafe fn set_tensor_nd_decriptor( unsafe fn set_tensor_nd_decriptor(
@ -234,39 +106,11 @@ unsafe fn set_tensor_nd_decriptor(
)) ))
} }
unsafe fn get_tensor_nd_decriptor(
tensor_desc: *mut cudnnTensorStruct,
data_type: *mut cudnnDataType_t,
dim_a: *mut i32,
stride_a: *mut i32,
) -> cudnnStatus_t {
let mut miopen_data_type = from_data_type(*data_type);
let status = miopen_sys::miopenGetTensorDescriptor(
tensor_desc as _,
&mut miopen_data_type,
dim_a as _,
stride_a as _,
);
*data_type = to_data_type(miopen_data_type);
to_cudnn(status)
}
fn to_data_type(type_: miopenDataType_t) -> cudnnDataType_t {
match type_ {
miopenDataType_t::miopenFloat => cudnnDataType_t::CUDNN_DATA_FLOAT,
miopenDataType_t::miopenDouble => cudnnDataType_t::CUDNN_DATA_DOUBLE,
miopenDataType_t::miopenHalf => cudnnDataType_t::CUDNN_DATA_HALF,
miopenDataType_t::miopenBFloat16 => cudnnDataType_t::CUDNN_DATA_BFLOAT16,
_ => todo!(),
}
}
fn from_data_type(type_: cudnnDataType_t) -> miopenDataType_t { fn from_data_type(type_: cudnnDataType_t) -> miopenDataType_t {
match type_ { match type_ {
cudnnDataType_t::CUDNN_DATA_FLOAT => miopenDataType_t::miopenFloat, cudnnDataType_t::CUDNN_DATA_FLOAT => miopenDataType_t::miopenFloat,
cudnnDataType_t::CUDNN_DATA_DOUBLE => miopenDataType_t::miopenDouble, cudnnDataType_t::CUDNN_DATA_DOUBLE => miopenDataType_t::miopenDouble,
cudnnDataType_t::CUDNN_DATA_HALF => miopenDataType_t::miopenHalf, cudnnDataType_t::CUDNN_DATA_HALF => miopenDataType_t::miopenHalf,
cudnnDataType_t::CUDNN_DATA_BFLOAT16 => miopenDataType_t::miopenBFloat16,
_ => todo!(), _ => todo!(),
} }
} }
@ -288,114 +132,6 @@ unsafe fn set_filter_nd_descriptor(
)) ))
} }
unsafe fn set_convolution_nd_descriptor_by_attribute(
conv_desc: cudnnConvolutionDescriptor_t,
attribute_name: cudnnBackendAttributeName_t,
count: i64,
elements: *const ::std::os::raw::c_void,
) -> cudnnStatus_t {
let mut array_length = 2;
let mut pad_a = [0; 2];
let mut filter_stride_a = [0; 2];
let mut dilation_a = [0; 2];
let mut mode = cudnnConvolutionMode_t::CUDNN_CONVOLUTION;
get_convolution_nd_descriptor(
conv_desc,
&mut array_length, // TODO
pad_a.as_mut_ptr(),
filter_stride_a.as_mut_ptr(),
dilation_a.as_mut_ptr(),
&mut mode,
cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused
);
match attribute_name {
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_COMP_TYPE => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_CONV_MODE => {
let parameters = elements as *const cudnnConvolutionMode_t;
set_convolution_nd_descriptor(
conv_desc,
array_length, // TODO
pad_a.as_ptr(),
filter_stride_a.as_ptr(),
dilation_a.as_ptr(),
*parameters,
cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused
)
},
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_DILATIONS => {
if count != 2 {
todo!()
}
let parameters = elements as *const i64;
for i in 0..(array_length as usize) {
dilation_a[i] = *parameters.add(i) as i32;
}
set_convolution_nd_descriptor(
conv_desc,
count as i32, // TODO
pad_a.as_ptr(),
filter_stride_a.as_ptr(),
dilation_a.as_ptr(),
mode,
cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused
)
},
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES => {
if count != 2 {
todo!()
}
let parameters = elements as *const i64;
for i in 0..(array_length as usize) {
filter_stride_a[i] = *parameters.add(i) as i32;
}
set_convolution_nd_descriptor(
conv_desc,
count as i32, // TODO
pad_a.as_ptr(),
filter_stride_a.as_ptr(),
dilation_a.as_ptr(),
mode,
cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused
)
},
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_POST_PADDINGS => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS => {
if count != 2 {
todo!()
}
let parameters = elements as *const i64;
for i in 0..(array_length as usize) {
pad_a[i] = *parameters.add(i) as i32;
}
set_convolution_nd_descriptor(
conv_desc,
count as i32, // TODO
pad_a.as_ptr(),
filter_stride_a.as_ptr(),
dilation_a.as_ptr(),
mode,
cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused
)
},
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS => {
let parameters = elements as *const i64;
set_convolution_nd_descriptor(
conv_desc,
(*parameters) as i32,
pad_a.as_ptr(),
filter_stride_a.as_ptr(),
dilation_a.as_ptr(),
mode,
cudnnDataType_t::CUDNN_DATA_FLOAT, // will be unused
)
},
_ => {
println!("[ZLUDA] Unsupported convolution attribute: {}", attribute_name.0);
crate::unsupported()
},
}
}
unsafe fn set_convolution_nd_descriptor( unsafe fn set_convolution_nd_descriptor(
conv_desc: cudnnConvolutionDescriptor_t, conv_desc: cudnnConvolutionDescriptor_t,
array_length: i32, array_length: i32,
@ -427,31 +163,6 @@ unsafe fn set_convolution_nd_descriptor(
)) ))
} }
unsafe fn get_convolution_nd_descriptor(
conv_desc: cudnnConvolutionDescriptor_t,
array_length: *mut i32,
pad_a: *mut i32,
filter_stride_a: *mut i32,
dilation_a: *mut i32,
mode: *mut cudnnConvolutionMode_t,
_compute_type: cudnnDataType_t,
) -> cudnnStatus_t {
*array_length = 2; // TODO
let mut miopen_conv_mode = conv_mode_to_cudnn(*mode);
let status = miopen_sys::miopenGetConvolutionDescriptor(
conv_desc as _,
&mut miopen_conv_mode,
pad_a.add(0),
pad_a.add(1),
filter_stride_a.add(0),
filter_stride_a.add(1),
dilation_a.add(0),
dilation_a.add(1),
);
*mode = conv_mode_from_cudnn(miopen_conv_mode);
to_cudnn(status)
}
fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t { fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
match mode { match mode {
cudnnConvolutionMode_t::CUDNN_CONVOLUTION => miopenConvolutionMode_t::miopenTranspose, cudnnConvolutionMode_t::CUDNN_CONVOLUTION => miopenConvolutionMode_t::miopenTranspose,
@ -462,16 +173,6 @@ fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
} }
} }
fn conv_mode_from_cudnn(mode: miopenConvolutionMode_t) -> cudnnConvolutionMode_t {
match mode {
miopenConvolutionMode_t::miopenTranspose => cudnnConvolutionMode_t::CUDNN_CONVOLUTION,
miopenConvolutionMode_t::miopenConvolution => {
cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION
}
_ => panic!(),
}
}
unsafe fn get_convolution_nd_forward_output_dim( unsafe fn get_convolution_nd_forward_output_dim(
conv_desc: cudnnConvolutionDescriptor_t, conv_desc: cudnnConvolutionDescriptor_t,
input_tensor_desc: cudnnTensorDescriptor_t, input_tensor_desc: cudnnTensorDescriptor_t,
@ -1398,102 +1099,3 @@ unsafe fn convolution_backward_data(
work_space_size_in_bytes, work_space_size_in_bytes,
)) ))
} }
unsafe fn set_engineheur_descriptor_by_attribute(
engineheur_desc: *mut cudnnEngineHeurStruct,
attribute_name: cudnnBackendAttributeName_t,
count: i64,
elements: *const ::std::os::raw::c_void,
) -> cudnnStatus_t {
match attribute_name {
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_MODE => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH => {
(*engineheur_desc).operation_graph = *(elements as *const cudnnOperationGraphDescriptor_t);
cudnnStatus_t::CUDNN_STATUS_SUCCESS
},
_ => panic!(),
}
}
unsafe fn get_engineheur_results(
engineheur_desc: *mut cudnnEngineHeurStruct,
requested_algo_count: i64,
returned_algo_count: *mut i64,
perf_results: *mut std::ffi::c_void,
) -> cudnnStatus_t {
let operation_graph = *(*engineheur_desc).operation_graph;
let ops = *(operation_graph.ops as cudnnOperationConvolutionForwardDescriptor_t);
let mut req = requested_algo_count as i32;
if requested_algo_count == 0 { // total?
// TODO
req = 10;
}
find_convolution_forward_algorithm(
operation_graph.handle,
ops.x_desc,
ops.w_desc,
ops.conv_desc,
ops.y_desc,
req,
returned_algo_count as _,
perf_results as _,
)
}
unsafe fn set_operation_convolution_forward_descriptor_by_attribute(
operation_convolution_forward_desc: *mut cudnnOperationConvolutionForwardStruct,
attribute_name: cudnnBackendAttributeName_t,
count: i64,
elements: *const cudnnBackendDescriptor_t, // *const ::std::os::raw::c_void
) -> cudnnStatus_t {
match attribute_name {
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC => {
(*operation_convolution_forward_desc).conv_desc = (*elements) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
},
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W => {
(*operation_convolution_forward_desc).w_desc = (*elements) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
},
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X => {
(*operation_convolution_forward_desc).x_desc = (*elements) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
},
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y => {
(*operation_convolution_forward_desc).y_desc = (*elements) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
},
_ => panic!(),
}
}
unsafe fn set_operationgraph_descriptor_by_attribute(
operationgraph_desc: *mut cudnnOperationGraphStruct,
attribute_name: cudnnBackendAttributeName_t,
count: i64,
elements: *const *mut ::std::os::raw::c_void, // *const ::std::os::raw::c_void
) -> cudnnStatus_t {
match attribute_name {
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_HANDLE => {
(*operationgraph_desc).handle = (*elements) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
},
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_OPS => {
(*operationgraph_desc).ops = (*elements) as _;
cudnnStatus_t::CUDNN_STATUS_SUCCESS
},
_ => crate::unsupported(),
}
}
unsafe fn get_stream(
handle: *mut cudnnContext,
stream_id: *mut cudaStream_t,
) -> cudnnStatus_t {
to_cudnn(miopenGetStream(
handle as _,
stream_id as _,
))
}