diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 078cb31..da37ee3 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -317,7 +317,7 @@ pub struct PredAt { pub enum Instruction { Ld(LdData, Arg2

), Mov(MovType, Arg2

), - MovVector(MovVectorType, Arg2Vec

), + MovVector(MovVectorDetails, Arg2Vec

), Mul(MulDetails, Arg3

), Add(AddDetails, Arg3

), Setp(SetpData, Arg4

), @@ -333,6 +333,11 @@ pub enum Instruction { Abs(AbsDetails, Arg2

), } +#[derive(Copy, Clone)] +pub struct MovVectorDetails { + pub typ: MovVectorType, + pub length: u8, +} pub struct AbsDetails { pub flush_to_zero: bool, pub typ: ScalarType, @@ -377,10 +382,12 @@ pub struct Arg2St { pub src2: P::Operand, } +// We duplicate dst here because during further compilation +// composite dst and composite src will receive different ids pub enum Arg2Vec { - Dst(P::VecOperand, P::ID), + Dst((P::ID, u8), P::ID, P::ID), Src(P::ID, P::VecOperand), - Both(P::VecOperand, P::VecOperand), + Both((P::ID, u8), P::ID, P::VecOperand), } pub struct Arg3 { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 6e5f5e3..1ffbca2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -499,7 +499,7 @@ InstMov: ast::Instruction> = { ast::Instruction::Mov(t, a) }, "mov" => { - ast::Instruction::MovVector(t, a) + ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a) } }; @@ -1030,9 +1030,9 @@ Arg2: ast::Arg2> = { }; Arg2Vec: ast::Arg2Vec> = { - "," => ast::Arg2Vec::Dst(dst, src), + "," => ast::Arg2Vec::Dst(dst, dst.0, src), "," => ast::Arg2Vec::Src(dst, src), - "," => ast::Arg2Vec::Both(dst, src), + "," => ast::Arg2Vec::Both(dst, dst.0, src), }; VectorOperand: (&'input str, u8) = { diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index 25dd80e..ff0ee97 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -4,20 +4,20 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %58 = OpExtInstImport "OpenCL.std" + %60 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %31 "vector" %void = OpTypeVoid %uint = OpTypeInt 32 0 %v2uint = OpTypeVector %uint 2 - %62 = OpTypeFunction %v2uint %v2uint + %64 = OpTypeFunction %v2uint %v2uint %_ptr_Function_v2uint = OpTypePointer Function %v2uint %_ptr_Function_uint = OpTypePointer Function %uint %ulong = OpTypeInt 64 0 - %66 = OpTypeFunction %void %ulong %ulong + %68 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %62 + %1 = OpFunction %v2uint None %64 %7 = OpFunctionParameter %v2uint %30 = OpLabel %3 = OpVariable %_ptr_Function_v2uint Function @@ -27,40 +27,40 @@ %6 = OpVariable %_ptr_Function_uint Function OpStore %3 %7 %9 = OpLoad %v2uint %3 - %24 = OpCompositeExtract %uint %9 0 - %8 = OpCopyObject %uint %24 + %27 = OpCompositeExtract %uint %9 0 + %8 = OpCopyObject %uint %27 OpStore %5 %8 %11 = OpLoad %v2uint %3 - %25 = OpCompositeExtract %uint %11 1 - %10 = OpCopyObject %uint %25 + %28 = OpCompositeExtract %uint %11 1 + %10 = OpCopyObject %uint %28 OpStore %6 %10 %13 = OpLoad %uint %5 %14 = OpLoad %uint %6 %12 = OpIAdd %uint %13 %14 OpStore %6 %12 - %16 = OpLoad %uint %6 - %26 = OpCopyObject %uint %16 - %15 = OpCompositeInsert %uint %26 %15 0 + %16 = OpLoad %v2uint %4 + %17 = OpLoad %uint %6 + %15 = OpCompositeInsert %v2uint %17 %16 0 OpStore %4 %15 - %18 = OpLoad %uint %6 - %27 = OpCopyObject %uint %18 - %17 = OpCompositeInsert %uint %27 %17 1 - OpStore %4 %17 - %20 = OpLoad %v2uint %4 - %29 = OpCompositeExtract %uint %20 1 - %28 = OpCopyObject %uint %29 - %19 = OpCompositeInsert %uint %28 %19 0 - OpStore %4 %19 + %19 = OpLoad %v2uint %4 + %20 = OpLoad %uint %6 + %18 = OpCompositeInsert %v2uint %20 %19 1 + OpStore %4 %18 %22 = OpLoad %v2uint %4 - %21 = OpCopyObject %v2uint %22 - OpStore %2 %21 - %23 = OpLoad %v2uint %2 - OpReturnValue %23 + %23 = OpLoad %v2uint %4 + %29 = OpCompositeExtract %uint %23 1 + %21 = OpCompositeInsert %v2uint %29 %22 0 + OpStore %4 %21 + %25 = OpLoad %v2uint %4 + %24 = OpCopyObject %v2uint %25 + OpStore %2 %24 + %26 = OpLoad %v2uint %2 + OpReturnValue %26 OpFunctionEnd - %31 = OpFunction %void None %66 + %31 = OpFunction %void None %68 %40 = OpFunctionParameter %ulong %41 = OpFunctionParameter %ulong - %56 = OpLabel + %58 = OpLabel %32 = OpVariable %_ptr_Function_ulong Function %33 = OpVariable %_ptr_Function_ulong Function %34 = OpVariable %_ptr_Function_ulong Function @@ -85,11 +85,13 @@ %48 = OpFunctionCall %v2uint %1 %49 OpStore %36 %48 %51 = OpLoad %v2uint %36 - %50 = OpCopyObject %ulong %51 + %55 = OpBitcast %ulong %51 + %56 = OpCopyObject %ulong %55 + %50 = OpCopyObject %ulong %56 OpStore %39 %50 %52 = OpLoad %ulong %35 %53 = OpLoad %v2uint %36 - %55 = OpConvertUToPtr %_ptr_Generic_v2uint %52 - OpStore %55 %53 + %57 = OpConvertUToPtr %_ptr_Generic_v2uint %52 + OpStore %57 %53 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 57d3485..3ac5222 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -323,7 +323,8 @@ fn to_ssa<'input, 'b>( let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body); let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); - let unadorned_statements = resolve_fn_calls(&fn_defs, unadorned_statements); + let unadorned_statements = + add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs); let (f_args, ssa_statements) = insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs); @@ -345,9 +346,10 @@ fn normalize_variable_decls(mut func: Vec) -> Vec, + fn_defs: &GlobalFnDeclResolver, + id_defs: &NumericIdResolver, ) -> Vec { func.into_iter() .map(|s| { @@ -365,6 +367,17 @@ fn resolve_fn_calls( }; Statement::Call(resolved_call) } + Statement::Instruction(ast::Instruction::MovVector(dets, args)) => { + // TODO fail on type mismatch + let new_dets = match id_defs.get_type(*args.dst()) { + Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails { + length: len, + ..dets + }, + _ => dets, + }; + Statement::Instruction(ast::Instruction::MovVector(new_dets, args)) + } s => s, } }) @@ -685,7 +698,7 @@ impl<'a, 'b> ArgumentMapVisitor fn variable( &mut self, desc: ArgumentDescriptor, - typ: Option, + _: Option, ) -> spirv::Word { desc.op } @@ -757,34 +770,18 @@ impl<'a, 'b> ArgumentMapVisitor fn src_vec_operand( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, - typ: ast::MovVectorType, + (scalar_type, vec_len): (ast::MovVectorType, u8), ) -> spirv::Word { - let (vector_id, index) = desc.op; - let new_id = self.id_def.new_id(Some(ast::Type::Scalar(typ.into()))); - let composite = if desc.is_dst { - Statement::Composite(CompositeAccess { - typ: typ, - dst: new_id, - src: vector_id, - index: index as u32, - is_write: true - }) - } else { - Statement::Composite(CompositeAccess { - typ: typ, - dst: new_id, - src: vector_id, - index: index as u32, - is_write: false - }) - }; - if desc.is_dst { - self.post_stmts.push(composite); - new_id - } else { - self.func.push(composite); - new_id - } + let new_id = self + .id_def + .new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len))); + self.func.push(Statement::Composite(CompositeRead { + typ: scalar_type, + dst: new_id, + src_composite: desc.op.0, + src_index: desc.op.1 as u32, + })); + new_id } } @@ -864,6 +861,55 @@ fn insert_implicit_conversions( |arg| ast::Instruction::St(st, arg), ) } + ast::Instruction::Mov(d, mut arg) => { + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2 + // TODO: handle the case of mixed vector/scalar implicit conversions + let inst_typ_is_bit = match d { + ast::MovType::Scalar(t) => { + ast::ScalarType::from(t).kind() == ScalarKind::Bit + } + ast::MovType::Vector(_, _) => false, + }; + let mut did_vector_implicit = false; + let mut post_conv = None; + if inst_typ_is_bit { + let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()); + if let ast::Type::Vector(_, _) = src_type { + arg.src = insert_conversion_src( + &mut result, + id_def, + arg.src, + src_type, + d.into(), + ConversionKind::Default, + ); + did_vector_implicit = true; + } + let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()); + if let ast::Type::Vector(_, _) = src_type { + post_conv = Some(get_conversion_dst( + id_def, + &mut arg.dst, + d.into(), + dst_type, + ConversionKind::Default, + )); + did_vector_implicit = true; + } + } + if did_vector_implicit { + result.push(Statement::Instruction(ast::Instruction::Mov(d, arg))); + } else { + insert_implicit_bitcasts( + &mut result, + id_def, + ast::Instruction::Mov(d, arg), + ); + } + if let Some(post_conv) = post_conv { + result.push(post_conv); + } + } inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst), }, s @ Statement::Composite(_) @@ -1087,10 +1133,31 @@ fn emit_function_body_ops( builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::SetpBool(_, _) => todo!(), - ast::Instruction::MovVector(t, arg) => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - builder.copy_object(result_type, Some(arg.dst()), arg.src())?; - } + ast::Instruction::MovVector(typ, arg) => match arg { + ast::Arg2Vec::Dst((dst, dst_index), composite_src, src) + | ast::Arg2Vec::Both((dst, dst_index), composite_src, src) => { + let result_type = map.get_or_add( + builder, + SpirvType::Vector( + SpirvScalarKey::from(ast::ScalarType::from(typ.typ)), + typ.length, + ), + ); + let result_id = Some(*dst); + builder.composite_insert( + result_type, + result_id, + *src, + *composite_src, + [*dst_index as u32], + )?; + } + ast::Arg2Vec::Src(dst, src) => { + let result_type = + map.get_or_add_scalar(builder, ast::ScalarType::from(typ.typ)); + builder.copy_object(result_type, Some(*dst), *src)?; + } + }, }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(*typ)); @@ -1105,15 +1172,12 @@ fn emit_function_body_ops( Statement::Composite(c) => { let result_type = map.get_or_add_scalar(builder, c.typ.into()); let result_id = Some(c.dst); - let indexes = [c.index]; - if c.is_write { - let object = c.src; - let composite = c.dst; - builder.composite_insert(result_type, result_id, object, composite, indexes)?; - } else { - let composite = c.src; - builder.composite_extract(result_type, result_id, composite, indexes)?; - } + builder.composite_extract( + result_type, + result_id, + c.src_composite, + [c.src_index], + )?; } } } @@ -1369,15 +1433,15 @@ fn emit_implicit_conversion( ) -> Result<(), dr::Error> { let from_parts = cv.from.to_parts(); let to_parts = cv.to.to_parts(); - match cv.kind { - ConversionKind::Ptr(space) => { + match (from_parts.kind, to_parts.kind, cv.kind) { + (_, _, ConversionKind::Ptr(space)) => { let dst_type = map.get_or_add( builder, SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } - ConversionKind::Default => { + (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { let dst_type = map.get_or_add(builder, SpirvType::from(cv.from)); if from_parts.scalar_kind != ScalarKind::Float @@ -1424,7 +1488,13 @@ fn emit_implicit_conversion( } } } - ConversionKind::SignExtend => todo!(), + (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(), + (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => { + let into_type = map.get_or_add(builder, SpirvType::from(cv.to)); + builder.bitcast(into_type, Some(cv.dst), cv.src)?; + } + _ => unreachable!(), } Ok(()) } @@ -1723,7 +1793,7 @@ enum Statement { LoadVar(ast::Arg2, ast::Type), StoreVar(ast::Arg2St, ast::Type), Call(ResolvedCall

), - Composite(CompositeAccess), + Composite(CompositeRead), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Conversion(ImplicitConversion), @@ -1874,7 +1944,7 @@ trait ArgumentMapVisitor { fn src_vec_operand( &mut self, desc: ArgumentDescriptor, - typ: ast::MovVectorType, + typ: (ast::MovVectorType, u8), ) -> U::VecOperand; } @@ -1902,9 +1972,12 @@ where fn src_vec_operand( &mut self, desc: ArgumentDescriptor, - t: ast::MovVectorType, + (scalar_type, vec_len): (ast::MovVectorType, u8), ) -> spirv::Word { - self(desc, Some(ast::Type::Scalar(t.into()))) + self( + desc.new_op(desc.op), + Some(ast::Type::Vector(scalar_type.into(), vec_len)), + ) } } @@ -1942,7 +2015,7 @@ where fn src_vec_operand( &mut self, desc: ArgumentDescriptor<(&str, u8)>, - _: ast::MovVectorType, + _: (ast::MovVectorType, u8), ) -> (spirv::Word, u8) { (self(desc.op.0), desc.op.1) } @@ -1970,7 +2043,9 @@ impl ast::Instruction { visitor: &mut V, ) -> ast::Instruction { match self { - ast::Instruction::MovVector(t, a) => ast::Instruction::MovVector(t, a.map(visitor, t)), + ast::Instruction::MovVector(t, a) => { + ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))) + } ast::Instruction::Abs(_, _) => todo!(), // Call instruction is converted to a call statement early on ast::Instruction::Call(_) => unreachable!(), @@ -2090,12 +2165,12 @@ where fn src_vec_operand( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, - t: ast::MovVectorType, + (scalar_type, vector_len): (ast::MovVectorType, u8), ) -> (spirv::Word, u8) { ( self( desc.new_op(desc.op.0), - Some(ast::Type::Vector(t.into(), desc.op.1)), + Some(ast::Type::Vector(scalar_type.into(), vector_len)), ), desc.op.1, ) @@ -2195,27 +2270,11 @@ impl VisitVariableExpanded for ast::Instruction { type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; -struct CompositeAccess { - pub typ: ast::MovVectorType, - pub dst: spirv::Word, - pub src: spirv::Word, - pub index: u32, - pub is_write: bool -} - -struct CompositeWrite { - pub typ: ast::MovVectorType, - pub dst: spirv::Word, - pub src_composite: spirv::Word, - pub src_scalar: spirv::Word, - pub index: u32, -} - struct CompositeRead { pub typ: ast::MovVectorType, pub dst: spirv::Word, - pub src: spirv::Word, - pub index: u32, + pub src_composite: spirv::Word, + pub src_index: u32, } struct ConstantDefinition { @@ -2407,28 +2466,47 @@ impl ast::Arg2St { } impl ast::Arg2Vec { + fn dst(&self) -> &T::ID { + match self { + ast::Arg2Vec::Dst((d, _), _, _) + | ast::Arg2Vec::Src(d, _) + | ast::Arg2Vec::Both((d, _), _, _) => d, + } + } + fn map>( self, visitor: &mut V, - t: ast::MovVectorType, + (scalar_type, vec_len): (ast::MovVectorType, u8), ) -> ast::Arg2Vec { match self { - ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst( - visitor.src_vec_operand( - ArgumentDescriptor { - op: dst, - is_dst: true, - is_pointer: false, - }, - t, + ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => ast::Arg2Vec::Dst( + ( + visitor.variable( + ArgumentDescriptor { + op: dst, + is_dst: true, + is_pointer: false, + }, + Some(ast::Type::Scalar(scalar_type.into())), + ), + len, ), visitor.variable( ArgumentDescriptor { - op: src, + op: composite_src, is_dst: false, is_pointer: false, }, - Some(ast::Type::Scalar(t.into())), + Some(ast::Type::Scalar(scalar_type.into())), + ), + visitor.variable( + ArgumentDescriptor { + op: scalar_src, + is_dst: false, + is_pointer: false, + }, + Some(ast::Type::Scalar(scalar_type.into())), ), ), ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src( @@ -2438,7 +2516,7 @@ impl ast::Arg2Vec { is_dst: true, is_pointer: false, }, - Some(ast::Type::Scalar(t.into())), + Some(ast::Type::Scalar(scalar_type.into())), ), visitor.src_vec_operand( ArgumentDescriptor { @@ -2446,17 +2524,28 @@ impl ast::Arg2Vec { is_dst: false, is_pointer: false, }, - t, + (scalar_type, vec_len), ), ), - ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both( - visitor.src_vec_operand( + ast::Arg2Vec::Both((dst, len), composite_src, src) => ast::Arg2Vec::Both( + ( + visitor.variable( + ArgumentDescriptor { + op: dst, + is_dst: true, + is_pointer: false, + }, + Some(ast::Type::Scalar(scalar_type.into())), + ), + len, + ), + visitor.variable( ArgumentDescriptor { - op: dst, - is_dst: true, + op: composite_src, + is_dst: false, is_pointer: false, }, - t, + Some(ast::Type::Scalar(scalar_type.into())), ), visitor.src_vec_operand( ArgumentDescriptor { @@ -2464,31 +2553,13 @@ impl ast::Arg2Vec { is_dst: false, is_pointer: false, }, - t, + (scalar_type, vec_len), ), ), } } } -impl ast::Arg2Vec { - fn dst(&self) -> spirv::Word { - match self { - ast::Arg2Vec::Dst(dst, _) | ast::Arg2Vec::Src(dst, _) | ast::Arg2Vec::Both(dst, _) => { - *dst - } - } - } - - fn src(&self) -> spirv::Word { - match self { - ast::Arg2Vec::Dst(_, src) | ast::Arg2Vec::Src(_, src) | ast::Arg2Vec::Both(_, src) => { - *src - } - } - } -} - impl ast::Arg3 { fn map_non_shift>( self,