This commit is contained in:
Seunghoon Lee 2024-05-30 03:10:42 +09:00
parent 796ca8de8d
commit 7dfd642e67
No known key found for this signature in database
GPG key ID: 436E38F4E70BD152
2 changed files with 62 additions and 513 deletions

File diff suppressed because one or more lines are too long

View file

@ -2,6 +2,7 @@ mod cudart;
pub use cudart::*;
use hip_runtime_sys::*;
use std::mem;
#[cfg(debug_assertions)]
fn unsupported() -> cudaError_t {
@ -548,26 +549,26 @@ unsafe fn stream_wait_event_ptsz(
event: cudaEvent_t,
flags: u32,
) -> cudaError_t {
let stream = to_stream(stream);
to_cuda(hipStreamWaitEvent_spt(
stream,
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_stream_wait_event = lib
.get::<unsafe extern "C" fn(
hStream: cuda_types::CUstream,
hEvent: cuda_types::CUevent,
Flags: ::std::os::raw::c_uint,
) -> cuda_types::CUresult>(b"cuStreamWaitEvent_ptsz\0")
.unwrap();
cudaError_t((cu_stream_wait_event)(
stream.cast(),
event.cast(),
flags,
))
).0)
}
// Stream-synchronize entry point returning a `cudaError_t`.
//
// NOTE(review): as shown here, the body holds two alternative implementations
// back to back (this looks like diff residue with the +/- markers lost), so the
// span is not valid Rust as-is. Only one of the two variants can be the live
// one — confirm against the actual file before relying on either.
unsafe fn stream_synchronize(
stream: cudaStream_t,
) -> cudaError_t {
// Variant 1: look up `cuStreamSynchronize` in the library returned by
// `get_cuda_library()` and call it directly. Both lookups are `unwrap`ed, so
// a failure of either one panics.
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_stream_synchronize = lib
.get::<unsafe extern "C" fn(
hStream: cuda_types::CUstream,
) -> cuda_types::CUresult>(b"cuStreamSynchronize\0")
.unwrap();
// The inner value of the returned `CUresult` is re-wrapped as `cudaError_t`.
cudaError_t((cu_stream_synchronize)(
stream.cast(),
).0)
// Variant 2: convert the handle with `to_stream`, call `hipStreamSynchronize`,
// and map the HIP status back with `to_cuda`.
let stream = to_stream(stream);
to_cuda(hipStreamSynchronize(stream))
}
unsafe fn stream_synchronize_ptsz(
@ -591,13 +592,6 @@ unsafe fn stream_query(
to_cuda(hipStreamQuery(stream))
}
/// `_ptsz` flavour of the stream query entry point.
///
/// Converts `stream` with `to_stream`, hands it to `hipStreamQuery_spt`, and
/// maps the resulting HIP status to a `cudaError_t` via `to_cuda`.
unsafe fn stream_query_ptsz(
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_stream = to_stream(stream);
    let hip_status = hipStreamQuery_spt(hip_stream);
    to_cuda(hip_status)
}
unsafe fn stream_attach_mem_async(
stream: cudaStream_t,
dev_ptr: *mut ::std::os::raw::c_void,
@ -624,49 +618,18 @@ unsafe fn stream_end_capture(
))
}
/// `_ptsz` flavour of the end-capture entry point.
///
/// Converts `stream` with `to_stream` and forwards it, together with `p_graph`
/// (pointer-cast, otherwise untouched), to `hipStreamEndCapture_spt`. The HIP
/// status is mapped to a `cudaError_t` via `to_cuda`.
unsafe fn stream_end_capture_ptsz(
    stream: cudaStream_t,
    p_graph: *mut cudaGraph_t,
) -> cudaError_t {
    let hip_stream = to_stream(stream);
    let hip_status = hipStreamEndCapture_spt(hip_stream, p_graph.cast());
    to_cuda(hip_status)
}
/// Forwards a capture-status query to the `cuStreamIsCapturing` symbol of the
/// library returned by `hip_common::zluda_ext::get_cuda_library()`.
///
/// Both `stream` and `p_capture_status` are pointer-cast and passed through
/// unchanged; the inner value of the returned `CUresult` is re-wrapped as a
/// `cudaError_t`.
///
/// # Panics
///
/// Panics if obtaining the library or resolving the symbol fails, because both
/// results are `unwrap`ed.
unsafe fn stream_is_capturing(
    stream: cudaStream_t,
    p_capture_status: *mut cudaStreamCaptureStatus,
) -> cudaError_t {
    // Signature of the driver-level symbol resolved below.
    type CuStreamIsCapturing = unsafe extern "C" fn(
        hStream: cuda_types::CUstream,
        captureStatus: *mut cuda_types::CUstreamCaptureStatus,
    ) -> cuda_types::CUresult;

    let cuda_library = hip_common::zluda_ext::get_cuda_library().unwrap();
    let is_capturing = cuda_library
        .get::<CuStreamIsCapturing>(b"cuStreamIsCapturing\0")
        .unwrap();
    let driver_result = (is_capturing)(stream.cast(), p_capture_status.cast());
    cudaError_t(driver_result.0)
}
// `_ptsz` capture-status query returning a `cudaError_t`.
//
// NOTE(review): as shown here, the body holds two alternative implementations
// back to back (this looks like diff residue with the +/- markers lost), so the
// span is not valid Rust as-is. Confirm which variant belongs to this function
// in the actual file.
unsafe fn stream_is_capturing_ptsz(
stream: cudaStream_t,
p_capture_status: *mut cudaStreamCaptureStatus,
) -> cudaError_t {
// Variant 1: look up `cuStreamIsCapturing_ptsz` in the library returned by
// `get_cuda_library()` and call it with the pointer-cast arguments. Both
// lookups are `unwrap`ed, so a failure of either one panics.
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_stream_is_capturing = lib
.get::<unsafe extern "C" fn(
hStream: cuda_types::CUstream,
captureStatus: *mut cuda_types::CUstreamCaptureStatus,
) -> cuda_types::CUresult>(b"cuStreamIsCapturing_ptsz\0")
.unwrap();
// The inner value of the returned `CUresult` is re-wrapped as `cudaError_t`.
cudaError_t((cu_stream_is_capturing)(
stream.cast(),
p_capture_status.cast(),
).0)
// Variant 2: call `hipStreamIsCapturing` into a zero-initialised local, then
// translate the HIP capture status with `to_cuda_stream_capture_status`.
let stream = to_stream(stream);
let mut capture_status = mem::zeroed();
let status = to_cuda(hipStreamIsCapturing(
stream,
&mut capture_status,
));
// Written unconditionally, i.e. also when `status` reports a failure.
*p_capture_status = to_cuda_stream_capture_status(capture_status);
status
}
unsafe fn stream_get_capture_info(
@ -674,39 +637,15 @@ unsafe fn stream_get_capture_info(
p_capture_status: *mut cudaStreamCaptureStatus,
p_id: *mut u64,
) -> cudaError_t {
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_stream_get_capture_info = lib
.get::<unsafe extern "C" fn(
hStream: cuda_types::CUstream,
captureStatus_out: *mut cuda_types::CUstreamCaptureStatus,
id_out: *mut cuda_types::cuuint64_t,
) -> cuda_types::CUresult>(b"cuStreamGetCaptureInfo\0")
.unwrap();
cudaError_t((cu_stream_get_capture_info)(
stream.cast(),
p_capture_status.cast(),
let stream = to_stream(stream);
let mut capture_status = mem::zeroed();
let status = to_cuda(hipStreamGetCaptureInfo(
stream,
&mut capture_status,
p_id,
).0)
}
// `_ptsz` capture-info query returning a `cudaError_t`.
//
// NOTE(review): the last lines of this span (`));` onwards) have no opening
// counterpart inside this function; they look like the tail of a different
// body left behind by a diff rendering. The span is not valid Rust as-is —
// confirm against the actual file.
unsafe fn stream_get_capture_info_ptsz(
stream: cudaStream_t,
p_capture_status: *mut cudaStreamCaptureStatus,
p_id: *mut u64,
) -> cudaError_t {
// Look up `cuStreamGetCaptureInfo_ptsz` in the library returned by
// `get_cuda_library()`. Both lookups are `unwrap`ed, so a failure of either
// one panics.
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_stream_get_capture_info = lib
.get::<unsafe extern "C" fn(
hStream: cuda_types::CUstream,
captureStatus_out: *mut cuda_types::CUstreamCaptureStatus,
id_out: *mut cuda_types::cuuint64_t,
) -> cuda_types::CUresult>(b"cuStreamGetCaptureInfo_ptsz\0")
.unwrap();
// `stream` and `p_capture_status` are pointer-cast, `p_id` is passed as-is;
// the inner value of the returned `CUresult` is re-wrapped as `cudaError_t`.
cudaError_t((cu_stream_get_capture_info)(
stream.cast(),
p_capture_status.cast(),
p_id,
).0)
// NOTE(review): orphan tail — `capture_status` and `status` are not defined
// anywhere in this function as shown.
));
*p_capture_status = to_cuda_stream_capture_status(capture_status);
status
}
unsafe fn event_create(
@ -738,17 +677,6 @@ unsafe fn event_record(
))
}
/// `_ptsz` flavour of the event-record entry point.
///
/// Converts `stream` with `to_stream` and calls `hipEventRecord_spt` with the
/// pointer-cast `event`; the HIP status is mapped to a `cudaError_t` via
/// `to_cuda`.
unsafe fn event_record_ptsz(
    event: cudaEvent_t,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_stream = to_stream(stream);
    let hip_status = hipEventRecord_spt(event.cast(), hip_stream);
    to_cuda(hip_status)
}
unsafe fn event_query(
event: cudaEvent_t,
) -> cudaError_t {
@ -806,27 +734,6 @@ unsafe fn launch_kernel(
))
}
/// `_ptsz` flavour of the kernel-launch entry point.
///
/// Both launch dimensions are converted with `to_hip_dim3` and the stream with
/// `to_stream`; `func`, `args` and `shared_mem` are forwarded unchanged to
/// `hipLaunchKernel_spt`. The HIP status is mapped to a `cudaError_t` via
/// `to_cuda`.
unsafe fn launch_kernel_ptsz(
    func: *const ::std::os::raw::c_void,
    grid_dim: cudart::dim3,
    block_dim: cudart::dim3,
    args: *mut *mut ::std::os::raw::c_void,
    shared_mem: usize,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_grid = to_hip_dim3(grid_dim);
    let hip_block = to_hip_dim3(block_dim);
    let hip_stream = to_stream(stream);
    let hip_status = hipLaunchKernel_spt(func, hip_grid, hip_block, args, shared_mem, hip_stream);
    to_cuda(hip_status)
}
unsafe fn launch_cooperative_kernel(
func: *const ::std::os::raw::c_void,
grid_dim: cudart::dim3,
@ -848,27 +755,6 @@ unsafe fn launch_cooperative_kernel(
))
}
/// `_ptsz` flavour of the cooperative kernel-launch entry point.
///
/// Both launch dimensions are converted with `to_hip_dim3` and the stream with
/// `to_stream`; `func` and `args` are forwarded unchanged to
/// `hipLaunchCooperativeKernel_spt`. The HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
unsafe fn launch_cooperative_kernel_ptsz(
    func: *const ::std::os::raw::c_void,
    grid_dim: cudart::dim3,
    block_dim: cudart::dim3,
    args: *mut *mut ::std::os::raw::c_void,
    shared_mem: usize,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_grid = to_hip_dim3(grid_dim);
    let hip_block = to_hip_dim3(block_dim);
    let hip_stream = to_stream(stream);
    let hip_status = hipLaunchCooperativeKernel_spt(
        func,
        hip_grid,
        hip_block,
        args,
        // `as _` converts to whatever integer type the HIP binding declares
        // for this parameter.
        shared_mem as _,
        hip_stream,
    );
    to_cuda(hip_status)
}
unsafe fn launch_host_func(
stream: cudaStream_t,
fn_: cudaHostFn_t,
@ -882,33 +768,6 @@ unsafe fn launch_host_func(
))
}
/// `_ptsz` flavour of the host-function launch entry point.
///
/// Converts `stream` with `to_stream` and forwards `fn_` and `user_data`
/// unchanged to `hipLaunchHostFunc_spt`; the HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
unsafe fn launch_host_func_ptsz(
    stream: cudaStream_t,
    fn_: cudaHostFn_t,
    user_data: *mut ::std::os::raw::c_void,
) -> cudaError_t {
    let hip_stream = to_stream(stream);
    let hip_status = hipLaunchHostFunc_spt(hip_stream, fn_, user_data);
    to_cuda(hip_status)
}
/// Occupancy query: passes every argument through unchanged to
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor` and maps the HIP status to a
/// `cudaError_t` via `to_cuda`.
unsafe fn occupancy_max_active_blocks_per_multiprocessor(
    num_blocks: *mut i32,
    func: *const ::std::os::raw::c_void,
    block_size: i32,
    dynamic_s_mem_size: usize,
) -> cudaError_t {
    let hip_status = hipOccupancyMaxActiveBlocksPerMultiprocessor(
        num_blocks,
        func,
        block_size,
        dynamic_s_mem_size,
    );
    to_cuda(hip_status)
}
unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
num_blocks: *mut i32,
func: *const ::std::os::raw::c_void,
@ -916,13 +775,23 @@ unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags(
dynamic_s_mem_size: usize,
flags: u32,
) -> cudaError_t {
to_cuda(hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
let lib = hip_common::zluda_ext::get_cuda_library().unwrap();
let cu_stream_synchronize = lib
.get::<unsafe extern "C" fn(
numBlocks: *mut ::std::os::raw::c_int,
func: *const cuda_types::CUfunc_st,
blockSize: ::std::os::raw::c_int,
dynamicSMemSize: usize,
flags: ::std::os::raw::c_uint,
) -> cuda_types::CUresult>(b"cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags\0")
.unwrap();
cudaError_t((cu_stream_synchronize)(
num_blocks,
func,
func.cast(),
block_size,
dynamic_s_mem_size,
flags,
))
).0)
}
unsafe fn malloc_managed(
@ -1339,24 +1208,6 @@ unsafe fn memcpy_async(
))
}
/// `_ptsz` flavour of the asynchronous memcpy entry point.
///
/// `kind` is converted with `to_hip_memcpy_kind` and `stream` with
/// `to_stream`; `dst`, `src` and `count` are forwarded unchanged to
/// `hipMemcpyAsync_spt`. The HIP status is mapped to a `cudaError_t` via
/// `to_cuda`.
unsafe fn memcpy_async_ptsz(
    dst: *mut ::std::os::raw::c_void,
    src: *const ::std::os::raw::c_void,
    count: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_kind = to_hip_memcpy_kind(kind);
    let hip_stream = to_stream(stream);
    let hip_status = hipMemcpyAsync_spt(dst, src, count, hip_kind, hip_stream);
    to_cuda(hip_status)
}
unsafe fn memcpy_peer_async(
dst: *mut ::std::os::raw::c_void,
dst_device: i32,
@ -1400,30 +1251,6 @@ unsafe fn memcpy_2d_async(
))
}
/// `_ptsz` flavour of the asynchronous 2D memcpy entry point.
///
/// `kind` is converted with `to_hip_memcpy_kind` and `stream` with
/// `to_stream`; pointers, pitches and extents are forwarded unchanged to
/// `hipMemcpy2DAsync_spt`. The HIP status is mapped to a `cudaError_t` via
/// `to_cuda`.
unsafe fn memcpy_2d_async_ptsz(
    dst: *mut ::std::os::raw::c_void,
    dpitch: usize,
    src: *const ::std::os::raw::c_void,
    spitch: usize,
    width: usize,
    height: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_kind = to_hip_memcpy_kind(kind);
    let hip_stream = to_stream(stream);
    let hip_status = hipMemcpy2DAsync_spt(
        dst, dpitch, src, spitch, width, height, hip_kind, hip_stream,
    );
    to_cuda(hip_status)
}
unsafe fn memcpy_2d_to_array_async(
dst: cudaArray_t,
w_offset: usize,
@ -1450,32 +1277,6 @@ unsafe fn memcpy_2d_to_array_async(
))
}
/// `_ptsz` flavour of the asynchronous 2D host/device-to-array copy.
///
/// `kind` is converted with `to_hip_memcpy_kind` and `stream` with
/// `to_stream`; `dst` is pointer-cast and the offsets, source pointer, pitch
/// and extents are forwarded unchanged to `hipMemcpy2DToArrayAsync_spt`. The
/// HIP status is mapped to a `cudaError_t` via `to_cuda`.
unsafe fn memcpy_2d_to_array_async_ptsz(
    dst: cudaArray_t,
    w_offset: usize,
    h_offset: usize,
    src: *const ::std::os::raw::c_void,
    spitch: usize,
    width: usize,
    height: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_kind = to_hip_memcpy_kind(kind);
    let hip_stream = to_stream(stream);
    let hip_status = hipMemcpy2DToArrayAsync_spt(
        dst.cast(),
        w_offset,
        h_offset,
        src,
        spitch,
        width,
        height,
        hip_kind,
        hip_stream,
    );
    to_cuda(hip_status)
}
unsafe fn memcpy_2d_from_array_async(
dst: *mut ::std::os::raw::c_void,
dpitch: usize,
@ -1502,32 +1303,6 @@ unsafe fn memcpy_2d_from_array_async(
))
}
/// `_ptsz` flavour of the asynchronous 2D array-to-host/device copy.
///
/// `kind` is converted with `to_hip_memcpy_kind` and `stream` with
/// `to_stream`; `src` is pointer-cast and the destination pointer, pitch,
/// offsets and extents are forwarded unchanged to
/// `hipMemcpy2DFromArrayAsync_spt`. The HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
unsafe fn memcpy_2d_from_array_async_ptsz(
    dst: *mut ::std::os::raw::c_void,
    dpitch: usize,
    src: cudaArray_const_t,
    w_offset: usize,
    h_offset: usize,
    width: usize,
    height: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_kind = to_hip_memcpy_kind(kind);
    let hip_stream = to_stream(stream);
    let hip_status = hipMemcpy2DFromArrayAsync_spt(
        dst,
        dpitch,
        src.cast(),
        w_offset,
        h_offset,
        width,
        height,
        hip_kind,
        hip_stream,
    );
    to_cuda(hip_status)
}
unsafe fn memcpy_to_symbol_async(
symbol: *const ::std::os::raw::c_void,
src: *const ::std::os::raw::c_void,
@ -1548,26 +1323,6 @@ unsafe fn memcpy_to_symbol_async(
))
}
/// `_ptsz` flavour of the asynchronous copy-to-symbol entry point.
///
/// `kind` is converted with `to_hip_memcpy_kind` and `stream` with
/// `to_stream`; `symbol`, `src`, `count` and `offset` are forwarded unchanged
/// to `hipMemcpyToSymbolAsync_spt`. The HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
unsafe fn memcpy_to_symbol_async_ptsz(
    symbol: *const ::std::os::raw::c_void,
    src: *const ::std::os::raw::c_void,
    count: usize,
    offset: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_kind = to_hip_memcpy_kind(kind);
    let hip_stream = to_stream(stream);
    let hip_status =
        hipMemcpyToSymbolAsync_spt(symbol, src, count, offset, hip_kind, hip_stream);
    to_cuda(hip_status)
}
unsafe fn memcpy_from_symbol_async(
dst: *mut ::std::os::raw::c_void,
symbol: *const ::std::os::raw::c_void,
@ -1588,26 +1343,6 @@ unsafe fn memcpy_from_symbol_async(
))
}
/// `_ptsz` flavour of the asynchronous copy-from-symbol entry point.
///
/// `kind` is converted with `to_hip_memcpy_kind` and `stream` with
/// `to_stream`; `dst`, `symbol`, `count` and `offset` are forwarded unchanged
/// to `hipMemcpyFromSymbolAsync_spt`. The HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
unsafe fn memcpy_from_symbol_async_ptsz(
    dst: *mut ::std::os::raw::c_void,
    symbol: *const ::std::os::raw::c_void,
    count: usize,
    offset: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_kind = to_hip_memcpy_kind(kind);
    let hip_stream = to_stream(stream);
    let hip_status =
        hipMemcpyFromSymbolAsync_spt(dst, symbol, count, offset, hip_kind, hip_stream);
    to_cuda(hip_status)
}
unsafe fn memset(
dev_ptr: *mut ::std::os::raw::c_void,
value: i32,
@ -1679,21 +1414,6 @@ unsafe fn memset_async(
))
}
/// `_ptsz` flavour of the asynchronous memset entry point.
///
/// Converts `stream` with `to_stream` and forwards `dev_ptr`, `value` and
/// `count` unchanged to `hipMemsetAsync_spt`; the HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
unsafe fn memset_async_ptsz(
    dev_ptr: *mut ::std::os::raw::c_void,
    value: i32,
    count: usize,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_stream = to_stream(stream);
    let hip_status = hipMemsetAsync_spt(dev_ptr, value, count, hip_stream);
    to_cuda(hip_status)
}
unsafe fn memset_2d_async(
dev_ptr: *mut ::std::os::raw::c_void,
pitch: usize,
@ -1713,25 +1433,6 @@ unsafe fn memset_2d_async(
))
}
/// `_ptsz` flavour of the asynchronous 2D memset entry point.
///
/// Converts `stream` with `to_stream` and forwards `dev_ptr`, `pitch`,
/// `value`, `width` and `height` unchanged to `hipMemset2DAsync_spt`; the HIP
/// status is mapped to a `cudaError_t` via `to_cuda`.
unsafe fn memset_2d_async_ptsz(
    dev_ptr: *mut ::std::os::raw::c_void,
    pitch: usize,
    value: i32,
    width: usize,
    height: usize,
    stream: cudaStream_t,
) -> cudaError_t {
    let hip_stream = to_stream(stream);
    let hip_status = hipMemset2DAsync_spt(dev_ptr, pitch, value, width, height, hip_stream);
    to_cuda(hip_status)
}
unsafe fn get_symbol_address(
dev_ptr: *mut *mut ::std::os::raw::c_void,
symbol: *const ::std::os::raw::c_void,
@ -1848,30 +1549,6 @@ unsafe fn memcpy_to_array_async(
))
}
/// `_ptsz` flavour of the asynchronous linear host/device-to-array copy.
///
/// The copy is expressed through `hipMemcpy2DToArrayAsync_spt` as a 2D copy of
/// a single row: source pitch and width are both `count` bytes and the height
/// is 1. `kind` is converted with `to_hip_memcpy_kind`, `stream` with
/// `to_stream`, `dst` is pointer-cast, and the HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
///
/// The previous revision passed `w_offset`/`h_offset` as the copy width and
/// height, which made the transferred extent depend on the destination
/// offsets rather than on `count` (zero offsets copied nothing).
///
/// NOTE(review): a single-row copy assumes `count` bytes fit in one row of the
/// destination array starting at `w_offset` — confirm whether multi-row
/// `count` values must be supported.
unsafe fn memcpy_to_array_async_ptsz(
    dst: cudaArray_t,
    w_offset: usize,
    h_offset: usize,
    src: *const ::std::os::raw::c_void,
    count: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let kind = to_hip_memcpy_kind(kind);
    let stream = to_stream(stream);
    to_cuda(hipMemcpy2DToArrayAsync_spt(
        dst.cast(),
        w_offset,
        h_offset,
        src,
        count, // spitch: the source is one contiguous run of `count` bytes
        count, // width in bytes
        1,     // height: a single row
        kind,
        stream,
    ))
}
unsafe fn memcpy_from_array_async(
dst: *mut ::std::os::raw::c_void,
src: cudaArray_const_t,
@ -1896,30 +1573,6 @@ unsafe fn memcpy_from_array_async(
))
}
/// `_ptsz` flavour of the asynchronous linear array-to-host/device copy.
///
/// The copy is expressed through `hipMemcpy2DFromArrayAsync_spt` as a 2D copy
/// of a single row: destination pitch and width are both `count` bytes and the
/// height is 1. `kind` is converted with `to_hip_memcpy_kind`, `stream` with
/// `to_stream`, `src` is pointer-cast, and the HIP status is mapped to a
/// `cudaError_t` via `to_cuda`.
///
/// The previous revision passed `w_offset`/`h_offset` as the copy width and
/// height, which made the transferred extent depend on the source offsets
/// rather than on `count` (zero offsets copied nothing).
///
/// NOTE(review): a single-row copy assumes `count` bytes are available in one
/// row of the source array starting at `w_offset` — confirm whether multi-row
/// `count` values must be supported.
unsafe fn memcpy_from_array_async_ptsz(
    dst: *mut ::std::os::raw::c_void,
    src: cudaArray_const_t,
    w_offset: usize,
    h_offset: usize,
    count: usize,
    kind: cudaMemcpyKind,
    stream: cudaStream_t,
) -> cudaError_t {
    let kind = to_hip_memcpy_kind(kind);
    let stream = to_stream(stream);
    to_cuda(hipMemcpy2DFromArrayAsync_spt(
        dst,
        count, // dpitch: the destination is one contiguous run of `count` bytes
        src.cast(),
        w_offset,
        h_offset,
        count, // width in bytes
        1,     // height: a single row
        kind,
        stream,
    ))
}
unsafe fn malloc_async(
dev_ptr: *mut *mut ::std::os::raw::c_void,
size: usize,