Compare commits

...

26 commits

Author SHA1 Message Date
Seunghoon Lee 3caec2532c
Merge branch 'master' into future/module/zluda_dnn 2024-07-13 18:01:33 +09:00
Seunghoon Lee 40d46b35a9
wip 2024-07-13 17:56:11 +09:00
Seunghoon Lee 75d332d5ff
Merge branch 'module/zluda_runtime' into future/module/zluda_dnn 2024-05-31 10:51:44 +09:00
Seunghoon Lee 7dfd642e67
WIP 2024-05-30 03:10:42 +09:00
Seunghoon Lee 66b3d22b7f
[Fix] Handle stream correctly. 2024-05-21 18:19:30 +09:00
Seunghoon Lee 485e336cf8
Merge branch 'module/zluda_runtime' into future/module/zluda_dnn 2024-05-21 18:18:35 +09:00
Seunghoon Lee 796ca8de8d
[Fix] Handle stream correctly. 2024-05-21 18:17:57 +09:00
Seunghoon Lee 47c0e724ff
Merge branch 'master' into future/module/zluda_dnn 2024-05-21 10:49:50 +09:00
Seunghoon Lee b8893b71b2
Merge branch 'master' into future/module/zluda_dnn 2024-05-17 13:47:21 +09:00
Seunghoon Lee 6464bda679
wip 2024-05-17 10:30:04 +09:00
Seunghoon Lee e20b03dbc2
WIP 2024-05-01 00:25:16 +09:00
Seunghoon Lee 33aa37ae54
Merge branch 'master' into future/module/zluda_dnn 2024-04-30 22:59:17 +09:00
Seunghoon Lee a086ad740a
WIP 2024-04-29 22:30:10 +09:00
Seunghoon Lee 7565ad160e
Implement cudnnGetProperty. 2024-04-28 15:20:19 +09:00
Seunghoon Lee d8fbbbdd1f
Merge branch 'master' into future/module/zluda_dnn 2024-04-28 15:02:25 +09:00
Seunghoon Lee 277c3a50b9
WIP 2024-04-28 01:33:01 +09:00
Seunghoon Lee 7137146ee0
WIP 2024-04-27 15:28:46 +09:00
Seunghoon Lee 15dd55af02
Merge branch 'master' into future/module/zluda_dnn 2024-04-21 21:00:15 +09:00
Seunghoon Lee b567623119
[WIP] Graph API. 2024-04-21 00:07:46 +09:00
Seunghoon Lee 3027beec5d
Update MIOpen. (graph api) 2024-04-18 01:45:37 +09:00
Seunghoon Lee 4f79acf74d
Merge branch 'master' into future/module/zluda_dnn 2024-04-18 01:11:12 +09:00
Seunghoon Lee 01433fc1f5
Remove unused functions. 2024-04-10 15:10:08 +09:00
Seunghoon Lee 5b8e627066
wip 2024-04-04 02:35:29 +09:00
Seunghoon Lee c98ca550d4
wip 2024-03-31 22:05:55 +09:00
Seunghoon Lee a60a8cfa9e
bindgen 2024-03-30 16:04:14 +09:00
Seunghoon Lee 0fd8c94328
Update MIOpen. 2024-03-30 04:00:41 +09:00
8 changed files with 3067 additions and 568 deletions

569
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1 +1 @@
bindgen /opt/rocm/include/miopen/miopen.h -o src/miopen.rs --no-layout-tests --size_t-is-usize --default-enum-style=newtype --no-derive-debug --allowlist-function "miopen.*" --allowlist-var "MIOPEN_*" --must-use-type miopenStatus_t -- -D__HIP_PLATFORM_AMD__ -DMIOPEN_BACKEND_HIP=1 -I/opt/rocm/include -x c++
bindgen $Env:HIP_PATH/include/miopen/miopen.h -o src/miopen.rs --no-layout-tests --default-enum-style=newtype --no-derive-debug --allowlist-function "miopen.*" --allowlist-var "MIOPEN_*" --must-use-type miopenStatus_t -- -D__HIP_PLATFORM_AMD__ -DMIOPEN_BACKEND_HIP=1 -DMIOPEN_BETA_API=1 -I"$Env:HIP_PATH/include" -x c++

