[MLIR] Support for ReturnOps in memref map layout normalization
-- This commit handles the returnOp in memref map layout normalization. -- An initial filter is applied on FuncOps which helps us know which functions can be a suitable candidate for memref normalization which doesn't lead to invalid IR. -- Handles memref map normalization for external function assuming the external function is normalizable. Differential Revision: https://reviews.llvm.org/D85226
This commit is contained in:
parent
fc7f004b88
commit
6d4f7801b1
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
#define DEBUG_TYPE "normalize-memrefs"
|
||||
|
||||
|
@ -24,39 +25,45 @@ namespace {
|
|||
|
||||
/// All memrefs passed across functions with non-trivial layout maps are
|
||||
/// converted to ones with trivial identity layout ones.
|
||||
|
||||
// Input :-
|
||||
// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
|
||||
// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
|
||||
// (memref<16xf64, #tile>) {
|
||||
// affine.for %arg3 = 0 to 16 {
|
||||
// %a = affine.load %A[%arg3] : memref<16xf64, #tile>
|
||||
// %p = mulf %a, %a : f64
|
||||
// affine.store %p, %A[%arg3] : memref<16xf64, #tile>
|
||||
// }
|
||||
// %c = alloc() : memref<16xf64, #tile>
|
||||
// %d = affine.load %c[0] : memref<16xf64, #tile>
|
||||
// return %A: memref<16xf64, #tile>
|
||||
// }
|
||||
|
||||
// Output :-
|
||||
// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
|
||||
// -> memref<4x4xf64> {
|
||||
// affine.for %arg3 = 0 to 16 {
|
||||
// %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
|
||||
// %3 = mulf %2, %2 : f64
|
||||
// affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
|
||||
// }
|
||||
// %0 = alloc() : memref<16xf64, #map0>
|
||||
// %1 = affine.load %0[0] : memref<16xf64, #map0>
|
||||
// return %arg0 : memref<4x4xf64>
|
||||
// }
|
||||
|
||||
/// If all the memref types/uses in a function are normalizable, we treat
|
||||
/// such functions as normalizable. Also, if a normalizable function is known
|
||||
/// to call a non-normalizable function, we treat that function as
|
||||
/// non-normalizable as well. We assume external functions to be normalizable.
|
||||
///
|
||||
/// Input :-
|
||||
/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
|
||||
/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
|
||||
/// (memref<16xf64, #tile>) {
|
||||
/// affine.for %arg3 = 0 to 16 {
|
||||
/// %a = affine.load %A[%arg3] : memref<16xf64, #tile>
|
||||
/// %p = mulf %a, %a : f64
|
||||
/// affine.store %p, %A[%arg3] : memref<16xf64, #tile>
|
||||
/// }
|
||||
/// %c = alloc() : memref<16xf64, #tile>
|
||||
/// %d = affine.load %c[0] : memref<16xf64, #tile>
|
||||
/// return %A: memref<16xf64, #tile>
|
||||
/// }
|
||||
///
|
||||
/// Output :-
|
||||
/// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
|
||||
/// -> memref<4x4xf64> {
|
||||
/// affine.for %arg3 = 0 to 16 {
|
||||
/// %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] :
|
||||
/// memref<4x4xf64> %3 = mulf %2, %2 : f64 affine.store %3, %arg0[%arg3
|
||||
/// floordiv 4, %arg3 mod 4] : memref<4x4xf64>
|
||||
/// }
|
||||
/// %0 = alloc() : memref<16xf64, #map0>
|
||||
/// %1 = affine.load %0[0] : memref<16xf64, #map0>
|
||||
/// return %arg0 : memref<4x4xf64>
|
||||
/// }
|
||||
///
|
||||
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
|
||||
void runOnOperation() override;
|
||||
void runOnFunction(FuncOp funcOp);
|
||||
void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
|
||||
bool areMemRefsNormalizable(FuncOp funcOp);
|
||||
void updateFunctionSignature(FuncOp funcOp);
|
||||
void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
|
||||
void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
|
||||
DenseSet<FuncOp> &normalizableFuncs);
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
@ -67,41 +74,109 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
|
|||
|
||||
void NormalizeMemRefs::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
// We traverse each function within the module in order to normalize the
|
||||
// memref type arguments.
|
||||
// TODO: Handle external functions.
|
||||
// We maintain all normalizable FuncOps in a DenseSet. It is initialized
|
||||
// with all the functions within a module and then functions which are not
|
||||
// normalizable are removed from this set.
|
||||
// TODO: Change this to work on FuncLikeOp once there is an operation
|
||||
// interface for it.
|
||||
DenseSet<FuncOp> normalizableFuncs;
|
||||
// Initialize `normalizableFuncs` with all the functions within a module.
|
||||
moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
|
||||
|
||||
// Traverse through all the functions applying a filter which determines
|
||||
// whether that function is normalizable or not. All callers/callees of
|
||||
// a non-normalizable function will also become non-normalizable even if
|
||||
// they aren't passing any or specific non-normalizable memrefs. So,
|
||||
// functions which calls or get called by a non-normalizable becomes non-
|
||||
// normalizable functions themselves.
|
||||
moduleOp.walk([&](FuncOp funcOp) {
|
||||
if (areMemRefsNormalizable(funcOp))
|
||||
runOnFunction(funcOp);
|
||||
if (normalizableFuncs.contains(funcOp)) {
|
||||
if (!areMemRefsNormalizable(funcOp)) {
|
||||
// Since this function is not normalizable, we set all the caller
|
||||
// functions and the callees of this function as not normalizable.
|
||||
// TODO: Drop this conservative assumption in the future.
|
||||
setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
|
||||
normalizableFuncs);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Those functions which can be normalized are subjected to normalization.
|
||||
for (FuncOp &funcOp : normalizableFuncs)
|
||||
normalizeFuncOpMemRefs(funcOp, moduleOp);
|
||||
}
|
||||
|
||||
// Return true if this operation dereferences one or more memref's.
|
||||
// TODO: Temporary utility, will be replaced when this is modeled through
|
||||
// side-effects/op traits.
|
||||
/// Return true if this operation dereferences one or more memref's.
|
||||
/// TODO: Temporary utility, will be replaced when this is modeled through
|
||||
/// side-effects/op traits.
|
||||
static bool isMemRefDereferencingOp(Operation &op) {
|
||||
return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
|
||||
AffineDmaWaitOp>(op);
|
||||
}
|
||||
|
||||
// Check whether all the uses of oldMemRef are either dereferencing uses or the
|
||||
// op is of type : DeallocOp, CallOp. Only if these constraints are satisfied
|
||||
// will the value become a candidate for replacement.
|
||||
/// Check whether all the uses of oldMemRef are either dereferencing uses or the
|
||||
/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
|
||||
/// are satisfied will the value become a candidate for replacement.
|
||||
/// TODO: Extend this for DimOps.
|
||||
static bool isMemRefNormalizable(Value::user_range opUsers) {
|
||||
if (llvm::any_of(opUsers, [](Operation *op) {
|
||||
if (isMemRefDereferencingOp(*op))
|
||||
return false;
|
||||
return !isa<DeallocOp, CallOp>(*op);
|
||||
return !isa<DeallocOp, CallOp, ReturnOp>(*op);
|
||||
}))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check whether all the uses of AllocOps, CallOps and function arguments of a
|
||||
// function are either of dereferencing type or of type: DeallocOp, CallOp. Only
|
||||
// if these constraints are satisfied will the function become a candidate for
|
||||
// normalization.
|
||||
/// Set all the calling functions and the callees of the function as not
|
||||
/// normalizable.
|
||||
void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
|
||||
FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
|
||||
if (!normalizableFuncs.contains(funcOp))
|
||||
return;
|
||||
|
||||
normalizableFuncs.erase(funcOp);
|
||||
// Caller of the function.
|
||||
Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
|
||||
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
|
||||
// TODO: Extend this for ops that are FunctionLike. This would require
|
||||
// creating an OpInterface for FunctionLike ops.
|
||||
FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
|
||||
for (FuncOp &funcOp : normalizableFuncs) {
|
||||
if (parentFuncOp == funcOp) {
|
||||
setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
|
||||
normalizableFuncs);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Functions called by this function.
|
||||
funcOp.walk([&](CallOp callOp) {
|
||||
StringRef callee = callOp.getCallee();
|
||||
for (FuncOp &funcOp : normalizableFuncs) {
|
||||
// We compare FuncOp and callee's name.
|
||||
if (callee == funcOp.getName()) {
|
||||
setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
|
||||
normalizableFuncs);
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Check whether all the uses of AllocOps, CallOps and function arguments of a
|
||||
/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
|
||||
/// or ReturnOp. Only if these constraints are satisfied will the function
|
||||
/// become a candidate for normalization. We follow a conservative approach here
|
||||
/// wherein even if the non-normalizable memref is not a part of the function's
|
||||
/// argument or return type, we still label the entire function as
|
||||
/// non-normalizable. We assume external functions to be normalizable.
|
||||
bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
|
||||
// We assume external functions to be normalizable.
|
||||
if (funcOp.isExternal())
|
||||
return true;
|
||||
|
||||
if (funcOp
|
||||
.walk([&](AllocOp allocOp) -> WalkResult {
|
||||
Value oldMemRef = allocOp.getResult();
|
||||
|
@ -136,28 +211,138 @@ bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Fetch the updated argument list and result of the function and update the
|
||||
// function signature.
|
||||
void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) {
|
||||
/// Fetch the updated argument list and result of the function and update the
|
||||
/// function signature. This updates the function's return type at the caller
|
||||
/// site and in case the return type is a normalized memref then it updates
|
||||
/// the calling function's signature.
|
||||
/// TODO: An update to the calling function signature is required only if the
|
||||
/// returned value is in turn used in ReturnOp of the calling function.
|
||||
void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
|
||||
ModuleOp moduleOp) {
|
||||
FunctionType functionType = funcOp.getType();
|
||||
SmallVector<Type, 8> argTypes;
|
||||
SmallVector<Type, 4> resultTypes;
|
||||
|
||||
for (const auto &arg : llvm::enumerate(funcOp.getArguments()))
|
||||
argTypes.push_back(arg.value().getType());
|
||||
|
||||
FunctionType newFuncType;
|
||||
resultTypes = llvm::to_vector<4>(functionType.getResults());
|
||||
// We create a new function type and modify the function signature with this
|
||||
// new type.
|
||||
FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes,
|
||||
/*results=*/resultTypes,
|
||||
/*context=*/&getContext());
|
||||
|
||||
// TODO: Handle ReturnOps to update function results the caller site.
|
||||
funcOp.setType(newFuncType);
|
||||
// External function's signature was already updated in
|
||||
// 'normalizeFuncOpMemRefs()'.
|
||||
if (!funcOp.isExternal()) {
|
||||
SmallVector<Type, 8> argTypes;
|
||||
for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
|
||||
argTypes.push_back(argEn.value().getType());
|
||||
|
||||
// Traverse ReturnOps to check if an update to the return type in the
|
||||
// function signature is required.
|
||||
funcOp.walk([&](ReturnOp returnOp) {
|
||||
for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
|
||||
Type opType = operandEn.value().getType();
|
||||
MemRefType memrefType = opType.dyn_cast<MemRefType>();
|
||||
// If type is not memref or if the memref type is same as that in
|
||||
// function's return signature then no update is required.
|
||||
if (!memrefType || memrefType == resultTypes[operandEn.index()])
|
||||
continue;
|
||||
// Update function's return type signature.
|
||||
// Return type gets normalized either as a result of function argument
|
||||
// normalization, AllocOp normalization or an update made at CallOp.
|
||||
// There can be many call flows inside a function and an update to a
|
||||
// specific ReturnOp has not yet been made. So we check that the result
|
||||
// memref type is normalized.
|
||||
// TODO: When selective normalization is implemented, handle multiple
|
||||
// results case where some are normalized, some aren't.
|
||||
if (memrefType.getAffineMaps().empty())
|
||||
resultTypes[operandEn.index()] = memrefType;
|
||||
}
|
||||
});
|
||||
|
||||
// We create a new function type and modify the function signature with this
|
||||
// new type.
|
||||
newFuncType = FunctionType::get(/*inputs=*/argTypes,
|
||||
/*results=*/resultTypes,
|
||||
/*context=*/&getContext());
|
||||
}
|
||||
|
||||
// Since we update the function signature, it might affect the result types at
|
||||
// the caller site. Since this result might even be used by the caller
|
||||
// function in ReturnOps, the caller function's signature will also change.
|
||||
// Hence we record the caller function in 'funcOpsToUpdate' to update their
|
||||
// signature as well.
|
||||
llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
|
||||
// We iterate over all symbolic uses of the function and update the return
|
||||
// type at the caller site.
|
||||
Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
|
||||
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
|
||||
Operation *callOp = symbolUse.getUser();
|
||||
OpBuilder builder(callOp);
|
||||
StringRef callee = cast<CallOp>(callOp).getCallee();
|
||||
Operation *newCallOp = builder.create<CallOp>(
|
||||
callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
|
||||
callOp->getOperands());
|
||||
bool replacingMemRefUsesFailed = false;
|
||||
bool returnTypeChanged = false;
|
||||
for (unsigned resIndex : llvm::seq<unsigned>(0, callOp->getNumResults())) {
|
||||
OpResult oldResult = callOp->getResult(resIndex);
|
||||
OpResult newResult = newCallOp->getResult(resIndex);
|
||||
// This condition ensures that if the result is not of type memref or if
|
||||
// the resulting memref was already having a trivial map layout then we
|
||||
// need not perform any use replacement here.
|
||||
if (oldResult.getType() == newResult.getType())
|
||||
continue;
|
||||
AffineMap layoutMap =
|
||||
oldResult.getType().dyn_cast<MemRefType>().getAffineMaps().front();
|
||||
if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
|
||||
/*extraIndices=*/{},
|
||||
/*indexRemap=*/layoutMap,
|
||||
/*extraOperands=*/{},
|
||||
/*symbolOperands=*/{},
|
||||
/*domInstFilter=*/nullptr,
|
||||
/*postDomInstFilter=*/nullptr,
|
||||
/*allowDereferencingOps=*/true,
|
||||
/*replaceInDeallocOp=*/true))) {
|
||||
// If it failed (due to escapes for example), bail out.
|
||||
// It should never hit this part of the code because it is called by
|
||||
// only those functions which are normalizable.
|
||||
newCallOp->erase();
|
||||
replacingMemRefUsesFailed = true;
|
||||
break;
|
||||
}
|
||||
returnTypeChanged = true;
|
||||
}
|
||||
if (replacingMemRefUsesFailed)
|
||||
continue;
|
||||
// Replace all uses for other non-memref result types.
|
||||
callOp->replaceAllUsesWith(newCallOp);
|
||||
callOp->erase();
|
||||
if (returnTypeChanged) {
|
||||
// Since the return type changed it might lead to a change in function's
|
||||
// signature.
|
||||
// TODO: If funcOp doesn't return any memref type then no need to update
|
||||
// signature.
|
||||
// TODO: Further optimization - Check if the memref is indeed part of
|
||||
// ReturnOp at the parentFuncOp and only then updation of signature is
|
||||
// required.
|
||||
// TODO: Extend this for ops that are FunctionLike. This would require
|
||||
// creating an OpInterface for FunctionLike ops.
|
||||
FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
|
||||
funcOpsToUpdate.insert(parentFuncOp);
|
||||
}
|
||||
}
|
||||
// Because external function's signature is already updated in
|
||||
// 'normalizeFuncOpMemRefs()', we don't need to update it here again.
|
||||
if (!funcOp.isExternal())
|
||||
funcOp.setType(newFuncType);
|
||||
|
||||
// Updating the signature type of those functions which call the current
|
||||
// function. Only if the return type of the current function has a normalized
|
||||
// memref will the caller function become a candidate for signature update.
|
||||
for (FuncOp parentFuncOp : funcOpsToUpdate)
|
||||
updateFunctionSignature(parentFuncOp, moduleOp);
|
||||
}
|
||||
|
||||
void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
|
||||
/// Normalizes the memrefs within a function which includes those arising as a
|
||||
/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
|
||||
/// is used to help update function's signature after normalization.
|
||||
void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
|
||||
ModuleOp moduleOp) {
|
||||
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
|
||||
// alloc ops first and then process since normalizeMemRef replaces/erases ops
|
||||
// during memref rewriting.
|
||||
|
@ -169,22 +354,27 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
|
|||
// We use this OpBuilder to create new memref layout later.
|
||||
OpBuilder b(funcOp);
|
||||
|
||||
FunctionType functionType = funcOp.getType();
|
||||
SmallVector<Type, 8> inputTypes;
|
||||
// Walk over each argument of a function to perform memref normalization (if
|
||||
// any).
|
||||
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
|
||||
Type argType = funcOp.getArgument(argIndex).getType();
|
||||
for (unsigned argIndex :
|
||||
llvm::seq<unsigned>(0, functionType.getNumInputs())) {
|
||||
Type argType = functionType.getInput(argIndex);
|
||||
MemRefType memrefType = argType.dyn_cast<MemRefType>();
|
||||
// Check whether argument is of MemRef type. Any other argument type can
|
||||
// simply be part of the final function signature.
|
||||
if (!memrefType)
|
||||
if (!memrefType) {
|
||||
inputTypes.push_back(argType);
|
||||
continue;
|
||||
}
|
||||
// Fetch a new memref type after normalizing the old memref to have an
|
||||
// identity map layout.
|
||||
MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
|
||||
/*numSymbolicOperands=*/0);
|
||||
if (newMemRefType == memrefType) {
|
||||
if (newMemRefType == memrefType || funcOp.isExternal()) {
|
||||
// Either memrefType already had an identity map or the map couldn't be
|
||||
// transformed to an identity map.
|
||||
inputTypes.push_back(newMemRefType);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -202,7 +392,7 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
|
|||
/*domInstFilter=*/nullptr,
|
||||
/*postDomInstFilter=*/nullptr,
|
||||
/*allowNonDereferencingOps=*/true,
|
||||
/*handleDeallocOp=*/true))) {
|
||||
/*replaceInDeallocOp=*/true))) {
|
||||
// If it failed (due to escapes for example), bail out. Removing the
|
||||
// temporary argument inserted previously.
|
||||
funcOp.front().eraseArgument(argIndex);
|
||||
|
@ -214,5 +404,36 @@ void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
|
|||
funcOp.front().eraseArgument(argIndex + 1);
|
||||
}
|
||||
|
||||
updateFunctionSignature(funcOp);
|
||||
// In a normal function, memrefs in the return type signature gets normalized
|
||||
// as a result of normalization of functions arguments, AllocOps or CallOps'
|
||||
// result types. Since an external function doesn't have a body, memrefs in
|
||||
// the return type signature can only get normalized by iterating over the
|
||||
// individual return types.
|
||||
if (funcOp.isExternal()) {
|
||||
SmallVector<Type, 4> resultTypes;
|
||||
for (unsigned resIndex :
|
||||
llvm::seq<unsigned>(0, functionType.getNumResults())) {
|
||||
Type resType = functionType.getResult(resIndex);
|
||||
MemRefType memrefType = resType.dyn_cast<MemRefType>();
|
||||
// Check whether result is of MemRef type. Any other argument type can
|
||||
// simply be part of the final function signature.
|
||||
if (!memrefType) {
|
||||
resultTypes.push_back(resType);
|
||||
continue;
|
||||
}
|
||||
// Computing a new memref type after normalizing the old memref to have an
|
||||
// identity map layout.
|
||||
MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
|
||||
/*numSymbolicOperands=*/0);
|
||||
resultTypes.push_back(newMemRefType);
|
||||
continue;
|
||||
}
|
||||
|
||||
FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
|
||||
/*results=*/resultTypes,
|
||||
/*context=*/&getContext());
|
||||
// Setting the new function signature for this external function.
|
||||
funcOp.setType(newFuncType);
|
||||
}
|
||||
updateFunctionSignature(funcOp, moduleOp);
|
||||
}
|
||||
|
|
|
@ -274,12 +274,12 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
|
|||
// for the memref to be used in a non-dereferencing way outside of the
|
||||
// region where this replacement is happening.
|
||||
if (!isMemRefDereferencingOp(*op)) {
|
||||
// Currently we support the following non-dereferencing types to be a
|
||||
// candidate for replacement: Dealloc and CallOp.
|
||||
// TODO: Add support for other kinds of ops.
|
||||
if (!allowNonDereferencingOps)
|
||||
return failure();
|
||||
if (!(isa<DeallocOp, CallOp>(*op)))
|
||||
// Currently we support the following non-dereferencing ops to be a
|
||||
// candidate for replacement: Dealloc, CallOp and ReturnOp.
|
||||
// TODO: Add support for other kinds of ops.
|
||||
if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
|
|
@ -126,14 +126,6 @@ func @symbolic_operands(%s : index) {
|
|||
return
|
||||
}
|
||||
|
||||
// Memref escapes; no normalization.
|
||||
// CHECK-LABEL: func @escaping() -> memref<64xf32, #map{{[0-9]+}}>
|
||||
func @escaping() -> memref<64xf32, affine_map<(d0) -> (d0 + 2)>> {
|
||||
// CHECK: %{{.*}} = alloc() : memref<64xf32, #map{{[0-9]+}}>
|
||||
%A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
|
||||
return %A : memref<64xf32, affine_map<(d0) -> (d0 + 2)>>
|
||||
}
|
||||
|
||||
// Semi-affine maps, normalization not implemented yet.
|
||||
// CHECK-LABEL: func @semi_affine_layout_map
|
||||
func @semi_affine_layout_map(%s0: index, %s1: index) {
|
||||
|
@ -205,9 +197,125 @@ func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
|
|||
return %d : i1
|
||||
}
|
||||
|
||||
// Test case 4: No normalization should take place because the function is returning the memref.
|
||||
// CHECK-LABEL: func @memref_used_in_return
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
|
||||
func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
|
||||
return %A : memref<8xf64, #tile>
|
||||
// Test cases here onwards deal with normalization of memref in function signature, caller site.
|
||||
|
||||
// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls.
|
||||
// CHECK-LABEL: func @ret_multiple_argument_type
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
|
||||
func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
|
||||
%a = affine.load %A[0] : memref<16xf64, #tile>
|
||||
%p = mulf %a, %a : f64
|
||||
%cond = constant 1 : i1
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
^bb1:
|
||||
%res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
|
||||
return %res2, %p: memref<8xf64, #tile>, f64
|
||||
^bb2:
|
||||
return %C, %p: memref<8xf64, #tile>, f64
|
||||
}
|
||||
|
||||
// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
|
||||
// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
|
||||
// CHECK: %true = constant true
|
||||
// CHECK: cond_br %true, ^bb1, ^bb2
|
||||
// CHECK: ^bb1: // pred: ^bb0
|
||||
// CHECK: %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
|
||||
// CHECK: return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
|
||||
// CHECK: ^bb2: // pred: ^bb0
|
||||
// CHECK: return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64
|
||||
|
||||
// CHECK-LABEL: func @ret_single_argument_type
|
||||
// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
|
||||
func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>){
|
||||
%a = alloc() : memref<8xf64, #tile>
|
||||
%b = alloc() : memref<16xf64, #tile>
|
||||
%d = constant 23.0 : f64
|
||||
call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
|
||||
call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
|
||||
%res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
|
||||
%res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
|
||||
return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile>
|
||||
}
|
||||
|
||||
// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
|
||||
// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
|
||||
// CHECK: %cst = constant 2.300000e+01 : f64
|
||||
// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
|
||||
// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
|
||||
// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
|
||||
// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
|
||||
// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>
|
||||
|
||||
// Test case set #5: To check normalization in a chain of interconnected functions.
|
||||
// CHECK-LABEL: func @func_A
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
|
||||
func @func_A(%A: memref<8xf64, #tile>) {
|
||||
call @func_B(%A) : (memref<8xf64, #tile>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @func_B
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
|
||||
func @func_B(%A: memref<8xf64, #tile>) {
|
||||
call @func_C(%A) : (memref<8xf64, #tile>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @func_C
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
|
||||
func @func_C(%A: memref<8xf64, #tile>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test case set #6: Checking if no normalization takes place in a scenario: A -> B -> C and B has an unsupported type.
|
||||
// CHECK-LABEL: func @some_func_A
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
|
||||
func @some_func_A(%A: memref<8xf64, #tile>) {
|
||||
call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @some_func_B
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
|
||||
func @some_func_B(%A: memref<8xf64, #tile>) {
|
||||
"test.test"(%A) : (memref<8xf64, #tile>) -> ()
|
||||
call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @some_func_C
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
|
||||
func @some_func_C(%A: memref<8xf64, #tile>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test case set #7: Check normalization in case of external functions.
|
||||
// CHECK-LABEL: func @external_func_A
|
||||
// CHECK-SAME: (memref<4x4xf64>)
|
||||
func @external_func_A(memref<16xf64, #tile>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @external_func_B
|
||||
// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
|
||||
func @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
|
||||
|
||||
// CHECK-LABEL: func @simply_call_external()
|
||||
func @simply_call_external() {
|
||||
%a = alloc() : memref<16xf64, #tile>
|
||||
call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64>
|
||||
// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @use_value_of_external
|
||||
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64>
|
||||
func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
|
||||
%res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
|
||||
return %res : memref<8xf64, #tile>
|
||||
}
|
||||
// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
|
||||
// CHECK: return %{{.*}} : memref<2x4xf64>
|
||||
|
|
Loading…
Reference in a new issue