[mlir][Vector] Modernize default lowering of vector transpose

This patch removes an old recursive implementation to lower vector.transpose to extract/insert operations
and replaces it with a iterative approach that leverages newer linearization/delinearization utilities.
The patch should be NFC except by the order in which the extract/insert ops are generated.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D121321
This commit is contained in:
Diego Caballero 2022-03-10 20:44:24 +00:00
parent 3c9e849943
commit f71f9958b9
8 changed files with 62 additions and 67 deletions

View file

@ -32,19 +32,6 @@ class LinalgDependenceGraph;
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
ArrayRef<int64_t> permutation) {
SmallVector<T, N> auxVec(inVec.size());
for (const auto &en : enumerate(permutation))
auxVec[en.index()] = inVec[en.value()];
inVec = auxVec;
}
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

View file

@ -30,6 +30,19 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
int64_t linearIndex);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
ArrayRef<int64_t> permutation) {
SmallVector<T, N> auxVec(inVec.size());
for (const auto &en : enumerate(permutation))
auxVec[en.index()] = inVec[en.value()];
inVec = auxVec;
}
/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront = 0,

View file

@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"

View file

@ -17,6 +17,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AsmState.h"

View file

@ -14,6 +14,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"

View file

@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"

View file

@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
@ -300,16 +301,18 @@ public:
}
};
/// Return the number of leftmost dimensions from the first rightmost transposed
/// dimension found in 'transpose'.
size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) {
/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
SmallVectorImpl<int64_t> &result) {
size_t numTransposedDims = transpose.size();
for (size_t transpDim : llvm::reverse(transpose)) {
if (transpDim != numTransposedDims - 1)
break;
numTransposedDims--;
}
return numTransposedDims;
result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}
/// Progressive lowering of TransposeOp.
@ -334,6 +337,8 @@ public:
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value input = op.vector();
VectorType inputType = op.getVectorType();
VectorType resType = op.getResultType();
// Set up convenience transposition table.
@ -354,7 +359,7 @@ public:
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
auto matrix =
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
Value trans = rewriter.create<vector::FlatTransposeOp>(
@ -365,54 +370,40 @@ public:
// Generate unrolled extract/insert ops. We do not unroll the rightmost
// (i.e., highest-order) dimensions that are not transposed and leave them
// in vector form to improve performance.
size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp);
// The type of the extract operation will be scalar if all the dimensions
// are unrolled. Otherwise, it will be a vector with the shape of the
// dimensions that are not transposed.
Type extractType =
numLeftmostTransposedDims == transp.size()
? resType.getElementType()
: VectorType::Builder(resType).setShape(
resType.getShape().drop_front(numLeftmostTransposedDims));
// in vector form to improve performance. Therefore, we prune those
// dimensions from the shape/transpose data structures used to generate the
// extract/insert ops.
SmallVector<int64_t, 4> prunedTransp;
pruneNonTransposedDims(transp, prunedTransp);
size_t numPrunedDims = transp.size() - prunedTransp.size();
auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
auto prunedInStrides = computeStrides(prunedInShape, ones);
// Generates the extract/insert operations for every scalar/vector element
// of the leftmost transposed dimensions. We traverse every transpose
// element using a linearized index that we delinearize to generate the
// appropriate indices for the extract/insert operations.
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0);
SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0);
rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0,
numLeftmostTransposedDims, transp, lhs,
rhs, op.vector(), result, rewriter));
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
++linearIdx) {
auto extractIdxs = delinearize(prunedInStrides, linearIdx);
SmallVector<int64_t, 4> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
result =
rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
}
rewriter.replaceOp(op, result);
return success();
}
private:
// Builds the indices arrays for the lhs and rhs. Generates the extract/insert
// operations when all the ranks go over the last dimension being transposed.
Value expandIndices(Location loc, VectorType resType, Type extractType,
int64_t pos, int64_t numLeftmostTransposedDims,
SmallVector<int64_t, 4> &transp,
SmallVector<int64_t, 4> &lhs,
SmallVector<int64_t, 4> &rhs, Value input, Value result,
PatternRewriter &rewriter) const {
if (pos >= numLeftmostTransposedDims) {
auto ridx = rewriter.getI64ArrayAttr(rhs);
auto lidx = rewriter.getI64ArrayAttr(lhs);
Value e =
rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx);
return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
}
for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
lhs[pos] = d;
rhs[transp[pos]] = d;
result = expandIndices(loc, resType, extractType, pos + 1,
numLeftmostTransposedDims, transp, lhs, rhs, input,
result, rewriter);
}
return result;
}
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
};

View file

@ -8,14 +8,14 @@
// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
// ELTWISE: return %[[T11]] : vector<3x2xf32>