View file

@ -1,4 +1,19 @@
fn main() {
use std::env::VarError;
use std::{env, path::PathBuf};
fn main() -> Result<(), VarError> {
println!("cargo:rustc-link-lib=dylib=MIOpen");
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
if cfg!(windows) {
let env = env::var("CARGO_CFG_TARGET_ENV")?;
if env == "msvc" {
let mut path = PathBuf::from(env::var("HIP_PATH")?);
path.push("lib");
println!("cargo:rustc-link-search=native={}", path.display());
} else {
println!("cargo:rustc-link-search=native=C:\\Windows\\System32");
};
} else {
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
}
Ok(())
}

File diff suppressed because one or more lines are too long

View file

@ -9,10 +9,12 @@ name = "cudnn"
crate-type = ["cdylib"]
[dependencies]
cuda_types = { path = "../cuda_types" }
hip_common = { path = "../hip_common" }
miopen-sys = { path = "../miopen-sys" }
hip_runtime-sys = { path = "../hip_runtime-sys" }
zluda_dark_api = { path = "../zluda_dark_api" }
[package.metadata.zluda]
linux_only = true
linux_names = ["libcudnn.so.7", "libcudnn.so.8"]
dump_names = ["libcudnn.so"]

View file

@ -151,6 +151,33 @@ pub struct cudnnTensorTransformStruct {
_unused: [u8; 0],
}
pub type cudnnTensorTransformDescriptor_t = *mut cudnnTensorTransformStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnEngineHeurStruct {
pub operation_graph: cudnnOperationGraphDescriptor_t,
}
pub type cudnnEngineHeurDescriptor_t = *mut cudnnEngineHeurStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnOperationConvolutionForwardStruct {
pub x_desc: cudnnTensorDescriptor_t,
pub y_desc: cudnnTensorDescriptor_t,
pub w_desc: cudnnFilterDescriptor_t,
pub conv_desc: cudnnConvolutionDescriptor_t,
}
pub type cudnnOperationConvolutionForwardDescriptor_t = *mut cudnnOperationConvolutionForwardStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnOperationGraphStruct {
pub handle: cudnnHandle_t,
pub ops: *const cudnnBackendDescriptorType_t,
}
pub type cudnnOperationGraphDescriptor_t = *mut cudnnOperationGraphStruct;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct cudnnVariantPackStruct {
}
pub type cudnnVariantPackDescriptor_t = *mut cudnnVariantPackStruct;
impl cudnnDataType_t {
pub const CUDNN_DATA_FLOAT: cudnnDataType_t = cudnnDataType_t(0);
}

File diff suppressed because it is too large Load diff

View file

@ -8,6 +8,8 @@ pub mod types {
pub use super::cudnn_types_v8::*;
}
use cuda_types::{CUuuid, CUresult};
#[allow(warnings)]
mod cudnn_v7;
pub use cudnn_v7::*;
@ -22,6 +24,25 @@ use hip_runtime_sys::*;
use miopen_sys::*;
use std::{mem, ptr};
impl miopenBackendHeurMode_t {
pub const MIOPEN_HEUR_MODE_INSTANT: miopenBackendHeurMode_t = miopenBackendHeurMode_t(0);
}
impl miopenBackendHeurMode_t {
pub const MIOPEN_HEUR_MODE_B: miopenBackendHeurMode_t = miopenBackendHeurMode_t(1);
}
impl miopenBackendHeurMode_t {
pub const MIOPEN_HEUR_MODE_FALLBACK: miopenBackendHeurMode_t = miopenBackendHeurMode_t(2);
}
impl miopenBackendHeurMode_t {
pub const MIOPEN_HEUR_MODE_A: miopenBackendHeurMode_t = miopenBackendHeurMode_t(3);
}
impl miopenBackendHeurMode_t {
pub const MIOPEN_HEUR_MODES_COUNT: miopenBackendHeurMode_t = miopenBackendHeurMode_t(4);
}
#[repr(transparent)]
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct miopenBackendHeurMode_t(pub ::std::os::raw::c_uint);
macro_rules! call {
($expr:expr) => {{
let result = $expr;
@ -44,24 +65,58 @@ fn unsupported() -> cudnnStatus_t {
fn to_cudnn(status: miopen_sys::miopenStatus_t) -> cudnnStatus_t {
match status {
miopen_sys::miopenStatus_t::miopenStatusSuccess => cudnnStatus_t::CUDNN_STATUS_SUCCESS,
err => panic!("{}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
miopen_sys::miopenStatus_t::miopenStatusInvalidValue => cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE,
miopen_sys::miopenStatus_t::miopenStatusBadParm => cudnnStatus_t::CUDNN_STATUS_BAD_PARAM,
miopen_sys::miopenStatus_t::miopenStatusNotImplemented => cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
miopen_sys::miopenStatus_t::miopenStatusUnknownError => cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
miopen_sys::miopenStatus_t::miopenStatusUnsupportedOp => cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED,
err => panic!("[ZLUDA] MIOpen failed: {}", err.0), //cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR,
}
}
fn to_miopen(status: cudnnStatus_t) -> miopen_sys::miopenStatus_t {
match status {
cudnnStatus_t::CUDNN_STATUS_SUCCESS => miopen_sys::miopenStatus_t::miopenStatusSuccess,
cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE => miopen_sys::miopenStatus_t::miopenStatusInvalidValue,
cudnnStatus_t::CUDNN_STATUS_BAD_PARAM => miopen_sys::miopenStatus_t::miopenStatusBadParm,
cudnnStatus_t::CUDNN_STATUS_NOT_SUPPORTED => miopen_sys::miopenStatus_t::miopenStatusNotImplemented,
cudnnStatus_t::CUDNN_STATUS_INTERNAL_ERROR => miopen_sys::miopenStatus_t::miopenStatusUnknownError,
err => panic!("[ZLUDA] MIOpen failed: {}", err.0),
}
}
unsafe fn get_error_string(status: cudnnStatus_t) -> *const ::std::os::raw::c_char {
let status = to_miopen(status);
miopenGetErrorString(status)
}
unsafe fn get_property(
prop: libraryPropertyType,
value: *mut i32,
) -> cudnnStatus_t {
*value = match prop {
libraryPropertyType_t::MAJOR_VERSION => 8,
libraryPropertyType_t::MINOR_VERSION => 7,
libraryPropertyType_t::PATCH_LEVEL => 0,
_ => panic!(),
};
cudnnStatus_t::CUDNN_STATUS_SUCCESS
}
unsafe fn create(handle: *mut cudnnHandle_t) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreate(handle as _))
to_cudnn(miopenCreate(handle as _))
}
unsafe fn cudnn_create_tensor_descriptor(
tensor_desc: *mut cudnnTensorDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreateTensorDescriptor(tensor_desc as _))
to_cudnn(miopenCreateTensorDescriptor(tensor_desc as _))
}
unsafe fn cudnn_create_activation_descriptor(
activation_desc: *mut cudnnActivationDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreateActivationDescriptor(
to_cudnn(miopenCreateActivationDescriptor(
activation_desc as _,
))
}
@ -69,7 +124,7 @@ unsafe fn cudnn_create_activation_descriptor(
unsafe fn cudnn_create_convolution_descriptor(
conv_desc: *mut cudnnConvolutionDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreateConvolutionDescriptor(
to_cudnn(miopenCreateConvolutionDescriptor(
conv_desc as _,
))
}
@ -77,17 +132,19 @@ unsafe fn cudnn_create_convolution_descriptor(
unsafe fn cudnn_create_filter_descriptor(
filter_desc: *mut cudnnFilterDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreateTensorDescriptor(filter_desc as _))
to_cudnn(miopenCreateTensorDescriptor(filter_desc as _))
}
unsafe fn cudnn_create_lrn_descriptor(norm_desc: *mut cudnnLRNDescriptor_t) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreateLRNDescriptor(norm_desc as _))
to_cudnn(miopenCreateLRNDescriptor(norm_desc as _))
}
unsafe fn cudnn_create_pooling_descriptor(
pooling_desc: *mut cudnnPoolingDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenCreatePoolingDescriptor(pooling_desc as _))
to_cudnn(miopenCreatePoolingDescriptor(
pooling_desc as _,
))
}
unsafe fn set_tensor_nd_decriptor(
@ -97,20 +154,22 @@ unsafe fn set_tensor_nd_decriptor(
dim_a: *const i32,
stride_a: *const i32,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenSetTensorDescriptor(
let data_type = to_data_type(data_type);
to_cudnn(miopenSetTensorDescriptor(
tensor_desc as _,
from_data_type(data_type),
data_type,
nb_dims,
dim_a as _,
stride_a as _,
))
}
fn from_data_type(type_: cudnnDataType_t) -> miopenDataType_t {
fn to_data_type(type_: cudnnDataType_t) -> miopenDataType_t {
match type_ {
cudnnDataType_t::CUDNN_DATA_FLOAT => miopenDataType_t::miopenFloat,
cudnnDataType_t::CUDNN_DATA_DOUBLE => miopenDataType_t::miopenDouble,
cudnnDataType_t::CUDNN_DATA_HALF => miopenDataType_t::miopenHalf,
cudnnDataType_t::CUDNN_DATA_BFLOAT16 => miopenDataType_t::miopenBFloat16,
_ => todo!(),
}
}
@ -122,7 +181,7 @@ unsafe fn set_filter_nd_descriptor(
nb_dims: i32,
filter_dim_a: *const i32,
) -> cudnnStatus_t {
let data_type = from_data_type(data_type);
let data_type = to_data_type(data_type);
to_cudnn(miopenSetTensorDescriptor(
filter_desc as _,
data_type,
@ -150,8 +209,8 @@ unsafe fn set_convolution_nd_descriptor(
let v = *filter_stride_a.add(1);
let d_h = *dilation_a.add(0);
let d_w = *dilation_a.add(1);
let mode = conv_mode_to_cudnn(mode);
to_cudnn(miopen_sys::miopenInitConvolutionDescriptor(
let mode = to_conv_mode(mode);
to_cudnn(miopenInitConvolutionDescriptor(
conv_desc as _,
mode,
pad_h,
@ -163,7 +222,7 @@ unsafe fn set_convolution_nd_descriptor(
))
}
fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
fn to_conv_mode(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
match mode {
cudnnConvolutionMode_t::CUDNN_CONVOLUTION => miopenConvolutionMode_t::miopenTranspose,
cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION => {
@ -173,6 +232,14 @@ fn conv_mode_to_cudnn(mode: cudnnConvolutionMode_t) -> miopenConvolutionMode_t {
}
}
fn to_heur_mode(mode: cudnnBackendHeurMode_t) -> miopenBackendHeurMode_t {
match mode {
cudnnBackendHeurMode_t::CUDNN_HEUR_MODE_INSTANT => miopenBackendHeurMode_t::MIOPEN_HEUR_MODE_INSTANT,
cudnnBackendHeurMode_t::CUDNN_HEUR_MODE_FALLBACK => miopenBackendHeurMode_t::MIOPEN_HEUR_MODE_FALLBACK,
_ => panic!("[ZLUDA] Unknown heuristic mode: {}", mode.0),
}
}
unsafe fn get_convolution_nd_forward_output_dim(
conv_desc: cudnnConvolutionDescriptor_t,
input_tensor_desc: cudnnTensorDescriptor_t,
@ -180,7 +247,7 @@ unsafe fn get_convolution_nd_forward_output_dim(
mut nb_dims: i32,
tensor_ouput_dim_a: *mut i32,
) -> cudnnStatus_t {
to_cudnn(miopen_sys::miopenGetConvolutionNdForwardOutputDim(
to_cudnn(miopenGetConvolutionNdForwardOutputDim(
conv_desc as _,
input_tensor_desc as _,
filter_desc as _,
@ -816,7 +883,7 @@ unsafe fn set_tensor_4d_descriptor_ex(
h_stride: i32,
w_stride: i32,
) -> cudnnStatus_t {
let data_type = from_data_type(data_type);
let data_type = to_data_type(data_type);
to_cudnn(miopenSet4dTensorDescriptorEx(
tensor_desc as _,
data_type,
@ -851,11 +918,26 @@ unsafe fn transform_tensor(
))
}
unsafe fn set_stream(stream_id: *mut CUstream_st) -> cudnnStatus_t {
if stream_id != ptr::null_mut() {
todo!()
}
cudnnStatus_t::CUDNN_STATUS_SUCCESS
unsafe fn set_stream(
handle: cudnnHandle_t,
stream_id: *mut CUstream_st,
) -> cudnnStatus_t {
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_get_export_table = lib
.get::<unsafe extern "C" fn(
ppExportTable: *mut *const ::std::os::raw::c_void,
pExportTableId: *const CUuuid,
) -> CUresult>(b"cuGetExportTable\0")
.unwrap();
let mut export_table = ptr::null();
let error = (cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID);
assert_eq!(error, CUresult::CUDA_SUCCESS);
let zluda_ext = zluda_dark_api::ZludaExt::new(export_table);
let stream: Result<_, _> = zluda_ext.get_hip_stream(stream_id as _).into();
to_cudnn(miopenSetStream(
handle.cast(),
stream.unwrap() as _,
))
}
fn set_convolution_math_type(
@ -1099,3 +1181,201 @@ unsafe fn convolution_backward_data(
work_space_size_in_bytes,
))
}
unsafe fn get_stream(
handle: *mut cudnnContext,
stream_id: *mut cudaStream_t,
) -> cudnnStatus_t {
to_cudnn(miopenGetStream(
handle as _,
stream_id as _,
))
}
fn to_backend_descriptor_type(descriptor_type: cudnnBackendDescriptorType_t) -> miopenBackendDescriptorType_t {
match descriptor_type {
cudnnBackendDescriptorType_t::CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_CONVOLUTION_DESCRIPTOR,
cudnnBackendDescriptorType_t::CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_ENGINEHEUR_DESCRIPTOR,
cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
cudnnBackendDescriptorType_t::CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_OPERATIONGRAPH_DESCRIPTOR,
cudnnBackendDescriptorType_t::CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_VARIANT_PACK_DESCRIPTOR,
cudnnBackendDescriptorType_t::CUDNN_BACKEND_TENSOR_DESCRIPTOR => miopenBackendDescriptorType_t::MIOPEN_BACKEND_TENSOR_DESCRIPTOR,
_ => panic!("[ZLUDA] Unknown descriptor type: {}", descriptor_type.0),
}
}
unsafe fn backend_create_descriptor(
descriptor_type: cudnnBackendDescriptorType_t,
descriptor: *mut cudnnBackendDescriptor_t,
) -> cudnnStatus_t {
let descriptor_type = to_backend_descriptor_type(descriptor_type);
to_cudnn(miopenBackendCreateDescriptor(
descriptor_type,
descriptor.cast(),
))
}
unsafe fn backend_destroy_descriptor(
descriptor: cudnnBackendDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopenBackendDestroyDescriptor(
descriptor.cast(),
))
}
unsafe fn backend_finalize(
descriptor: cudnnBackendDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopenBackendFinalize(
descriptor.cast(),
))
}
fn to_backend_attribute_name(name: cudnnBackendAttributeName_t) -> miopenBackendAttributeName_t {
match name {
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_COMP_TYPE => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_COMP_TYPE,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_CONV_MODE => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_CONV_MODE,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_DILATIONS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_DILATIONS,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_FILTER_STRIDES,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_POST_PADDINGS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_POST_PADDINGS,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_PRE_PADDINGS,
cudnnBackendAttributeName_t::CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS => miopenBackendAttributeName_t::MIOPEN_ATTR_CONVOLUTION_SPATIAL_DIMS,
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_MODE => miopenBackendAttributeName_t::MIOPEN_ATTR_ENGINEHEUR_MODE,
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH => miopenBackendAttributeName_t::MIOPEN_ATTR_ENGINEHEUR_OPERATION_GRAPH,
cudnnBackendAttributeName_t::CUDNN_ATTR_ENGINEHEUR_RESULTS => miopenBackendAttributeName_t::MIOPEN_ATTR_ENGINEHEUR_RESULTS,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_W,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_X,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_HANDLE => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATIONGRAPH_HANDLE,
cudnnBackendAttributeName_t::CUDNN_ATTR_OPERATIONGRAPH_OPS => miopenBackendAttributeName_t::MIOPEN_ATTR_OPERATIONGRAPH_OPS,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_BYTE_ALIGNMENT,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DATA_TYPE => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_DATA_TYPE,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_DIMENSIONS => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_DIMENSIONS,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_STRIDES => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_STRIDES,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_UNIQUE_ID => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_UNIQUE_ID,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_IS_VIRTUAL => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_IS_VIRTUAL,
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_IS_BY_VALUE => miopenBackendAttributeName_t::MIOPEN_ATTR_TENSOR_IS_BY_VALUE,
cudnnBackendAttributeName_t::CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS => miopenBackendAttributeName_t::MIOPEN_ATTR_VARIANT_PACK_UNIQUE_IDS,
cudnnBackendAttributeName_t::CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS => miopenBackendAttributeName_t::MIOPEN_ATTR_VARIANT_PACK_DATA_POINTERS,
cudnnBackendAttributeName_t::CUDNN_ATTR_VARIANT_PACK_WORKSPACE => miopenBackendAttributeName_t::MIOPEN_ATTR_VARIANT_PACK_WORKSPACE,
_ => panic!("[ZLUDA] Unknown attribute name: {}", name.0),
}
}
fn is_unsupported_attribute_name(name: cudnnBackendAttributeName_t) -> bool {
match name {
cudnnBackendAttributeName_t::CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT => true,
_ => false,
}
}
fn to_backend_attribute_type(attribute_type: cudnnBackendAttributeType_t) -> miopenBackendAttributeType_t {
match attribute_type {
cudnnBackendAttributeType_t::CUDNN_TYPE_HANDLE => miopenBackendAttributeType_t::MIOPEN_TYPE_HANDLE,
cudnnBackendAttributeType_t::CUDNN_TYPE_DATA_TYPE => miopenBackendAttributeType_t::MIOPEN_TYPE_DATA_TYPE,
cudnnBackendAttributeType_t::CUDNN_TYPE_BOOLEAN => miopenBackendAttributeType_t::MIOPEN_TYPE_BOOLEAN,
cudnnBackendAttributeType_t::CUDNN_TYPE_INT64 => miopenBackendAttributeType_t::MIOPEN_TYPE_INT64,
cudnnBackendAttributeType_t::CUDNN_TYPE_FLOAT => miopenBackendAttributeType_t::MIOPEN_TYPE_FLOAT,
cudnnBackendAttributeType_t::CUDNN_TYPE_DOUBLE => miopenBackendAttributeType_t::MIOPEN_TYPE_DOUBLE,
cudnnBackendAttributeType_t::CUDNN_TYPE_VOID_PTR => miopenBackendAttributeType_t::MIOPEN_TYPE_VOID_PTR,
cudnnBackendAttributeType_t::CUDNN_TYPE_CONVOLUTION_MODE => miopenBackendAttributeType_t::MIOPEN_TYPE_CONVOLUTION_MODE,
cudnnBackendAttributeType_t::CUDNN_TYPE_HEUR_MODE => miopenBackendAttributeType_t::MIOPEN_TYPE_HEUR_MODE,
cudnnBackendAttributeType_t::CUDNN_TYPE_BACKEND_DESCRIPTOR => miopenBackendAttributeType_t::MIOPEN_TYPE_BACKEND_DESCRIPTOR,
_ => panic!("[ZLUDA] Unknown attribute type: {}", attribute_type.0),
}
}
unsafe fn backend_cudnn_to_miopen(
elements_type: miopenBackendAttributeType_t,
element_count: i64,
array_of_elements: *mut ::std::os::raw::c_void,
) -> () {
match elements_type {
miopenBackendAttributeType_t::MIOPEN_TYPE_HANDLE => (),
miopenBackendAttributeType_t::MIOPEN_TYPE_DATA_TYPE => {
if element_count != 1 {
panic!("[ZLUDA] Unexpected value: element_count={}", element_count)
}
let p_data_type: *mut miopenDataType_t = array_of_elements.cast();
*p_data_type = to_data_type(*(p_data_type as *mut cudnnDataType_t));
},
miopenBackendAttributeType_t::MIOPEN_TYPE_INT64 => (),
miopenBackendAttributeType_t::MIOPEN_TYPE_FLOAT => (),
miopenBackendAttributeType_t::MIOPEN_TYPE_DOUBLE => (),
miopenBackendAttributeType_t::MIOPEN_TYPE_VOID_PTR => (),
miopenBackendAttributeType_t::MIOPEN_TYPE_CONVOLUTION_MODE => {
if element_count != 1 {
panic!("[ZLUDA] Unexpected value: element_count={}", element_count)
}
let p_conv_mode: *mut miopenConvolutionMode_t = array_of_elements.cast();
*p_conv_mode = to_conv_mode(*(p_conv_mode as *mut cudnnConvolutionMode_t));
},
miopenBackendAttributeType_t::MIOPEN_TYPE_HEUR_MODE => {
if element_count != 1 {
panic!("[ZLUDA] Unexpected value: element_count={}", element_count)
}
let p_heur_mode: *mut miopenBackendHeurMode_t = array_of_elements.cast();
*p_heur_mode = to_heur_mode(*(p_heur_mode as *mut cudnnBackendHeurMode_t));
},
miopenBackendAttributeType_t::MIOPEN_TYPE_BACKEND_DESCRIPTOR => (),
_ => println!("[ZLUDA] Warning: found unknown backend attribute type: {}", elements_type.0),
}
}
unsafe fn backend_set_attribute(
descriptor: cudnnBackendDescriptor_t,
attribute_name: cudnnBackendAttributeName_t,
attribute_type: cudnnBackendAttributeType_t,
element_count: i64,
array_of_elements: *const ::std::os::raw::c_void,
) -> cudnnStatus_t {
if is_unsupported_attribute_name(attribute_name) { // temporary skip unimplemented attribute names
return cudnnStatus_t::CUDNN_STATUS_SUCCESS;
}
let attribute_name = to_backend_attribute_name(attribute_name);
let attribute_type = to_backend_attribute_type(attribute_type);
let elements = array_of_elements.clone();
backend_cudnn_to_miopen(attribute_type, element_count, elements.cast_mut());
to_cudnn(miopenBackendSetAttribute(
descriptor.cast(),
attribute_name,
attribute_type,
element_count,
elements.cast_mut(),
))
}
unsafe fn backend_get_attribute(
descriptor: cudnnBackendDescriptor_t,
attribute_name: cudnnBackendAttributeName_t,
attribute_type: cudnnBackendAttributeType_t,
requested_element_count: i64,
element_count: *mut i64,
array_of_elements: *mut ::std::os::raw::c_void,
) -> cudnnStatus_t {
let attribute_name = to_backend_attribute_name(attribute_name);
let attribute_type = to_backend_attribute_type(attribute_type);
to_cudnn(miopenBackendGetAttribute(
descriptor.cast(),
attribute_name,
attribute_type,
requested_element_count,
element_count,
array_of_elements,
))
}
unsafe fn backend_execute(
handle: cudnnHandle_t,
execution_plan: cudnnBackendDescriptor_t,
variant_pack: cudnnBackendDescriptor_t,
) -> cudnnStatus_t {
to_cudnn(miopenBackendExecute(
handle.cast(),
execution_plan.cast(),
variant_pack.cast(),
))
}