Add support for declaring __local variables and their alignment
This commit is contained in:
parent
efd83981b8
commit
0f4a4c634b
21
ptx/src/test/spirv_run/local_align.ptx
Normal file
21
ptx/src/test/spirv_run/local_align.ptx
Normal file
|
@ -0,0 +1,21 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry local_align(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.local .align 8 .b8 __local_depot0[8];
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp, [in_addr];
|
||||
st.u64 [out_addr], temp;
|
||||
ret;
|
||||
}
|
38
ptx/src/test/spirv_run/local_align.spvtxt
Normal file
38
ptx/src/test/spirv_run/local_align.spvtxt
Normal file
|
@ -0,0 +1,38 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%1 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %5 "local_align"
|
||||
OpDecorate %8 Alignment 8
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%4 = OpTypeFunction %void %ulong %ulong
|
||||
%uchar = OpTypeInt 8 0
|
||||
%_arr_uchar_8 = OpTypeArray %uchar %8
|
||||
%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%5 = OpFunction %void None %4
|
||||
%6 = OpFunctionParameter %ulong
|
||||
%7 = OpFunctionParameter %ulong
|
||||
%18 = OpLabel
|
||||
%8 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
|
||||
%9 = OpVariable %_ptr_Function_ulong Function
|
||||
%10 = OpVariable %_ptr_Function_ulong Function
|
||||
%11 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %9 %6
|
||||
OpStore %10 %7
|
||||
%13 = OpLoad %ulong %9
|
||||
%16 = OpConvertUToPtr %_ptr_Generic_ulong %13
|
||||
%12 = OpLoad %ulong %16
|
||||
OpStore %11 %12
|
||||
%14 = OpLoad %ulong %10
|
||||
%15 = OpLoad %ulong %11
|
||||
%17 = OpConvertUToPtr %_ptr_Generic_ulong %14
|
||||
OpStore %17 %15
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -51,6 +51,7 @@ test_ptx!(shl, [11u64], [44u64]);
|
|||
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
|
||||
test_ptx!(cvta, [3.0f32], [3.0f32]);
|
||||
test_ptx!(block, [1u64], [2u64]);
|
||||
test_ptx!(local_align, [1u64], [1u64]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
|
|
@ -5,20 +5,21 @@ use std::{borrow::Cow, iter, mem};
|
|||
|
||||
use rspirv::binary::Assemble;
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
enum SpirvType {
|
||||
Base(SpirvScalarKey),
|
||||
Pointer(SpirvScalarKey, spirv::StorageClass),
|
||||
Array(SpirvScalarKey, u32),
|
||||
Pointer(Box<SpirvType>, spirv::StorageClass),
|
||||
}
|
||||
|
||||
impl SpirvType {
|
||||
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
|
||||
let key = match t {
|
||||
ast::Type::Scalar(typ) => SpirvScalarKey::from(typ),
|
||||
ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ),
|
||||
ast::Type::Array(_, _) => todo!(),
|
||||
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
||||
ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
||||
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
|
||||
};
|
||||
SpirvType::Pointer(key, sc)
|
||||
SpirvType::Pointer(Box::new(key), sc)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -27,7 +28,7 @@ impl From<ast::Type> for SpirvType {
|
|||
match t {
|
||||
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
||||
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
|
||||
ast::Type::Array(_, _) => todo!(),
|
||||
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -126,13 +127,20 @@ impl TypeWordMap {
|
|||
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
|
||||
match t {
|
||||
SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
|
||||
SpirvType::Pointer(typ, storage) => {
|
||||
let base = self.get_or_add_spirv_scalar(b, typ);
|
||||
SpirvType::Pointer(ref typ, storage) => {
|
||||
let base = self.get_or_add(b, *typ.clone());
|
||||
*self
|
||||
.complex
|
||||
.entry(t)
|
||||
.or_insert_with(|| b.type_pointer(None, storage, base))
|
||||
}
|
||||
SpirvType::Array(typ, len) => {
|
||||
let base = self.get_or_add_spirv_scalar(b, typ);
|
||||
*self
|
||||
.complex
|
||||
.entry(t)
|
||||
.or_insert_with(|| b.type_array(base, len))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -248,7 +256,7 @@ fn normalize_labels(
|
|||
labels_in_use.insert(cond.if_true);
|
||||
labels_in_use.insert(cond.if_false);
|
||||
}
|
||||
Statement::Variable(_, _, _)
|
||||
Statement::Variable(_, _, _, _)
|
||||
| Statement::LoadVar(_, _)
|
||||
| Statement::StoreVar(_, _)
|
||||
| Statement::Conversion(_)
|
||||
|
@ -298,9 +306,9 @@ fn normalize_predicates(
|
|||
result.push(Statement::Instruction(inst));
|
||||
}
|
||||
}
|
||||
ast::Statement::Variable(var) => {
|
||||
result.push(Statement::Variable(var.name, var.v_type, var.space))
|
||||
}
|
||||
ast::Statement::Variable(var) => result.push(Statement::Variable(
|
||||
var.name, var.v_type, var.space, var.align,
|
||||
)),
|
||||
// Blocks are flattened when resolving ids
|
||||
ast::Statement::Block(_) => unreachable!(),
|
||||
}
|
||||
|
@ -373,7 +381,7 @@ fn insert_mem_ssa_statements(
|
|||
bra.predicate = generated_id;
|
||||
result.push(Statement::Conditional(bra));
|
||||
}
|
||||
s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s),
|
||||
s @ Statement::Variable(_, _, _, _) | s @ Statement::Label(_) => result.push(s),
|
||||
Statement::LoadVar(_, _)
|
||||
| Statement::StoreVar(_, _)
|
||||
| Statement::Conversion(_)
|
||||
|
@ -395,7 +403,9 @@ fn expand_arguments(
|
|||
let new_inst = inst.map(&mut visitor);
|
||||
result.push(Statement::Instruction(new_inst));
|
||||
}
|
||||
Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)),
|
||||
Statement::Variable(id, typ, ss, align) => {
|
||||
result.push(Statement::Variable(id, typ, ss, align))
|
||||
}
|
||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
|
||||
|
@ -555,7 +565,7 @@ fn insert_implicit_conversions(
|
|||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Label(_)
|
||||
| s @ Statement::Constant(_)
|
||||
| s @ Statement::Variable(_, _, _)
|
||||
| s @ Statement::Variable(_, _, _, _)
|
||||
| s @ Statement::LoadVar(_, _)
|
||||
| s @ Statement::StoreVar(_, _) => result.push(s),
|
||||
Statement::Conversion(_) => unreachable!(),
|
||||
|
@ -614,15 +624,24 @@ fn emit_function_body_ops(
|
|||
}
|
||||
match s {
|
||||
Statement::Label(_) => (),
|
||||
Statement::Variable(id, typ, ss) => {
|
||||
Statement::Variable(id, typ, ss, align) => {
|
||||
let type_id = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
|
||||
);
|
||||
if *ss != ast::StateSpace::Reg {
|
||||
todo!()
|
||||
let st_class = match ss {
|
||||
ast::StateSpace::Reg => spirv::StorageClass::Function,
|
||||
ast::StateSpace::Local => spirv::StorageClass::Workgroup,
|
||||
_ => todo!(),
|
||||
};
|
||||
builder.variable(type_id, Some(*id), st_class, None);
|
||||
if let Some(align) = align {
|
||||
builder.decorate(
|
||||
*id,
|
||||
spirv::Decoration::Alignment,
|
||||
&[dr::Operand::LiteralInt32(*align)],
|
||||
);
|
||||
}
|
||||
builder.variable(type_id, Some(*id), spirv::StorageClass::Function, None);
|
||||
}
|
||||
Statement::Constant(cnst) => {
|
||||
let typ_id = map.get_or_add_scalar(builder, cnst.typ);
|
||||
|
@ -1006,7 +1025,10 @@ fn emit_implicit_conversion(
|
|||
ConversionKind::Ptr(space) => {
|
||||
let dst_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(SpirvScalarKey::from(to_type), space.to_spirv()),
|
||||
SpirvType::Pointer(
|
||||
Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))),
|
||||
space.to_spirv(),
|
||||
),
|
||||
);
|
||||
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
|
@ -1221,7 +1243,7 @@ impl NumericIdResolver {
|
|||
}
|
||||
|
||||
enum Statement<I> {
|
||||
Variable(spirv::Word, ast::Type, ast::StateSpace),
|
||||
Variable(spirv::Word, ast::Type, ast::StateSpace, Option<u32>),
|
||||
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
|
||||
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
|
||||
Label(u32),
|
||||
|
@ -1235,7 +1257,7 @@ enum Statement<I> {
|
|||
impl Statement<ast::Instruction<ExpandedArgParams>> {
|
||||
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
|
||||
match self {
|
||||
Statement::Variable(id, t, ss) => Statement::Variable(f(id), t, ss),
|
||||
Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align),
|
||||
Statement::LoadVar(a, t) => {
|
||||
Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue