Upgrade/fix/simplify store-to-load forwarding

- fix store-to-load forwarding for a certain set of cases where forwarding
  shouldn't have happened (the load and store could not be proven to access
  the same memref element); use AffineValueMap-difference-based MemRefAccess
  equality checking; the utility logic is also greatly simplified

- add the missing ==/!= operators for comparing an AffineExpr against an
  integer

- add ==/!= operators on MemRefAccess

Closes tensorflow/mlir#136

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/136 from bondhugula:store-load-forwarding d79fd1add8bcfbd9fa71d841a6a9905340dcd792
PiperOrigin-RevId: 270457011
Uday Bondhugula 2019-09-21 10:08:32 -07:00 committed by A. Unique TensorFlower
parent 4d880d09e0
commit f559c38c28
8 changed files with 119 additions and 130 deletions


@@ -71,8 +71,17 @@ struct MemRefAccess {
bool isStore() const;
/// Populates 'accessMap' with composition of AffineApplyOps reachable from
-// 'indices'.
+/// 'indices'.
void getAccessMap(AffineValueMap *accessMap) const;
/// Equal if both affine accesses can be proved to be equivalent at compile
/// time (considering the memrefs, their respective affine access maps and
/// operands). The equality of access functions + operands is checked by
/// subtracting fully composed value maps, and then simplifying the difference
/// using the expression flattener.
/// TODO: this does not account for aliasing of memrefs.
bool operator==(const MemRefAccess &rhs) const;
bool operator!=(const MemRefAccess &rhs) const { return !(*this == rhs); }
};
// DependenceComponent contains state about the direction of a dependence as an
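For illustration, a minimal usage sketch of the operators declared above (a hypothetical standalone helper, not part of this diff; assumes the two operations are affine load/store ops):

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

/// Hypothetical helper: true iff the store provably writes exactly the
/// element that the load reads, for all iterations in scope.
static bool accessesSameElement(Operation *storeOpInst, Operation *loadOpInst) {
  MemRefAccess srcAccess(storeOpInst);  // access function of the store
  MemRefAccess destAccess(loadOpInst);  // access function of the load
  return srcAccess == destAccess;       // the operator== declared above
}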


@@ -532,12 +532,6 @@ public:
/// 'num' identifiers starting at position 'pos'.
void constantFoldIdRange(unsigned pos, unsigned num);
-/// Returns true if all the identifiers in the specified range [start, limit)
-/// can only take a single value each if the remaining identifiers are treated
-/// as symbols/parameters, i.e., for given values of the latter, there only
-/// exists a unique value for each of the dimensions in the specified range.
-bool isRangeOneToOne(unsigned start, unsigned limit) const;
/// Updates the constraints to be the smallest bounding (enclosing) box that
/// contains the points of 'this' set and that of 'other', with the symbols
/// being treated specially. For each of the dimensions, the min of the lower


@@ -88,6 +88,8 @@ public:
bool operator==(AffineExpr other) const { return expr == other.expr; }
bool operator!=(AffineExpr other) const { return !(*this == other); }
bool operator==(int64_t v) const;
bool operator!=(int64_t v) const { return !(*this == v); }
explicit operator bool() const { return expr; }
bool operator!() const { return expr == nullptr; }
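A small sketch of how the new integer comparisons read at a call site (a hypothetical helper mirroring this commit's use of them; AffineMap::getResults and llvm::all_of are existing APIs):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

/// Hypothetical helper: true iff every result expression of 'map' is the
/// constant 0, e.g. for the simplified difference of two access maps.
static bool allResultsZero(AffineMap map) {
  return llvm::all_of(map.getResults(),
                      [](AffineExpr e) { return e == 0; });
}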


@@ -2717,51 +2717,6 @@ void FlatAffineConstraints::projectOut(Value *id) {
FourierMotzkinEliminate(pos);
}
-bool FlatAffineConstraints::isRangeOneToOne(unsigned start,
-                                            unsigned limit) const {
-assert(start <= getNumIds() - 1 && "invalid start position");
-assert(limit > start && limit <= getNumIds() && "invalid limit");
-FlatAffineConstraints tmpCst(*this);
-if (start != 0) {
-// Move [start, limit) to the left.
-for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
-for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
-if (c >= start && c < limit)
-tmpCst.atIneq(r, c - start) = atIneq(r, c);
-else if (c < start)
-tmpCst.atIneq(r, c + limit - start) = atIneq(r, c);
-else
-tmpCst.atIneq(r, c) = atIneq(r, c);
-}
-}
-for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) {
-for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
-if (c >= start && c < limit)
-tmpCst.atEq(r, c - start) = atEq(r, c);
-else if (c < start)
-tmpCst.atEq(r, c + limit - start) = atEq(r, c);
-else
-tmpCst.atEq(r, c) = atEq(r, c);
-}
-}
-}
-// Mark everything to the right as symbols so that we can check the extents in
-// a symbolic way below.
-tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start));
-// Check if the extents of all the specified dimensions are just one (when
-// treating the rest as symbols).
-for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) {
-auto extent = tmpCst.getConstantBoundOnDimSize(pos);
-if (!extent.hasValue() || extent.getValue() != 1)
-return false;
-}
-return true;
-}
void FlatAffineConstraints::clearConstraints() {
equalities.clear();
inequalities.clear();


@@ -23,11 +23,8 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -881,6 +878,24 @@ unsigned mlir::getNestingDepth(Operation &op) {
return depth;
}
/// Equal if both affine accesses are provably equivalent (at compile
/// time) when considering the memref, the affine maps and their respective
/// operands. The equality of access functions + operands is checked by
/// subtracting fully composed value maps, and then simplifying the difference
/// using the expression flattener.
/// TODO: this does not account for aliasing of memrefs.
bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
if (memref != rhs.memref)
return false;
AffineValueMap diff, thisMap, rhsMap;
getAccessMap(&thisMap);
rhs.getAccessMap(&rhsMap);
AffineValueMap::difference(thisMap, rhsMap, &diff);
return llvm::all_of(diff.getAffineMap().getResults(),
[](AffineExpr e) { return e == 0; });
}
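For intuition, a hand-worked example of the difference-based check (mirroring the should_not_fwd test added below):

// store %A[%M] vs. load %A[%M]: both accesses compose to the map (d0) -> (d0)
// over the same operand %M, so the difference map simplifies to the constant
// 0 and the accesses compare equal.
//
// store %A[%M] vs. load %A[%N]: the difference is (d0, d1) -> (d0 - d1) over
// distinct operands (%M, %N); the flattener cannot simplify this to 0, so the
// accesses are not provably equal and no forwarding takes place.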
/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
/// where each lists loops from outer-most to inner-most in loop nest.
unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {


@@ -279,6 +279,10 @@ int64_t AffineConstantExpr::getValue() const {
return static_cast<ImplType *>(expr)->constant;
}
bool AffineExpr::operator==(int64_t v) const {
return *this == getAffineConstantExpr(v, getContext());
}
AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
auto assignCtx = [context](AffineConstantExprStorage *storage) {
storage->context = context;


@@ -40,19 +40,19 @@ namespace {
// The store to load forwarding relies on three conditions:
//
-// 1) there has to be a dependence from the store to the load satisfied at the
-// block* immediately within the innermost loop enclosing both the load op and
-// the store op,
+// 1) they need to have mathematically equivalent affine access functions
+// (checked after full composition of load/store operands); this implies that
+// they access the same single memref element for all iterations of the common
+// surrounding loop,
//
// 2) the store op should dominate the load op,
//
-// 3) among all candidate store op's that satisfy (1) and (2), if there exists a
-// store op that postdominates all those that satisfy (1), such a store op is
-// provably the last writer to the particular memref location being loaded from
-// by the load op, and its store value can be forwarded to the load.
-//
-// 4) the load should touch a single location in the memref for a given
-// iteration of the innermost loop enclosing both the store op and the load op.
+// 3) among all op's that satisfy both (1) and (2), the one that postdominates
+// all store op's that have a dependence into the load is provably the last
+// writer to the particular memref location being loaded at the load op, and its
+// store value can be forwarded to the load. Note that the only dependences
+// that are to be considered are those that are satisfied at the block* of the
+// innermost common surrounding loop of the <store, load> being considered.
//
// (* A dependence being satisfied at a block: a dependence that is satisfied by
// virtue of the destination operation appearing textually / lexically after
@@ -60,9 +60,9 @@
// dependence is always either satisfied by a loop or by a block).
//
// The above conditions are simple to check, sufficient, and powerful for most
-// cases in practice - condition (1) and (3) are precise and necessary, while
-// condition (2) is a sufficient one but not necessary (since it doesn't reason
-// about loops that are guaranteed to execute at least once).
+// cases in practice - they are sufficient, but not necessary --- since they
+// don't reason about loops that are guaranteed to execute at least once, or
+// about multiple sources to forward from.
//
// TODO(mlir-team): more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
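As a concrete illustration of conditions (1)-(3), a hand-worked example in the style of the tests accompanying this commit (not itself part of the diff):

//   affine.for %i = 0 to 100 {
//     affine.store %cf, %A[%i]  // (1) same access function %A[%i] as the
//                               // load below; (2) dominates the load.
//     %v = affine.load %A[%i]   // (3) the store post-dominates every store
//                               // with a dependence into this load (it is
//                               // the only one), so %cf is forwarded to %v.
//   }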
@@ -78,7 +78,7 @@ struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> {
// A list of memref's that are potentially dead / could be eliminated.
SmallPtrSet<Value *, 4> memrefsToErase;
// Load op's whose results were replaced by those forwarded from stores.
-std::vector<Operation *> loadOpsToErase;
+SmallVector<Operation *, 8> loadOpsToErase;
DominanceInfo *domInfo = nullptr;
PostDominanceInfo *postDomInfo = nullptr;
@@ -93,9 +93,8 @@ std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefDataFlowOptPass() {
}
// This is a straightforward implementation not optimized for speed. Optimize
-// this in the future if needed.
+// if needed.
void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
-Operation *lastWriteStoreOp = nullptr;
Operation *loadOpInst = loadOp.getOperation();
// First pass over the use list to get minimum number of surrounding
@@ -113,81 +112,63 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
storeOps.push_back(storeOpInst);
}
-unsigned loadOpDepth = getNestingDepth(*loadOpInst);
-// 1. Check if there is a dependence satisfied at depth equal to the depth
-// of the loop body of the innermost common surrounding loop of the storeOp
-// and loadOp.
-// The list of store op candidates for forwarding - need to satisfy the
-// conditions listed at the top.
+// The list of store op candidates for forwarding that satisfy conditions
+// (1) and (2) above - they will be filtered later when checking (3).
SmallVector<Operation *, 8> fwdingCandidates;
// Store ops that have a dependence into the load (even if they aren't
// forwarding candidates). Each forwarding candidate will be checked for a
// post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
SmallVector<Operation *, 8> depSrcStores;
for (auto *storeOpInst : storeOps) {
MemRefAccess srcAccess(storeOpInst);
MemRefAccess destAccess(loadOpInst);
// Find stores that may be reaching the load.
FlatAffineConstraints dependenceConstraints;
unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+unsigned d;
// Dependences at loop depth <= minSurroundingLoops do NOT matter.
-for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
+for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
DependenceResult result = checkMemrefAccessDependence(
srcAccess, destAccess, d, &dependenceConstraints,
/*dependenceComponents=*/nullptr);
-if (!hasDependence(result))
-continue;
-depSrcStores.push_back(storeOpInst);
-// Check if this store is a candidate for forwarding; we only forward if
-// the dependence from the store is carried by the *body* of innermost
-// common surrounding loop. As an example this filters out cases like:
-// affine.for %i0
-//   affine.for %i1
-//     %idx = affine.apply (d0) -> (d0 + 1) (%i0)
-//     store %A[%idx]
-//     load %A[%i0]
-//
-if (d != nsLoops + 1)
+if (hasDependence(result))
break;
-// 2. The store has to dominate the load op to be candidate. This is not
-// strictly a necessary condition since dominance isn't a prerequisite for
-// a memref element store to reach a load, but this is sufficient and
-// reasonably powerful in practice.
-if (!domInfo->dominates(storeOpInst, loadOpInst))
-break;
-// Finally, forwarding is only possible if the load touches a single
-// location in the memref across the enclosing loops *not* common with the
-// store. This is filtering out cases like:
-// for (i ...)
-//   a[i] = ...
-// for (j ...)
-//   ... = a[j]
-// If storeOpInst and loadOpDepth at the same nesting depth, the load Op
-// is trivially loading from a single location at that depth; so there
-// isn't a need to call isRangeOneToOne.
-if (getNestingDepth(*storeOpInst) < loadOpDepth) {
-MemRefRegion region(loadOpInst->getLoc());
-region.compute(loadOpInst, nsLoops);
-if (!region.getConstraints()->isRangeOneToOne(
-/*start=*/0, /*limit=*/loadOp.getMemRefType().getRank()))
-break;
-}
-// After all these conditions, we have a candidate for forwarding!
-fwdingCandidates.push_back(storeOpInst);
-break;
}
+if (d == minSurroundingLoops)
+continue;
+// Stores that *may* be reaching the load.
+depSrcStores.push_back(storeOpInst);
+// 1. Check if the store and the load have mathematically equivalent
+// affine access functions; this implies that they statically refer to the
+// same single memref element. As an example this filters out cases like:
+//   store %A[%i0 + 1]
+//   load %A[%i0]
+//   store %A[%M]
+//   load %A[%N]
+// Use the AffineValueMap difference based memref access equality checking.
+if (srcAccess != destAccess)
+continue;
+// 2. The store has to dominate the load op to be candidate.
+if (!domInfo->dominates(storeOpInst, loadOpInst))
+continue;
+// We now have a candidate for forwarding.
+fwdingCandidates.push_back(storeOpInst);
}
-// Note: this can implemented in a cleaner way with postdominator tree
+// 3. Of all the store op's that meet the above criteria, the store that
+// postdominates all 'depSrcStores' (if one exists) is the unique store
+// providing the value to the load, i.e., provably the last writer to that
+// memref loc.
+// Note: this can be implemented in a cleaner way with postdominator tree
// traversals. Consider this for the future if needed.
+Operation *lastWriteStoreOp = nullptr;
for (auto *storeOpInst : fwdingCandidates) {
-// 3. Of all the store op's that meet the above criteria, the store
-// that postdominates all 'depSrcStores' (if such a store exists) is the
-// unique store providing the value to the load, i.e., provably the last
-// writer to that memref loc.
if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
return postDomInfo->postDominates(storeOpInst, depStore);
})) {
@@ -195,10 +176,6 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
break;
}
}
-// TODO: optimization for future: those store op's that are determined to be
-// postdominated above can actually be recorded and skipped on the 'i' loop
-// iteration above --- since they can never post dominate everything.
if (!lastWriteStoreOp)
return;
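A hand-worked illustration of the post-dominance filter above (straight-line example, assumptions mine):

//   affine.store %v1, %A[%i]  // S1: dominates the load, same access function
//   affine.store %v2, %A[%i]  // S2: also a candidate; post-dominates S1
//   %u = affine.load %A[%i]   // S2 is provably the last writer, so %v2 is
//                             // forwarded to %u
//
// If S1 and S2 instead sat in the two branches of a conditional, neither would
// post-dominate the other, no unique last writer would be found, and no
// forwarding would take place.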


@@ -247,3 +247,36 @@ func @store_load_store_nested_fwd(%N : index) -> f32 {
// CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: return %{{.*}} : f32
}
// CHECK-LABEL: func @should_not_fwd
func @should_not_fwd(%A: memref<100xf32>, %M : index, %N : index) -> f32 {
%cf = constant 0.0 : f32
affine.store %cf, %A[%M] : memref<100xf32>
// CHECK: affine.load %{{.*}}[%{{.*}}]
%v = affine.load %A[%N] : memref<100xf32>
return %v : f32
}
// The store to %A[%j, %i] can be forwarded to the load from %A[%j, %i], but
// not to the load from %A[%i, %j].
// CHECK-LABEL: func @refs_not_known_to_be_equal
func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index) {
%N = affine.apply (d0) -> (d0 + 1) (%M)
%cf1 = constant 1.0 : f32
affine.for %i = 0 to 100 {
// CHECK: affine.for %[[I:.*]] =
affine.for %j = 0 to 100 {
// CHECK: affine.for %[[J:.*]] =
// CHECK: affine.load %{{.*}}[%[[I]], %[[J]]]
%u = affine.load %A[%i, %j] : memref<100x100xf32>
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[J]], %[[I]]]
affine.store %cf1, %A[%j, %i] : memref<100x100xf32>
// CHECK-NEXT: affine.load %{{.*}}[%[[I]], %[[J]]]
%v = affine.load %A[%i, %j] : memref<100x100xf32>
// This load should disappear.
%w = affine.load %A[%j, %i] : memref<100x100xf32>
// CHECK-NEXT: "foo"
"foo" (%u, %v, %w) : (f32, f32, f32) -> ()
}
}
return
}
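A hand-worked note on refs_not_known_to_be_equal (reasoning not in the diff): with dims aligned as (d0, d1) = (%i, %j), the store %A[%j, %i] and the load %A[%i, %j] have the difference map (d0, d1) -> (d0 - d1, d1 - d0), which is zero only when %i == %j; since that cannot be established at compile time, the accesses compare unequal and that load is preserved. The load %A[%j, %i], by contrast, matches the store's access function exactly, so the stored value is forwarded and the load disappears.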