Support symbolic operands for memref replacement; fix normalizeMemRef

- allow symbols in the index remapping provided for memref replacement
- fix a normalizeMemRef crash on layout maps with symbols (illustrated below)
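
As a minimal illustration (mirroring the new test case added below), this is
the kind of IR whose normalization previously crashed: an alloc whose layout
map binds a symbolic operand.

  %A = alloc()[%s] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)>

Normalization rewrites %A to an identity-layout memref<100xf32> and remaps
all of its accesses through the old layout map.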

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Reported by: Alex Zinenko

Closes tensorflow/mlir#139

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/139 from bondhugula:memref-rep-symbols 2f48c1fdb5d4c58915bbddbd9f07b18541819233
PiperOrigin-RevId: 269851182


@@ -162,6 +162,12 @@ def AllocOp : Std_Op<"alloc"> {
     unsigned getNumSymbolicOperands() {
       return getNumOperands() - getType().getNumDynamicDims();
     }
+
+    /// Returns the symbolic operands (the ones in square brackets), which bind
+    /// to the symbols of the memref's layout map.
+    operand_range getSymbolicOperands() {
+      return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
+    }
   }];
 
   let hasCanonicalizer = 1;
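
For reference, a minimal sketch (operand names assumed, not from this patch)
of how an alloc's operand list splits: dynamic-dimension operands come first,
then the symbolic operands in square brackets, which bind to the layout map's
symbols:

  // %d binds the dynamic dimension; %s binds the layout map symbol s0.
  // getSymbolicOperands() returns {%s} here.
  %0 = alloc(%d)[%s] : memref<?x10xf32, (d0,d1)[s0] -> (d0 * s0 + d1)>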


@@ -40,15 +40,15 @@ class OpBuilder;
 /// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while
 /// optionally remapping the old memref's indices using the supplied affine map,
 /// `indexRemap`. The new memref could be of a different shape or rank.
-/// `extraIndices` provides additional access indices to be added to the start.
+/// `extraIndices` provides any additional access indices to be added to the
+/// start.
 ///
 /// `indexRemap` remaps indices of the old memref access to a new set of indices
 /// that are used to index the memref. Additional input operands to indexRemap
-/// can be optionally provided, and they are added at the start of its input
-/// list. `indexRemap` is expected to have only dimensional inputs, and the
-/// number of its inputs equal to extraOperands.size() plus rank of the memref.
-/// 'extraOperands' is an optional argument that corresponds to additional
-/// operands (inputs) for indexRemap at the beginning of its input list.
+/// can be optionally provided in `extraOperands`, and they occupy the start
+/// of its input list. `indexRemap`'s dimensional inputs are expected to
+/// correspond to the memref's indices, and its symbolic inputs, if any,
+/// should be provided in `symbolOperands`.
 ///
 /// `domInstFilter`, if non-null, restricts the replacement to only those
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
@@ -70,6 +70,7 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                        ArrayRef<Value *> extraIndices = {},
                                        AffineMap indexRemap = AffineMap(),
                                        ArrayRef<Value *> extraOperands = {},
+                                       ArrayRef<Value *> symbolOperands = {},
                                        Operation *domInstFilter = nullptr,
                                        Operation *postDomInstFilter = nullptr);
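
To make the new contract concrete, a hypothetical example (operand names are
illustrative): with extraOperands = [%t], an old memref of rank 2 accessed as
[%i, %j], and symbolOperands = [%s], indexRemap needs three dimensional inputs
and one symbolic input:

  (d0, d1, d2)[s0] -> (d0 + d1 * s0 + d2)

Here d0 binds %t, (d1, d2) bind (%i, %j), and s0 binds %s, so that
getNumInputs() == extraOperands.size() + oldMemRefRank + symbolOperands.size().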
@@ -79,7 +80,8 @@ LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                        Operation *op,
                                        ArrayRef<Value *> extraIndices = {},
                                        AffineMap indexRemap = AffineMap(),
-                                       ArrayRef<Value *> extraOperands = {});
+                                       ArrayRef<Value *> extraOperands = {},
+                                       ArrayRef<Value *> symbolOperands = {});
 
 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses


@@ -955,6 +955,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   LogicalResult res =
       replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                                /*extraOperands=*/outerIVs,
+                               /*symbolOperands=*/{},
                                /*domInstFilter=*/&*forOp.getBody()->begin());
   assert(succeeded(res) &&
          "replaceAllMemrefUsesWith should always succeed here");


@@ -122,6 +122,7 @@ static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
                          /*extraIndices=*/{ivModTwoOp},
                          /*indexRemap=*/AffineMap(),
                          /*extraOperands=*/{},
+                         /*symbolOperands=*/{},
                          /*domInstFilter=*/&*forOp.getBody()->begin()))) {
     LLVM_DEBUG(
         forOp.emitError("memref replacement for double buffering failed"));


@@ -1548,6 +1548,7 @@ static LogicalResult generateCopy(
       replaceAllMemRefUsesWith(memref, fastMemRef,
                                /*extraIndices=*/{}, indexRemap,
                                /*extraOperands=*/regionSymbols,
+                               /*symbolOperands=*/{},
                                /*domInstFilter=*/&*begin,
                                /*postDomInstFilter=*/&*postDomFilter);


@@ -62,14 +62,17 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                              Operation *op,
                                              ArrayRef<Value *> extraIndices,
                                              AffineMap indexRemap,
-                                             ArrayRef<Value *> extraOperands) {
+                                             ArrayRef<Value *> extraOperands,
+                                             ArrayRef<Value *> symbolOperands) {
   unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
-  (void)oldMemRefRank;
+  (void)oldMemRefRank; // unused in opt mode
   if (indexRemap) {
-    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
-    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
+    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
+           "symbolic operand count mismatch");
+    assert(indexRemap.getNumInputs() ==
+           extraOperands.size() + oldMemRefRank + symbolOperands.size());
     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
   } else {
     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
@@ -131,9 +134,11 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
   // provided. The indices of a memref come right after it, i.e.,
   // at position memRefOperandPos + 1.
   SmallVector<Value *, 4> remapOperands;
-  remapOperands.reserve(extraOperands.size() + oldMemRefRank);
+  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
+                        symbolOperands.size());
   remapOperands.append(extraOperands.begin(), extraOperands.end());
   remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
 
   SmallVector<Value *, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
@@ -226,6 +231,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                              ArrayRef<Value *> extraIndices,
                                              AffineMap indexRemap,
                                              ArrayRef<Value *> extraOperands,
+                                             ArrayRef<Value *> symbolOperands,
                                              Operation *domInstFilter,
                                              Operation *postDomInstFilter) {
   unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
@@ -233,8 +239,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
   unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
   (void)oldMemRefRank;
   if (indexRemap) {
-    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
-    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
+    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
+           "symbol operand count mismatch");
+    assert(indexRemap.getNumInputs() ==
+           extraOperands.size() + oldMemRefRank + symbolOperands.size());
     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
   } else {
     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
@@ -287,7 +295,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
   for (auto *op : opsToReplace) {
     if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
-                                        indexRemap, extraOperands)))
+                                        indexRemap, extraOperands,
+                                        symbolOperands)))
       llvm_unreachable("memref replacement guaranteed to succeed here");
   }
@@ -446,6 +455,8 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   }
   auto *oldMemRef = allocOp.getResult();
 
+  SmallVector<Value *, 4> symbolOperands(allocOp.getSymbolicOperands());
+
   auto newMemRefType = b.getMemRefType(newShape, memrefType.getElementType(),
                                        b.getMultiDimIdentityMap(newRank));
   auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType);
@@ -453,7 +464,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                       /*extraIndices=*/{},
-                                      /*indexRemap=*/layoutMap))) {
+                                      /*indexRemap=*/layoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/symbolOperands))) {
     // If it failed (due to escapes for example), bail out.
     newAlloc.erase();
     return failure();


@@ -96,6 +96,21 @@ func @strided_cumulative() {
   return
 }
 
+// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith
+// when the index remap has symbols.
+// CHECK-LABEL: func @symbolic_operands
+func @symbolic_operands(%s : index) {
+  // CHECK: alloc() : memref<100xf32>
+  %A = alloc()[%s] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)>
+  affine.for %i = 0 to 10 {
+    affine.for %j = 0 to 10 {
+      // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
+      affine.load %A[%i, %j] : memref<10x10xf32, (d0,d1)[s0] -> (10*d0 + d1)>
+    }
+  }
+  return
+}
+
 // Memref escapes; no normalization.
 // CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
 func @escaping() -> memref<64xf32, (d0) -> (d0 + 2)> {