Add support for declaring __local variables and their alignment

This commit is contained in:
Andrzej Janik 2020-09-02 22:47:27 +02:00
parent efd83981b8
commit 0f4a4c634b
4 changed files with 105 additions and 23 deletions

View file

@ -0,0 +1,21 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry local_align(
.param .u64 input,
.param .u64 output
)
{
.local .align 8 .b8 __local_depot0[8];
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr];
st.u64 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,38 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %5 "local_align"
OpDecorate %8 Alignment 8
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%4 = OpTypeFunction %void %ulong %ulong
%uchar = OpTypeInt 8 0
%_arr_uchar_8 = OpTypeArray %uchar %8
%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%5 = OpFunction %void None %4
%6 = OpFunctionParameter %ulong
%7 = OpFunctionParameter %ulong
%18 = OpLabel
%8 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
%9 = OpVariable %_ptr_Function_ulong Function
%10 = OpVariable %_ptr_Function_ulong Function
%11 = OpVariable %_ptr_Function_ulong Function
OpStore %9 %6
OpStore %10 %7
%13 = OpLoad %ulong %9
%16 = OpConvertUToPtr %_ptr_Generic_ulong %13
%12 = OpLoad %ulong %16
OpStore %11 %12
%14 = OpLoad %ulong %10
%15 = OpLoad %ulong %11
%17 = OpConvertUToPtr %_ptr_Generic_ulong %14
OpStore %17 %15
OpReturn
OpFunctionEnd

View file

@ -51,6 +51,7 @@ test_ptx!(shl, [11u64], [44u64]);
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -5,20 +5,21 @@ use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
Pointer(SpirvScalarKey, spirv::StorageClass),
Array(SpirvScalarKey, u32),
Pointer(Box<SpirvType>, spirv::StorageClass),
}
impl SpirvType {
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
let key = match t {
ast::Type::Scalar(typ) => SpirvScalarKey::from(typ),
ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ),
ast::Type::Array(_, _) => todo!(),
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
};
SpirvType::Pointer(key, sc)
SpirvType::Pointer(Box::new(key), sc)
}
}
@ -27,7 +28,7 @@ impl From<ast::Type> for SpirvType {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
ast::Type::Array(_, _) => todo!(),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
}
}
}
@ -126,13 +127,20 @@ impl TypeWordMap {
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
match t {
SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
SpirvType::Pointer(typ, storage) => {
let base = self.get_or_add_spirv_scalar(b, typ);
SpirvType::Pointer(ref typ, storage) => {
let base = self.get_or_add(b, *typ.clone());
*self
.complex
.entry(t)
.or_insert_with(|| b.type_pointer(None, storage, base))
}
SpirvType::Array(typ, len) => {
let base = self.get_or_add_spirv_scalar(b, typ);
*self
.complex
.entry(t)
.or_insert_with(|| b.type_array(base, len))
}
}
}
@ -248,7 +256,7 @@ fn normalize_labels(
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
Statement::Variable(_, _, _)
Statement::Variable(_, _, _, _)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
@ -298,9 +306,9 @@ fn normalize_predicates(
result.push(Statement::Instruction(inst));
}
}
ast::Statement::Variable(var) => {
result.push(Statement::Variable(var.name, var.v_type, var.space))
}
ast::Statement::Variable(var) => result.push(Statement::Variable(
var.name, var.v_type, var.space, var.align,
)),
// Blocks are flattened when resolving ids
ast::Statement::Block(_) => unreachable!(),
}
@ -373,7 +381,7 @@ fn insert_mem_ssa_statements(
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s),
s @ Statement::Variable(_, _, _, _) | s @ Statement::Label(_) => result.push(s),
Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
@ -395,7 +403,9 @@ fn expand_arguments(
let new_inst = inst.map(&mut visitor);
result.push(Statement::Instruction(new_inst));
}
Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)),
Statement::Variable(id, typ, ss, align) => {
result.push(Statement::Variable(id, typ, ss, align))
}
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
@ -555,7 +565,7 @@ fn insert_implicit_conversions(
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_, _, _)
| s @ Statement::Variable(_, _, _, _)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _) => result.push(s),
Statement::Conversion(_) => unreachable!(),
@ -614,15 +624,24 @@ fn emit_function_body_ops(
}
match s {
Statement::Label(_) => (),
Statement::Variable(id, typ, ss) => {
Statement::Variable(id, typ, ss, align) => {
let type_id = map.get_or_add(
builder,
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
);
if *ss != ast::StateSpace::Reg {
todo!()
let st_class = match ss {
ast::StateSpace::Reg => spirv::StorageClass::Function,
ast::StateSpace::Local => spirv::StorageClass::Workgroup,
_ => todo!(),
};
builder.variable(type_id, Some(*id), st_class, None);
if let Some(align) = align {
builder.decorate(
*id,
spirv::Decoration::Alignment,
&[dr::Operand::LiteralInt32(*align)],
);
}
builder.variable(type_id, Some(*id), spirv::StorageClass::Function, None);
}
Statement::Constant(cnst) => {
let typ_id = map.get_or_add_scalar(builder, cnst.typ);
@ -1006,7 +1025,10 @@ fn emit_implicit_conversion(
ConversionKind::Ptr(space) => {
let dst_type = map.get_or_add(
builder,
SpirvType::Pointer(SpirvScalarKey::from(to_type), space.to_spirv()),
SpirvType::Pointer(
Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))),
space.to_spirv(),
),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
@ -1221,7 +1243,7 @@ impl NumericIdResolver {
}
enum Statement<I> {
Variable(spirv::Word, ast::Type, ast::StateSpace),
Variable(spirv::Word, ast::Type, ast::StateSpace, Option<u32>),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Label(u32),
@ -1235,7 +1257,7 @@ enum Statement<I> {
impl Statement<ast::Instruction<ExpandedArgParams>> {
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
match self {
Statement::Variable(id, t, ss) => Statement::Variable(f(id), t, ss),
Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align),
Statement::LoadVar(a, t) => {
Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t)
}