Rollup merge of #73757 - oli-obk:const_prop_hardening, r=wesleywiser
Const prop: erase all block-only locals at the end of every block I messed up this erasure in https://github.com/rust-lang/rust/pull/73656#discussion_r446040140. I think it is too fragile to have the previous scheme. Let's benchmark the new scheme and see what happens. r? @wesleywiser cc @felix91gr
This commit is contained in:
commit
ccc1bf79c8
5 changed files with 90 additions and 16 deletions
|
@ -132,6 +132,10 @@ pub enum LocalValue<Tag = ()> {
|
|||
}
|
||||
|
||||
impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
|
||||
/// Read the local's value or error if the local is not yet live or not live anymore.
|
||||
///
|
||||
/// Note: This may only be invoked from the `Machine::access_local` hook and not from
|
||||
/// anywhere else. You may be invalidating machine invariants if you do!
|
||||
pub fn access(&self) -> InterpResult<'tcx, Operand<Tag>> {
|
||||
match self.value {
|
||||
LocalValue::Dead => throw_ub!(DeadLocal),
|
||||
|
@ -144,6 +148,9 @@ impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
|
|||
|
||||
/// Overwrite the local. If the local can be overwritten in place, return a reference
|
||||
/// to do so; otherwise return the `MemPlace` to consult instead.
|
||||
///
|
||||
/// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from
|
||||
/// anywhere else. You may be invalidating machine invariants if you do!
|
||||
pub fn access_mut(
|
||||
&mut self,
|
||||
) -> InterpResult<'tcx, Result<&mut LocalValue<Tag>, MemPlace<Tag>>> {
|
||||
|
|
|
@ -11,7 +11,7 @@ use rustc_span::def_id::DefId;
|
|||
|
||||
use super::{
|
||||
AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult,
|
||||
Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
|
||||
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
|
||||
};
|
||||
|
||||
/// Data returned by Machine::stack_pop,
|
||||
|
@ -192,6 +192,8 @@ pub trait Machine<'mir, 'tcx>: Sized {
|
|||
) -> InterpResult<'tcx>;
|
||||
|
||||
/// Called to read the specified `local` from the `frame`.
|
||||
/// Since reading a ZST is not actually accessing memory or locals, this is never invoked
|
||||
/// for ZST reads.
|
||||
#[inline]
|
||||
fn access_local(
|
||||
_ecx: &InterpCx<'mir, 'tcx, Self>,
|
||||
|
@ -201,6 +203,21 @@ pub trait Machine<'mir, 'tcx>: Sized {
|
|||
frame.locals[local].access()
|
||||
}
|
||||
|
||||
/// Called to write the specified `local` from the `frame`.
|
||||
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
|
||||
/// for ZST reads.
|
||||
#[inline]
|
||||
fn access_local_mut<'a>(
|
||||
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
|
||||
frame: usize,
|
||||
local: mir::Local,
|
||||
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
|
||||
where
|
||||
'tcx: 'mir,
|
||||
{
|
||||
ecx.stack_mut()[frame].locals[local].access_mut()
|
||||
}
|
||||
|
||||
/// Called before a basic block terminator is executed.
|
||||
/// You can use this to detect endlessly running programs.
|
||||
#[inline]
|
||||
|
|
|
@ -432,7 +432,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
|
|||
})
|
||||
}
|
||||
|
||||
/// This is used by [priroda](https://github.com/oli-obk/priroda) to get an OpTy from a local
|
||||
/// Read from a local. Will not actually access the local if reading from a ZST.
|
||||
/// Will not access memory, instead an indirect `Operand` is returned.
|
||||
///
|
||||
/// This is public because it is used by [priroda](https://github.com/oli-obk/priroda) to get an
|
||||
/// OpTy from a local
|
||||
pub fn access_local(
|
||||
&self,
|
||||
frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>,
|
||||
|
|
|
@ -741,7 +741,7 @@ where
|
|||
// but not factored as a separate function.
|
||||
let mplace = match dest.place {
|
||||
Place::Local { frame, local } => {
|
||||
match self.stack_mut()[frame].locals[local].access_mut()? {
|
||||
match M::access_local_mut(self, frame, local)? {
|
||||
Ok(local) => {
|
||||
// Local can be updated in-place.
|
||||
*local = LocalValue::Live(Operand::Immediate(src));
|
||||
|
@ -974,7 +974,7 @@ where
|
|||
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option<Size>)> {
|
||||
let (mplace, size) = match place.place {
|
||||
Place::Local { frame, local } => {
|
||||
match self.stack_mut()[frame].locals[local].access_mut()? {
|
||||
match M::access_local_mut(self, frame, local)? {
|
||||
Ok(&mut local_val) => {
|
||||
// We need to make an allocation.
|
||||
|
||||
|
@ -998,7 +998,7 @@ where
|
|||
}
|
||||
// Now we can call `access_mut` again, asserting it goes well,
|
||||
// and actually overwrite things.
|
||||
*self.stack_mut()[frame].locals[local].access_mut().unwrap().unwrap() =
|
||||
*M::access_local_mut(self, frame, local).unwrap().unwrap() =
|
||||
LocalValue::Live(Operand::Indirect(mplace));
|
||||
(mplace, Some(size))
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
use std::cell::Cell;
|
||||
|
||||
use rustc_ast::ast::Mutability;
|
||||
use rustc_data_structures::fx::FxHashSet;
|
||||
use rustc_hir::def::DefKind;
|
||||
use rustc_hir::HirId;
|
||||
use rustc_index::bit_set::BitSet;
|
||||
|
@ -28,7 +29,7 @@ use rustc_trait_selection::traits;
|
|||
use crate::const_eval::error_to_const_error;
|
||||
use crate::interpret::{
|
||||
self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState,
|
||||
LocalValue, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
|
||||
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
|
||||
ScalarMaybeUninit, StackPopCleanup,
|
||||
};
|
||||
use crate::transform::{MirPass, MirSource};
|
||||
|
@ -151,11 +152,19 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
|
|||
struct ConstPropMachine<'mir, 'tcx> {
|
||||
/// The virtual call stack.
|
||||
stack: Vec<Frame<'mir, 'tcx, (), ()>>,
|
||||
/// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end.
|
||||
written_only_inside_own_block_locals: FxHashSet<Local>,
|
||||
/// Locals that need to be cleared after every block terminates.
|
||||
only_propagate_inside_block_locals: BitSet<Local>,
|
||||
}
|
||||
|
||||
impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> {
|
||||
fn new() -> Self {
|
||||
Self { stack: Vec::new() }
|
||||
fn new(only_propagate_inside_block_locals: BitSet<Local>) -> Self {
|
||||
Self {
|
||||
stack: Vec::new(),
|
||||
written_only_inside_own_block_locals: Default::default(),
|
||||
only_propagate_inside_block_locals,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -227,6 +236,18 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
|
|||
l.access()
|
||||
}
|
||||
|
||||
fn access_local_mut<'a>(
|
||||
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
|
||||
frame: usize,
|
||||
local: Local,
|
||||
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
|
||||
{
|
||||
if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) {
|
||||
ecx.machine.written_only_inside_own_block_locals.insert(local);
|
||||
}
|
||||
ecx.machine.stack[frame].locals[local].access_mut()
|
||||
}
|
||||
|
||||
fn before_access_global(
|
||||
_memory_extra: &(),
|
||||
_alloc_id: AllocId,
|
||||
|
@ -274,8 +295,6 @@ struct ConstPropagator<'mir, 'tcx> {
|
|||
// Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
|
||||
// the last known `SourceInfo` here and just keep revisiting it.
|
||||
source_info: Option<SourceInfo>,
|
||||
// Locals we need to forget at the end of the current block
|
||||
locals_of_current_block: BitSet<Local>,
|
||||
}
|
||||
|
||||
impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> {
|
||||
|
@ -313,8 +332,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
|
|||
let param_env = tcx.param_env(def_id).with_reveal_all();
|
||||
|
||||
let span = tcx.def_span(def_id);
|
||||
let mut ecx = InterpCx::new(tcx, span, param_env, ConstPropMachine::new(), ());
|
||||
let can_const_prop = CanConstProp::check(body);
|
||||
let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len());
|
||||
for (l, mode) in can_const_prop.iter_enumerated() {
|
||||
if *mode == ConstPropMode::OnlyInsideOwnBlock {
|
||||
only_propagate_inside_block_locals.insert(l);
|
||||
}
|
||||
}
|
||||
let mut ecx = InterpCx::new(
|
||||
tcx,
|
||||
span,
|
||||
param_env,
|
||||
ConstPropMachine::new(only_propagate_inside_block_locals),
|
||||
(),
|
||||
);
|
||||
|
||||
let ret = ecx
|
||||
.layout_of(body.return_ty().subst(tcx, substs))
|
||||
|
@ -345,7 +376,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
|
|||
//FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it
|
||||
local_decls: body.local_decls.clone(),
|
||||
source_info: None,
|
||||
locals_of_current_block: BitSet::new_empty(body.local_decls.len()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -900,7 +930,6 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
|
|||
Will remove it from const-prop after block is finished. Local: {:?}",
|
||||
place.local
|
||||
);
|
||||
self.locals_of_current_block.insert(place.local);
|
||||
}
|
||||
ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
|
||||
trace!("can't propagate into {:?}", place);
|
||||
|
@ -1089,10 +1118,27 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
|
|||
}
|
||||
}
|
||||
}
|
||||
// We remove all Locals which are restricted in propagation to their containing blocks.
|
||||
for local in self.locals_of_current_block.iter() {
|
||||
|
||||
// We remove all Locals which are restricted in propagation to their containing blocks and
|
||||
// which were modified in the current block.
|
||||
// Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`
|
||||
let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
|
||||
for &local in locals.iter() {
|
||||
Self::remove_const(&mut self.ecx, local);
|
||||
}
|
||||
self.locals_of_current_block.clear();
|
||||
locals.clear();
|
||||
// Put it back so we reuse the heap of the storage
|
||||
self.ecx.machine.written_only_inside_own_block_locals = locals;
|
||||
if cfg!(debug_assertions) {
|
||||
// Ensure we are correctly erasing locals with the non-debug-assert logic.
|
||||
for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
|
||||
assert!(
|
||||
self.get_const(local.into()).is_none()
|
||||
|| self
|
||||
.layout_of(self.local_decls[local].ty)
|
||||
.map_or(true, |layout| layout.is_zst())
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue