Add support for parsing instruction cvt

This commit is contained in:
Andrzej Janik 2020-08-03 01:42:13 +02:00
parent ff449289eb
commit a10ee48e91
4 changed files with 302 additions and 32 deletions

View file

@ -9,6 +9,8 @@ quick_error! {
display("{}", err)
cause(err)
}
SyntaxError {}
NonF32Ftz {}
}
}
@ -101,9 +103,11 @@ pub enum ScalarType {
impl From<IntType> for ScalarType {
fn from(t: IntType) -> Self {
match t {
IntType::S8 => ScalarType::S8,
IntType::S16 => ScalarType::S16,
IntType::S32 => ScalarType::S32,
IntType::S64 => ScalarType::S64,
IntType::U8 => ScalarType::U8,
IntType::U16 => ScalarType::U16,
IntType::U32 => ScalarType::U32,
IntType::U64 => ScalarType::U64,
@ -113,14 +117,38 @@ impl From<IntType> for ScalarType {
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum IntType {
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64,
}
impl IntType {
pub fn is_signed(self) -> bool {
match self {
IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false,
IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true,
}
}
pub fn width(self) -> u8 {
match self {
IntType::U8 => 1,
IntType::U16 => 2,
IntType::U32 => 4,
IntType::U64 => 8,
IntType::S8 => 1,
IntType::S16 => 2,
IntType::S32 => 4,
IntType::S64 => 8,
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum FloatType {
F16,
@ -178,7 +206,7 @@ pub enum Instruction<P: ArgParams> {
SetpBool(SetpBoolData, Arg5<P>),
Not(NotType, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtData, Arg2<P>),
Cvt(CvtDetails, Arg2<P>),
Shl(ShlType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
@ -398,7 +426,88 @@ pub struct BraData {
pub uniform: bool,
}
pub struct CvtData {}
pub enum CvtDetails {
IntFromInt(CvtIntToIntDesc),
FloatFromFloat(CvtDesc<FloatType, FloatType>),
IntFromFloat(CvtDesc<IntType, FloatType>),
FloatFromInt(CvtDesc<FloatType, IntType>),
}
pub struct CvtIntToIntDesc {
pub dst: IntType,
pub src: IntType,
pub saturate: bool,
}
pub struct CvtDesc<Dst, Src> {
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
pub dst: Dst,
pub src: Src,
}
impl CvtDetails {
pub fn new_int_from_int_checked(
saturate: bool,
dst: IntType,
src: IntType,
err: &mut Vec<PtxError>,
) -> Self {
if saturate {
if src.is_signed() {
if dst.is_signed() && dst.width() >= src.width() {
err.push(PtxError::SyntaxError);
}
} else {
if dst == src || dst.width() >= src.width() {
err.push(PtxError::SyntaxError);
}
}
}
CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate })
}
pub fn new_float_from_int_checked(
rounding: RoundingMode,
flush_to_zero: bool,
saturate: bool,
dst: FloatType,
src: IntType,
err: &mut Vec<PtxError>,
) -> Self {
if flush_to_zero && dst != FloatType::F32 {
err.push(PtxError::NonF32Ftz);
}
CvtDetails::FloatFromInt(CvtDesc {
dst,
src,
saturate,
flush_to_zero,
rounding: Some(rounding),
})
}
pub fn new_int_from_float_checked(
rounding: RoundingMode,
flush_to_zero: bool,
saturate: bool,
dst: IntType,
src: FloatType,
err: &mut Vec<PtxError>,
) -> Self {
if flush_to_zero && src != FloatType::F32 {
err.push(PtxError::NonF32Ftz);
}
CvtDetails::IntFromFloat(CvtDesc {
dst,
src,
saturate,
flush_to_zero,
rounding: Some(rounding),
})
}
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum ShlType {

View file

@ -403,13 +403,13 @@ InstMulMode: ast::MulDetails = {
typ: t,
control: ctr
}),
<r:RoundingMode?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
<r:RoundingModeFloat?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F32,
rounding: r,
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
}),
<r:RoundingMode?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
<r:RoundingModeFloat?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F64,
rounding: r,
flush_to_zero: false,
@ -436,13 +436,20 @@ MulIntControl: ast::MulIntControl = {
};
#[inline]
RoundingMode : ast::RoundingMode = {
RoundingModeFloat : ast::RoundingMode = {
".rn" => ast::RoundingMode::NearestEven,
".rz" => ast::RoundingMode::Zero,
".rm" => ast::RoundingMode::NegativeInf,
".rp" => ast::RoundingMode::PositiveInf,
};
RoundingModeInt : ast::RoundingMode = {
".rni" => ast::RoundingMode::NearestEven,
".rzi" => ast::RoundingMode::Zero,
".rmi" => ast::RoundingMode::NegativeInf,
".rpi" => ast::RoundingMode::PositiveInf,
};
IntType : ast::IntType = {
".u16" => ast::IntType::U16,
".u32" => ast::IntType::U32,
@ -468,13 +475,13 @@ InstAddMode: ast::AddDetails = {
typ: ast::IntType::S32,
saturate: true,
}),
<rn:RoundingMode?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F32,
rounding: rn,
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
}),
<rn:RoundingMode?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
<rn:RoundingModeFloat?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F64,
rounding: rn,
flush_to_zero: false,
@ -580,28 +587,153 @@ InstBra: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtData{}, a)
}
"cvt" <s:".sat"?> <dst_t:CvtTypeInt> <src_t:CvtTypeInt> <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::new_int_from_int_checked(
s.is_some(),
dst_t,
src_t,
errors
),
a)
},
"cvt" <r:RoundingModeFloat> <f:".ftz"?> <s:".sat"?> <dst_t:CvtTypeFloat> <src_t:CvtTypeInt> <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::new_float_from_int_checked(
r,
f.is_some(),
s.is_some(),
dst_t,
src_t,
errors
),
a)
},
"cvt" <r:RoundingModeInt> <f:".ftz"?> <s:".sat"?> <dst_t:CvtTypeInt> <src_t:CvtTypeFloat> <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::new_int_from_float_checked(
r,
f.is_some(),
s.is_some(),
dst_t,
src_t,
errors
),
a)
},
"cvt" <r:RoundingModeInt?> <s:".sat"?> ".f16" ".f16" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: r,
flush_to_zero: false,
saturate: s.is_some(),
dst: ast::FloatType::F16,
src: ast::FloatType::F16
}
), a)
},
"cvt" <f:".ftz"?> <s:".sat"?> ".f32" ".f16" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: None,
flush_to_zero: f.is_some(),
saturate: s.is_some(),
dst: ast::FloatType::F32,
src: ast::FloatType::F16
}
), a)
},
"cvt" <s:".sat"?> ".f64" ".f16" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: None,
flush_to_zero: false,
saturate: s.is_some(),
dst: ast::FloatType::F64,
src: ast::FloatType::F16
}
), a)
},
"cvt" <r:RoundingModeFloat> <f:".ftz"?> <s:".sat"?> ".f16" ".f32" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: Some(r),
flush_to_zero: f.is_some(),
saturate: s.is_some(),
dst: ast::FloatType::F16,
src: ast::FloatType::F32
}
), a)
},
"cvt" <r:RoundingModeFloat?> <f:".ftz"?> <s:".sat"?> ".f32" ".f32" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: r,
flush_to_zero: f.is_some(),
saturate: s.is_some(),
dst: ast::FloatType::F32,
src: ast::FloatType::F32
}
), a)
},
"cvt" <s:".sat"?> ".f64" ".f32" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: None,
flush_to_zero: false,
saturate: s.is_some(),
dst: ast::FloatType::F64,
src: ast::FloatType::F32
}
), a)
},
"cvt" <r:RoundingModeFloat> <s:".sat"?> ".f16" ".f64" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: Some(r),
flush_to_zero: false,
saturate: s.is_some(),
dst: ast::FloatType::F16,
src: ast::FloatType::F64
}
), a)
},
"cvt" <r:RoundingModeFloat> <f:".ftz"?> <s:".sat"?> ".f32" ".f64" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: Some(r),
flush_to_zero: s.is_some(),
saturate: s.is_some(),
dst: ast::FloatType::F32,
src: ast::FloatType::F64
}
), a)
},
"cvt" <r:RoundingModeFloat?> <s:".sat"?> ".f64" ".f64" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: r,
flush_to_zero: false,
saturate: s.is_some(),
dst: ast::FloatType::F64,
src: ast::FloatType::F64
}
), a)
},
};
CvtRnd = {
CvtIrnd,
CvtFrnd
}
CvtIrnd = {
".rni", ".rzi", ".rmi", ".rpi"
CvtTypeInt: ast::IntType = {
".u8" => ast::IntType::U8,
".u16" => ast::IntType::U16,
".u32" => ast::IntType::U32,
".u64" => ast::IntType::U64,
".s8" => ast::IntType::S8,
".s16" => ast::IntType::S16,
".s32" => ast::IntType::S32,
".s64" => ast::IntType::S64,
};
CvtFrnd = {
".rn", ".rz", ".rm", ".rp"
};
CvtType = {
".u8", ".u16", ".u32", ".u64",
".s8", ".s16", ".s32", ".s64",
".f16", ".f32", ".f64"
CvtTypeFloat: ast::FloatType = {
".f16" => ast::FloatType::F16,
".f32" => ast::FloatType::F32,
".f64" => ast::FloatType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl

View file

@ -1580,13 +1580,6 @@ impl ast::MulDetails {
}
impl ast::IntType {
fn is_signed(self) -> bool {
match self {
ast::IntType::S16 | ast::IntType::S32 | ast::IntType::S64 => true,
ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false,
}
}
fn try_new(t: ast::ScalarType) -> Option<Self> {
match t {
ast::ScalarType::U16 => Some(ast::IntType::U16),

36
ptx/tools/cvt.py Normal file
View file

@ -0,0 +1,36 @@
import os
import subprocess
import tempfile
types = ["u8", "u16", "u32", "u64", "s8", "s16", "s32", "s64", "f16", "f32", "f64"]
rnd = ["", ".rn", ".rni"]
ftz_all = ["", ".ftz"]
sat = ["", ".sat"]
for in_type in types:
for out_type in types:
for r in rnd:
for ftz in ftz_all:
for s in sat:
with tempfile.TemporaryDirectory() as dir:
f_name = os.path.join(dir, 'ptx')
out_name = os.path.join(dir, 'out')
with open(f_name, 'w') as f:
f.write(
f"""
.version 6.5
.target sm_30
.address_size 64
.visible .entry VecAdd_kernel()
{{
.reg.{in_type} r1;
.reg.{out_type} r2;
cvt{r}{ftz}{s}.{out_type}.{in_type} r2, r1;
ret;
}}
""")
err = subprocess.run(f"ptxas {f_name} -o {out_name}", capture_output = True)
if err.returncode == 0:
print(f"cvt{r}{ftz}{s}.{out_type}.{in_type}")
#else:
# print(f"[INVALID] cvt{r}{ftz}{s}.{out_type}.{in_type}")