Refactor host code to use one big lock

This commit is contained in:
Andrzej Janik 2020-11-11 22:35:34 +01:00
parent 7c93997cc9
commit a2e77fe961
15 changed files with 914 additions and 540 deletions

View file

@ -173,6 +173,16 @@ impl Context {
check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result));
Ok(Context(result))
}
pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> {
check! {
sys::zeMemFree(
self.0,
ptr,
)
};
Ok(())
}
}
impl Drop for Context {
@ -239,7 +249,7 @@ pub struct Module(sys::ze_module_handle_t);
impl Module {
// HACK ALERT
// We use OpenCL for now to do SPIR-V linking, because Level0
// We use OpenCL for now to do SPIR-V linking, because Level0
// does not allow linking. Don't let presence of zeModuleDynamicLink fool
// you, it's not currently possible to create non-compiled modules.
// zeModuleCreate always compiles (builds and links).

27
notcuda/build.rs Normal file
View file

@ -0,0 +1,27 @@
// HACK ALERT
// This buidl script has been copy-pasted from cl-sys to avoid CUDA libraries
// overriding path to OpenCL
fn main() {
if cfg!(windows) {
let known_sdk = [
// E.g. "c:\Program Files (x86)\Intel\OpenCL SDK\lib\x86\"
("INTELOCLSDKROOT", "x64", "x86"),
// E.g. "c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\Win32\"
("CUDA_PATH", "x64", "Win32"),
// E.g. "C:\Program Files (x86)\AMD APP SDK\3.0\lib\x86\"
("AMDAPPSDKROOT", "x86_64", "x86"),
];
for info in known_sdk.iter() {
if let Ok(sdk) = std::env::var(info.0) {
let mut path = std::path::PathBuf::from(sdk);
path.push("lib");
path.push(if cfg!(target_arch="x86_64") { info.1 } else { info.2 });
println!("cargo:rustc-link-search=native={}", path.display());
}
}
println!("cargo:rustc-link-search=native=C:\\Program Files (x86)\\OCL_SDK_Light\\lib\\x86_64");
}
}

View file

@ -2210,12 +2210,12 @@ pub extern "C" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_int)
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult {
r#impl::device::get(device.decuda(), ordinal)
r#impl::device::get(device.decuda(), ordinal).encuda()
}
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDeviceGetCount(count: *mut ::std::os::raw::c_int) -> CUresult {
r#impl::device::get_count(count)
r#impl::device::get_count(count).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -2314,7 +2314,6 @@ pub extern "C" fn cuDevicePrimaryCtxReset(dev: CUdevice) -> CUresult {
cuDevicePrimaryCtxReset_v2(dev)
}
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDevicePrimaryCtxReset_v2(dev: CUdevice) -> CUresult {
r#impl::unimplemented()
@ -2331,7 +2330,7 @@ pub extern "C" fn cuCtxCreate_v2(
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxDestroy_v2(ctx: CUcontext) -> CUresult {
r#impl::context::destroy_v2(ctx.decuda())
r#impl::context::destroy_v2(ctx.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -2356,7 +2355,7 @@ pub extern "C" fn cuCtxGetCurrent(pctx: *mut CUcontext) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxGetDevice(device: *mut CUdevice) -> CUresult {
r#impl::context::get_device(device.decuda())
r#impl::context::get_device(device.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -2404,7 +2403,7 @@ pub extern "C" fn cuCtxGetApiVersion(
ctx: CUcontext,
version: *mut ::std::os::raw::c_uint,
) -> CUresult {
r#impl::context::get_api_version(ctx.decuda(), version)
r#impl::context::get_api_version(ctx.decuda(), version).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -2422,12 +2421,12 @@ pub extern "C" fn cuCtxResetPersistingL2Cache() -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult {
r#impl::unimplemented()
r#impl::context::attach(pctx.decuda(), flags).encuda()
}
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxDetach(ctx: CUcontext) -> CUresult {
r#impl::unimplemented()
r#impl::context::detach(ctx.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -2443,7 +2442,7 @@ pub extern "C" fn cuModuleLoadData(
module: *mut CUmodule,
image: *const ::std::os::raw::c_void,
) -> CUresult {
r#impl::unimplemented()
r#impl::module::load_data(module.decuda(), image).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -2564,7 +2563,7 @@ pub extern "C" fn cuMemGetInfo_v2(free: *mut usize, total: *mut usize) -> CUresu
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuMemAlloc_v2(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult {
r#impl::memory::alloc_v2(dptr.decuda(), bytesize)
r#impl::memory::alloc_v2(dptr.decuda(), bytesize).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -3281,7 +3280,7 @@ pub extern "C" fn cuStreamCreate(
phStream: *mut CUstream,
Flags: ::std::os::raw::c_uint,
) -> CUresult {
r#impl::unimplemented()
r#impl::stream::create(phStream.decuda(), Flags).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -3311,7 +3310,7 @@ pub extern "C" fn cuStreamGetFlags(
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuStreamGetCtx(hStream: CUstream, pctx: *mut CUcontext) -> CUresult {
r#impl::unimplemented()
r#impl::stream::get_ctx(hStream.decuda(), pctx.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@ -3390,7 +3389,7 @@ pub extern "C" fn cuStreamSynchronize(hStream: CUstream) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuStreamDestroy_v2(hStream: CUstream) -> CUresult {
r#impl::unimplemented()
r#impl::stream::destroy_v2(hStream.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]

View file

@ -1,18 +1,15 @@
use super::CUresult;
use super::{device, HasLivenessCookie, LiveCheck};
use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck};
use super::{CUresult, GlobalState};
use crate::{cuda::CUcontext, cuda_impl};
use l0::sys::ze_result_t;
use std::mem::{self, ManuallyDrop};
use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32};
use std::{
cell::RefCell,
num::NonZeroU32,
os::raw::c_uint,
ptr,
sync::{atomic::AtomicU32, Mutex},
collections::HashSet,
mem::{self},
};
thread_local! {
pub static CONTEXT_STACK: RefCell<Vec<*const Context>> = RefCell::new(Vec::new());
pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new());
}
pub type Context = LiveCheck<ContextData>;
@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData {
#[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0x0b643ffb;
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
fn try_drop(&mut self) -> Result<(), CUresult> {
for stream in self.streams.iter() {
let stream = unsafe { &mut **stream };
stream.context = ptr::null_mut();
Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?;
}
Ok(())
}
}
enum ContextRefCount {
@ -67,26 +75,16 @@ impl ContextRefCount {
}
}
}
fn is_primary(&self) -> bool {
match self {
ContextRefCount::Primary => true,
ContextRefCount::NonPrimary(_) => false,
}
}
}
pub struct ContextData {
pub flags: AtomicU32,
pub device_index: device::Index,
// This pointer is null only for a moment when constructing primary context
pub device: *const Mutex<device::Device>,
// The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState
pub mutable: Mutex<ContextDataMutable>,
}
pub struct ContextDataMutable {
pub device: *mut device::Device,
ref_count: ContextRefCount,
pub default_stream: StreamData,
pub streams: HashSet<*mut StreamData>,
// All the fields below are here to support internal CUDA driver API
pub cuda_manager: *mut cuda_impl::rt::ContextStateManager,
pub cuda_state: *mut cuda_impl::rt::ContextState,
pub cuda_dtor_cb: Option<
@ -100,63 +98,75 @@ pub struct ContextDataMutable {
impl ContextData {
pub fn new(
l0_ctx: &mut l0::Context,
l0_dev: &l0::Device,
flags: c_uint,
is_primary: bool,
dev_index: device::Index,
dev: *const Mutex<device::Device>,
) -> Self {
ContextData {
dev: *mut device::Device,
) -> Result<Self, CUresult> {
let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?;
Ok(ContextData {
flags: AtomicU32::new(flags),
device_index: dev_index,
device: dev,
mutable: Mutex::new(ContextDataMutable {
ref_count: ContextRefCount::new(is_primary),
cuda_manager: ptr::null_mut(),
cuda_state: ptr::null_mut(),
cuda_dtor_cb: None,
}),
}
ref_count: ContextRefCount::new(is_primary),
default_stream,
streams: HashSet::new(),
cuda_manager: ptr::null_mut(),
cuda_state: ptr::null_mut(),
cuda_dtor_cb: None,
})
}
}
pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult {
impl Context {
pub fn late_init(&mut self) {
let ctx_data = self.as_option_mut().unwrap();
ctx_data.default_stream.context = ctx_data as *mut _;
}
}
pub fn create_v2(
pctx: *mut *mut Context,
flags: u32,
dev_idx: device::Index,
) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
return CUresult::CUDA_ERROR_INVALID_VALUE;
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let dev = device::get_device_ref(dev_idx);
let dev = match dev {
Ok(d) => d,
Err(e) => return e,
};
let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev)));
let ctx_ref = ctx.as_mut() as *mut Context;
let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| {
let dev_ptr = dev as *mut _;
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
&mut dev.l0_context,
&dev.base,
flags,
false,
dev_ptr as *mut _,
)?));
ctx_box.late_init();
Ok::<_, CUresult>(ctx_box)
})??;
let ctx_ref = ctx_box.as_mut() as *mut Context;
unsafe { *pctx = ctx_ref };
mem::forget(ctx);
mem::forget(ctx_box);
CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref));
CUresult::CUDA_SUCCESS
Ok(())
}
pub fn destroy_v2(ctx: *mut Context) -> CUresult {
pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
return CUresult::CUDA_ERROR_INVALID_VALUE;
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut();
let should_pop = match stack.last() {
Some(active_ctx) => *active_ctx == (ctx as *const _),
Some(active_ctx) => *active_ctx == (ctx as *mut _),
None => false,
};
if should_pop {
stack.pop();
}
});
let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) });
if !ctx_box.try_drop() {
CUresult::CUDA_ERROR_INVALID_CONTEXT
} else {
unsafe { ManuallyDrop::drop(&mut ctx_box) };
CUresult::CUDA_SUCCESS
}
GlobalState::lock(|_| Context::destroy_impl(ctx))?
}
pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
CUresult::CUDA_SUCCESS
}
pub fn with_current<F: FnOnce(&ContextData) -> R, R>(f: F) -> Result<R, CUresult> {
CONTEXT_STACK.with(|stack| {
stack
.borrow()
.last()
.and_then(|c| unsafe { &**c }.as_ref())
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.map(f)
})
}
pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> {
if pctx == ptr::null_mut() {
return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT);
@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult {
}
}
pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult {
let _ctx = match unsafe { ctx.as_mut() } {
None => return CUresult::CUDA_ERROR_INVALID_VALUE,
Some(ctx) => match ctx.as_mut() {
None => return CUresult::CUDA_ERROR_INVALID_CONTEXT,
Some(ctx) => ctx,
},
};
pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
GlobalState::lock(|_| {
unsafe { &*ctx }.as_result()?;
Ok::<_, CUresult>(())
})??;
//TODO: query device for properties roughly matching CUDA API version
unsafe { *version = 1100 };
CUresult::CUDA_SUCCESS
Ok(())
}
pub fn get_device(dev: *mut device::Index) -> CUresult {
let dev_idx = with_current(|ctx| ctx.device_index);
match dev_idx {
Ok(idx) => {
unsafe { *dev = idx }
CUresult::CUDA_SUCCESS
}
Err(err) => err,
pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> {
let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?;
unsafe { *dev = dev_idx };
Ok(())
}
pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
let ctx = unchecked_ctx.as_result_mut()?;
ctx.ref_count.incr()?;
Ok::<_, CUresult>(unchecked_ctx as *mut _)
})??;
unsafe { *pctx = ctx };
Ok(())
}
pub fn detach(pctx: *mut Context) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
let ctx = unchecked_ctx.as_result_mut()?;
if ctx.ref_count.decr() {
Context::destroy_impl(unchecked_ctx)?;
}
Ok::<_, CUresult>(())
})?
}
#[cfg(test)]
pub fn is_context_stack_empty() -> bool {
CONTEXT_STACK.with(|stack| stack.borrow().is_empty())
}
#[cfg(test)]
mod tests {
mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::{ffi::c_void, ptr};

View file

@ -1,24 +1,21 @@
use super::{context, transmute_lifetime, CUresult, Error};
use super::{context, CUresult, GlobalState};
use crate::cuda;
use cuda::{CUdevice_attribute, CUuuid_st};
use std::{
cmp, mem,
os::raw::{c_char, c_int},
ptr,
sync::{
atomic::{AtomicU32, Ordering},
Mutex, MutexGuard,
},
sync::atomic::{AtomicU32, Ordering},
};
const PROJECT_URL_SUFFIX: &'static str = " [github.com/vosen/notCUDA]";
static mut DEVICES: Option<Vec<Mutex<Device>>> = None;
#[repr(transparent)]
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct Index(pub c_int);
pub struct Device {
pub index: Index,
pub base: l0::Device,
pub default_queue: l0::CommandQueue,
pub l0_context: l0::Context,
@ -33,17 +30,19 @@ unsafe impl Send for Device {}
impl Device {
// Unsafe because it does not fully initalize primary_context
unsafe fn new(drv: &l0::Driver, d: l0::Device, idx: usize) -> l0::Result<Self> {
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
let mut ctx = l0::Context::new(drv)?;
let queue = l0::CommandQueue::new(&mut ctx, &d)?;
let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
let primary_context = context::Context::new(context::ContextData::new(
&mut ctx,
&l0_dev,
0,
true,
Index(idx as c_int),
ptr::null(),
));
ptr::null_mut(),
)?);
Ok(Self {
base: d,
index: Index(idx as c_int),
base: l0_dev,
default_queue: queue,
l0_context: ctx,
primary_context: primary_context,
@ -93,83 +92,53 @@ impl Device {
Err(e) => Err(e),
}
}
pub fn late_init(&mut self) {
self.primary_context.as_option_mut().unwrap().device = self as *mut _;
}
}
pub fn init(driver: &l0::Driver) -> l0::Result<()> {
pub fn init(driver: &l0::Driver) -> Result<Vec<Device>, CUresult> {
let ze_devices = driver.devices()?;
let mut devices = ze_devices
.into_iter()
.enumerate()
.map(|(idx, d)| unsafe { Device::new(driver, d, idx) }.map(Mutex::new))
.map(|(idx, d)| unsafe { Device::new(driver, d, idx) })
.collect::<Result<Vec<_>, _>>()?;
for d in devices.iter_mut() {
d.get_mut()
.unwrap()
.primary_context
.as_mut()
.unwrap()
.device = d;
for dev in devices.iter_mut() {
dev.late_init();
dev.primary_context.late_init();
}
unsafe { DEVICES = Some(devices) };
Ok(devices)
}
pub fn get_count(count: *mut c_int) -> Result<(), CUresult> {
let len = GlobalState::lock(|state| state.devices.len())?;
unsafe { *count = len as c_int };
Ok(())
}
fn devices() -> Result<&'static Vec<Mutex<Device>>, CUresult> {
match unsafe { &DEVICES } {
Some(devs) => Ok(devs),
None => Err(CUresult::CUDA_ERROR_NOT_INITIALIZED),
}
}
pub fn get_device_ref(Index(dev_idx): Index) -> Result<&'static Mutex<Device>, CUresult> {
let devs = devices()?;
if dev_idx < 0 || dev_idx >= devs.len() as c_int {
return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
}
Ok(&devs[dev_idx as usize])
}
pub fn get_device(dev_idx: Index) -> Result<MutexGuard<'static, Device>, CUresult> {
let dev = get_device_ref(dev_idx)?;
dev.lock().map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
}
pub fn get_count(count: *mut c_int) -> CUresult {
let len = devices().map(|d| d.len());
match len {
Ok(len) => {
unsafe { *count = len as c_int };
CUresult::CUDA_SUCCESS
}
Err(e) => e,
}
}
pub fn get(device: *mut Index, ordinal: c_int) -> CUresult {
pub fn get(device: *mut Index, ordinal: c_int) -> Result<(), CUresult> {
if device == ptr::null_mut() || ordinal < 0 {
return CUresult::CUDA_ERROR_INVALID_VALUE;
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let len = devices().map(|d| d.len());
match len {
Ok(len) if ordinal < (len as i32) => {
unsafe { *device = Index(ordinal) };
CUresult::CUDA_SUCCESS
}
Ok(_) => CUresult::CUDA_ERROR_INVALID_VALUE,
Err(e) => e,
let len = GlobalState::lock(|state| state.devices.len())?;
if ordinal < (len as i32) {
unsafe { *device = Index(ordinal) };
Ok(())
} else {
Err(CUresult::CUDA_ERROR_INVALID_VALUE)
}
}
pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> {
pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUresult> {
if name == ptr::null_mut() || len < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
// This is safe because devices are 'static
let name_ptr = {
let mut dev = get_device(dev)?;
let props = dev.get_properties().map_err(Into::<CUresult>::into)?;
props.name.as_ptr()
};
let name_ptr = GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_properties()?;
Ok::<_, l0::sys::ze_result_t>(props.name.as_ptr())
})??;
let name_len = (0..256)
.position(|i| unsafe { *name_ptr.add(i) } == 0)
.unwrap_or(256);
@ -189,20 +158,14 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult>
Ok(())
}
pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> {
pub fn total_mem_v2(bytes: *mut usize, dev_idx: Index) -> Result<(), CUresult> {
if bytes == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
// This is safe because devices are 'static
let mem_props = {
let mut dev = get_device(dev)?;
unsafe {
transmute_lifetime(
dev.get_memory_properties()
.map_err(Into::<CUresult>::into)?,
)
}
};
let mem_props = GlobalState::lock_device(dev_idx, |dev| {
let mem_props = dev.get_memory_properties()?;
Ok::<_, l0::sys::ze_result_t>(mem_props)
})??;
let max_mem = mem_props
.iter()
.map(|p| p.totalSize)
@ -228,56 +191,101 @@ impl CUdevice_attribute {
}
}
pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Result<(), Error> {
pub fn get_attribute(
pi: *mut i32,
attrib: CUdevice_attribute,
dev_idx: Index,
) -> Result<(), CUresult> {
if pi == ptr::null_mut() {
return Err(Error::Cuda(CUresult::CUDA_ERROR_INVALID_VALUE));
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
if let Some(value) = attrib.get_static_value() {
unsafe { *pi = value };
return Ok(());
}
let mut dev = get_device(dev).map_err(Error::Cuda)?;
let value = match attrib {
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => {
dev.get_properties().map_err(Error::L0)?.maxHardwareContexts as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_properties()?;
Ok::<_, l0::sys::ze_result_t>(props.maxHardwareContexts as i32)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT => {
let props = dev.get_properties().map_err(Error::L0)?;
(props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_properties()?;
Ok::<_, l0::sys::ze_result_t>(
(props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32,
)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => {
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_image_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::min(
props.maxImageDims1D,
c_int::max_value() as u32,
) as c_int)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => cmp::min(
dev.get_image_properties()
.map_err(Error::L0)?
.maxImageDims1D,
c_int::max_value() as u32,
) as c_int,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => {
let props = dev.get_compute_properties().map_err(Error::L0)?;
cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxGroupCountX,
) as i32)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => {
let props = dev.get_compute_properties().map_err(Error::L0)?;
cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxGroupCountY,
) as i32)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => {
let props = dev.get_compute_properties().map_err(Error::L0)?;
cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxGroupCountZ,
) as i32)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => {
let props = dev.get_compute_properties().map_err(Error::L0)?;
cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(
cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32,
)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => {
let props = dev.get_compute_properties().map_err(Error::L0)?;
cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(
cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32,
)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => {
let props = dev.get_compute_properties().map_err(Error::L0)?;
cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(
cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32,
)
})??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
let props = dev.get_compute_properties().map_err(Error::L0)?;
cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32
GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_compute_properties()?;
Ok::<_, l0::sys::ze_result_t>(cmp::max(
i32::max_value() as u32,
props.maxTotalGroupSize,
) as i32)
})??
}
_ => {
// TODO: support more attributes for CUDA runtime
@ -293,14 +301,11 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re
Ok(())
}
pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> {
let ze_uuid = {
get_device(dev)
.map_err(Error::Cuda)?
.get_properties()
.map_err(Error::L0)?
.uuid
};
pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
let ze_uuid = GlobalState::lock_device(dev_idx, |dev| {
let props = dev.get_properties()?;
Ok::<_, l0::sys::ze_result_t>(props.uuid)
})??;
unsafe {
*uuid = CUuuid_st {
bytes: mem::transmute(ze_uuid.id),
@ -309,53 +314,39 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> {
Ok(())
}
pub fn with_current_exclusive<F: FnOnce(&mut Device) -> R, R>(f: F) -> Result<R, CUresult> {
let dev = super::context::with_current(|ctx| ctx.device);
dev.and_then(|dev| {
unsafe { &*dev }
.try_lock()
.map(|mut dev| f(&mut dev))
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
})
}
pub fn with_exclusive<F: FnOnce(&mut Device) -> R, R>(dev: Index, f: F) -> Result<R, CUresult> {
let dev = get_device_ref(dev)?;
dev.try_lock()
.map(|mut dev| f(&mut dev))
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
}
pub fn primary_ctx_get_state(
idx: Index,
dev_idx: Index,
flags: *mut u32,
active: *mut i32,
) -> Result<(), CUresult> {
let (ctx_ptr, flags_ptr) = with_exclusive(idx, |dev| {
let (is_active, flags_value) = GlobalState::lock_device(dev_idx, |dev| {
// This is safe because primary context can't be dropped
let ctx_ptr = &dev.primary_context as *const _;
let ctx_ptr = &mut dev.primary_context as *mut _;
let flags_ptr =
(&unsafe { dev.primary_context.as_ref_unchecked() }.flags) as *const AtomicU32;
(ctx_ptr, flags_ptr)
})?;
let is_active = context::CONTEXT_STACK
.with(|stack| stack.borrow().last().map(|x| *x))
.map(|current| current == ctx_ptr)
.unwrap_or(false);
let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed);
unsafe { *flags = flags_value };
let is_active = context::CONTEXT_STACK
.with(|stack| stack.borrow().last().map(|x| *x))
.map(|current| current == ctx_ptr)
.unwrap_or(false);
let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed);
Ok::<_, l0::sys::ze_result_t>((is_active, flags_value))
})??;
unsafe { *active = if is_active { 1 } else { 0 } };
unsafe { *flags = flags_value };
Ok(())
}
pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Result<(), CUresult> {
let ctx_ptr = with_exclusive(dev, |dev| &mut dev.primary_context as *mut _)?;
pub fn primary_ctx_retain(
pctx: *mut *mut context::Context,
dev_idx: Index,
) -> Result<(), CUresult> {
let ctx_ptr = GlobalState::lock_device(dev_idx, |dev| &mut dev.primary_context as *mut _)?;
unsafe { *pctx = ctx_ptr };
Ok(())
}
#[cfg(test)]
mod tests {
mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;

View file

@ -4,7 +4,7 @@ use crate::{
cuda_impl,
};
use super::{context, device, module, Decuda, Encuda};
use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@ -110,17 +110,8 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [
VTableEntry { ptr: ptr::null() },
];
unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
}
fn cudart_interface_fn1_impl(
pctx: *mut *mut context::Context,
dev: device::Index,
) -> Result<(), CUresult> {
let ctx_ptr = device::with_exclusive(dev, |d| &mut d.primary_context as *mut context::Context)?;
unsafe { *pctx = ctx_ptr };
Ok(())
unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
super::unimplemented()
}
/*
@ -200,7 +191,7 @@ unsafe extern "C" fn get_module_from_cubin(
ptr1: *mut c_void,
ptr2: *mut c_void,
) -> CUresult {
// Not sure what those twoparameters are actually used for,
// Not sure what those two parameters are actually used for,
// they are somehow involved in __cudaRegisterHostVar
if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() {
return CUresult::CUDA_ERROR_NOT_SUPPORTED;
@ -234,10 +225,13 @@ unsafe extern "C" fn get_module_from_cubin(
},
Err(_) => continue,
};
let module = module::ModuleData::compile_spirv(kernel_text_string);
let module = module::SpirvModule::new(kernel_text_string);
match module {
Ok(module) => {
*result = Box::into_raw(Box::new(module));
match module::load_data_impl(result, module) {
Ok(()) => {}
Err(err) => return err,
}
return CUresult::CUDA_SUCCESS;
}
Err(_) => continue,
@ -309,7 +303,7 @@ unsafe extern "C" fn context_local_storage_ctor(
}
fn context_local_storage_ctor_impl(
mut cu_ctx: *mut context::Context,
cu_ctx: *mut context::Context,
mgr: *mut cuda_impl::rt::ContextStateManager,
ctx_state: *mut cuda_impl::rt::ContextState,
dtor_cb: Option<
@ -320,26 +314,11 @@ fn context_local_storage_ctor_impl(
),
>,
) -> Result<(), CUresult> {
if cu_ctx == ptr::null_mut() {
context::get_current(&mut cu_ctx)?;
}
if cu_ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
unsafe { &*cu_ctx }
.as_ref()
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.and_then(|ctx| {
ctx.mutable
.try_lock()
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
.map(|mut mutable| {
mutable.cuda_manager = mgr;
mutable.cuda_state = ctx_state;
mutable.cuda_dtor_cb = dtor_cb;
})
})?;
Ok(())
lock_context(cu_ctx, |ctx: &mut ContextData| {
ctx.cuda_manager = mgr;
ctx.cuda_state = ctx_state;
ctx.cuda_dtor_cb = dtor_cb;
})
}
// some kind of dtor
@ -357,24 +336,10 @@ unsafe extern "C" fn context_local_storage_get_state(
fn context_local_storage_get_state_impl(
ctx_state: *mut *mut cuda_impl::rt::ContextState,
mut cu_ctx: *mut context::Context,
cu_ctx: *mut context::Context,
_: *mut cuda_impl::rt::ContextStateManager,
) -> Result<(), CUresult> {
if cu_ctx == ptr::null_mut() {
context::get_current(&mut cu_ctx)?;
}
if cu_ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let cuda_state = unsafe { &*cu_ctx }
.as_ref()
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.and_then(|ctx| {
ctx.mutable
.try_lock()
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
.map(|mutable| mutable.cuda_state)
})?;
let cuda_state = lock_context(cu_ctx, |ctx: &mut ContextData| ctx.cuda_state)?;
if cuda_state == ptr::null_mut() {
Err(CUresult::CUDA_ERROR_INVALID_VALUE)
} else {
@ -382,3 +347,17 @@ fn context_local_storage_get_state_impl(
Ok(())
}
}
fn lock_context<T>(
cu_ctx: *mut context::Context,
fn_impl: impl FnOnce(&mut ContextData) -> T,
) -> Result<T, CUresult> {
if cu_ctx == ptr::null_mut() {
GlobalState::lock_current_context(fn_impl)
} else {
GlobalState::lock(|_| {
let ctx = unsafe { &mut *cu_ctx }.as_result_mut()?;
Ok(fn_impl(ctx))
})?
}
}

View file

@ -1,11 +1,28 @@
use ::std::os::raw::{c_uint, c_void};
use std::ptr;
use super::{device, stream::Stream, CUresult};
use super::{CUresult, GlobalState, HasLivenessCookie, LiveCheck, stream::Stream};
pub struct Function {
pub type Function = LiveCheck<FunctionData>;
impl HasLivenessCookie for FunctionData {
#[cfg(target_pointer_width = "64")]
const COOKIE: usize = 0x5e2ab14d5840678e;
#[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0x33e6a1e6;
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
fn try_drop(&mut self) -> Result<(), CUresult> {
Ok(())
}
}
pub struct FunctionData {
pub base: l0::Kernel<'static>,
pub arg_size: Vec<usize>,
pub use_shared_mem: bool,
}
pub fn launch_kernel(
@ -17,36 +34,43 @@ pub fn launch_kernel(
block_dim_y: c_uint,
block_dim_z: c_uint,
shared_mem_bytes: c_uint,
strean: *mut Stream,
hstream: *mut Stream,
kernel_params: *mut *mut c_void,
extra: *mut *mut c_void,
) -> Result<(), CUresult> {
if f == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
if shared_mem_bytes != 0 || strean != ptr::null_mut() || extra != ptr::null_mut() {
if extra != ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED);
}
let func = unsafe { &*f };
for (i, arg_size) in func.arg_size.iter().copied().enumerate() {
unsafe {
func.base
.set_arg_raw(i as u32, arg_size, *kernel_params.add(i))?
};
}
unsafe { &*f }
.base
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
device::with_current_exclusive(|dev| {
let mut cmd_list = l0::CommandList::new(&mut dev.l0_context, &dev.base)?;
GlobalState::lock_stream(hstream, |stream| {
let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
for (i, arg_size) in func.arg_size.iter().enumerate() {
unsafe {
func.base
.set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))?
};
}
if func.use_shared_mem {
unsafe {
func.base.set_arg_raw(
func.arg_size.len() as u32,
shared_mem_bytes as usize,
ptr::null(),
)?
};
}
func.base
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
let mut cmd_list = stream.command_list()?;
cmd_list.append_launch_kernel(
&unsafe { &*f }.base,
&mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z],
None,
&mut [],
)?;
dev.default_queue.execute(cmd_list)?;
l0::Result::Ok(())
})??;
Ok(())
stream.queue.execute(cmd_list)?;
Ok(())
})?
}

View file

@ -1,57 +1,34 @@
use super::CUresult;
use super::{stream, CUresult, GlobalState};
use std::ffi::c_void;
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
let alloc_result = super::device::with_current_exclusive(|dev| unsafe {
dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0)
});
match alloc_result {
Ok(Ok(alloc)) => {
unsafe { *dptr = alloc };
CUresult::CUDA_SUCCESS
}
Ok(Err(e)) => e.into(),
Err(e) => e,
}
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let ptr = GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
})??;
unsafe { *dptr = ptr };
Ok(())
}
pub fn copy_v2(
dst: *mut c_void,
src: *const c_void,
bytesize: usize,
) -> Result<Result<(), l0::sys::ze_result_t>, CUresult> {
super::device::with_current_exclusive(|dev| unsafe {
memcpy_impl(
&mut dev.l0_context,
dst,
src,
bytesize,
&dev.base,
&mut dev.default_queue,
)
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
let mut cmd_list = stream.command_list()?;
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
stream.queue.execute(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
})
}
unsafe fn memcpy_impl(
ctx: &mut l0::Context,
dst: *mut c_void,
src: *const c_void,
bytes_count: usize,
dev: &l0::Device,
queue: &mut l0::CommandQueue,
) -> l0::Result<()> {
let mut cmd_list = l0::CommandList::new(ctx, &dev)?;
cmd_list.append_memory_copy_unsafe(dst, src, bytes_count, None, &mut [])?;
queue.execute(cmd_list)?;
Ok(())
}
pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
Ok(())
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
}
#[cfg(test)]
mod tests {
mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::ptr;
@ -82,4 +59,20 @@ mod tests {
assert_ne!(mem, ptr::null_mut());
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(free_without_ctx);
fn free_without_ctx<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut mem = ptr::null_mut();
assert_eq!(
T::cuMemAlloc_v2(&mut mem, std::mem::size_of::<usize>()),
CUresult::CUDA_SUCCESS
);
assert_ne!(mem, ptr::null_mut());
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuMemFree_v2(mem), CUresult::CUDA_ERROR_INVALID_VALUE);
}
}

View file

@ -1,5 +1,15 @@
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
use crate::{
cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st},
r#impl::device::Device,
};
use std::{
ffi::c_void,
mem::{self, ManuallyDrop},
os::raw::c_int,
ptr,
sync::Mutex,
sync::TryLockError,
};
#[cfg(test)]
#[macro_use]
@ -7,9 +17,9 @@ pub mod test;
pub mod context;
pub mod device;
pub mod export_table;
pub mod function;
pub mod memory;
pub mod module;
pub mod function;
pub mod stream;
#[cfg(debug_assertions)]
@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult {
CUresult::CUDA_ERROR_NOT_SUPPORTED
}
pub trait HasLivenessCookie {
pub trait HasLivenessCookie: Sized {
const COOKIE: usize;
const LIVENESS_FAIL: CUresult;
fn try_drop(&mut self) -> Result<(), CUresult>;
}
// This struct is a best-effort check if wrapped value has been dropped,
@ -42,19 +55,23 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
}
}
fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) });
ctx_box.try_drop()?;
unsafe { ManuallyDrop::drop(&mut ctx_box) };
Ok(())
}
unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>());
outer_ptr as *mut Self
}
pub unsafe fn as_ref_unchecked(&self) -> &T {
&self.data
}
pub fn as_ref(&self) -> Option<&T> {
if self.cookie == T::COOKIE {
Some(&self.data)
} else {
None
}
}
pub fn as_mut(&mut self) -> Option<&mut T> {
pub fn as_option_mut(&mut self) -> Option<&mut T> {
if self.cookie == T::COOKIE {
Some(&mut self.data)
} else {
@ -62,14 +79,31 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
}
}
pub fn as_result(&self) -> Result<&T, CUresult> {
if self.cookie == T::COOKIE {
Ok(&self.data)
} else {
Err(T::LIVENESS_FAIL)
}
}
pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
if self.cookie == T::COOKIE {
Ok(&mut self.data)
} else {
Err(T::LIVENESS_FAIL)
}
}
#[must_use]
pub fn try_drop(&mut self) -> bool {
pub fn try_drop(&mut self) -> Result<(), CUresult> {
if self.cookie == T::COOKIE {
self.cookie = 0;
self.data.try_drop()?;
unsafe { ManuallyDrop::drop(&mut self.data) };
return true;
return Ok(());
}
false
Err(T::LIVENESS_FAIL)
}
}
@ -121,6 +155,12 @@ impl From<l0::sys::ze_result_t> for CUresult {
}
}
impl<T> From<TryLockError<T>> for CUresult {
fn from(_: TryLockError<T>) -> Self {
CUresult::CUDA_ERROR_ILLEGAL_STATE
}
}
pub trait Encuda {
type To: Sized;
fn encuda(self: Self) -> Self::To;
@ -157,58 +197,103 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1,
}
}
pub enum Error {
L0(l0::sys::ze_result_t),
Cuda(CUresult),
}
impl Encuda for Error {
type To = CUresult;
fn encuda(self: Self) -> Self::To {
match self {
Error::L0(e) => e.into(),
Error::Cuda(e) => e,
}
}
}
lazy_static! {
static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None);
}
struct GlobalState {
driver: l0::Driver,
devices: Vec<Device>,
}
unsafe impl Send for GlobalState {}
impl GlobalState {
fn lock<T>(f: impl FnOnce(&mut GlobalState) -> T) -> Result<T, CUresult> {
let mut mutex = GLOBAL_STATE
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let global_state = mutex.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?;
Ok(f(global_state))
}
fn lock_device<T>(
device::Index(dev_idx): device::Index,
f: impl FnOnce(&'static mut device::Device) -> T,
) -> Result<T, CUresult> {
if dev_idx < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
}
Self::lock(|global_state| {
if dev_idx >= global_state.devices.len() as c_int {
Err(CUresult::CUDA_ERROR_INVALID_DEVICE)
} else {
Ok(f(unsafe {
transmute_lifetime_mut(&mut global_state.devices[dev_idx as usize])
}))
}
})?
}
fn lock_current_context<F: FnOnce(&mut context::ContextData) -> R, R>(
f: F,
) -> Result<R, CUresult> {
Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))?
}
fn lock_current_context_unchecked<F: FnOnce(&mut context::Context) -> R, R>(
f: F,
) -> Result<R, CUresult> {
context::CONTEXT_STACK.with(|stack| {
stack
.borrow_mut()
.last_mut()
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
.map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))?
})
}
fn lock_stream<T>(
stream: *mut stream::Stream,
f: impl FnOnce(&mut stream::StreamData) -> T,
) -> Result<T, CUresult> {
if stream == ptr::null_mut()
|| stream == stream::CU_STREAM_LEGACY
|| stream == stream::CU_STREAM_PER_THREAD
{
Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))?
} else {
Self::lock(|_| {
let stream = unsafe { &mut *stream }.as_result_mut()?;
Ok(f(stream))
})?
}
}
}
// TODO: implement
fn is_intel_gpu_driver(_: &l0::Driver) -> bool {
true
}
pub fn init() -> l0::Result<()> {
pub fn init() -> Result<(), CUresult> {
let mut global_state = GLOBAL_STATE
.lock()
.map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?;
.map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
if global_state.is_some() {
return Ok(());
}
l0::init()?;
let drivers = l0::Driver::get()?;
let driver = match drivers.into_iter().find(is_intel_gpu_driver) {
None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN),
Some(driver) => {
device::init(&driver)?;
driver
}
let devices = match drivers.into_iter().find(is_intel_gpu_driver) {
None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
Some(driver) => device::init(&driver)?,
};
*global_state = Some(GlobalState { driver });
*global_state = Some(GlobalState { devices });
drop(global_state);
Ok(())
}
unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
mem::transmute(t)
}

View file

@ -1,79 +1,90 @@
use std::{
collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem,
os::raw::c_char, ptr, slice,
};
use super::{function::Function, transmute_lifetime, CUresult};
use super::{
device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie,
LiveCheck,
};
use ptx;
pub type Module = Mutex<ModuleData>;
pub type Module = LiveCheck<ModuleData>;
impl HasLivenessCookie for ModuleData {
#[cfg(target_pointer_width = "64")]
const COOKIE: usize = 0xf1313bd46505f98a;
#[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0xbdbe3f15;
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
fn try_drop(&mut self) -> Result<(), CUresult> {
Ok(())
}
}
pub struct ModuleData {
base: l0::Module,
arg_lens: HashMap<CString, Vec<usize>>,
pub spirv: SpirvModule,
// This should be a Vec<>, but I'm feeling lazy
pub device_binaries: HashMap<device::Index, CompiledModule>,
}
pub enum ModuleCompileError<'a> {
Parse(
Vec<ptx::ast::PtxError>,
Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
),
Compile(ptx::TranslateError),
L0(l0::sys::ze_result_t),
CUDA(CUresult),
pub struct SpirvModule {
pub binaries: Vec<u32>,
pub kernel_info: HashMap<String, ptx::KernelInfo>,
pub should_link_ptx_impl: Option<&'static [u8]>,
pub build_options: CString,
}
impl<'a> ModuleCompileError<'a> {
pub fn get_build_log(&self) {}
pub struct CompiledModule {
pub base: l0::Module,
pub kernels: HashMap<CString, Box<Function>>,
}
impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
fn from(err: ptx::TranslateError) -> Self {
ModuleCompileError::Compile(err)
impl<L, T, E> From<ptx::ParseError<L, T, E>> for CUresult {
fn from(_: ptx::ParseError<L, T, E>) -> Self {
CUresult::CUDA_ERROR_INVALID_PTX
}
}
impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
fn from(err: l0::sys::ze_result_t) -> Self {
ModuleCompileError::L0(err)
impl From<ptx::TranslateError> for CUresult {
fn from(_: ptx::TranslateError) -> Self {
CUresult::CUDA_ERROR_INVALID_PTX
}
}
impl<'a> From<CUresult> for ModuleCompileError<'a> {
fn from(err: CUresult) -> Self {
ModuleCompileError::CUDA(err)
impl SpirvModule {
pub fn new_raw<'a>(text: *const c_char) -> Result<Self, CUresult> {
let u8_text = unsafe { CStr::from_ptr(text) };
let ptx_text = u8_text
.to_str()
.map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?;
Self::new(ptx_text)
}
}
impl ModuleData {
pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
pub fn new<'a>(ptx_text: &str) -> Result<Self, CUresult> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
let ast = match ast {
Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))),
Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
Ok(ast) => ast,
};
let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?;
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
let spirv_module = ptx::to_spirv_module(ast)?;
Ok(SpirvModule {
binaries: spirv_module.assemble(),
kernel_info: spirv_module.kernel_info,
should_link_ptx_impl: spirv_module.should_link_ptx_impl,
build_options: spirv_module.build_options,
})
}
pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
let byte_il = unsafe {
slice::from_raw_parts::<u8>(
spirv.as_ptr() as *const _,
spirv.len() * mem::size_of::<u32>(),
slice::from_raw_parts(
self.binaries.as_ptr() as *const u8,
self.binaries.len() * mem::size_of::<u32>(),
)
};
let module = super::device::with_current_exclusive(|dev| {
l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
});
match module {
Ok((Ok(module), _)) => Ok(Mutex::new(Self {
base: module,
arg_lens: all_arg_lens
.into_iter()
.map(|(k, v)| (CString::new(k).unwrap(), v))
.collect(),
})),
Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
Err(err) => Err(ModuleCompileError::from(err)),
}
let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?;
Ok(l0_module)
}
}
@ -85,34 +96,75 @@ pub fn get_function(
if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let name = unsafe { CStr::from_ptr(name) };
let (mut kernel, args_len) = unsafe { &*hmod }
.try_lock()
.map(|module| {
Result::<_, CUresult>::Ok((
l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?,
module
.arg_lens
.get(name)
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?
.clone(),
))
})
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
)?;
unsafe {
*hfunc = Box::into_raw(Box::new(Function {
base: kernel,
arg_size: args_len,
}))
};
let name = unsafe { CStr::from_ptr(name) }.to_owned();
let function: *mut Function = GlobalState::lock_current_context(|ctx| {
let module = unsafe { &mut *hmod }.as_result_mut()?;
let device = unsafe { &mut *ctx.device };
let compiled_module = match module.device_binaries.entry(device.index) {
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule {
base: module.spirv.compile(&mut device.l0_context, &device.base)?,
kernels: HashMap::new(),
};
entry.insert(new_module)
}
};
//let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
let kernel = match compiled_module.kernels.entry(name) {
hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
hash_map::Entry::Vacant(entry) => {
let kernel_info = module
.spirv
.kernel_info
.get(unsafe {
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
let kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
entry.insert(Box::new(Function::new(FunctionData {
base: kernel,
arg_size: kernel_info.arguments_sizes.clone(),
use_shared_mem: kernel_info.uses_shared_mem,
})))
}
};
Ok::<_, CUresult>(kernel as *mut _)
})??;
unsafe { *hfunc = function };
Ok(())
}
pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> {
let spirv_data = SpirvModule::new_raw(image as *const _)?;
load_data_impl(pmod, spirv_data)
}
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device };
let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule {
base: l0_module,
kernels: HashMap::new(),
};
device_binaries.insert(device.index, compiled_module);
let module_data = ModuleData {
spirv: spirv_data,
device_binaries,
};
Ok::<_, CUresult>(module_data)
})??;
let module_ptr = Box::into_raw(Box::new(Module::new(module)));
unsafe { *pmod = module_ptr };
Ok(())
}
pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> {
if module == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
GlobalState::lock(|_| Module::destroy_impl(module))?
}

View file

@ -1,36 +1,114 @@
use std::cell::RefCell;
use super::{
context::{Context, ContextData},
CUresult, GlobalState,
};
use std::{mem, ptr};
use device::Device;
use super::{HasLivenessCookie, LiveCheck};
use super::device;
pub type Stream = LiveCheck<StreamData>;
pub struct Stream {
dev: *mut Device,
}
pub const CU_STREAM_LEGACY: *mut Stream = 1 as *mut _;
pub const CU_STREAM_PER_THREAD: *mut Stream = 2 as *mut _;
pub struct DefaultStream {
streams: Vec<Option<Stream>>,
}
impl HasLivenessCookie for StreamData {
#[cfg(target_pointer_width = "64")]
const COOKIE: usize = 0x512097354de18d35;
impl DefaultStream {
fn new() -> Self {
DefaultStream {
streams: Vec::new(),
#[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0x77d5cc0b;
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
fn try_drop(&mut self) -> Result<(), CUresult> {
if self.context != ptr::null_mut() {
let context = unsafe { &mut *self.context };
if !context.streams.remove(&(self as *mut _)) {
return Err(CUresult::CUDA_ERROR_UNKNOWN);
}
}
Ok(())
}
}
thread_local! {
pub static DEFAULT_STREAM: RefCell<DefaultStream> = RefCell::new(DefaultStream::new());
pub struct StreamData {
pub context: *mut ContextData,
pub queue: l0::CommandQueue,
}
impl StreamData {
pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
Ok(StreamData {
context: ptr::null_mut(),
queue: l0::CommandQueue::new(ctx, dev)?,
})
}
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
let l0_dev = &unsafe { &*ctx.device }.base;
Ok(StreamData {
context: ctx as *mut _,
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
})
}
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
let ctx = unsafe { &mut *self.context };
let dev = unsafe { &mut *ctx.device };
l0::CommandList::new(&mut dev.l0_context, &dev.base)
}
}
impl Drop for StreamData {
fn drop(&mut self) {
if self.context == ptr::null_mut() {
return;
}
unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) };
}
}
pub(crate) fn get_ctx(hstream: *mut Stream, pctx: *mut *mut Context) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let ctx_ptr = GlobalState::lock_stream(hstream, |stream| stream.context)?;
if ctx_ptr == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED);
}
unsafe { *pctx = Context::ptr_from_inner(ctx_ptr) };
Ok(())
}
pub(crate) fn create(phstream: *mut *mut Stream, _flags: u32) -> Result<(), CUresult> {
let stream_ptr = GlobalState::lock_current_context(|ctx| {
let mut stream_box = Box::new(Stream::new(StreamData::new(ctx)?));
let stream_ptr = stream_box.as_mut().as_option_mut().unwrap() as *mut _;
if !ctx.streams.insert(stream_ptr) {
return Err(CUresult::CUDA_ERROR_UNKNOWN);
}
mem::forget(stream_box);
Ok::<_, CUresult>(stream_ptr)
})??;
unsafe { *phstream = Stream::ptr_from_inner(stream_ptr) };
Ok(())
}
pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> {
if pstream == ptr::null_mut() || pstream == CU_STREAM_LEGACY || pstream == CU_STREAM_PER_THREAD
{
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
GlobalState::lock(|_| Stream::destroy_impl(pstream))?
}
#[cfg(test)]
mod tests {
mod test {
use crate::cuda::CUstream;
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::ptr;
use std::{ptr, thread};
const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
@ -65,5 +143,100 @@ mod tests {
CUresult::CUDA_SUCCESS
);
assert_eq!(ctx2, stream_ctx2);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx1), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(stream_context_destroyed);
fn stream_context_destroyed<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut stream = ptr::null_mut();
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
let mut stream_ctx1 = ptr::null_mut();
assert_eq!(
T::cuStreamGetCtx(stream, &mut stream_ctx1),
CUresult::CUDA_SUCCESS
);
assert_eq!(stream_ctx1, ctx);
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
let mut stream_ctx2 = ptr::null_mut();
// When a context gets destroyed, its streams are also destroyed
let cuda_result = T::cuStreamGetCtx(stream, &mut stream_ctx2);
assert!(
cuda_result == CUresult::CUDA_ERROR_INVALID_HANDLE
|| cuda_result == CUresult::CUDA_ERROR_INVALID_CONTEXT
|| cuda_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED
);
assert_eq!(
T::cuStreamDestroy_v2(stream),
CUresult::CUDA_ERROR_INVALID_HANDLE
);
// Check if creating another context is possible
let mut ctx2 = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(stream_moves_context_to_another_thread);
fn stream_moves_context_to_another_thread<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut stream = ptr::null_mut();
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
let mut stream_ctx1 = ptr::null_mut();
assert_eq!(
T::cuStreamGetCtx(stream, &mut stream_ctx1),
CUresult::CUDA_SUCCESS
);
assert_eq!(stream_ctx1, ctx);
let stream_ptr = stream as usize;
let stream_ctx_on_thread = thread::spawn(move || {
let mut stream_ctx2 = ptr::null_mut();
assert_eq!(
T::cuStreamGetCtx(stream_ptr as *mut _, &mut stream_ctx2),
CUresult::CUDA_SUCCESS
);
stream_ctx2 as usize
})
.join()
.unwrap();
assert_eq!(stream_ctx1, stream_ctx_on_thread as *mut _);
// Cleanup
assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(can_destroy_stream);
fn can_destroy_stream<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
let mut stream = ptr::null_mut();
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
cuda_driver_test!(cant_destroy_default_stream);
fn cant_destroy_default_stream<T: CudaDriverFns>() {
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
let mut ctx = ptr::null_mut();
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
assert_ne!(
T::cuStreamDestroy_v2(super::CU_STREAM_LEGACY as *mut _),
CUresult::CUDA_SUCCESS
);
// Cleanup
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
}

View file

@ -1,8 +1,12 @@
#![allow(non_snake_case)]
use crate::{cuda::CUstream, r#impl as notcuda};
use crate::r#impl::CUresult;
use crate::{cuda::CUuuid, r#impl::Encuda};
use crate::cuda as notcuda;
use crate::cuda::CUstream;
use crate::cuda::CUuuid;
use crate::{
cuda::{CUdevice, CUdeviceptr},
r#impl::CUresult,
};
use ::std::{
ffi::c_void,
os::raw::{c_int, c_uint},
@ -37,48 +41,63 @@ pub trait CudaDriverFns {
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult;
fn cuMemFree_v2(mem: *mut c_void) -> CUresult;
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult;
}
pub struct NotCuda();
impl CudaDriverFns for NotCuda {
fn cuInit(_flags: c_uint) -> CUresult {
crate::cuda::cuInit(_flags as _)
notcuda::cuInit(_flags as _)
}
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
notcuda::context::create_v2(pctx as *mut _, flags, notcuda::device::Index(dev)).encuda()
notcuda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev))
}
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
notcuda::context::destroy_v2(ctx as *mut _)
notcuda::cuCtxDestroy_v2(ctx as *mut _)
}
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
notcuda::context::pop_current_v2(pctx as *mut _)
notcuda::cuCtxPopCurrent_v2(pctx as *mut _)
}
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
notcuda::context::get_api_version(ctx as *mut _, version)
notcuda::cuCtxGetApiVersion(ctx as *mut _, version)
}
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
notcuda::context::get_current(pctx as *mut _).encuda()
notcuda::cuCtxGetCurrent(pctx as *mut _)
}
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
notcuda::memory::alloc_v2(dptr as *mut _, bytesize)
notcuda::cuMemAlloc_v2(dptr as *mut _, bytesize)
}
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
notcuda::device::get_uuid(uuid, notcuda::device::Index(dev)).encuda()
notcuda::cuDeviceGetUuid(uuid, CUdevice(dev))
}
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
notcuda::device::primary_ctx_get_state(notcuda::device::Index(dev), flags, active).encuda()
notcuda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active)
}
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
crate::cuda::cuStreamGetCtx(hStream, pctx as _)
notcuda::cuStreamGetCtx(hStream, pctx as _)
}
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
notcuda::cuStreamCreate(stream, flags)
}
fn cuMemFree_v2(dptr: *mut c_void) -> CUresult {
notcuda::cuMemFree_v2(CUdeviceptr(dptr as _))
}
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
notcuda::cuStreamDestroy_v2(stream)
}
}
@ -123,4 +142,16 @@ impl CudaDriverFns for Cuda {
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
}
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) }
}
fn cuMemFree_v2(mem: *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) }
}
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) }
}
}

View file

@ -34,8 +34,9 @@ pub use crate::ptx::ModuleParser;
pub use lalrpop_util::lexer::Token;
pub use lalrpop_util::ParseError;
pub use rspirv::dr::Error as SpirvError;
pub use translate::TranslateError as TranslateError;
pub use translate::to_spirv;
pub use translate::to_spirv_module;
pub use translate::KernelInfo;
pub use translate::TranslateError;
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect()

View file

@ -12,7 +12,7 @@ fn parse_and_assert(s: &str) {
fn compile_and_assert(s: &str) -> Result<(), TranslateError> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
crate::to_spirv(ast)?;
crate::to_spirv_module(ast)?;
Ok(())
}

View file

@ -1,7 +1,7 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem};
use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@ -450,6 +450,11 @@ pub struct Module {
pub should_link_ptx_impl: Option<&'static [u8]>,
pub build_options: CString,
}
impl Module {
pub fn assemble(&self) -> Vec<u32> {
self.spirv.assemble()
}
}
pub struct KernelInfo {
pub arguments_sizes: Vec<usize>,
@ -1046,8 +1051,12 @@ fn emit_function_header<'a>(
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
if let MethodName::Kernel(name) = func_decl.name {
let args_lens = func_decl
.input
let input_args = if !func_decl.uses_shared_mem {
func_decl.input.as_slice()
} else {
&func_decl.input[0..func_decl.input.len() - 1]
};
let args_lens = input_args
.iter()
.map(|param| param.v_type.size_of())
.collect();
@ -1135,21 +1144,6 @@ fn emit_function_header<'a>(
Ok(())
}
pub fn to_spirv<'a>(
ast: ast::Module<'a>,
) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
let module = to_spirv_module(ast)?;
Ok((
module.should_link_ptx_impl,
module.spirv.assemble(),
module
.kernel_info
.into_iter()
.map(|(k, v)| (k, v.arguments_sizes))
.collect(),
))
}
fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::GenericPointer);
builder.capability(spirv::Capability::Linkage);