From 0aea49a7308322e6987c7b45e4e0d7ab15609e78 Mon Sep 17 00:00:00 2001
From: Benoit Jacob
Date: Mon, 13 Dec 2021 20:00:28 +0000
Subject: [PATCH] [mlir][Vector] Patterns flattening vector transfers to 1D

This is the first part of https://reviews.llvm.org/D114993 which has been
split into small independent commits.

This is needed at the moment to get good codegen from 2d vector.transfer
ops that aim to compile to SIMD load/store instructions but that can only
do so if the whole 2d transfer shape is handled in one piece, in particular
taking advantage of the memref being contiguous row-major.

For instance, if the target architecture has 128-bit SIMD then we would
expect that contiguous row-major transfers of <4x4xi8> map to one SIMD
load/store instruction each.

The current generic lowering of multi-dimensional vector.transfer ops can't
achieve that because it peels dimensions one by one, so a transfer of
<4x4xi8> becomes 4 transfers of <4xi8>.

The new patterns here are only enabled for now by
-test-vector-transfer-flatten-patterns.

Reviewed By: nicolasvasilache
---
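As a sketch of the rewrite performed by the new patterns (taken from the
test case added below; strided layout attributes are elided here for
readability), a transfer whose memref source carries unit dims, such as

  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst
      : memref<1x1x3x2xi8, ...>, vector<3x2xi8>

is rewritten to read from a rank-reducing memref.subview instead:

  %subview = memref.subview %arg[0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
      : memref<1x1x3x2xi8, ...> to memref<3x2xi8, ...>
  %v = vector.transfer_read %subview[%c0, %c0], %cst
      : memref<3x2xi8, ...>, vector<3x2xi8>
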
 mlir/include/mlir/Dialect/Vector/VectorOps.h  |   7 +
 .../Vector/VectorTransferOpTransforms.cpp     | 132 ++++++++++++++++++
 ...ctor-transfer-drop-unit-dims-patterns.mlir |  33 +++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  17 +++
 4 files changed, 189 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 6dd4b9aaf555..c6b63a949f64 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -67,6 +67,13 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);

+/// Collect a set of unit dimension removal patterns.
+///
+/// These patterns insert rank-reducing memref.subview ops to remove unit
+/// dimensions. With them, there are more chances that we can avoid
+/// potentially expensive vector.shape_cast operations.
+void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after
diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index ae6f3949c399..c9438c4a28f4 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -10,12 +10,14 @@
 // transfer_write ops.
 //
 //===----------------------------------------------------------------------===//
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"

@@ -209,6 +211,128 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }

+/// Drops unit dimensions from the input MemRefType.
+static MemRefType dropUnitDims(MemRefType inputType) {
+  ArrayRef<int64_t> none{};
+  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
+      0, inputType, none, none, none);
+  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
+}
+
+/// Creates a rank-reducing memref.subview op that drops unit dims from its
+/// input, or just returns the input if it already has no unit dims.
+static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
+                                                 mlir::Location loc,
+                                                 Value input) {
+  MemRefType inputType = input.getType().cast<MemRefType>();
+  assert(inputType.hasStaticShape());
+  MemRefType resultType = dropUnitDims(inputType);
+  if (resultType == inputType)
+    return input;
+  SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
+  SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
+  return rewriter.create<memref::SubViewOp>(
+      loc, resultType, input, subviewOffsets, inputType.getShape(),
+      subviewStrides);
+}
+
+/// Returns the number of dims that aren't unit dims.
+static int getReducedRank(ArrayRef<int64_t> shape) {
+  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
+}
+
+/// Returns true if the value is an `arith.constant 0 : index`.
+static bool isZero(Value v) {
+  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
+  return cst && cst.value() == 0;
+}
+
+/// Rewrites vector.transfer_read ops where the source has unit dims, by
+/// inserting a memref.subview dropping those unit dims.
+class TransferReadDropUnitDimsPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    Value vector = transferReadOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // TODO: support tensor types.
+    if (!sourceType || !sourceType.hasStaticShape())
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferReadOp.permutation_map().isMinorIdentity())
+      return failure();
+    int reducedRank = getReducedRank(sourceType.getShape());
+    if (reducedRank == sourceType.getRank())
+      return failure(); // The source shape can't be further reduced.
+    if (reducedRank != vectorType.getRank())
+      return failure(); // This pattern requires the vector shape to match the
+                        // reduced source shape.
+    if (llvm::any_of(transferReadOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure();
+    Value reducedShapeSource =
+        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> zeros(reducedRank, c0);
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
+    return success();
+  }
+};
+
+/// Rewrites vector.transfer_write ops where the "source" (i.e. the
+/// destination) has unit dims, by inserting a memref.subview dropping those
+/// unit dims.
+class TransferWriteDropUnitDimsPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // TODO: support tensor types.
+    if (!sourceType || !sourceType.hasStaticShape())
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferWriteOp.permutation_map().isMinorIdentity())
+      return failure();
+    int reducedRank = getReducedRank(sourceType.getShape());
+    if (reducedRank == sourceType.getRank())
+      return failure(); // The source shape can't be further reduced.
+    if (reducedRank != vectorType.getRank())
+      return failure(); // This pattern requires the vector shape to match the
+                        // reduced source shape.
+    if (llvm::any_of(transferWriteOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure();
+    Value reducedShapeSource =
+        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> zeros(reducedRank, c0);
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
+    return success();
+  }
+};
+
 } // namespace

 void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -226,3 +350,11 @@ void mlir::vector::transferOpflowOpt(FuncOp func) {
   });
   opt.removeDeadOp();
 }
+
+void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
+    RewritePatternSet &patterns) {
+  patterns
+      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
+          patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns);
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
new file mode 100644
index 000000000000..a3d34a646c2f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -test-vector-transfer-drop-unit-dims-patterns -split-input-file | FileCheck %s
+
+// -----
+
+func @transfer_read_rank_reducing(
+      %arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>) -> vector<3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, vector<3x2xi8>
+    return %v : vector<3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_rank_reducing
+// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME:      memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK:         vector.transfer_read %[[SUBVIEW]]
+
+// -----
+
+func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, %vec : vector<3x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<3x2xi8>, memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>
+    return
+}
+
+// CHECK-LABEL: func @transfer_write_rank_reducing
+// CHECK-SAME:    %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK:         %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME:      memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK:         vector.transfer_write %{{.*}}, %[[SUBVIEW]]
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f7e13bc330d4..cf33b0d7117d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -583,6 +583,21 @@ struct TestVectorReduceToContractPatternsPatterns
   }
 };

+struct TestVectorTransferDropUnitDimsPatterns
+    : public PassWrapper<TestVectorTransferDropUnitDimsPatterns, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transfer-drop-unit-dims-patterns";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorTransferDropUnitDimsPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // namespace

 namespace mlir {
@@ -613,6 +628,8 @@ void registerTestVectorLowerings() {
   PassRegistration();

   PassRegistration();
+
+  PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
 }
 } // namespace test
 } // namespace mlir
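
As with the read case, the transfer_write test is expected to rewrite to
roughly the following form (strided layout attributes elided; the exact
layout depends on the rank-reducing subview):

  %subview = memref.subview %arg[0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
      : memref<1x1x3x2xi8, ...> to memref<3x2xi8, ...>
  vector.transfer_write %vec, %subview[%c0, %c0]
      : vector<3x2xi8>, memref<3x2xi8, ...>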