Implement div, sqrt, rsqrt and more of setp

This commit is contained in:
Andrzej Janik 2020-11-01 14:34:03 +01:00
parent a82eb20817
commit b7d61baf37
12 changed files with 645 additions and 44 deletions

View file

@ -539,6 +539,9 @@ pub enum Instruction<P: ArgParams> {
Bar(BarDetails, Arg1Bar<P>),
Atom(AtomDetails, Arg3<P>),
AtomCas(AtomCasDetails, Arg4<P>),
Div(DivDetails, Arg3<P>),
Sqrt(SqrtDetails, Arg2<P>),
Rsqrt(RsqrtDetails, Arg2<P>),
}
#[derive(Copy, Clone)]
@ -1132,7 +1135,28 @@ pub struct AtomCasDetails {
pub semantics: AtomSemantics,
pub scope: MemScope,
pub space: AtomSpace,
pub typ: BitType
pub typ: BitType,
}
#[derive(Copy, Clone)]
pub enum DivDetails {
Unsigned(UIntType),
Signed(SIntType),
Float(DivFloatDetails),
}
#[derive(Copy, Clone)]
pub struct DivFloatDetails {
pub typ: FloatType,
pub flush_to_zero: Option<bool>,
pub kind: DivFloatKind,
}
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum DivFloatKind {
Approx,
Full,
Rounding(RoundingMode),
}
pub enum NumsOrArrays<'a> {
@ -1140,6 +1164,25 @@ pub enum NumsOrArrays<'a> {
Arrays(Vec<NumsOrArrays<'a>>),
}
#[derive(Copy, Clone)]
pub struct SqrtDetails {
pub typ: FloatType,
pub flush_to_zero: Option<bool>,
pub kind: SqrtKind,
}
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum SqrtKind {
Approx,
Rounding(RoundingMode),
}
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct RsqrtDetails {
pub typ: FloatType,
pub flush_to_zero: bool,
}
impl<'a> NumsOrArrays<'a> {
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
self.normalize_dimensions(dimensions)?;

View file

@ -66,6 +66,7 @@ match {
".f64",
".file",
".ftz",
".full",
".func",
".ge",
".geu",
@ -94,6 +95,7 @@ match {
".num",
".or",
".param",
".pragma",
".pred",
".reg",
".relaxed",
@ -145,6 +147,7 @@ match {
"cvt",
"cvta",
"debug",
"div",
"fma",
"ld",
"mad",
@ -157,11 +160,13 @@ match {
"or",
"rcp",
"ret",
"rsqrt",
"selp",
"setp",
"shl",
"shr",
r"sm_[0-9]+" => ShaderModel,
"sqrt",
"st",
"sub",
"texmode_independent",
@ -184,6 +189,7 @@ ExtendedID : &'input str = {
"cvt",
"cvta",
"debug",
"div",
"fma",
"ld",
"mad",
@ -196,11 +202,13 @@ ExtendedID : &'input str = {
"or",
"rcp",
"ret",
"rsqrt",
"selp",
"setp",
"shl",
"shr",
ShaderModel,
"sqrt",
"st",
"sub",
"texmode_independent",
@ -415,9 +423,14 @@ Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
DebugDirective => None,
<v:MultiVariable> ";" => Some(ast::Statement::Variable(v)),
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
PragmaStatement => None,
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
};
PragmaStatement: () = {
".pragma" String ";"
}
DebugDirective: () = {
DebugLocation
};
@ -667,7 +680,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstSelp,
InstBar,
InstAtom,
InstAtomCas
InstAtomCas,
InstDiv,
InstSqrt,
InstRsqrt,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@ -1485,6 +1501,82 @@ AtomSIntType: ast::SIntType = {
".s64" => ast::SIntType::S64,
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div
InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
"div" <t:UIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Unsigned(t), a),
"div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a),
"div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => {
let inner = ast::DivFloatDetails {
typ: ast::FloatType::F32,
flush_to_zero: Some(ftz.is_some()),
kind
};
ast::Instruction::Div(ast::DivDetails::Float(inner), a)
},
"div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => {
let inner = ast::DivFloatDetails {
typ: ast::FloatType::F64,
flush_to_zero: None,
kind: ast::DivFloatKind::Rounding(rnd)
};
ast::Instruction::Div(ast::DivDetails::Float(inner), a)
},
}
DivFloatKind: ast::DivFloatKind = {
".approx" => ast::DivFloatKind::Approx,
".full" => ast::DivFloatKind::Full,
<rnd:RoundingModeFloat> => ast::DivFloatKind::Rounding(rnd),
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt
InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails {
typ: ast::FloatType::F32,
flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Approx,
};
ast::Instruction::Sqrt(details, a)
},
"sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails {
typ: ast::FloatType::F32,
flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Rounding(rnd),
};
ast::Instruction::Sqrt(details, a)
},
"sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => {
let details = ast::SqrtDetails {
typ: ast::FloatType::F64,
flush_to_zero: None,
kind: ast::SqrtKind::Rounding(rnd),
};
ast::Instruction::Sqrt(details, a)
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64
InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::RsqrtDetails {
typ: ast::FloatType::F32,
flush_to_zero: ftz.is_some(),
};
ast::Instruction::Rsqrt(details, a)
},
"rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => {
let details = ast::RsqrtDetails {
typ: ast::FloatType::F64,
flush_to_zero: ftz.is_some(),
};
ast::Instruction::Rsqrt(details, a)
},
}
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry div_approx(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp1;
.reg .f32 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 temp1, [in_addr];
ld.f32 temp2, [in_addr+4];
div.approx.f32 temp1, temp1, temp2;
st.f32 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,65 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 38
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
; OpCapability FunctionFloatControlINTEL
; OpExtension "SPV_INTEL_float_controls2"
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "div_approx"
OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
OpDecorate %18 FPFastMathMode AllowRecip
%31 = OpTypeVoid
%32 = OpTypeInt 64 0
%33 = OpTypeFunction %31 %32 %32
%34 = OpTypePointer Function %32
%35 = OpTypeFloat 32
%36 = OpTypePointer Function %35
%37 = OpTypePointer Generic %35
%23 = OpConstant %32 4
%1 = OpFunction %31 None %33
%8 = OpFunctionParameter %32
%9 = OpFunctionParameter %32
%28 = OpLabel
%2 = OpVariable %34 Function
%3 = OpVariable %34 Function
%4 = OpVariable %34 Function
%5 = OpVariable %34 Function
%6 = OpVariable %36 Function
%7 = OpVariable %36 Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %32 %2
%10 = OpCopyObject %32 %11
OpStore %4 %10
%13 = OpLoad %32 %3
%12 = OpCopyObject %32 %13
OpStore %5 %12
%15 = OpLoad %32 %4
%25 = OpConvertUToPtr %37 %15
%14 = OpLoad %35 %25
OpStore %6 %14
%17 = OpLoad %32 %4
%24 = OpIAdd %32 %17 %23
%26 = OpConvertUToPtr %37 %24
%16 = OpLoad %35 %26
OpStore %7 %16
%19 = OpLoad %35 %6
%20 = OpLoad %35 %7
%18 = OpFDiv %35 %19 %20
OpStore %6 %18
%21 = OpLoad %32 %5
%22 = OpLoad %35 %6
%27 = OpConvertUToPtr %37 %21
OpStore %27 %22
OpReturn
OpFunctionEnd

View file

@ -97,9 +97,13 @@ test_ptx!(and, [6u32, 3u32], [2u32]);
test_ptx!(selp, [100u16, 200u16], [200u16]);
test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
test_ptx!(shared_variable, [513u64], [513u64]);
test_ptx!(shared_ptr_32, [513u64], [513u64]);
test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
test_ptx!(sqrt, [0.25f32], [0.5f32]);
test_ptx!(rsqrt, [0.25f64], [2f64]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,21 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry rsqrt(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f64 temp1;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f64 temp1, [in_addr];
rsqrt.approx.f64 temp1, temp1;
st.f64 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,56 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 31
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
; OpCapability FunctionFloatControlINTEL
; OpExtension "SPV_INTEL_float_controls2"
%23 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "rsqrt"
OpDecorate %1 FunctionDenormModeINTEL 64 Preserve
%24 = OpTypeVoid
%25 = OpTypeInt 64 0
%26 = OpTypeFunction %24 %25 %25
%27 = OpTypePointer Function %25
%28 = OpTypeFloat 64
%29 = OpTypePointer Function %28
%30 = OpTypePointer Generic %28
%1 = OpFunction %24 None %26
%7 = OpFunctionParameter %25
%8 = OpFunctionParameter %25
%21 = OpLabel
%2 = OpVariable %27 Function
%3 = OpVariable %27 Function
%4 = OpVariable %27 Function
%5 = OpVariable %27 Function
%6 = OpVariable %29 Function
OpStore %2 %7
OpStore %3 %8
%10 = OpLoad %25 %2
%9 = OpCopyObject %25 %10
OpStore %4 %9
%12 = OpLoad %25 %3
%11 = OpCopyObject %25 %12
OpStore %5 %11
%14 = OpLoad %25 %4
%19 = OpConvertUToPtr %30 %14
%13 = OpLoad %28 %19
OpStore %6 %13
%16 = OpLoad %28 %6
%15 = OpExtInst %28 %23 native_rsqrt %16
OpStore %6 %15
%17 = OpLoad %25 %5
%18 = OpLoad %28 %6
%20 = OpConvertUToPtr %30 %17
OpStore %20 %18
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,29 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry shared_ptr_32(
.param .u64 input,
.param .u64 output
)
{
.shared .align 4 .b8 shared_mem1[128];
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 shared_addr;
.reg .u64 temp1;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
mov.u32 shared_addr, shared_mem1;
ld.global.u64 temp1, [in_addr];
st.shared.u64 [shared_addr], temp1;
ld.shared.u64 temp2, [shared_addr+0];
st.global.u64 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,74 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 47
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
; OpCapability FunctionFloatControlINTEL
; OpExtension "SPV_INTEL_float_controls2"
%34 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "shared_ptr_32" %4
OpDecorate %4 Alignment 4
%35 = OpTypeVoid
%36 = OpTypeInt 32 0
%37 = OpTypeInt 8 0
%38 = OpConstant %36 128
%39 = OpTypeArray %37 %38
%40 = OpTypePointer Workgroup %39
%4 = OpVariable %40 Workgroup
%41 = OpTypeInt 64 0
%42 = OpTypeFunction %35 %41 %41
%43 = OpTypePointer Function %41
%44 = OpTypePointer Function %36
%45 = OpTypePointer CrossWorkgroup %41
%46 = OpTypePointer Workgroup %41
%25 = OpConstant %36 0
%1 = OpFunction %35 None %42
%10 = OpFunctionParameter %41
%11 = OpFunctionParameter %41
%32 = OpLabel
%2 = OpVariable %43 Function
%3 = OpVariable %43 Function
%5 = OpVariable %43 Function
%6 = OpVariable %43 Function
%7 = OpVariable %44 Function
%8 = OpVariable %43 Function
%9 = OpVariable %43 Function
OpStore %2 %10
OpStore %3 %11
%13 = OpLoad %41 %2
%12 = OpCopyObject %41 %13
OpStore %5 %12
%15 = OpLoad %41 %3
%14 = OpCopyObject %41 %15
OpStore %6 %14
%27 = OpConvertPtrToU %36 %4
%16 = OpCopyObject %36 %27
OpStore %7 %16
%18 = OpLoad %41 %5
%28 = OpConvertUToPtr %45 %18
%17 = OpLoad %41 %28
OpStore %8 %17
%19 = OpLoad %36 %7
%20 = OpLoad %41 %8
%29 = OpConvertUToPtr %46 %19
OpStore %29 %20
%22 = OpLoad %36 %7
%26 = OpIAdd %36 %22 %25
%30 = OpConvertUToPtr %46 %26
%21 = OpLoad %41 %30
OpStore %9 %21
%23 = OpLoad %41 %6
%24 = OpLoad %41 %9
%31 = OpConvertUToPtr %45 %23
OpStore %31 %24
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,21 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry sqrt(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .f32 temp1;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.f32 temp1, [in_addr];
sqrt.approx.f32 temp1, temp1;
st.f32 [out_addr], temp1;
ret;
}

View file

@ -0,0 +1,56 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 31
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
; OpCapability FunctionFloatControlINTEL
; OpExtension "SPV_INTEL_float_controls2"
%23 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "sqrt"
OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
%24 = OpTypeVoid
%25 = OpTypeInt 64 0
%26 = OpTypeFunction %24 %25 %25
%27 = OpTypePointer Function %25
%28 = OpTypeFloat 32
%29 = OpTypePointer Function %28
%30 = OpTypePointer Generic %28
%1 = OpFunction %24 None %26
%7 = OpFunctionParameter %25
%8 = OpFunctionParameter %25
%21 = OpLabel
%2 = OpVariable %27 Function
%3 = OpVariable %27 Function
%4 = OpVariable %27 Function
%5 = OpVariable %27 Function
%6 = OpVariable %29 Function
OpStore %2 %7
OpStore %3 %8
%10 = OpLoad %25 %2
%9 = OpCopyObject %25 %10
OpStore %4 %9
%12 = OpLoad %25 %3
%11 = OpCopyObject %25 %12
OpStore %5 %11
%14 = OpLoad %25 %4
%19 = OpConvertUToPtr %30 %14
%13 = OpLoad %28 %19
OpStore %6 %13
%16 = OpLoad %28 %6
%15 = OpExtInst %28 %23 native_sqrt %16
OpStore %6 %15
%17 = OpLoad %25 %5
%18 = OpLoad %28 %6
%20 = OpConvertUToPtr %30 %17
OpStore %20 %18
OpReturn
OpFunctionEnd

View file

@ -1,8 +1,11 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
};
use rspirv::binary::Assemble;
@ -1499,6 +1502,15 @@ fn convert_to_typed_statements(
ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
ast::Instruction::AtomCas(d, a.cast()),
)),
ast::Instruction::Div(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast())))
}
ast::Instruction::Sqrt(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast())))
}
ast::Instruction::Rsqrt(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@ -1982,7 +1994,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
| ArgumentSemantics::DefaultRelaxed
| ArgumentSemantics::PhysicalPointer => {
if desc.sema == ArgumentSemantics::PhysicalPointer {
typ = ast::Type::Scalar(ast::ScalarType::U64);
typ = self.id_def.get_typed(reg)?;
}
let (width, kind) = match typ {
ast::Type::Scalar(scalar_t) => {
@ -2013,7 +2025,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
value: ast::ImmediateValue::S64(-(offset as i64)),
value: ast::ImmediateValue::U64(-(offset as i64) as u64),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Sub(
@ -2765,6 +2777,34 @@ fn emit_function_body_ops(
arg.src2,
)?;
}
ast::Instruction::Div(details, arg) => match details {
ast::DivDetails::Unsigned(t) => {
let result_type = map.get_or_add_scalar(builder, (*t).into());
builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::DivDetails::Signed(t) => {
let result_type = map.get_or_add_scalar(builder, (*t).into());
builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::DivDetails::Float(t) => {
let result_type = map.get_or_add_scalar(builder, t.typ.into());
builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
emit_float_div_decoration(builder, arg.dst, t.kind);
}
},
ast::Instruction::Sqrt(details, a) => {
emit_sqrt(builder, map, opencl, details, a)?;
}
ast::Instruction::Rsqrt(details, a) => {
let result_type = map.get_or_add_scalar(builder, details.typ.into());
builder.ext_inst(
result_type,
Some(a.dst),
opencl,
spirv::CLOp::native_rsqrt as spirv::Word,
&[a.src],
)?;
}
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@ -2795,6 +2835,47 @@ fn emit_function_body_ops(
Ok(())
}
fn emit_sqrt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
details: &ast::SqrtDetails,
a: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
let result_type = map.get_or_add_scalar(builder, details.typ.into());
let (ocl_op, rounding) = match details.kind {
ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None),
ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
};
builder.ext_inst(
result_type,
Some(a.dst),
opencl,
ocl_op as spirv::Word,
&[a.src],
)?;
emit_rounding_decoration(builder, a.dst, rounding);
Ok(())
}
fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) {
match kind {
ast::DivFloatKind::Approx => {
builder.decorate(
dst,
spirv::Decoration::FPFastMathMode,
&[dr::Operand::FPFastMathMode(
spirv::FPFastMathMode::ALLOW_RECIP,
)],
);
}
ast::DivFloatKind::Rounding(rnd) => {
emit_rounding_decoration(builder, dst, Some(rnd));
}
ast::DivFloatKind::Full => {}
}
}
fn emit_atom(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -3307,7 +3388,25 @@ fn emit_setp(
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
_ => todo!(),
(ast::SetpCompareOp::NanEq, _) => {
builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanNotEq, _) => {
builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanLess, _) => {
builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanLessOrEq, _) => {
builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanGreater, _) => {
builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanGreaterOrEq, _) => {
builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
_ => todo!()
}?;
Ok(())
}
@ -3486,8 +3585,8 @@ fn emit_implicit_conversion(
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
(_, _, ConversionKind::PtrToBit) => {
let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
(_, _, ConversionKind::PtrToBit(typ)) => {
let dst_type = map.get_or_add_scalar(builder, typ.into());
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::BitToPtr(_)) => {
@ -4570,6 +4669,15 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::AtomCas(d, a) => {
ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
}
ast::Instruction::Div(d, a) => {
ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?)
}
ast::Instruction::Sqrt(d, a) => {
ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
}
ast::Instruction::Rsqrt(d, a) => {
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
}
})
}
}
@ -4794,32 +4902,7 @@ impl ast::Instruction<ExpandedArgParams> {
fn jump_target(&self) -> Option<spirv::Word> {
match self {
ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
| ast::Instruction::SetpBool(_, _)
| ast::Instruction::Not(_, _)
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Cvta(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::Shr(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Ret(_)
| ast::Instruction::Abs(_, _)
| ast::Instruction::Call(_)
| ast::Instruction::Or(_, _)
| ast::Instruction::Sub(_, _)
| ast::Instruction::Min(_, _)
| ast::Instruction::Max(_, _)
| ast::Instruction::Rcp(_, _)
| ast::Instruction::And(_, _)
| ast::Instruction::Selp(_, _)
| ast::Instruction::Bar(_, _)
| ast::Instruction::Atom(_, _)
| ast::Instruction::AtomCas(_, _)
| ast::Instruction::Mad(_, _) => None,
_ => None,
}
}
@ -4856,6 +4939,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None,
ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None,
ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None,
ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@ -4884,14 +4970,20 @@ impl ast::Instruction<ExpandedArgParams> {
ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }),
_,
)
| ast::Instruction::Cvt(
ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }),
_,
)
| ast::Instruction::Cvt(
ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
_,
) => flush_to_zero.map(|ftz| (ftz, 4)),
ast::Instruction::Div(ast::DivDetails::Float(details), _) => details
.flush_to_zero
.map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
ast::Instruction::Sqrt(details, _) => details
.flush_to_zero
.map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
ast::Instruction::Rsqrt(details, _) => Some((
details.flush_to_zero,
ast::ScalarType::from(details.typ).size_of(),
)),
}
}
}
@ -4978,13 +5070,13 @@ struct ImplicitConversion {
kind: ConversionKind,
}
#[derive(Debug, PartialEq, Copy, Clone)]
#[derive(PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
BitToPtr(ast::LdStateSpace),
PtrToBit,
PtrToBit(ast::UIntType),
PtrToPtr { spirv_ptr: bool },
}
@ -6027,6 +6119,16 @@ impl ast::MinMaxDetails {
}
}
impl ast::DivDetails {
fn get_type(&self) -> ast::Type {
ast::Type::Scalar(match self {
ast::DivDetails::Unsigned(t) => (*t).into(),
ast::DivDetails::Signed(t) => (*t).into(),
ast::DivDetails::Float(d) => d.typ.into(),
})
}
}
impl ast::AtomInnerDetails {
fn get_type(&self) -> ast::ScalarType {
match self {
@ -6193,6 +6295,15 @@ fn bitcast_physical_pointer(
Err(TranslateError::Unreachable)
}
}
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => {
if let Some(ast::LdStateSpace::Shared) = ss {
Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
} else {
Err(TranslateError::MismatchedType)
}
}
ast::Type::Pointer(op_scalar_t, op_space) => {
if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
if op_space == instr_space {
@ -6220,10 +6331,16 @@ fn bitcast_physical_pointer(
fn force_bitcast_ptr_to_bit(
_: &ast::Type,
_: &ast::Type,
instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
Ok(Some(ConversionKind::PtrToBit))
// TODO: verify this on f32, u16 and the like
if let ast::Type::Scalar(scalar_t) = instr_type {
if let Ok(int_type) = (*scalar_t).try_into() {
return Ok(Some(ConversionKind::PtrToBit(int_type)));
}
}
Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@ -6542,9 +6659,9 @@ mod tests {
&ast::Type::Scalar(*instr_type),
);
if instr_idx == op_idx {
assert_eq!(conversion, None);
assert!(conversion == None);
} else {
assert_eq!(conversion, conv_table[instr_idx][op_idx]);
assert!(conversion == conv_table[instr_idx][op_idx]);
}
}
}