[MLIR] Move eraseArguments and eraseResults to FunctionLike

Previously, they were only defined for `FuncOp`.

To support this, `FunctionLike` needs a way to get an updated type
from the concrete operation. This adds a new hook for that purpose,
called `getTypeWithoutArgsAndResults`.

For now, `FunctionLike` continues to assume the type is
`FunctionType`, and concrete operations that use another type can hide
the `getType`, `setType`, and `getTypeWithoutArgsAndResults` methods.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D90363
This commit is contained in:
mikeurbach 2020-10-22 11:39:39 -06:00
parent 50c2f2b6f0
commit 2e36e0dad5
11 changed files with 255 additions and 89 deletions

View file

@ -255,11 +255,17 @@ particular:
- they can have argument and result attributes that are stored in dictionary
attributes on the operation itself.
This trait does *NOT* provide type support for the functions, meaning that
concrete Ops must handle the type of the declared or defined function.
`getTypeAttrName()` is a convenience function that returns the name of the
attribute that can be used to store the function type, but the trait makes no
assumption based on it.
This trait provides limited type support for the declared or defined functions.
The convenience function `getTypeAttrName()` returns the name of an attribute
that can be used to store the function type. In addition, this trait provides
`getType` and `setType` helpers to store a `FunctionType` in the attribute named
by `getTypeAttrName()`.
In general, this trait assumes concrete ops use `FunctionType` under the hood.
If this is not the case, in order to use the function type support, concrete ops
must define the following methods, using the same name, to hide the ones defined
for `FunctionType`: `addBodyBlock`, `getType`, `getTypeWithoutArgsAndResults`
and `setType`.
### HasParent

View file

@ -16,6 +16,10 @@
#include "mlir/IR/BlockSupport.h"
#include "mlir/IR/Visitors.h"
namespace llvm {
class BitVector;
} // end namespace llvm
namespace mlir {
class TypeRange;
template <typename ValueRangeT> class ValueTypeRange;
@ -98,6 +102,13 @@ public:
/// Erase the argument at 'index' and remove it from the argument list.
void eraseArgument(unsigned index);
/// Erases the arguments listed in `argIndices` and removes them from the
/// argument list.
/// `argIndices` is allowed to have duplicates and can be in any order.
void eraseArguments(ArrayRef<unsigned> argIndices);
/// Erases the arguments that have their corresponding bit set in
/// `eraseIndices` and removes them from the argument list.
void eraseArguments(llvm::BitVector eraseIndices);
unsigned getNumArguments() { return arguments.size(); }
BlockArgument getArgument(unsigned i) { return arguments[i]; }

View file

@ -59,18 +59,6 @@ public:
void print(OpAsmPrinter &p);
LogicalResult verify();
/// Erase a single argument at `argIndex`.
void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
/// Erases the arguments listed in `argIndices`.
/// `argIndices` is allowed to have duplicates and can be in any order.
void eraseArguments(ArrayRef<unsigned> argIndices);
/// Erase a single result at `resultIndex`.
void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
/// Erases the results listed in `resultIndices`.
/// `resultIndices` is allowed to have duplicates and can be in any order.
void eraseResults(ArrayRef<unsigned> resultIndices);
/// Create a deep copy of this function and all of its blocks, remapping
/// any operands that use values outside of the function using the map that is
/// provided (leaving them alone if no entry is present). If the mapper

View file

@ -71,6 +71,14 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
return resultDict ? resultDict.getValue() : llvm::None;
}
/// Erase the specified arguments and update the function type attribute.
void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
unsigned originalNumArgs, Type newType);
/// Erase the specified results and update the function type attribute.
void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
unsigned originalNumResults, Type newType);
} // namespace impl
namespace OpTrait {
@ -84,12 +92,21 @@ namespace OpTrait {
/// arguments;
/// - they can have argument attributes that are stored in a dictionary
/// attribute on the Op itself.
/// This trait does *NOT* provide type support for the functions, meaning that
/// concrete Ops must handle the type of the declared or defined function.
/// `getTypeAttrName()` is a convenience function that returns the name of the
/// attribute that can be used to store the function type, but the trait makes
/// no assumption based on it.
///
/// This trait provides limited type support for the declared or defined
/// functions. The convenience function `getTypeAttrName()` returns the name of
/// an attribute that can be used to store the function type. In addition, this
/// trait provides `getType` and `setType` helpers to store a `FunctionType` in
/// the attribute named by `getTypeAttrName()`.
///
/// In general, this trait assumes concrete ops use `FunctionType` under the
/// hood. If this is not the case, in order to use the function type support,
/// concrete ops must define the following methods, using the same name, to hide
/// the ones defined for `FunctionType`: `addBodyBlock`, `getType`,
/// `getTypeWithoutArgsAndResults` and `setType`.
///
/// Besides the requirements above, concrete ops must interact with this trait
/// using the following functions:
/// - Concrete ops *must* define a member function `getNumFuncArguments()` that
/// returns the number of function arguments based exclusively on type (so
/// that it can be called on function declarations).
@ -183,6 +200,19 @@ public:
return getTypeAttr().getValue().template cast<FunctionType>();
}
/// Return the type of this function without the specified arguments and
/// results. This is used to update the function's signature in the
/// `eraseArguments` and `eraseResults` methods. The arrays of indices are
/// allowed to have duplicates and can be in any order.
///
/// Note that the concrete class must define a method with the same name to
/// hide this one if the concrete class does not use FunctionType for the
/// function type under the hood.
FunctionType getTypeWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices) {
return getType().getWithoutArgsAndResults(argIndices, resultIndices);
}
bool isTypeAttrValid() {
auto typeAttr = getTypeAttr();
if (!typeAttr)
@ -204,7 +234,7 @@ public:
void setType(FunctionType newType);
//===--------------------------------------------------------------------===//
// Argument Handling
// Argument and Result Handling
//===--------------------------------------------------------------------===//
using BlockArgListType = Region::BlockArgListType;
@ -229,6 +259,30 @@ public:
return getBody().getArgumentTypes();
}
/// Erase a single argument at `argIndex`.
void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
/// Erases the arguments listed in `argIndices`.
/// `argIndices` is allowed to have duplicates and can be in any order.
void eraseArguments(ArrayRef<unsigned> argIndices) {
unsigned originalNumArgs = getNumArguments();
Type newType = getTypeWithoutArgsAndResults(argIndices, {});
::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices,
originalNumArgs, newType);
}
/// Erase a single result at `resultIndex`.
void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
/// Erases the results listed in `resultIndices`.
/// `resultIndices` is allowed to have duplicates and can be in any order.
void eraseResults(ArrayRef<unsigned> resultIndices) {
unsigned originalNumResults = getNumResults();
Type newType = getTypeWithoutArgsAndResults({}, resultIndices);
::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices,
originalNumResults, newType);
}
//===--------------------------------------------------------------------===//
// Argument Attributes
//===--------------------------------------------------------------------===//

