Compare commits
26 commits
master
...
future/mod
Author | SHA1 | Date | |
---|---|---|---|
3caec2532c | |||
40d46b35a9 | |||
75d332d5ff | |||
7dfd642e67 | |||
66b3d22b7f | |||
485e336cf8 | |||
796ca8de8d | |||
47c0e724ff | |||
b8893b71b2 | |||
6464bda679 | |||
e20b03dbc2 | |||
33aa37ae54 | |||
a086ad740a | |||
7565ad160e | |||
d8fbbbdd1f | |||
277c3a50b9 | |||
7137146ee0 | |||
15dd55af02 | |||
b567623119 | |||
3027beec5d | |||
4f79acf74d | |||
01433fc1f5 | |||
5b8e627066 | |||
c98ca550d4 | |||
a60a8cfa9e | |||
0fd8c94328 |
569
Cargo.lock
generated
569
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1 +1 @@
|
|||
bindgen /opt/rocm/include/miopen/miopen.h -o src/miopen.rs --no-layout-tests --size_t-is-usize --default-enum-style=newtype --no-derive-debug --allowlist-function "miopen.*" --allowlist-var "MIOPEN_*" --must-use-type miopenStatus_t -- -D__HIP_PLATFORM_AMD__ -DMIOPEN_BACKEND_HIP=1 -I/opt/rocm/include -x c++
|
||||
bindgen $Env:HIP_PATH/include/miopen/miopen.h -o src/miopen.rs --no-layout-tests --default-enum-style=newtype --no-derive-debug --allowlist-function "miopen.*" --allowlist-var "MIOPEN_*" --must-use-type miopenStatus_t -- -D__HIP_PLATFORM_AMD__ -DMIOPEN_BACKEND_HIP=1 -DMIOPEN_BETA_API=1 -I"$Env:HIP_PATH/include" -x c++
|
|
@ -1,4 +1,19 @@
|
|||
fn main() {
|
||||
use std::env::VarError;
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
fn main() -> Result<(), VarError> {
|
||||
println!("cargo:rustc-link-lib=dylib=MIOpen");
|
||||
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
|
||||
if cfg!(windows) {
|
||||
let env = env::var("CARGO_CFG_TARGET_ENV")?;
|
||||
if env == "msvc" {
|
||||
let mut path = PathBuf::from(env::var("HIP_PATH")?);
|
||||
path.push("lib");
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
} else {
|
||||
println!("cargo:rustc-link-search=native=C:\\Windows\\System32");
|
||||
};
|
||||
} else {
|
||||
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -9,10 +9,12 @@ name = "cudnn"
|
|||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
cuda_types = { path = "../cuda_types" }
|
||||
hip_common = { path = "../hip_common" }
|
||||
miopen-sys = { path = "../miopen-sys" }
|
||||
hip_runtime-sys = { path = "../hip_runtime-sys" }
|
||||
zluda_dark_api = { path = "../zluda_dark_api" }
|
||||
|
||||
[package.metadata.zluda]
|
||||
linux_only = true
|
||||
linux_names = ["libcudnn.so.7", "libcudnn.so.8"]
|
||||
dump_names = ["libcudnn.so"]
|
||||
|
|
|
@ -151,6 +151,33 @@ pub struct cudnnTensorTransformStruct {
|
|||
_unused: [u8; 0],
|
||||
}
|
||||
pub type cudnnTensorTransformDescriptor_t = *mut cudnnTensorTransformStruct;
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct cudnnEngineHeurStruct {
|
||||
pub operation_graph: cudnnOperationGraphDescriptor_t,
|
||||
}
|
||||
pub type cudnnEngineHeurDescriptor_t = *mut cudnnEngineHeurStruct;
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct cudnnOperationConvolutionForwardStruct {
|
||||
pub x_desc: cudnnTensorDescriptor_t,
|
||||
pub y_desc: cudnnTensorDescriptor_t,
|
||||
pub w_desc: cudnnFilterDescriptor_t,
|
||||
pub conv_desc: cudnnConvolutionDescriptor_t,
|
||||
}
|
||||
pub type cudnnOperationConvolutionForwardDescriptor_t = *mut cudnnOperationConvolutionForwardStruct;
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct cudnnOperationGraphStruct {
|
||||
pub handle: cudnnHandle_t,
|
||||
pub ops: *const cudnnBackendDescriptorType_t,
|
||||
}
|
||||
pub type cudnnOperationGraphDescriptor_t = *mut cudnnOperationGraphStruct;
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct cudnnVariantPackStruct {
|
||||
}
|
||||
pub type cudnnVariantPackDescriptor_t = *mut cudnnVariantPackStruct;
|
||||
impl cudnnDataType_t {
|
||||
pub const CUDNN_DATA_FLOAT: cudnnDataType_t = cudnnDataType_t(0);
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -8,6 +8,8 @@ pub mod types {
|
|||
pub use super::cudnn_types_v8::*;
|
||||
}
|
||||
|
||||
use cuda_types::{CUuuid, CUresult};
|
||||
|
||||
#[allow(warnings)]
|
||||
mod cudnn_v7;
|
||||
pub use cudnn_v7::*;
|
||||
|
@ -22,6 +24,25 @@ use hip_runtime_sys::*;
|
|||
use miopen_sys::*;
|
||||
use std::{mem, ptr};
|
||||
|
||||
impl miopenBackendHeurMode_t {
|
||||
pub const MIOPEN_HEUR_MODE_INSTANT: miopenBackendHeurMode_t = miopenBackendHeurMode_t(0);
|
||||
}
|
||||
impl miopenBackendHeurMode_t {
|
||||
pub const MIOPEN_HEUR_MODE_B: miopenBackendHeurMode_t = miopenBackendHeurMode_t(1);
|
||||
}
|
||||
impl miopenBackendHeurMode_t {
|
||||
pub const MIOPEN_HEUR_MODE_FALLBACK: miopenBackendHeurMode_t = miopenBackendHeurMode_t(2);
|
||||
}
|
||||
impl miopenBackendHeurMode_t {
|
||||
pub const MIOPEN_HEUR_MODE_A: miopenBackendHeurMode_t = miopenBackendHeurMode_t(3);
|
||||
}
|
||||
impl miopenBackendHeurMode_t {
|
||||
pub const MIOPEN_HEUR_MODES_COUNT: miopenBackendHeurMode_t = miopenBackendHeurMode_t(4);
|
||||
}
|
||||
#[repr(transparent)]
|
||||
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct miopenBackendHeurMode_t(pub ::std::os::raw::c_uint);
|
||||
|
||||
macro_rules! call {
|
||||
($expr:expr) => {{
|
||||
let result = $expr;
|
||||
|
@ -44,24 +65,58 @@ fn unsupported() -> cudnnStatus_t {
|
|||
fn to_cudnn(status: miopen_sys::miopenStatus_t) -> cudnnStatus_t {
|
||||
match status {
|
||||
miopen_sys::miopenStatus_t::miopenStatusSuccess => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
|
||||
err => panic!("{}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
|
||||
miopen_sys::miopenStatus_t::miopenStatusInvalidValue => cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE,
|
||||
miopen_sys::miopenStatus_t::miopenStatusBadParm => cudnnStatus_t::CUDNN_STATUS_BAD_PARAM,
|
||||
miopen_sys::miopenStatus_t::miopenStatusNotImplemented => cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
|
||||
miopen_sys::miopenStatus_t::miopenStatusUnknownError => cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
|
||||
miopen_sys::miopenStatus_t::miopenStatusUnsupportedOp => cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
|
||||
err => panic!("[ZLUDA] MIOpen failed: {}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_miopen(status: cudnnStatus_t) -> miopen_sys::miopenStatus_t {
|
||||
match status {
|
||||
cudnnStatus_t::CUDNN_STATUS_SUCCESS => miopen_sys::miopenStatus_t::miopenStatusSuccess,
|
||||
cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE => miopen_sys::miopenStatus_t::miopenStatusInvalidValue,
|
||||
cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => miopen_sys::miopenStatus_t::miopenStatusBadParm,
|
||||
cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => miopen_sys::miopenStatus_t::miopenStatusNotImplemented,
|
||||
cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR => miopen_sys::miopenStatus_t::miopenStatusUnknownError,
|
||||
err => panic!("[ZLUDA] MIOpen failed: {}", err.0),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_error_string(status: cudnnStatus_t) -> *const ::std::os::raw::c_char {
|
||||
let status = to_miopen(status);
|
||||
miopenGetErrorString(status)
|
||||
}
|
||||
|
||||
unsafe fn get_property(
|
||||
prop: libraryPropertyType,
|
||||
value: *mut i32,
|
||||
) -> cudnnStatus_t {
|
||||
*value = match prop {
|
||||
libraryPropertyType_t::MAJOR_VERSION => 8,
|
||||
libraryPropertyType_t::MINOR_VERSION => 7,
|
||||
libraryPropertyType_t::PATCH_LEVEL => 0,
|
||||
_ => panic!(),
|
||||
};
|
||||
cudnnStatus_t::CUDNN_STATUS_SUCCESS
|
||||
}
|
||||
|
||||
unsafe fn create(handle: *mut cudnnHandle_t) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenCreate(handle as _))
|
||||
to_cudnn(miopenCreate(handle as _))
|
||||
}
|
||||
|
||||
unsafe fn cudnn_create_tensor_descriptor(
|
||||
tensor_desc: *mut cudnnTensorDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenCreateTensorDescriptor(tensor_desc as _))
|
||||
to_cudnn(miopenCreateTensorDescriptor(tensor_desc as _))
|
||||
}
|
||||
|
||||
unsafe fn cudnn_create_activation_descriptor(
|
||||
activation_desc: *mut cudnnActivationDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenCreateActivationDescriptor(
|
||||
to_cudnn(miopenCreateActivationDescriptor(
|
||||
activation_desc as _,
|
||||
))
|
||||
}
|
||||
|
@ -69,7 +124,7 @@ unsafe fn cudnn_create_activation_descriptor(
|
|||
unsafe fn cudnn_create_convolution_descriptor(
|
||||
conv_desc: *mut cudnnConvolutionDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenCreateConvolutionDescriptor(
|
||||
to_cudnn(miopenCreateConvolutionDescriptor(
|
||||
conv_desc as _,
|
||||
))
|
||||
}
|
||||
|
@ -77,17 +132,19 @@ unsafe fn cudnn_create_convolution_descriptor(
|
|||
unsafe fn cudnn_create_filter_descriptor(
|
||||
filter_desc: *mut cudnnFilterDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenCreateTensorDescriptor(filter_desc as _))
|
||||
to_cudnn(miopenCreateTensorDescriptor(filter_desc as _))
|
||||
}
|
||||
|
||||
unsafe fn cudnn_create_lrn_descriptor(norm_desc: *mut cudnnLRNDescriptor_t) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenCreateLRNDescriptor(norm_desc as _))
|
||||
to_cudnn(miopenCreateLRNDescriptor(norm_desc as _))
|
||||
}
|
||||
|
||||
unsafe fn cudnn_create_pooling_descriptor(
|
||||
pooling_desc: *mut cudnnPoolingDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenCreatePoolingDescriptor(pooling_desc as _))
|
||||
to_cudnn(miopenCreatePoolingDescriptor(
|
||||
pooling_desc as _,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn set_tensor_nd_decriptor(
|
||||
|
@ -97,20 +154,22 @@ unsafe fn set_tensor_nd_decriptor(
|
|||
dim_a: *const i32,
|
||||
stride_a: *const i32,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenSetTensorDescriptor(
|
||||
let data_type = to_data_type(data_type);
|
||||
to_cudnn(miopenSetTensorDescriptor(
|
||||
tensor_desc as _,
|
||||
from_data_type(data_type),
|
||||
data_type,
|
||||
nb_dims,
|
||||
dim_a as _,
|
||||
stride_a as _,
|
||||
))
|
||||
}
|
||||
|
||||
fn from_data_type(type_: cudnnDataType_t) -> miopenDataType_t {
|
||||
fn to_data_type(type_: cudnnDataType_t) -> miopenDataType_t {
|
||||
match type_ {
|
||||
cudnnDataType_t::CUDNN_DATA_FLOAT => miopenDataType_t::miopenFloat,
|
||||
cudnnDataType_t::CUDNN_DATA_DOUBLE => miopenDataType_t::miopenDouble,
|
||||
cudnnDataType_t::CUDNN_DATA_HALF => miopenDataType_t::miopenHalf,
|
||||
cudnnDataType_t::CUDNN_DATA_BFLOAT16 => miopenDataType_t::miopenBFloat16,
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
@ -122,7 +181,7 @@ unsafe fn set_filter_nd_descriptor(
|
|||
nb_dims: i32,
|
||||
filter_dim_a: *const i32,
|
||||
) -> cudnnStatus_t {
|
||||
let data_type = from_data_type(data_type);
|
||||
let data_type = to_data_type(data_type);
|
||||
to_cudnn(miopenSetTensorDescriptor(
|
||||
filter_desc as _,
|
||||
data_type,
|
||||
|
@ -150,8 +209,8 @@ unsafe fn set_convolution_nd_descriptor(
|
|||
let v = *filter_stride_a.add(1);
|
||||
let d_h = *dilation_a.add(0);
|
||||
let d_w = *dilation_a.add(1);
|
||||
let mode = conv_mode_to_cudnn(mode);
|
||||
to_cudnn(miopen_sys::miopenInitConvolutionDescriptor(
|
||||
let mode = to_conv_mode(mode);
|
||||
to_cudnn(miopenInitConvolutionDescriptor(
|
||||
conv_desc as _,
|
||||
mode,
|
||||
pad_h,
|
||||
|
@ -163,7 +222,7 @@ unsafe fn set_convolution_nd_descriptor(
|
|||
))
|
||||
}
|
||||
|
||||
fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
|
||||
fn to_conv_mode(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
|
||||
match mode {
|
||||
cudnnConvolutionMode_t::CUDNN_CONVOLUTION => miopenConvolutionMode_t::miopenTranspose,
|
||||
cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION => {
|
||||
|
@ -173,6 +232,14 @@ fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
|
|||
}
|
||||
}
|
||||
|
||||
fn to_heur_mode(mode: cudnnBackendHeurMode_t) -> miopenBackendHeurMode_t {
|
||||
match mode {
|
||||
cudnnBackendHeurMode_t::CUDNN_HEUR_MODE_INSTANT => miopenBackendHeurMode_t::MIOPEN_HEUR_MODE_INSTANT,
|
||||
cudnnBackendHeurMode_t::CUDNN_HEUR_MODE_FALLBACK => miopenBackendHeurMode_t::MIOPEN_HEUR_MODE_FALLBACK,
|
||||
_ => panic!("[ZLUDA] Unknown heuristic mode: {}", mode.0),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_convolution_nd_forward_output_dim(
|
||||
conv_desc: cudnnConvolutionDescriptor_t,
|
||||
input_tensor_desc: cudnnTensorDescriptor_t,
|
||||
|
@ -180,7 +247,7 @@ unsafe fn get_convolution_nd_forward_output_dim(
|
|||
mut nb_dims: i32,
|
||||
tensor_ouput_dim_a: *mut i32,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopen_sys::miopenGetConvolutionNdForwardOutputDim(
|
||||
to_cudnn(miopenGetConvolutionNdForwardOutputDim(
|
||||
conv_desc as _,
|
||||
input_tensor_desc as _,
|
||||
filter_desc as _,
|
||||
|
@ -816,7 +883,7 @@ unsafe fn set_tensor_4d_descriptor_ex(
|
|||
h_stride: i32,
|
||||
w_stride: i32,
|
||||
) -> cudnnStatus_t {
|
||||
let data_type = from_data_type(data_type);
|
||||
let data_type = to_data_type(data_type);
|
||||
to_cudnn(miopenSet4dTensorDescriptorEx(
|
||||
tensor_desc as _,
|
||||
data_type,
|
||||
|
@ -851,11 +918,26 @@ unsafe fn transform_tensor(
|
|||
))
|
||||
}
|
||||
|
||||
unsafe fn set_stream(stream_id: *mut CUstream_st) -> cudnnStatus_t {
|
||||
if stream_id != ptr::null_mut() {
|
||||
todo!()
|
||||
}
|
||||
cudnnStatus_t::CUDNN_STATUS_SUCCESS
|
||||
unsafe fn set_stream(
|
||||
handle: cudnnHandle_t,
|
||||
stream_id: *mut CUstream_st,
|
||||
) -> cudnnStatus_t {
|
||||
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
|
||||
let cu_get_export_table = lib
|
||||
.get::<unsafe extern "C" fn(
|
||||
ppExportTable: *mut *const ::std::os::raw::c_void,
|
||||
pExportTableId: *const CUuuid,
|
||||
) -> CUresult>(b"cuGetExportTable\0")
|
||||
.unwrap();
|
||||
let mut export_table = ptr::null();
|
||||
let error = (cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID);
|
||||
assert_eq!(error, CUresult::CUDA_SUCCESS);
|
||||
let zluda_ext = zluda_dark_api::ZludaExt::new(export_table);
|
||||
let stream: Result<_, _> = zluda_ext.get_hip_stream(stream_id as _).into();
|
||||
to_cudnn(miopenSetStream(
|
||||
handle.cast(),
|
||||
stream.unwrap() as _,
|
||||
))
|
||||
}
|
||||
|
||||
fn set_convolution_math_type(
|
||||
|
@ -1099,3 +1181,201 @@ unsafe fn convolution_backward_data(
|
|||
work_space_size_in_bytes,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn get_stream(
|
||||
handle: *mut cudnnContext,
|
||||
stream_id: *mut cudaStream_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopenGetStream(
|
||||
handle as _,
|
||||
stream_id as _,
|
||||
))
|
||||
}
|
||||
|
||||
fn to_backend_descriptor_type(descriptor_type: cudnnBackendDescriptorType_t) -> miopenBackendDescriptorType_t {
|
||||
match descriptor_type {
|
||||
cudnnBackendDescriptorType_t::CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_CONVOLUTION_DESCRIPTOR,
|
||||
cudnnBackendDescriptorType_t::CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_ENGINEHEUR_DESCRIPTOR,
|
||||
cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
|
||||
cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_OPERATIONGRAPH_DESCRIPTOR,
|
||||
cudnnBackendDescriptorType_t::CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_VARIANT_PACK_DESCRIPTOR,
|
||||
cudnnBackendDescriptorType_t::CUDNN_BACKEND_TENSOR_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_TENSOR_DESCRIPTOR,
|
||||
_ => panic!("[ZLUDA] Unknown descriptor type: {}", descriptor_type.0),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn backend_create_descriptor(
|
||||
descriptor_type: cudnnBackendDescriptorType_t,
|
||||
descriptor: *mut cudnnBackendDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
let descriptor_type = to_backend_descriptor_type(descriptor_type);
|
||||
to_cudnn(miopenBackendCreateDescriptor(
|
||||
descriptor_type,
|
||||
descriptor.cast(),
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn backend_destroy_descriptor(
|
||||
descriptor: cudnnBackendDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopenBackendDestroyDescriptor(
|
||||
descriptor.cast(),
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn backend_finalize(
|
||||
descriptor: cudnnBackendDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopenBackendFinalize(
|
||||
descriptor.cast(),
|
||||
))
|
||||
}
|
||||
|
||||
fn to_backend_attribute_name(name: cudnnBackendAttributeName_t) -> miopenBackendAttributeName_t {
|
||||
match name {
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_COMP_TYPE => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_COMP_TYPE,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_CONV_MODE => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_CONV_MODE,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_DILATIONS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_DILATIONS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_FILTER_STRIDES,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_POST_PADDINGS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_POST_PADDINGS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_PRE_PADDINGS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_SPATIAL_DIMS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_MODE => miopenBackendAttributeName_t::MIOPEN_ATTR_ENGINEHEUR_MODE,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH => miopenBackendAttributeName_t::MIOPEN_ATTR_ENGINEHEUR_OPERATION_GRAPH,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_RESULTS => miopenBackendAttributeName_t::MIOPEN_ATTR_ENGINEHEUR_RESULTS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_W,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_X,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_HANDLE => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATIONGRAPH_HANDLE,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_OPS => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATIONGRAPH_OPS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_BYTE_ALIGNMENT,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DATA_TYPE => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_DATA_TYPE,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DIMENSIONS => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_DIMENSIONS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_STRIDES => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_STRIDES,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_UNIQUE_ID => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_UNIQUE_ID,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_IS_VIRTUAL => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_IS_VIRTUAL,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_IS_BY_VALUE => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_IS_BY_VALUE,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS => miopenBackendAttributeName_t::MIOPEN_ATTR_VARIANT_PACK_UNIQUE_IDS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS => miopenBackendAttributeName_t::MIOPEN_ATTR_VARIANT_PACK_DATA_POINTERS,
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_VARIANT_PACK_WORKSPACE => miopenBackendAttributeName_t::MIOPEN_ATTR_VARIANT_PACK_WORKSPACE,
|
||||
_ => panic!("[ZLUDA] Unknown attribute name: {}", name.0),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_unsupported_attribute_name(name: cudnnBackendAttributeName_t) -> bool {
|
||||
match name {
|
||||
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_backend_attribute_type(attribute_type: cudnnBackendAttributeType_t) -> miopenBackendAttributeType_t {
|
||||
match attribute_type {
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_HANDLE => miopenBackendAttributeType_t::MIOPEN_TYPE_HANDLE,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_DATA_TYPE => miopenBackendAttributeType_t::MIOPEN_TYPE_DATA_TYPE,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_BOOLEAN => miopenBackendAttributeType_t::MIOPEN_TYPE_BOOLEAN,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_INT64 => miopenBackendAttributeType_t::MIOPEN_TYPE_INT64,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT => miopenBackendAttributeType_t::MIOPEN_TYPE_FLOAT,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE => miopenBackendAttributeType_t::MIOPEN_TYPE_DOUBLE,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_VOID_PTR => miopenBackendAttributeType_t::MIOPEN_TYPE_VOID_PTR,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_CONVOLUTION_MODE => miopenBackendAttributeType_t::MIOPEN_TYPE_CONVOLUTION_MODE,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_HEUR_MODE => miopenBackendAttributeType_t::MIOPEN_TYPE_HEUR_MODE,
|
||||
cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR => miopenBackendAttributeType_t::MIOPEN_TYPE_BACKEND_DESCRIPTOR,
|
||||
_ => panic!("[ZLUDA] Unknown attribute type: {}", attribute_type.0),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn backend_cudnn_to_miopen(
|
||||
elements_type: miopenBackendAttributeType_t,
|
||||
element_count: i64,
|
||||
array_of_elements: *mut ::std::os::raw::c_void,
|
||||
) -> () {
|
||||
match elements_type {
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_HANDLE => (),
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_DATA_TYPE => {
|
||||
if element_count != 1 {
|
||||
panic!("[ZLUDA] Unexpected value: element_count={}", element_count)
|
||||
}
|
||||
let p_data_type: *mut miopenDataType_t = array_of_elements.cast();
|
||||
*p_data_type = to_data_type(*(p_data_type as *mut cudnnDataType_t));
|
||||
},
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_INT64 => (),
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_FLOAT => (),
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_DOUBLE => (),
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_VOID_PTR => (),
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_CONVOLUTION_MODE => {
|
||||
if element_count != 1 {
|
||||
panic!("[ZLUDA] Unexpected value: element_count={}", element_count)
|
||||
}
|
||||
let p_conv_mode: *mut miopenConvolutionMode_t = array_of_elements.cast();
|
||||
*p_conv_mode = to_conv_mode(*(p_conv_mode as *mut cudnnConvolutionMode_t));
|
||||
},
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_HEUR_MODE => {
|
||||
if element_count != 1 {
|
||||
panic!("[ZLUDA] Unexpected value: element_count={}", element_count)
|
||||
}
|
||||
let p_heur_mode: *mut miopenBackendHeurMode_t = array_of_elements.cast();
|
||||
*p_heur_mode = to_heur_mode(*(p_heur_mode as *mut cudnnBackendHeurMode_t));
|
||||
},
|
||||
miopenBackendAttributeType_t::MIOPEN_TYPE_BACKEND_DESCRIPTOR => (),
|
||||
_ => println!("[ZLUDA] Warning: found unknown backend attribute type: {}", elements_type.0),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn backend_set_attribute(
|
||||
descriptor: cudnnBackendDescriptor_t,
|
||||
attribute_name: cudnnBackendAttributeName_t,
|
||||
attribute_type: cudnnBackendAttributeType_t,
|
||||
element_count: i64,
|
||||
array_of_elements: *const ::std::os::raw::c_void,
|
||||
) -> cudnnStatus_t {
|
||||
if is_unsupported_attribute_name(attribute_name) { // temporary skip unimplemented attribute names
|
||||
return cudnnStatus_t::CUDNN_STATUS_SUCCESS;
|
||||
}
|
||||
let attribute_name = to_backend_attribute_name(attribute_name);
|
||||
let attribute_type = to_backend_attribute_type(attribute_type);
|
||||
let elements = array_of_elements.clone();
|
||||
backend_cudnn_to_miopen(attribute_type, element_count, elements.cast_mut());
|
||||
to_cudnn(miopenBackendSetAttribute(
|
||||
descriptor.cast(),
|
||||
attribute_name,
|
||||
attribute_type,
|
||||
element_count,
|
||||
elements.cast_mut(),
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn backend_get_attribute(
|
||||
descriptor: cudnnBackendDescriptor_t,
|
||||
attribute_name: cudnnBackendAttributeName_t,
|
||||
attribute_type: cudnnBackendAttributeType_t,
|
||||
requested_element_count: i64,
|
||||
element_count: *mut i64,
|
||||
array_of_elements: *mut ::std::os::raw::c_void,
|
||||
) -> cudnnStatus_t {
|
||||
let attribute_name = to_backend_attribute_name(attribute_name);
|
||||
let attribute_type = to_backend_attribute_type(attribute_type);
|
||||
to_cudnn(miopenBackendGetAttribute(
|
||||
descriptor.cast(),
|
||||
attribute_name,
|
||||
attribute_type,
|
||||
requested_element_count,
|
||||
element_count,
|
||||
array_of_elements,
|
||||
))
|
||||
}
|
||||
|
||||
unsafe fn backend_execute(
|
||||
handle: cudnnHandle_t,
|
||||
execution_plan: cudnnBackendDescriptor_t,
|
||||
variant_pack: cudnnBackendDescriptor_t,
|
||||
) -> cudnnStatus_t {
|
||||
to_cudnn(miopenBackendExecute(
|
||||
handle.cast(),
|
||||
execution_plan.cast(),
|
||||
variant_pack.cast(),
|
||||
))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue