diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index d858d06..f4502af 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -506,11 +506,12 @@ pub enum Instruction { Call(CallInst

), Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), - Or(OrType, Arg3

), + Or(OrAndType, Arg3

), Sub(ArithDetails, Arg3

), Min(MinMaxDetails, Arg3

), Max(MinMaxDetails, Arg3

), Rcp(RcpDetails, Arg2

), + And(OrAndType, Arg3

), } #[derive(Copy, Clone)] @@ -974,7 +975,7 @@ pub struct RetData { pub uniform: bool, } -sub_enum!(OrType { +sub_enum!(OrAndType { Pred, B16, B32, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d445baa..7414443 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -125,6 +125,7 @@ match { // IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID "abs", "add", + "and", "bra", "call", "cvt", @@ -158,6 +159,7 @@ match { ExtendedID : &'input str = { "abs", "add", + "and", "bra", "call", "cvt", @@ -608,6 +610,7 @@ Instruction: ast::Instruction> = { InstAbs, InstMad, InstOr, + InstAnd, InstSub, InstMin, InstMax, @@ -1190,16 +1193,21 @@ SignedIntType: ast::ScalarType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or InstOr: ast::Instruction> = { - "or" => ast::Instruction::Or(d, a), + "or" => ast::Instruction::Or(d, a), }; -OrType: ast::OrType = { - ".pred" => ast::OrType::Pred, - ".b16" => ast::OrType::B16, - ".b32" => ast::OrType::B32, - ".b64" => ast::OrType::B64, +OrAndType: ast::OrAndType = { + ".pred" => ast::OrAndType::Pred, + ".b16" => ast::OrAndType::B16, + ".b32" => ast::OrAndType::B32, + ".b64" => ast::OrAndType::B64, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and +InstAnd: ast::Instruction> = { + "and" => ast::Instruction::And(d, a), +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp InstRcp: ast::Instruction> = { "rcp" ".f32" => { diff --git a/ptx/src/test/spirv_run/and.ptx b/ptx/src/test/spirv_run/and.ptx new file mode 100644 index 0000000..88292a7 --- /dev/null +++ b/ptx/src/test/spirv_run/and.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry and( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp1; + .reg .u32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp1, [in_addr]; + ld.u32 temp2, [in_addr+4]; + and.b32 temp1, temp1, temp2; + st.u32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt new file mode 100644 index 0000000..9b72477 --- /dev/null +++ b/ptx/src/test/spirv_run/and.spvtxt @@ -0,0 +1,66 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 41 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +OpCapability FunctionFloatControlINTEL +OpExtension "SPV_INTEL_float_controls2" +%33 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "and" +%34 = OpTypeVoid +%35 = OpTypeInt 64 0 +%36 = OpTypeFunction %34 %35 %35 +%37 = OpTypePointer Function %35 +%38 = OpTypeInt 32 0 +%39 = OpTypePointer Function %38 +%40 = OpTypePointer Generic %38 +%23 = OpConstant %35 4 +%1 = OpFunction %34 None %36 +%8 = OpFunctionParameter %35 +%9 = OpFunctionParameter %35 +%31 = OpLabel +%2 = OpVariable %37 Function +%3 = OpVariable %37 Function +%4 = OpVariable %37 Function +%5 = OpVariable %37 Function +%6 = OpVariable %39 Function +%7 = OpVariable %39 Function +OpStore %2 %8 +OpStore %3 %9 +%11 = OpLoad %35 %2 +%10 = OpCopyObject %35 %11 +OpStore %4 %10 +%13 = OpLoad %35 %3 +%12 = OpCopyObject %35 %13 +OpStore %5 %12 +%15 = OpLoad %35 %4 +%25 = OpConvertUToPtr %40 %15 +%14 = OpLoad %38 %25 +OpStore %6 %14 +%17 = OpLoad %35 %4 +%24 = OpIAdd %35 %17 %23 +%26 = OpConvertUToPtr %40 %24 +%16 = OpLoad %38 %26 +OpStore %7 %16 +%19 = OpLoad %38 %6 +%20 = OpLoad %38 %7 +%28 = OpCopyObject %38 %19 +%29 = OpCopyObject %38 %20 +%27 = OpBitwiseAnd %38 %28 %29 +%18 = OpCopyObject %38 %27 +OpStore %6 %18 +%21 = OpLoad %35 %5 +%22 = OpLoad %38 %6 +%30 = OpConvertUToPtr %40 %21 +OpStore %30 %22 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 40acd46..dfdec72 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -89,6 +89,7 @@ test_ptx!(rcp, [2f32], [0.5f32]); test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]); test_ptx!(constant_f32, [10f32], [5f32]); test_ptx!(constant_negative, [-101i32], [101i32]); +test_ptx!(and, [6u32, 3u32], [2u32]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c0ff8f0..c699cc4 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1263,6 +1263,9 @@ fn convert_to_typed_statements( ast::Instruction::Rcp(d, a) => { result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast()))) } + ast::Instruction::And(d, a) => { + result.push(Statement::Instruction(ast::Instruction::And(d, a.cast()))) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -2325,7 +2328,7 @@ fn emit_function_body_ops( }, ast::Instruction::Or(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::OrType::Pred { + if *t == ast::OrAndType::Pred { builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; @@ -2351,6 +2354,14 @@ fn emit_function_body_ops( ast::Instruction::Rcp(d, a) => { emit_rcp(builder, map, d, a)?; } + ast::Instruction::And(t, a) => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); + if *t == ast::OrAndType::Pred { + builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; + } else { + builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; + } + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); @@ -4041,6 +4052,10 @@ impl ast::Instruction { }); ast::Instruction::Rcp(d, a.map(visitor, &typ)?) } + ast::Instruction::And(t, a) => ast::Instruction::And( + t, + a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, + ), }) } } @@ -4285,6 +4300,7 @@ impl ast::Instruction { | ast::Instruction::Min(_, _) | ast::Instruction::Max(_, _) | ast::Instruction::Rcp(_, _) + | ast::Instruction::And(_, _) | ast::Instruction::Mad(_, _) => None, } } @@ -4303,6 +4319,7 @@ impl ast::Instruction { ast::Instruction::Ret(_) => None, ast::Instruction::Call(_) => None, ast::Instruction::Or(_, _) => None, + ast::Instruction::And(_, _) => None, ast::Instruction::Cvta(_, _) => None, ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None,