[MLIR][Presburger] Remove inheritence in MultiAffineFunction

This patch removes inheritence of MultiAffineFunction from IntegerPolyhedron
and instead makes IntegerPolyhedron as a member.

This patch removes virtualization in MultiAffineFunction and also removes
unnecessary functions inherited from IntegerPolyhedron.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D123921
This commit is contained in:
Groverkss 2022-04-19 01:14:18 +05:30
parent ef34442232
commit 15650b320b
6 changed files with 106 additions and 85 deletions

View file

@ -49,7 +49,6 @@ public:
enum class Kind {
FlatAffineConstraints,
FlatAffineValueConstraints,
MultiAffineFunction,
IntegerRelation,
IntegerPolyhedron,
};

View file

@ -42,54 +42,35 @@ namespace presburger {
///
/// Checking equality of two such functions is supported, as well as finding the
/// value of the function at a specified point.
class MultiAffineFunction : protected IntegerPolyhedron {
class MultiAffineFunction {
public:
/// We use protected inheritance to avoid inheriting the whole public
/// interface of IntegerPolyhedron. These using declarations explicitly make
/// only the relevant functions part of the public interface.
using IntegerPolyhedron::getNumDimAndSymbolIds;
using IntegerPolyhedron::getNumDimIds;
using IntegerPolyhedron::getNumIds;
using IntegerPolyhedron::getNumLocalIds;
using IntegerPolyhedron::getNumSymbolIds;
using IntegerPolyhedron::getSpace;
MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
: IntegerPolyhedron(domain), output(output) {}
: domainSet(domain), output(output) {}
MultiAffineFunction(const Matrix &output, const PresburgerSpace &space)
: IntegerPolyhedron(space), output(output) {}
: domainSet(space), output(output) {}
~MultiAffineFunction() override = default;
Kind getKind() const override { return Kind::MultiAffineFunction; }
bool classof(const IntegerRelation *rel) const {
return rel->getKind() == Kind::MultiAffineFunction;
}
unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
unsigned getNumInputs() const { return domainSet.getNumDimAndSymbolIds(); }
unsigned getNumOutputs() const { return output.getNumRows(); }
bool isConsistent() const {
return output.getNumColumns() == getNumIds() + 1;
return output.getNumColumns() == domainSet.getNumIds() + 1;
}
const IntegerPolyhedron &getDomain() const { return *this; }
const IntegerPolyhedron &getDomain() const { return domainSet; }
const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
/// Insert `num` identifiers of the specified kind at position `pos`.
/// Positions are relative to the kind of identifier. The coefficient columns
/// corresponding to the added identifiers are initialized to zero. Return the
/// absolute column position (i.e., not relative to the kind of identifier)
/// of the first added identifier.
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
/// Swap the posA^th identifier with the posB^th identifier.
void swapId(unsigned posA, unsigned posB) override;
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1);
/// Remove the specified range of ids.
void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override;
using IntegerRelation::removeIdRange;
void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
/// Eliminate the `posB^th` local identifier, replacing every instance of it
/// with the `posA^th` local identifier. This should be used when the two
/// local variables are known to always take the same values.
void eliminateRedundantLocalId(unsigned posA, unsigned posB) override;
/// Given a MAF `other`, merges local identifiers such that both funcitons
/// have union of local ids, without changing the set of points in domain or
/// the output.
void mergeLocalIds(MultiAffineFunction &other);
/// Return whether the outputs of `this` and `other` agree wherever both
/// functions are defined, i.e., the outputs should be equal for all points in
@ -114,6 +95,10 @@ public:
void dump() const;
private:
/// The IntegerPolyhedron representing the domain over which the function is
/// defined.
IntegerPolyhedron domainSet;
/// The function's output is a tuple of integers, with the ith element of the
/// tuple defined by the affine expression given by the ith row of this output
/// matrix.

View file

@ -130,6 +130,21 @@ void removeDuplicateDivs(
SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
llvm::function_ref<bool(unsigned i, unsigned j)> merge);
/// Given two relations, A and B, add additional local ids to the sets such
/// that both have the union of the local ids in each set, without changing
/// the set of points that lie in A and B.
///
/// While taking union, if a local id in any set has a division representation
/// which is a duplicate of division representation, of another local id in any
/// set, it is not added to the final union of local ids and is instead merged.
///
/// On every possible merge, `merge(i, j)` is called. `i`, `j` are position
/// of local identifiers in both sets which are being merged. If `merge(i, j)`
/// returns true, the divisions are merged, otherwise the divisions are not
/// merged.
void mergeLocalIds(IntegerRelation &relA, IntegerRelation &relB,
llvm::function_ref<bool(unsigned i, unsigned j)> merge);
/// Compute the gcd of the range.
int64_t gcdRange(ArrayRef<int64_t> range);

View file

@ -1092,36 +1092,11 @@ void IntegerRelation::eliminateRedundantLocalId(unsigned posA, unsigned posB) {
/// obtained, and thus these local ids are not considered for detecting
/// duplicates.
unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
assert(space.isCompatible(other.getSpace()) &&
"Spaces should be compatible.");
IntegerRelation &relA = *this;
IntegerRelation &relB = other;
unsigned oldALocals = relA.getNumLocalIds();
// Merge local ids of relA and relB without using division information,
// i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
// to `relB` at start of its local ids.
unsigned initLocals = relA.getNumLocalIds();
insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
relB.insertId(IdKind::Local, 0, initLocals);
// Get division representations from each rel.
std::vector<SmallVector<int64_t, 8>> divsA, divsB;
SmallVector<unsigned, 4> denomsA, denomsB;
relA.getLocalReprs(divsA, denomsA);
relB.getLocalReprs(divsB, denomsB);
// Copy division information for relB into `divsA` and `denomsA`, so that
// these have the combined division information of both rels. Since newly
// added local variables in relA and relB have no constraints, they will not
// have any division representation.
std::copy(divsB.begin() + initLocals, divsB.end(),
divsA.begin() + initLocals);
std::copy(denomsB.begin() + initLocals, denomsB.end(),
denomsA.begin() + initLocals);
// Merge function that merges the local variables in both sets by treating
// them as the same identifier.
auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool {
@ -1140,9 +1115,7 @@ unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
return true;
};
// Merge all divisions by removing duplicate divisions.
unsigned localOffset = getIdKindOffset(IdKind::Local);
presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
presburger::mergeLocalIds(*this, other, merge);
// Since we do not remove duplicate divisions in relA, this is guranteed to be
// non-negative.

View file

@ -35,7 +35,7 @@ PresburgerSet PWMAFunction::getDomain() const {
Optional<SmallVector<int64_t, 8>>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
assert(point.size() == getNumDimAndSymbolIds() &&
assert(point.size() == domainSet.getNumDimAndSymbolIds() &&
"Point has incorrect dimensionality!");
Optional<SmallVector<int64_t, 8>> maybeLocalValues =
@ -74,7 +74,7 @@ PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
void MultiAffineFunction::print(raw_ostream &os) const {
os << "Domain:";
IntegerPolyhedron::print(os);
domainSet.print(os);
os << "Output:\n";
output.print(os);
os << "\n";
@ -83,36 +83,24 @@ void MultiAffineFunction::print(raw_ostream &os) const {
void MultiAffineFunction::dump() const { print(llvm::errs()); }
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
return space.isCompatible(other.getSpace()) &&
return getDomainSpace().isCompatible(other.getDomainSpace()) &&
getDomain().isEqual(other.getDomain()) &&
isEqualWhereDomainsOverlap(other);
}
unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
unsigned num) {
assert((kind != IdKind::Domain || num == 0) &&
"Domain has to be zero in a set");
unsigned absolutePos = getIdKindOffset(kind) + pos;
assert(kind != IdKind::Domain && "Domain has to be zero in a set");
unsigned absolutePos = domainSet.getIdKindOffset(kind) + pos;
output.insertColumns(absolutePos, num);
return IntegerPolyhedron::insertId(kind, pos, num);
}
void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
output.swapColumns(posA, posB);
IntegerPolyhedron::swapId(posA, posB);
return domainSet.insertId(kind, pos, num);
}
void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart,
unsigned idLimit) {
output.removeColumns(idStart + getIdKindOffset(kind), idLimit - idStart);
IntegerPolyhedron::removeIdRange(kind, idStart, idLimit);
}
void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
unsigned posB) {
unsigned localOffset = getIdKindOffset(IdKind::Local);
output.addToColumn(localOffset + posB, localOffset + posA, /*scale=*/1);
IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
output.removeColumns(idStart + domainSet.getIdKindOffset(kind),
idLimit - idStart);
domainSet.removeIdRange(kind, idStart, idLimit);
}
void MultiAffineFunction::truncateOutput(unsigned count) {
@ -127,9 +115,37 @@ void PWMAFunction::truncateOutput(unsigned count) {
numOutputs = count;
}
void MultiAffineFunction::mergeLocalIds(MultiAffineFunction &other) {
// Merge output local ids of both functions without using division
// information i.e. append local ids of `other` to `this` and insert
// local ids of `this` to `other` at the start of it's local ids.
output.insertColumns(domainSet.getIdKindEnd(IdKind::Local),
other.domainSet.getNumLocalIds());
other.output.insertColumns(other.domainSet.getIdKindOffset(IdKind::Local),
domainSet.getNumLocalIds());
auto merge = [this, &other](unsigned i, unsigned j) -> bool {
// Merge local at position j into local at position i in function domain.
domainSet.eliminateRedundantLocalId(i, j);
other.domainSet.eliminateRedundantLocalId(i, j);
unsigned localOffset = domainSet.getIdKindOffset(IdKind::Local);
// Merge local at position j into local at position i in output domain.
output.addToColumn(localOffset + j, localOffset + i, 1);
output.removeColumn(localOffset + j);
other.output.addToColumn(localOffset + j, localOffset + i, 1);
other.output.removeColumn(localOffset + j);
return true;
};
presburger::mergeLocalIds(domainSet, other.domainSet, merge);
}
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
MultiAffineFunction other) const {
if (!space.isCompatible(other.getSpace()))
if (!getDomainSpace().isCompatible(other.getDomainSpace()))
return false;
// `commonFunc` has the same output as `this`.
@ -139,7 +155,7 @@ bool MultiAffineFunction::isEqualWhereDomainsOverlap(
commonFunc.mergeLocalIds(other);
// After this, the domain of `commonFunc` will be the intersection of the
// domains of `this` and `other`.
commonFunc.IntegerPolyhedron::append(other);
commonFunc.domainSet.append(other.domainSet);
// `commonDomainMatching` contains the subset of the common domain
// where the outputs of `this` and `other` match.
@ -180,7 +196,7 @@ bool PWMAFunction::isEqual(const PWMAFunction &other) const {
}
void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
assert(space.isCompatible(piece.getSpace()) &&
assert(space.isCompatible(piece.getDomainSpace()) &&
"Piece to be added is not compatible with this PWMAFunction!");
assert(piece.isConsistent() && "Piece is internally inconsistent!");
assert(this->getDomain()

View file

@ -304,6 +304,39 @@ void presburger::removeDuplicateDivs(
}
}
void presburger::mergeLocalIds(
IntegerRelation &relA, IntegerRelation &relB,
llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
assert(relA.getSpace().isCompatible(relB.getSpace()) &&
"Spaces should be compatible.");
// Merge local ids of relA and relB without using division information,
// i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
// to `relB` at start of its local ids.
unsigned initLocals = relA.getNumLocalIds();
relA.insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
relB.insertId(IdKind::Local, 0, initLocals);
// Get division representations from each rel.
std::vector<SmallVector<int64_t, 8>> divsA, divsB;
SmallVector<unsigned, 4> denomsA, denomsB;
relA.getLocalReprs(divsA, denomsA);
relB.getLocalReprs(divsB, denomsB);
// Copy division information for relB into `divsA` and `denomsA`, so that
// these have the combined division information of both rels. Since newly
// added local variables in relA and relB have no constraints, they will not
// have any division representation.
std::copy(divsB.begin() + initLocals, divsB.end(),
divsA.begin() + initLocals);
std::copy(denomsB.begin() + initLocals, denomsB.end(),
denomsA.begin() + initLocals);
// Merge all divisions by removing duplicate divisions.
unsigned localOffset = relA.getIdKindOffset(IdKind::Local);
presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
}
int64_t presburger::gcdRange(ArrayRef<int64_t> range) {
int64_t gcd = 0;
for (int64_t elem : range) {