View file

@ -238,15 +238,19 @@ public:
static FunctionType get(TypeRange inputs, TypeRange results,
MLIRContext *context);
// Input types.
/// Input types.
unsigned getNumInputs() const;
Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type> getInputs() const;
// Result types.
/// Result types.
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type> getResults() const;
/// Returns a new function type without the specified arguments and results.
FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices);
};
//===----------------------------------------------------------------------===//

View file

@ -25,8 +25,6 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include <set>
#define DEBUG_TYPE "linalg-drop-unit-dims"
using namespace mlir;
@ -166,9 +164,8 @@ LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
for (unsigned unitDimLoop : unitDims) {
entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
}
std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
for (unsigned i : llvm::reverse(orderedUnitDims))
entryBlock->eraseArgument(i);
SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
entryBlock->eraseArguments(unitDimsToErase);
return success();
}

View file

@ -9,6 +9,7 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/BitVector.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@ -176,6 +177,22 @@ void Block::eraseArgument(unsigned index) {
arguments.erase(arguments.begin() + index);
}
void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
llvm::BitVector eraseIndices(getNumArguments());
for (unsigned i : argIndices)
eraseIndices.set(i);
eraseArguments(eraseIndices);
}
void Block::eraseArguments(llvm::BitVector eraseIndices) {
// We do this in reverse so that we erase later indices before earlier
// indices, to avoid shifting the later indices.
unsigned originalNumArgs = getNumArguments();
for (unsigned i = 0; i < originalNumArgs; ++i)
if (eraseIndices.test(originalNumArgs - i - 1))
eraseArgument(originalNumArgs - i - 1);
}
/// Insert one value to the given position of the argument list. The existing
/// arguments are shifted. The block is expected not to have predecessors.
BlockArgument Block::insertArgument(args_iterator it, Type type) {

View file

@ -10,6 +10,7 @@ add_mlir_library(MLIRIR
Dominance.cpp
Function.cpp
FunctionImplementation.cpp
FunctionSupport.cpp
IntegerSet.cpp
Location.cpp
MLIRContext.cpp

View file

@ -98,65 +98,6 @@ LogicalResult FuncOp::verify() {
return success();
}
void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
auto oldType = getType();
int originalNumArgs = oldType.getNumInputs();
llvm::BitVector eraseIndices(originalNumArgs);
for (auto index : argIndices)
eraseIndices.set(index);
auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
// - Block arguments of entry block.
// Update the function type and arg attrs.
SmallVector<Type, 4> newInputTypes;
SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
for (int i = 0; i < originalNumArgs; i++) {
if (shouldEraseArg(i))
continue;
newInputTypes.emplace_back(oldType.getInput(i));
newArgAttrs.emplace_back(getArgAttrDict(i));
}
setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
setAllArgAttrs(newArgAttrs);
// Update the entry block's arguments.
// We do this in reverse so that we erase later indices before earlier
// indices, to avoid shifting the later indices.
Block &entry = front();
for (int i = 0; i < originalNumArgs; i++)
if (shouldEraseArg(originalNumArgs - i - 1))
entry.eraseArgument(originalNumArgs - i - 1);
}
void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
auto oldType = getType();
int originalNumResults = oldType.getNumResults();
llvm::BitVector eraseIndices(originalNumResults);
for (auto index : resultIndices)
eraseIndices.set(index);
auto shouldEraseResult = [&](int i) { return eraseIndices.test(i); };
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
// Update the function type and result attrs.
SmallVector<Type, 4> newResultTypes;
SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
for (int i = 0; i < originalNumResults; i++) {
if (shouldEraseResult(i))
continue;
newResultTypes.emplace_back(oldType.getResult(i));
newResultAttrs.emplace_back(getResultAttrDict(i));
}
setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
setAllResultAttrs(newResultAttrs);
}
/// Clone the internal blocks from this function into dest and all attributes
/// from this function to dest.
void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {

View file

@ -0,0 +1,103 @@
//===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/FunctionSupport.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/BitVector.h"
using namespace mlir;
/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
ArrayRef<unsigned> indices,
function_ref<void(unsigned)> callback) {
llvm::BitVector skipIndices(totalIndices);
for (unsigned i : indices)
skipIndices.set(i);
for (unsigned i = 0; i < totalIndices; ++i)
if (!skipIndices.test(i))
callback(i);
}
//===----------------------------------------------------------------------===//
// Function Arguments and Results.
//===----------------------------------------------------------------------===//
void mlir::impl::eraseFunctionArguments(Operation *op,
ArrayRef<unsigned> argIndices,
unsigned originalNumArgs,
Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
// - Block arguments of entry block.
Block &entry = op->getRegion(0).front();
SmallString<8> nameBuf;
// Collect arg attrs to set.
SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
newArgAttrs.emplace_back(getArgAttrDict(op, i));
});
// Remove any arg attrs that are no longer needed.
for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
op->removeAttr(getArgAttrName(i, nameBuf));
// Set the function type.
op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
// Set the new arg attrs, or remove them if empty.
for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
auto nameAttr = getArgAttrName(i, nameBuf);
auto argAttr = newArgAttrs[i];
if (argAttr.empty())
op->removeAttr(nameAttr);
else
op->setAttr(nameAttr, argAttr.getDictionary(op->getContext()));
}
// Update the entry block's arguments.
entry.eraseArguments(argIndices);
}
void mlir::impl::eraseFunctionResults(Operation *op,
ArrayRef<unsigned> resultIndices,
unsigned originalNumResults,
Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
SmallString<8> nameBuf;
// Collect result attrs to set.
SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
newResultAttrs.emplace_back(getResultAttrDict(op, i));
});
// Remove any result attrs that are no longer needed.
for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i)
op->removeAttr(getResultAttrName(i, nameBuf));
// Set the function type.
op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
// Set the new result attrs, or remove them if empty.
for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) {
auto nameAttr = getResultAttrName(i, nameBuf);
auto resultAttr = newResultAttrs[i];
if (resultAttr.empty())
op->removeAttr(nameAttr);
else
op->setAttr(nameAttr, resultAttr.getDictionary(op->getContext()));
}
}

View file

@ -10,6 +10,8 @@
#include "TypeDetail.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Twine.h"
using namespace mlir;
@ -46,6 +48,48 @@ ArrayRef<Type> FunctionType::getResults() const {
return getImpl()->getResults();
}
/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
ArrayRef<unsigned> indices,
function_ref<void(unsigned)> callback) {
llvm::BitVector skipIndices(totalIndices);
for (unsigned i : indices)
skipIndices.set(i);
for (unsigned i = 0; i < totalIndices; ++i)
if (!skipIndices.test(i))
callback(i);
}
/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices) {
ArrayRef<Type> newInputTypes = getInputs();
SmallVector<Type, 4> newInputTypesBuffer;
if (!argIndices.empty()) {
unsigned originalNumArgs = getNumInputs();
iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
newInputTypesBuffer.emplace_back(getInput(i));
});
newInputTypes = newInputTypesBuffer;
}
ArrayRef<Type> newResultTypes = getResults();
SmallVector<Type, 4> newResultTypesBuffer;
if (!resultIndices.empty()) {
unsigned originalNumResults = getNumResults();
iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
newResultTypesBuffer.emplace_back(getResult(i));
});
newResultTypes = newResultTypesBuffer;
}
return get(newInputTypes, newResultTypes, getContext());
}
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//