Implement missing pieces in vector support

This commit is contained in:
Andrzej Janik 2020-09-15 02:34:08 +02:00
parent bb5025c9b1
commit fcf3aaeb16
4 changed files with 226 additions and 146 deletions

View file

@ -317,7 +317,7 @@ pub struct PredAt<ID> {
pub enum Instruction<P: ArgParams> {
Ld(LdData, Arg2<P>),
Mov(MovType, Arg2<P>),
MovVector(MovVectorType, Arg2Vec<P>),
MovVector(MovVectorDetails, Arg2Vec<P>),
Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4<P>),
@ -333,6 +333,11 @@ pub enum Instruction<P: ArgParams> {
Abs(AbsDetails, Arg2<P>),
}
/// Details carried by a `MovVector` instruction: the instruction's type and
/// the length of the vector operand involved.
#[derive(Copy, Clone)]
pub struct MovVectorDetails {
/// Type given to the vector `mov` instruction.
pub typ: MovVectorType,
/// Vector length. The parser fills in 0; `add_types_to_statements` replaces
/// it with the length of the destination's vector type when that type is known.
pub length: u8,
}
pub struct AbsDetails {
pub flush_to_zero: bool,
pub typ: ScalarType,
@ -377,10 +382,12 @@ pub struct Arg2St<P: ArgParams> {
pub src2: P::Operand,
}
// We duplicate dst here because during further compilation
// composite dst and composite src will receive different ids
pub enum Arg2Vec<P: ArgParams> {
Dst(P::VecOperand, P::ID),
Dst((P::ID, u8), P::ID, P::ID),
Src(P::ID, P::VecOperand),
Both(P::VecOperand, P::VecOperand),
Both((P::ID, u8), P::ID, P::VecOperand),
}
pub struct Arg3<P: ArgParams> {

View file

@ -499,7 +499,7 @@ InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Mov(t, a)
},
"mov" <t:MovVectorType> <a:Arg2Vec> => {
ast::Instruction::MovVector(t, a)
ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a)
}
};
@ -1030,9 +1030,9 @@ Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
};
Arg2Vec: ast::Arg2Vec<ast::ParsedArgParams<'input>> = {
<dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, src),
<dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, dst.0, src),
<dst:ExtendedID> "," <src:VectorOperand> => ast::Arg2Vec::Src(dst, src),
<dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, src),
<dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, dst.0, src),
};
VectorOperand: (&'input str, u8) = {

View file

@ -4,20 +4,20 @@
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%58 = OpExtInstImport "OpenCL.std"
%60 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %31 "vector"
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%v2uint = OpTypeVector %uint 2
%62 = OpTypeFunction %v2uint %v2uint
%64 = OpTypeFunction %v2uint %v2uint
%_ptr_Function_v2uint = OpTypePointer Function %v2uint
%_ptr_Function_uint = OpTypePointer Function %uint
%ulong = OpTypeInt 64 0
%66 = OpTypeFunction %void %ulong %ulong
%68 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
%1 = OpFunction %v2uint None %62
%1 = OpFunction %v2uint None %64
%7 = OpFunctionParameter %v2uint
%30 = OpLabel
%3 = OpVariable %_ptr_Function_v2uint Function
@ -27,40 +27,40 @@
%6 = OpVariable %_ptr_Function_uint Function
OpStore %3 %7
%9 = OpLoad %v2uint %3
%24 = OpCompositeExtract %uint %9 0
%8 = OpCopyObject %uint %24
%27 = OpCompositeExtract %uint %9 0
%8 = OpCopyObject %uint %27
OpStore %5 %8
%11 = OpLoad %v2uint %3
%25 = OpCompositeExtract %uint %11 1
%10 = OpCopyObject %uint %25
%28 = OpCompositeExtract %uint %11 1
%10 = OpCopyObject %uint %28
OpStore %6 %10
%13 = OpLoad %uint %5
%14 = OpLoad %uint %6
%12 = OpIAdd %uint %13 %14
OpStore %6 %12
%16 = OpLoad %uint %6
%26 = OpCopyObject %uint %16
%15 = OpCompositeInsert %uint %26 %15 0
%16 = OpLoad %v2uint %4
%17 = OpLoad %uint %6
%15 = OpCompositeInsert %v2uint %17 %16 0
OpStore %4 %15
%18 = OpLoad %uint %6
%27 = OpCopyObject %uint %18
%17 = OpCompositeInsert %uint %27 %17 1
OpStore %4 %17
%20 = OpLoad %v2uint %4
%29 = OpCompositeExtract %uint %20 1
%28 = OpCopyObject %uint %29
%19 = OpCompositeInsert %uint %28 %19 0
OpStore %4 %19
%19 = OpLoad %v2uint %4
%20 = OpLoad %uint %6
%18 = OpCompositeInsert %v2uint %20 %19 1
OpStore %4 %18
%22 = OpLoad %v2uint %4
%21 = OpCopyObject %v2uint %22
OpStore %2 %21
%23 = OpLoad %v2uint %2
OpReturnValue %23
%23 = OpLoad %v2uint %4
%29 = OpCompositeExtract %uint %23 1
%21 = OpCompositeInsert %v2uint %29 %22 0
OpStore %4 %21
%25 = OpLoad %v2uint %4
%24 = OpCopyObject %v2uint %25
OpStore %2 %24
%26 = OpLoad %v2uint %2
OpReturnValue %26
OpFunctionEnd
%31 = OpFunction %void None %66
%31 = OpFunction %void None %68
%40 = OpFunctionParameter %ulong
%41 = OpFunctionParameter %ulong
%56 = OpLabel
%58 = OpLabel
%32 = OpVariable %_ptr_Function_ulong Function
%33 = OpVariable %_ptr_Function_ulong Function
%34 = OpVariable %_ptr_Function_ulong Function
@ -85,11 +85,13 @@
%48 = OpFunctionCall %v2uint %1 %49
OpStore %36 %48
%51 = OpLoad %v2uint %36
%50 = OpCopyObject %ulong %51
%55 = OpBitcast %ulong %51
%56 = OpCopyObject %ulong %55
%50 = OpCopyObject %ulong %56
OpStore %39 %50
%52 = OpLoad %ulong %35
%53 = OpLoad %v2uint %36
%55 = OpConvertUToPtr %_ptr_Generic_v2uint %52
OpStore %55 %53
%57 = OpConvertUToPtr %_ptr_Generic_v2uint %52
OpStore %57 %53
OpReturn
OpFunctionEnd

View file

@ -323,7 +323,8 @@ fn to_ssa<'input, 'b>(
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body);
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
let unadorned_statements = resolve_fn_calls(&fn_defs, unadorned_statements);
let unadorned_statements =
add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
let (f_args, ssa_statements) =
insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
@ -345,9 +346,10 @@ fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedSta
func
}
fn resolve_fn_calls(
fn_defs: &GlobalFnDeclResolver,
fn add_types_to_statements(
func: Vec<UnadornedStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &NumericIdResolver,
) -> Vec<UnadornedStatement> {
func.into_iter()
.map(|s| {
@ -365,6 +367,17 @@ fn resolve_fn_calls(
};
Statement::Call(resolved_call)
}
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
// TODO fail on type mismatch
let new_dets = match id_defs.get_type(*args.dst()) {
Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails {
length: len,
..dets
},
_ => dets,
};
Statement::Instruction(ast::Instruction::MovVector(new_dets, args))
}
s => s,
}
})
@ -685,7 +698,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
fn variable(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: Option<ast::Type>,
_: Option<ast::Type>,
) -> spirv::Word {
desc.op
}
@ -757,34 +770,18 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
typ: ast::MovVectorType,
(scalar_type, vec_len): (ast::MovVectorType, u8),
) -> spirv::Word {
let (vector_id, index) = desc.op;
let new_id = self.id_def.new_id(Some(ast::Type::Scalar(typ.into())));
let composite = if desc.is_dst {
Statement::Composite(CompositeAccess {
typ: typ,
dst: new_id,
src: vector_id,
index: index as u32,
is_write: true
})
} else {
Statement::Composite(CompositeAccess {
typ: typ,
dst: new_id,
src: vector_id,
index: index as u32,
is_write: false
})
};
if desc.is_dst {
self.post_stmts.push(composite);
new_id
} else {
self.func.push(composite);
new_id
}
let new_id = self
.id_def
.new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len)));
self.func.push(Statement::Composite(CompositeRead {
typ: scalar_type,
dst: new_id,
src_composite: desc.op.0,
src_index: desc.op.1 as u32,
}));
new_id
}
}
@ -864,6 +861,55 @@ fn insert_implicit_conversions(
|arg| ast::Instruction::St(st, arg),
)
}
ast::Instruction::Mov(d, mut arg) => {
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
    // TODO: handle the case of mixed vector/scalar implicit conversions
    //
    // A bit-typed `mov` may pack a vector into a scalar or unpack a scalar
    // into a vector. Only bit-typed scalar movs are eligible for that.
    let inst_typ_is_bit = match d {
        ast::MovType::Scalar(t) => {
            ast::ScalarType::from(t).kind() == ScalarKind::Bit
        }
        ast::MovType::Vector(_, _) => false,
    };
    let mut did_vector_implicit = false;
    // Conversion that has to be emitted *after* the mov itself (dst side).
    let mut post_conv = None;
    if inst_typ_is_bit {
        // Source side: a vector source is converted to the instruction
        // type before the mov reads it.
        let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!());
        if let ast::Type::Vector(_, _) = src_type {
            arg.src = insert_conversion_src(
                &mut result,
                id_def,
                arg.src,
                src_type,
                d.into(),
                ConversionKind::Default,
            );
            did_vector_implicit = true;
        }
        // Destination side: the decision must be based on the destination's
        // own type. The previous code tested `src_type` here, so a vector
        // destination fed by a scalar source was never converted, and a
        // scalar destination fed by a vector source got a needless
        // conversion.
        let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!());
        if let ast::Type::Vector(_, _) = dst_type {
            post_conv = Some(get_conversion_dst(
                id_def,
                &mut arg.dst,
                d.into(),
                dst_type,
                ConversionKind::Default,
            ));
            did_vector_implicit = true;
        }
    }
    if did_vector_implicit {
        // Operands were already adjusted above; emit the mov as-is.
        result.push(Statement::Instruction(ast::Instruction::Mov(d, arg)));
    } else {
        // No vector involved: fall back to the generic bitcast insertion.
        insert_implicit_bitcasts(
            &mut result,
            id_def,
            ast::Instruction::Mov(d, arg),
        );
    }
    if let Some(post_conv) = post_conv {
        result.push(post_conv);
    }
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
s @ Statement::Composite(_)
@ -1087,10 +1133,31 @@ fn emit_function_body_ops(
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::MovVector(t, arg) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
builder.copy_object(result_type, Some(arg.dst()), arg.src())?;
}
ast::Instruction::MovVector(typ, arg) => match arg {
ast::Arg2Vec::Dst((dst, dst_index), composite_src, src)
| ast::Arg2Vec::Both((dst, dst_index), composite_src, src) => {
let result_type = map.get_or_add(
builder,
SpirvType::Vector(
SpirvScalarKey::from(ast::ScalarType::from(typ.typ)),
typ.length,
),
);
let result_id = Some(*dst);
builder.composite_insert(
result_type,
result_id,
*src,
*composite_src,
[*dst_index as u32],
)?;
}
ast::Arg2Vec::Src(dst, src) => {
let result_type =
map.get_or_add_scalar(builder, ast::ScalarType::from(typ.typ));
builder.copy_object(result_type, Some(*dst), *src)?;
}
},
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@ -1105,15 +1172,12 @@ fn emit_function_body_ops(
Statement::Composite(c) => {
let result_type = map.get_or_add_scalar(builder, c.typ.into());
let result_id = Some(c.dst);
let indexes = [c.index];
if c.is_write {
let object = c.src;
let composite = c.dst;
builder.composite_insert(result_type, result_id, object, composite, indexes)?;
} else {
let composite = c.src;
builder.composite_extract(result_type, result_id, composite, indexes)?;
}
builder.composite_extract(
result_type,
result_id,
c.src_composite,
[c.src_index],
)?;
}
}
}
@ -1369,15 +1433,15 @@ fn emit_implicit_conversion(
) -> Result<(), dr::Error> {
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match cv.kind {
ConversionKind::Ptr(space) => {
match (from_parts.kind, to_parts.kind, cv.kind) {
(_, _, ConversionKind::Ptr(space)) => {
let dst_type = map.get_or_add(
builder,
SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
ConversionKind::Default => {
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.from));
if from_parts.scalar_kind != ScalarKind::Float
@ -1424,7 +1488,13 @@ fn emit_implicit_conversion(
}
}
}
ConversionKind::SignExtend => todo!(),
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(),
}
Ok(())
}
@ -1723,7 +1793,7 @@ enum Statement<I, P: ast::ArgParams> {
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Call(ResolvedCall<P>),
Composite(CompositeAccess),
Composite(CompositeRead),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Conversion(ImplicitConversion),
@ -1874,7 +1944,7 @@ trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<T::VecOperand>,
typ: ast::MovVectorType,
typ: (ast::MovVectorType, u8),
) -> U::VecOperand;
}
@ -1902,9 +1972,12 @@ where
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: ast::MovVectorType,
(scalar_type, vec_len): (ast::MovVectorType, u8),
) -> spirv::Word {
self(desc, Some(ast::Type::Scalar(t.into())))
self(
desc.new_op(desc.op),
Some(ast::Type::Vector(scalar_type.into(), vec_len)),
)
}
}
@ -1942,7 +2015,7 @@ where
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(&str, u8)>,
_: ast::MovVectorType,
_: (ast::MovVectorType, u8),
) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1)
}
@ -1970,7 +2043,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
visitor: &mut V,
) -> ast::Instruction<U> {
match self {
ast::Instruction::MovVector(t, a) => ast::Instruction::MovVector(t, a.map(visitor, t)),
ast::Instruction::MovVector(t, a) => {
ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length)))
}
ast::Instruction::Abs(_, _) => todo!(),
// Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => unreachable!(),
@ -2090,12 +2165,12 @@ where
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
t: ast::MovVectorType,
(scalar_type, vector_len): (ast::MovVectorType, u8),
) -> (spirv::Word, u8) {
(
self(
desc.new_op(desc.op.0),
Some(ast::Type::Vector(t.into(), desc.op.1)),
Some(ast::Type::Vector(scalar_type.into(), vector_len)),
),
desc.op.1,
)
@ -2195,27 +2270,11 @@ impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
struct CompositeAccess {
pub typ: ast::MovVectorType,
pub dst: spirv::Word,
pub src: spirv::Word,
pub index: u32,
pub is_write: bool
}
struct CompositeWrite {
pub typ: ast::MovVectorType,
pub dst: spirv::Word,
pub src_composite: spirv::Word,
pub src_scalar: spirv::Word,
pub index: u32,
}
struct CompositeRead {
pub typ: ast::MovVectorType,
pub dst: spirv::Word,
pub src: spirv::Word,
pub index: u32,
pub src_composite: spirv::Word,
pub src_index: u32,
}
struct ConstantDefinition {
@ -2407,28 +2466,47 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
}
impl<T: ArgParamsEx> ast::Arg2Vec<T> {
/// Returns the destination identifier of this vector `mov` argument set.
///
/// For the `Dst` and `Both` variants this is the id stored in the
/// `(id, index)` destination pair; for `Src` it is the first field.
fn dst(&self) -> &T::ID {
    match self {
        ast::Arg2Vec::Dst((id, _), _, _) => id,
        ast::Arg2Vec::Src(id, _) => id,
        ast::Arg2Vec::Both((id, _), _, _) => id,
    }
}
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::MovVectorType,
(scalar_type, vec_len): (ast::MovVectorType, u8),
) -> ast::Arg2Vec<U> {
match self {
ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst(
visitor.src_vec_operand(
ArgumentDescriptor {
op: dst,
is_dst: true,
is_pointer: false,
},
t,
ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => ast::Arg2Vec::Dst(
(
visitor.variable(
ArgumentDescriptor {
op: dst,
is_dst: true,
is_pointer: false,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
len,
),
visitor.variable(
ArgumentDescriptor {
op: src,
op: composite_src,
is_dst: false,
is_pointer: false,
},
Some(ast::Type::Scalar(t.into())),
Some(ast::Type::Scalar(scalar_type.into())),
),
visitor.variable(
ArgumentDescriptor {
op: scalar_src,
is_dst: false,
is_pointer: false,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
),
ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src(
@ -2438,7 +2516,7 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
is_dst: true,
is_pointer: false,
},
Some(ast::Type::Scalar(t.into())),
Some(ast::Type::Scalar(scalar_type.into())),
),
visitor.src_vec_operand(
ArgumentDescriptor {
@ -2446,17 +2524,28 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
is_dst: false,
is_pointer: false,
},
t,
(scalar_type, vec_len),
),
),
ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both(
visitor.src_vec_operand(
ast::Arg2Vec::Both((dst, len), composite_src, src) => ast::Arg2Vec::Both(
(
visitor.variable(
ArgumentDescriptor {
op: dst,
is_dst: true,
is_pointer: false,
},
Some(ast::Type::Scalar(scalar_type.into())),
),
len,
),
visitor.variable(
ArgumentDescriptor {
op: dst,
is_dst: true,
op: composite_src,
is_dst: false,
is_pointer: false,
},
t,
Some(ast::Type::Scalar(scalar_type.into())),
),
visitor.src_vec_operand(
ArgumentDescriptor {
@ -2464,31 +2553,13 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
is_dst: false,
is_pointer: false,
},
t,
(scalar_type, vec_len),
),
),
}
}
}
impl ast::Arg2Vec<ExpandedArgParams> {
fn dst(&self) -> spirv::Word {
match self {
ast::Arg2Vec::Dst(dst, _) | ast::Arg2Vec::Src(dst, _) | ast::Arg2Vec::Both(dst, _) => {
*dst
}
}
}
fn src(&self) -> spirv::Word {
match self {
ast::Arg2Vec::Dst(_, src) | ast::Arg2Vec::Src(_, src) | ast::Arg2Vec::Both(_, src) => {
*src
}
}
}
}
impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,