Rollup merge of #73757 - oli-obk:const_prop_hardening, r=wesleywiser
Const prop: erase all block-only locals at the end of every block I messed up this erasure in https://github.com/rust-lang/rust/pull/73656#discussion_r446040140. I think it is too fragile to have the previous scheme. Let's benchmark the new scheme and see what happens. r? @wesleywiser cc @felix91gr
This commit is contained in:
commit
ccc1bf79c8
5 changed files with 90 additions and 16 deletions
|
@ -132,6 +132,10 @@ pub enum LocalValue<Tag = ()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
|
impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
|
||||||
|
/// Read the local's value or error if the local is not yet live or not live anymore.
|
||||||
|
///
|
||||||
|
/// Note: This may only be invoked from the `Machine::access_local` hook and not from
|
||||||
|
/// anywhere else. You may be invalidating machine invariants if you do!
|
||||||
pub fn access(&self) -> InterpResult<'tcx, Operand<Tag>> {
|
pub fn access(&self) -> InterpResult<'tcx, Operand<Tag>> {
|
||||||
match self.value {
|
match self.value {
|
||||||
LocalValue::Dead => throw_ub!(DeadLocal),
|
LocalValue::Dead => throw_ub!(DeadLocal),
|
||||||
|
@ -144,6 +148,9 @@ impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
|
||||||
|
|
||||||
/// Overwrite the local. If the local can be overwritten in place, return a reference
|
/// Overwrite the local. If the local can be overwritten in place, return a reference
|
||||||
/// to do so; otherwise return the `MemPlace` to consult instead.
|
/// to do so; otherwise return the `MemPlace` to consult instead.
|
||||||
|
///
|
||||||
|
/// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from
|
||||||
|
/// anywhere else. You may be invalidating machine invariants if you do!
|
||||||
pub fn access_mut(
|
pub fn access_mut(
|
||||||
&mut self,
|
&mut self,
|
||||||
) -> InterpResult<'tcx, Result<&mut LocalValue<Tag>, MemPlace<Tag>>> {
|
) -> InterpResult<'tcx, Result<&mut LocalValue<Tag>, MemPlace<Tag>>> {
|
||||||
|
|
|
@ -11,7 +11,7 @@ use rustc_span::def_id::DefId;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult,
|
AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult,
|
||||||
Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
|
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Data returned by Machine::stack_pop,
|
/// Data returned by Machine::stack_pop,
|
||||||
|
@ -192,6 +192,8 @@ pub trait Machine<'mir, 'tcx>: Sized {
|
||||||
) -> InterpResult<'tcx>;
|
) -> InterpResult<'tcx>;
|
||||||
|
|
||||||
/// Called to read the specified `local` from the `frame`.
|
/// Called to read the specified `local` from the `frame`.
|
||||||
|
/// Since reading a ZST is not actually accessing memory or locals, this is never invoked
|
||||||
|
/// for ZST reads.
|
||||||
#[inline]
|
#[inline]
|
||||||
fn access_local(
|
fn access_local(
|
||||||
_ecx: &InterpCx<'mir, 'tcx, Self>,
|
_ecx: &InterpCx<'mir, 'tcx, Self>,
|
||||||
|
@ -201,6 +203,21 @@ pub trait Machine<'mir, 'tcx>: Sized {
|
||||||
frame.locals[local].access()
|
frame.locals[local].access()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Called to write the specified `local` from the `frame`.
|
||||||
|
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
|
||||||
|
/// for ZST reads.
|
||||||
|
#[inline]
|
||||||
|
fn access_local_mut<'a>(
|
||||||
|
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
|
||||||
|
frame: usize,
|
||||||
|
local: mir::Local,
|
||||||
|
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
|
||||||
|
where
|
||||||
|
'tcx: 'mir,
|
||||||
|
{
|
||||||
|
ecx.stack_mut()[frame].locals[local].access_mut()
|
||||||
|
}
|
||||||
|
|
||||||
/// Called before a basic block terminator is executed.
|
/// Called before a basic block terminator is executed.
|
||||||
/// You can use this to detect endlessly running programs.
|
/// You can use this to detect endlessly running programs.
|
||||||
#[inline]
|
#[inline]
|
||||||
|
|
|
@ -432,7 +432,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This is used by [priroda](https://github.com/oli-obk/priroda) to get an OpTy from a local
|
/// Read from a local. Will not actually access the local if reading from a ZST.
|
||||||
|
/// Will not access memory, instead an indirect `Operand` is returned.
|
||||||
|
///
|
||||||
|
/// This is public because it is used by [priroda](https://github.com/oli-obk/priroda) to get an
|
||||||
|
/// OpTy from a local
|
||||||
pub fn access_local(
|
pub fn access_local(
|
||||||
&self,
|
&self,
|
||||||
frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>,
|
frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>,
|
||||||
|
|
|
@ -741,7 +741,7 @@ where
|
||||||
// but not factored as a separate function.
|
// but not factored as a separate function.
|
||||||
let mplace = match dest.place {
|
let mplace = match dest.place {
|
||||||
Place::Local { frame, local } => {
|
Place::Local { frame, local } => {
|
||||||
match self.stack_mut()[frame].locals[local].access_mut()? {
|
match M::access_local_mut(self, frame, local)? {
|
||||||
Ok(local) => {
|
Ok(local) => {
|
||||||
// Local can be updated in-place.
|
// Local can be updated in-place.
|
||||||
*local = LocalValue::Live(Operand::Immediate(src));
|
*local = LocalValue::Live(Operand::Immediate(src));
|
||||||
|
@ -974,7 +974,7 @@ where
|
||||||
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option<Size>)> {
|
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option<Size>)> {
|
||||||
let (mplace, size) = match place.place {
|
let (mplace, size) = match place.place {
|
||||||
Place::Local { frame, local } => {
|
Place::Local { frame, local } => {
|
||||||
match self.stack_mut()[frame].locals[local].access_mut()? {
|
match M::access_local_mut(self, frame, local)? {
|
||||||
Ok(&mut local_val) => {
|
Ok(&mut local_val) => {
|
||||||
// We need to make an allocation.
|
// We need to make an allocation.
|
||||||
|
|
||||||
|
@ -998,7 +998,7 @@ where
|
||||||
}
|
}
|
||||||
// Now we can call `access_mut` again, asserting it goes well,
|
// Now we can call `access_mut` again, asserting it goes well,
|
||||||
// and actually overwrite things.
|
// and actually overwrite things.
|
||||||
*self.stack_mut()[frame].locals[local].access_mut().unwrap().unwrap() =
|
*M::access_local_mut(self, frame, local).unwrap().unwrap() =
|
||||||
LocalValue::Live(Operand::Indirect(mplace));
|
LocalValue::Live(Operand::Indirect(mplace));
|
||||||
(mplace, Some(size))
|
(mplace, Some(size))
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
use std::cell::Cell;
|
use std::cell::Cell;
|
||||||
|
|
||||||
use rustc_ast::ast::Mutability;
|
use rustc_ast::ast::Mutability;
|
||||||
|
use rustc_data_structures::fx::FxHashSet;
|
||||||
use rustc_hir::def::DefKind;
|
use rustc_hir::def::DefKind;
|
||||||
use rustc_hir::HirId;
|
use rustc_hir::HirId;
|
||||||
use rustc_index::bit_set::BitSet;
|
use rustc_index::bit_set::BitSet;
|
||||||
|
@ -28,7 +29,7 @@ use rustc_trait_selection::traits;
|
||||||
use crate::const_eval::error_to_const_error;
|
use crate::const_eval::error_to_const_error;
|
||||||
use crate::interpret::{
|
use crate::interpret::{
|
||||||
self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState,
|
self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState,
|
||||||
LocalValue, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
|
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
|
||||||
ScalarMaybeUninit, StackPopCleanup,
|
ScalarMaybeUninit, StackPopCleanup,
|
||||||
};
|
};
|
||||||
use crate::transform::{MirPass, MirSource};
|
use crate::transform::{MirPass, MirSource};
|
||||||
|
@ -151,11 +152,19 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
|
||||||
struct ConstPropMachine<'mir, 'tcx> {
|
struct ConstPropMachine<'mir, 'tcx> {
|
||||||
/// The virtual call stack.
|
/// The virtual call stack.
|
||||||
stack: Vec<Frame<'mir, 'tcx, (), ()>>,
|
stack: Vec<Frame<'mir, 'tcx, (), ()>>,
|
||||||
|
/// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end.
|
||||||
|
written_only_inside_own_block_locals: FxHashSet<Local>,
|
||||||
|
/// Locals that need to be cleared after every block terminates.
|
||||||
|
only_propagate_inside_block_locals: BitSet<Local>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> {
|
impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> {
|
||||||
fn new() -> Self {
|
fn new(only_propagate_inside_block_locals: BitSet<Local>) -> Self {
|
||||||
Self { stack: Vec::new() }
|
Self {
|
||||||
|
stack: Vec::new(),
|
||||||
|
written_only_inside_own_block_locals: Default::default(),
|
||||||
|
only_propagate_inside_block_locals,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,6 +236,18 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
|
||||||
l.access()
|
l.access()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn access_local_mut<'a>(
|
||||||
|
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
|
||||||
|
frame: usize,
|
||||||
|
local: Local,
|
||||||
|
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
|
||||||
|
{
|
||||||
|
if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) {
|
||||||
|
ecx.machine.written_only_inside_own_block_locals.insert(local);
|
||||||
|
}
|
||||||
|
ecx.machine.stack[frame].locals[local].access_mut()
|
||||||
|
}
|
||||||
|
|
||||||
fn before_access_global(
|
fn before_access_global(
|
||||||
_memory_extra: &(),
|
_memory_extra: &(),
|
||||||
_alloc_id: AllocId,
|
_alloc_id: AllocId,
|
||||||
|
@ -274,8 +295,6 @@ struct ConstPropagator<'mir, 'tcx> {
|
||||||
// Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
|
// Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
|
||||||
// the last known `SourceInfo` here and just keep revisiting it.
|
// the last known `SourceInfo` here and just keep revisiting it.
|
||||||
source_info: Option<SourceInfo>,
|
source_info: Option<SourceInfo>,
|
||||||
// Locals we need to forget at the end of the current block
|
|
||||||
locals_of_current_block: BitSet<Local>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> {
|
impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> {
|
||||||
|
@ -313,8 +332,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
|
||||||
let param_env = tcx.param_env(def_id).with_reveal_all();
|
let param_env = tcx.param_env(def_id).with_reveal_all();
|
||||||
|
|
||||||
let span = tcx.def_span(def_id);
|
let span = tcx.def_span(def_id);
|
||||||
let mut ecx = InterpCx::new(tcx, span, param_env, ConstPropMachine::new(), ());
|
|
||||||
let can_const_prop = CanConstProp::check(body);
|
let can_const_prop = CanConstProp::check(body);
|
||||||
|
let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len());
|
||||||
|
for (l, mode) in can_const_prop.iter_enumerated() {
|
||||||
|
if *mode == ConstPropMode::OnlyInsideOwnBlock {
|
||||||
|
only_propagate_inside_block_locals.insert(l);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut ecx = InterpCx::new(
|
||||||
|
tcx,
|
||||||
|
span,
|
||||||
|
param_env,
|
||||||
|
ConstPropMachine::new(only_propagate_inside_block_locals),
|
||||||
|
(),
|
||||||
|
);
|
||||||
|
|
||||||
let ret = ecx
|
let ret = ecx
|
||||||
.layout_of(body.return_ty().subst(tcx, substs))
|
.layout_of(body.return_ty().subst(tcx, substs))
|
||||||
|
@ -345,7 +376,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
|
||||||
//FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it
|
//FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it
|
||||||
local_decls: body.local_decls.clone(),
|
local_decls: body.local_decls.clone(),
|
||||||
source_info: None,
|
source_info: None,
|
||||||
locals_of_current_block: BitSet::new_empty(body.local_decls.len()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -900,7 +930,6 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
|
||||||
Will remove it from const-prop after block is finished. Local: {:?}",
|
Will remove it from const-prop after block is finished. Local: {:?}",
|
||||||
place.local
|
place.local
|
||||||
);
|
);
|
||||||
self.locals_of_current_block.insert(place.local);
|
|
||||||
}
|
}
|
||||||
ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
|
ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
|
||||||
trace!("can't propagate into {:?}", place);
|
trace!("can't propagate into {:?}", place);
|
||||||
|
@ -1089,10 +1118,27 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// We remove all Locals which are restricted in propagation to their containing blocks.
|
|
||||||
for local in self.locals_of_current_block.iter() {
|
// We remove all Locals which are restricted in propagation to their containing blocks and
|
||||||
|
// which were modified in the current block.
|
||||||
|
// Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`
|
||||||
|
let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
|
||||||
|
for &local in locals.iter() {
|
||||||
Self::remove_const(&mut self.ecx, local);
|
Self::remove_const(&mut self.ecx, local);
|
||||||
}
|
}
|
||||||
self.locals_of_current_block.clear();
|
locals.clear();
|
||||||
|
// Put it back so we reuse the heap of the storage
|
||||||
|
self.ecx.machine.written_only_inside_own_block_locals = locals;
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
// Ensure we are correctly erasing locals with the non-debug-assert logic.
|
||||||
|
for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
|
||||||
|
assert!(
|
||||||
|
self.get_const(local.into()).is_none()
|
||||||
|
|| self
|
||||||
|
.layout_of(self.local_decls[local].ty)
|
||||||
|
.map_or(true, |layout| layout.is_zst())
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue