[Fix] Clean up Runtime API.

This commit is contained in:
Seunghoon Lee 2024-05-21 10:49:19 +09:00
parent 11cc584451
commit 2ad9ad6851
No known key found for this signature in database
GPG key ID: 436E38F4E70BD152
3 changed files with 62 additions and 47 deletions

View file

@ -7154,9 +7154,17 @@ extern "C" {
extern "C" {
#[must_use]
pub fn __hipRegisterFatBinary(
data: *const ::std::os::raw::c_void,
data: *mut ::std::os::raw::c_void,
) -> *mut *mut ::std::os::raw::c_void;
}
/*
extern "C" {
#[must_use]
pub fn __hipRegisterFatBinaryEnd(
fatCubinHandle: *mut *mut ::std::os::raw::c_void,
) -> ::std::os::raw::c_void;
}
*/
extern "C" {
#[must_use]
pub fn __hipRegisterFunction(
@ -7172,6 +7180,17 @@ extern "C" {
wSize: *mut ::std::os::raw::c_int,
) -> ::std::os::raw::c_void;
}
/*
extern "C" {
#[must_use]
pub fn __hipRegisterHostVar(
fatCubinHandle: *mut *mut ::std::os::raw::c_void,
deviceName: *const ::std::os::raw::c_char,
hostVar: *mut ::std::os::raw::c_char,
size: usize,
) -> ::std::os::raw::c_void;
}
*/
extern "C" {
#[must_use]
pub fn __hipRegisterManagedVar(

File diff suppressed because one or more lines are too long

View file

@ -2,7 +2,6 @@ mod cudart;
pub use cudart::*;
use hip_runtime_sys::*;
use std::ptr;
#[cfg(debug_assertions)]
fn unsupported() -> cudaError_t {
@ -18,6 +17,7 @@ fn to_cuda(status: hipError_t) -> cudaError_t {
match status {
hipError_t::hipSuccess => cudaError_t::cudaSuccess,
hipError_t::hipErrorInvalidValue => cudaError_t::cudaErrorInvalidValue,
hipError_t::hipErrorOutOfMemory => cudaError_t::cudaErrorMemoryAllocation,
hipError_t::hipErrorInvalidResourceHandle => cudaError_t::cudaErrorInvalidResourceHandle,
hipError_t::hipErrorNotSupported => cudaError_t::cudaErrorNotSupported,
err => panic!("[ZLUDA] HIP Runtime failed: {}", err.0),
@ -28,6 +28,7 @@ fn to_hip(status: cudaError_t) -> hipError_t {
match status {
cudaError_t::cudaSuccess => hipError_t::hipSuccess,
cudaError_t::cudaErrorInvalidValue => hipError_t::hipErrorInvalidValue,
cudaError_t::cudaErrorMemoryAllocation => hipError_t::hipErrorOutOfMemory,
cudaError_t::cudaErrorInvalidResourceHandle => hipError_t::hipErrorInvalidResourceHandle,
cudaError_t::cudaErrorNotSupported => hipError_t::hipErrorNotSupported,
err => panic!("[ZLUDA] HIP Runtime failed: {}", err.0),
@ -41,7 +42,7 @@ fn to_hip_memcpy_kind(memcpy_kind: cudaMemcpyKind) -> hipMemcpyKind {
cudaMemcpyKind::cudaMemcpyDeviceToHost => hipMemcpyKind::hipMemcpyDeviceToHost,
cudaMemcpyKind::cudaMemcpyDeviceToDevice => hipMemcpyKind::hipMemcpyDeviceToDevice,
cudaMemcpyKind::cudaMemcpyDefault => hipMemcpyKind::hipMemcpyDefault,
_ => panic!()
_ => panic!(),
}
}
@ -55,7 +56,7 @@ fn to_hip_mem_pool_attr(mem_pool_attr: cudaMemPoolAttr) -> hipMemPoolAttr {
cudaMemPoolAttr::cudaMemPoolAttrReservedMemHigh => hipMemPoolAttr::hipMemPoolAttrReservedMemHigh,
cudaMemPoolAttr::cudaMemPoolAttrUsedMemCurrent => hipMemPoolAttr::hipMemPoolAttrUsedMemCurrent,
cudaMemPoolAttr::cudaMemPoolAttrUsedMemHigh => hipMemPoolAttr::hipMemPoolAttrUsedMemHigh,
_ => panic!("[ZLUDA] Unsupported memory pool attribute: {}", mem_pool_attr.0)
_ => panic!(),
}
}
@ -64,15 +65,15 @@ fn to_cuda_stream_capture_status(stream_capture_status: hipStreamCaptureStatus)
hipStreamCaptureStatus::hipStreamCaptureStatusNone => cudaStreamCaptureStatus::cudaStreamCaptureStatusNone,
hipStreamCaptureStatus::hipStreamCaptureStatusActive => cudaStreamCaptureStatus::cudaStreamCaptureStatusActive,
hipStreamCaptureStatus::hipStreamCaptureStatusInvalidated => cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated,
_ => panic!()
_ => panic!(),
}
}
fn to_hip_dim3(dim3: cudart::dim3) -> hip_runtime_api::dim3 {
fn to_hip_dim3(dim: cudart::dim3) -> hip_runtime_api::dim3 {
hip_runtime_api::dim3 {
x: dim3.x,
y: dim3.y,
z: dim3.z,
x: dim.x,
y: dim.y,
z: dim.z,
}
}
@ -80,7 +81,7 @@ unsafe fn pop_call_configuration(
grid_dim: *mut cudart::dim3,
block_dim: *mut cudart::dim3,
shared_mem: *mut usize,
stream: *mut ::std::os::raw::c_void,
stream: *mut cudaStream_t,
) -> cudaError_t {
to_cuda(__hipPopCallConfiguration(
grid_dim.cast(),
@ -94,8 +95,8 @@ unsafe fn push_call_configuration(
grid_dim: cudart::dim3,
block_dim: cudart::dim3,
shared_mem: usize,
stream: *mut ::std::os::raw::c_void,
) -> u32 {
stream: cudaStream_t,
) -> cudaError_t {
let grid_dim = to_hip_dim3(grid_dim);
let block_dim = to_hip_dim3(block_dim);
to_cuda(__hipPushCallConfiguration(
@ -103,7 +104,7 @@ unsafe fn push_call_configuration(
block_dim,
shared_mem,
stream.cast(),
)).0 as _
))
}
unsafe fn register_fat_binary(
@ -112,6 +113,12 @@ unsafe fn register_fat_binary(
__hipRegisterFatBinary(fat_cubin)
}
unsafe fn register_fat_binary_end(
_fat_cubin_handle: *mut *mut ::std::os::raw::c_void,
) -> () {
//__hipRegisterFatBinaryEnd(fat_cubin_handle)
}
unsafe fn register_function(
fat_cubin_handle: *mut *mut ::std::os::raw::c_void,
host_fun: *const ::std::os::raw::c_char,
@ -146,7 +153,7 @@ unsafe fn register_host_var(
) -> ::std::os::raw::c_void {
__hipRegisterVar(
fat_cubin_handle,
ptr::null_mut(),
host_var.cast(),
host_var,
device_name.cast_mut(),
0,
@ -161,18 +168,20 @@ unsafe fn register_managed_var(
host_var_ptr_address: *mut *mut ::std::os::raw::c_void,
device_address: *mut ::std::os::raw::c_char,
device_name: *const ::std::os::raw::c_char,
_ext: i32,
ext: i32,
size: usize,
constant: i32,
_global: i32,
global: i32,
) -> ::std::os::raw::c_void {
__hipRegisterManagedVar(
*fat_cubin_handle,
host_var_ptr_address,
device_address.cast(),
device_name,
__hipRegisterVar(
fat_cubin_handle,
*host_var_ptr_address,
device_address,
device_name.cast_mut(),
ext,
size,
constant as _,
constant,
global,
)
}
@ -186,9 +195,9 @@ unsafe fn register_surface(
) -> ::std::os::raw::c_void {
__hipRegisterSurface(
fat_cubin_handle,
host_var.cast_mut(),
(*device_address).cast(),
host_var as _,
device_name as _,
device_name.cast_mut(),
dim,
ext,
)
@ -205,9 +214,9 @@ unsafe fn register_texture(
) -> ::std::os::raw::c_void {
__hipRegisterTexture(
fat_cubin_handle,
host_var.cast_mut(),
(*device_address).cast(),
host_var as _,
device_name as _,
device_name.cast_mut(),
dim,
norm,
ext,
@ -226,8 +235,8 @@ unsafe fn register_var(
) -> ::std::os::raw::c_void {
__hipRegisterVar(
fat_cubin_handle,
device_address.cast(),
host_var,
host_var.cast(),
device_address,
device_name.cast_mut(),
ext,
size,
@ -336,16 +345,6 @@ unsafe fn get_device_count(
to_cuda(hipGetDeviceCount(count))
}
unsafe fn get_device_properties(
prop: *mut cudaDeviceProp,
device: i32,
) -> cudaError_t {
to_cuda(hipGetDeviceProperties(
prop.cast(),
device,
))
}
unsafe fn device_get_default_mem_pool(
mem_pool: *mut cudaMemPool_t,
device: i32,