Fix remaining bugs in SSA renaming

This commit is contained in:
Andrzej Janik 2020-05-02 01:08:44 +02:00
parent 6700f8bcc2
commit a69c12a387

View file

@ -105,63 +105,122 @@ fn emit_function<'a>(
spirv::FunctionControl::NONE,
map.fn_void(),
)?;
for arg in f.args.iter() {
let arg_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
builder.function_parameter(arg_type)?;
}
let (mut normalized_ids, max_id) = normalize_identifiers(f.body);
let mut contant_ids = HashMap::new();
collect_arg_ids(&mut contant_ids, &f.args);
collect_label_ids(&mut contant_ids, &f.body);
let (mut normalized_ids, unique_ids) = normalize_identifiers(f.body, &contant_ids);
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
ssa_legalize(&mut normalized_ids, max_id, bbs, &doms, &dom_fronts);
ssa_legalize(
&mut normalized_ids,
contant_ids.len() as u32,
unique_ids,
&bbs,
&doms,
&dom_fronts,
);
emit_function_body_ops(builder);
builder.ret()?;
builder.end_function()?;
Ok(func_id)
}
fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) {
let mut id = result.len() as u32;
for arg in args {
result.insert(arg.name, id);
id += 1;
}
}
fn collect_label_ids<'a>(
result: &mut HashMap<&'a str, spirv::Word>,
fn_body: &[ast::Statement<&'a str>],
) {
let mut id = result.len() as u32;
for s in fn_body {
match s {
ast::Statement::Label(name) => {
result.insert(name, id);
id += 1;
}
ast::Statement::Instruction(_, _) => (),
ast::Statement::Variable(_) => (),
}
}
}
fn emit_function_body_ops(builder: &mut dr::Builder) {
todo!()
}
// This functions converts string identifiers to numeric identifiers in a normalized form, where
// - identifiers in the range [0..constant_identifiers.len()) are arguments and labels
// - identifiers in the range [constant_identifiers.len()..result.1) are variables
// TODO: support scopes
fn normalize_identifiers<'a>(func: Vec<ast::Statement<&'a str>>) -> (Vec<Statement>, spirv::Word) {
fn normalize_identifiers<'a>(
func: Vec<ast::Statement<&'a str>>,
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
) -> (Vec<Statement>, spirv::Word) {
let mut result = Vec::with_capacity(func.len());
let mut id: u32 = 0;
let mut known_ids = HashMap::new();
let mut id: u32 = constant_identifiers.len() as u32;
let mut remapped_ids = HashMap::new();
let mut get_or_add = |key| {
*known_ids.entry(key).or_insert_with(|| {
let to_insert = id;
id += 1;
to_insert
})
constant_identifiers.get(key).map_or_else(
|| {
*remapped_ids.entry(key).or_insert_with(|| {
let to_insert = id;
id += 1;
to_insert
})
},
|id| *id,
)
};
for s in func {
if let Some(s) = Statement::from_ast(s, &mut get_or_add) {
result.push(s);
}
}
(result, id - 1)
(result, id)
}
fn ssa_legalize(
func: &mut [Statement],
max_id: spirv::Word,
bbs: Vec<BasicBlock>,
doms: &Vec<BBIndex>,
dom_fronts: &Vec<HashSet<BBIndex>>,
constant_ids: spirv::Word,
unique_ids: spirv::Word,
bbs: &[BasicBlock],
doms: &[BBIndex],
dom_fronts: &[HashSet<BBIndex>],
) -> Vec<Vec<PhiDef>> {
let phis = gather_phi_sets(&func, max_id, &bbs, dom_fronts);
apply_ssa_renaming(func, &bbs, doms, max_id, &phis)
let phis = gather_phi_sets(
&func,
constant_ids,
unique_ids,
&bbs,
dom_fronts,
);
apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis)
}
// "Modern Compiler Implementation in Java" - Algorithm 19.7
/* "Modern Compiler Implementation in Java" - Algorithm 19.7
* This algorithm modifies passed function body in-place by renumbering ids,
* result ids can be divided into following categories
* - if id < constant_ids
* it's a non-redefinable id
* - if id >= constant_ids && id < all_ids
* then it's an undefined id (a0, b0, c0)
* - if id >= all_ids
* then it's a normally redefined id
*/
fn apply_ssa_renaming(
func: &mut [Statement],
bbs: &[BasicBlock],
doms: &[BBIndex],
max_id: spirv::Word,
constant_ids: spirv::Word,
all_ids: spirv::Word,
old_phi: &[HashSet<spirv::Word>],
) -> Vec<Vec<PhiDef>> {
let mut dom_tree = vec![Vec::new(); bbs.len()];
@ -182,7 +241,7 @@ fn apply_ssa_renaming(
.collect::<HashMap<_, _>>()
})
.collect::<Vec<_>>();
let mut ssa_state = SSARewriteState::new(max_id);
let mut ssa_state = SSARewriteState::new(constant_ids, all_ids);
// once again, we do explicit stack
let mut state = Vec::new();
state.push((BBIndex(0), 0));
@ -300,32 +359,46 @@ fn get_bb_body_idx(func: &[Statement], all_bb: &[BasicBlock], bb: BBIndex) -> (u
// We assume here that the variables are defined in the dense sequence 0..max
struct SSARewriteState {
next: spirv::Word,
constant_ids: spirv::Word,
stack: Vec<Vec<spirv::Word>>,
}
impl SSARewriteState {
fn new(max: spirv::Word) -> Self {
let len = max + 1;
let stack = (0..len).map(|x| vec![x + len]).collect::<Vec<_>>();
impl<'a> SSARewriteState {
fn new(constant_ids: spirv::Word, all_ids: spirv::Word) -> Self {
let to_redefine = all_ids - constant_ids;
let stack = (0..to_redefine)
.map(|x| vec![x + constant_ids])
.collect::<Vec<_>>();
SSARewriteState {
next: 2 * len,
next: all_ids,
constant_ids: constant_ids,
stack,
}
}
fn get(&self, x: spirv::Word) -> spirv::Word {
*self.stack[x as usize].last().unwrap()
if x < self.constant_ids {
x
} else {
*self.stack[(x - self.constant_ids) as usize].last().unwrap()
}
}
fn redefine(&mut self, x: spirv::Word) -> spirv::Word {
let result = self.next;
self.next += 1;
self.stack[x as usize].push(result);
return result;
if x < self.constant_ids {
x
} else {
let result = self.next;
self.next += 1;
self.stack[(x - self.constant_ids) as usize].push(result);
result
}
}
fn pop(&mut self, x: spirv::Word) {
self.stack[x as usize].pop();
if x >= self.constant_ids {
self.stack[(x - self.constant_ids) as usize].pop();
}
}
}
@ -333,24 +406,28 @@ impl SSARewriteState {
// Calculates semi-pruned phis
fn gather_phi_sets(
func: &[Statement],
max_id: spirv::Word,
constant_ids: spirv::Word,
all_ids: spirv::Word,
cfg: &[BasicBlock],
dom_fronts: &[HashSet<BBIndex>],
) -> Vec<HashSet<spirv::Word>> {
let mut result = vec![HashSet::new(); cfg.len()];
let mut globals = HashSet::new();
let mut blocks = vec![(Vec::new(), HashSet::new()); (max_id as usize) + 1];
let mut blocks = vec![(Vec::new(), HashSet::new()); (all_ids - constant_ids) as usize];
for bb in 0..cfg.len() {
let mut var_kill = HashSet::new();
let mut visitor = |is_dst, id: &u32| {
if is_dst {
var_kill.insert(*id);
let (ref mut stack, ref mut set) = blocks[*id as usize];
stack.push(BBIndex(bb));
set.insert(BBIndex(bb));
} else {
if !var_kill.contains(id) {
globals.insert(*id);
if *id >= constant_ids {
let id = id - constant_ids;
if is_dst {
var_kill.insert(id);
let (ref mut stack, ref mut set) = blocks[id as usize];
stack.push(BBIndex(bb));
set.insert(BBIndex(bb));
} else {
if !var_kill.contains(&id) {
globals.insert(id);
}
}
}
};
@ -360,6 +437,7 @@ fn gather_phi_sets(
pred.as_ref().map(|p| p.visit_id(&mut visitor));
inst.visit_id(&mut visitor);
}
// label redefinition is a compile-time error
Statement::Label(_) => (),
}
}
@ -370,7 +448,7 @@ fn gather_phi_sets(
if let Some(bb) = work_stack.pop() {
work_set.remove(&bb);
for d_bb in &dom_fronts[bb.0] {
if result[d_bb.0].insert(id) {
if result[d_bb.0].insert(id + constant_ids) {
if work_set.insert(*d_bb) {
work_stack.push(*d_bb);
}
@ -627,6 +705,8 @@ impl Statement {
}
}
// WARNING: It is very important to first visit src operands and then dst operands,
// otherwise SSA renaming will yield weird results
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
match self {
Statement::Label(id) => f(true, id),
@ -793,8 +873,8 @@ impl<T> ast::Arg2<T> {
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
f(true, &mut self.dst);
self.src.visit_id_mut(f);
f(true, &mut self.dst);
}
}
@ -818,8 +898,8 @@ impl<T> ast::Arg2Mov<T> {
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
f(true, &mut self.dst);
self.src.visit_id_mut(f);
f(true, &mut self.dst);
}
}
@ -845,9 +925,9 @@ impl<T> ast::Arg3<T> {
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
f(true, &mut self.dst);
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
f(true, &mut self.dst);
}
}
@ -875,10 +955,10 @@ impl<T> ast::Arg4<T> {
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
f(true, &mut self.dst1);
self.dst2.as_mut().map(|i| f(true, i));
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
f(true, &mut self.dst1);
self.dst2.as_mut().map(|i| f(true, i));
}
}
@ -909,11 +989,11 @@ impl<T> ast::Arg5<T> {
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
f(true, &mut self.dst1);
self.dst2.as_mut().map(|i| f(true, i));
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
self.src3.visit_id_mut(f);
f(true, &mut self.dst1);
self.dst2.as_mut().map(|i| f(true, i));
}
}
@ -1335,7 +1415,9 @@ mod tests {
.parse(&mut errors, func)
.unwrap();
assert_eq!(errors.len(), 0);
let (normalized_ids, _) = normalize_identifiers(ast);
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &ast);
let (normalized_ids, _) = normalize_identifiers(ast, &constant_ids);
let mut bbs = get_basic_blocks(&normalized_ids);
bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!(
@ -1479,21 +1561,30 @@ mod tests {
.parse(&mut errors, func)
.unwrap();
assert_eq!(errors.len(), 0);
let (normalized_ids, max_id) = normalize_identifiers(fn_ast);
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 6);
let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids);
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
let phi = gather_phi_sets(&normalized_ids, max_id, &bbs, &dom_fronts);
let phi = gather_phi_sets(
&normalized_ids,
constant_ids.len() as u32,
max_id,
&bbs,
&dom_fronts,
);
assert_eq!(
phi,
vec![
HashSet::new(),
to_hashset(vec![1, 2]),
to_hashset(vec![7, 8]),
HashSet::new(),
HashSet::new(),
HashSet::new(),
to_hashset(vec![1, 2]),
to_hashset(vec![7, 8]),
HashSet::new()
]
);
@ -1503,12 +1594,157 @@ mod tests {
v.into_iter().collect::<HashSet<T>>()
}
fn assert_dst_unique(func: &[Statement]) {
#[test]
fn ssa_rename_19_4() {
let func = FIG_19_4;
let mut errors = Vec::new();
let fn_ast = ptx::FunctionBodyParser::new()
.parse(&mut errors, func)
.unwrap();
assert_eq!(errors.len(), 0);
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
let (mut func, unique_ids) = normalize_identifiers(fn_ast, &constant_ids);
let bbs = get_basic_blocks(&func);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
let mut ssa_phis = ssa_legalize(
&mut func,
constant_ids.len() as u32,
unique_ids,
&bbs,
&doms,
&dom_fronts,
);
assert_phi_dst_id(unique_ids, &ssa_phis);
assert_dst_unique(&func, &ssa_phis);
sort_phi(&mut ssa_phis);
let i1 = unique_ids;
let j1 = unique_ids + 1;
let j2 = get_dst_from_src(&ssa_phis[1], j1);
let j3 = get_dst(&func[10]);
let j4 = get_dst_from_src(&ssa_phis[5], j3);
let j5 = get_dst(&func[14]);
let k1 = unique_ids + 2;
let k2 = get_dst_from_src(&ssa_phis[1], k1);
let k3 = get_dst(&func[11]);
let k4 = get_dst_from_src(&ssa_phis[5], k3);
let k5 = get_dst(&func[15]);
let p1 = get_dst(&func[4]);
let q1 = get_dst(&func[7]);
let block_2 = get_dst(&func[3]);
let block_3 = get_dst(&func[6]);
let block_5 = get_dst(&func[9]);
let block_6 = get_dst(&func[13]);
let block_7 = get_dst(&func[16]);
let block_4 = get_dst(&func[18]);
{
assert_eq!(get_ids(&func[0]), vec![i1]);
assert_eq!(get_ids(&func[1]), vec![j1]);
assert_eq!(get_ids(&func[2]), vec![k1]);
assert_eq!(
ssa_phis[1],
to_phi(vec![(j2, vec![j4, j1]), (k2, vec![k4, k1])])
);
assert_eq!(get_ids(&func[3]), vec![block_2]);
assert_eq!(get_ids(&func[4]), vec![p1, k2]);
assert_eq!(get_ids(&func[5]), vec![p1, block_4]);
assert_eq!(get_ids(&func[6]), vec![block_3]);
assert_eq!(get_ids(&func[7]), vec![q1, j2]);
assert_eq!(get_ids(&func[8]), vec![q1, block_6]);
assert_eq!(get_ids(&func[9]), vec![block_5]);
assert_eq!(get_ids(&func[10]), vec![j3, i1]);
assert_eq!(get_ids(&func[11]), vec![k3, k2]);
assert_eq!(get_ids(&func[12]), vec![block_7]);
assert_eq!(get_ids(&func[13]), vec![block_6]);
assert_eq!(get_ids(&func[14]), vec![j5, k2]);
assert_eq!(get_ids(&func[15]), vec![k5, k2]);
assert_eq!(
ssa_phis[5],
to_phi(vec![(j4, vec![j3, j5]), (k4, vec![k3, k5])])
);
assert_eq!(get_ids(&func[16]), vec![block_7]);
assert_eq!(get_ids(&func[17]), vec![block_2]);
assert_eq!(get_ids(&func[18]), vec![block_4]);
assert_eq!(get_ids(&func[19]), vec![]);
}
}
fn assert_phi_dst_id(max_id: spirv::Word, phis: &[Vec<PhiDef>]) {
for phi_set in phis {
for phi in phi_set {
assert!(phi.dst > max_id);
}
}
}
fn assert_dst_unique(func: &[Statement], phis: &[Vec<PhiDef>]) {
let mut seen = HashSet::new();
for s in func {
s.for_dst_id(&mut |id| {
assert!(seen.insert(id));
});
}
for phi_set in phis {
for phi in phi_set {
assert!(seen.insert(phi.dst));
}
}
}
fn get_ids(s: &Statement) -> Vec<spirv::Word> {
let mut result = Vec::new();
s.visit_id(&mut |_, id| {
result.push(*id);
});
result
}
fn sort_phi(phis: &mut [Vec<PhiDef>]) {
for phi_set in phis {
phi_set.sort_by_key(|phi| phi.dst);
}
}
fn to_phi(raw: Vec<(spirv::Word, Vec<spirv::Word>)>) -> Vec<PhiDef> {
let result = raw
.into_iter()
.map(|(dst, src)| PhiDef {
dst: dst,
src: src.into_iter().collect::<HashSet<_>>(),
})
.collect::<Vec<_>>();
let mut result = [result];
sort_phi(&mut result);
let [result] = result;
result
}
fn get_dst(s: &Statement) -> spirv::Word {
let mut result = None;
s.visit_id(&mut |is_dst, id| {
if is_dst {
assert_eq!(result.replace(*id), None);
}
});
result.unwrap()
}
fn get_dst_from_src(phi: &[PhiDef], src: spirv::Word) -> spirv::Word {
for phi_set in phi {
if phi_set.src.contains(&src) {
return phi_set.dst;
}
}
panic!()
}
}