diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h index 6dd4b9aaf555..c6b63a949f64 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -67,6 +67,13 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns); /// pairs or forward write-read pairs. void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns); +/// Collect a set of unit dimension removal patterns. +/// +/// These patterns insert rank-reducing memref.subview ops to remove unit +/// dimensions. With them, there are more chances that we can avoid +/// potentially expensive vector.shape_cast operations. +void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns); + /// Collect a set of patterns that bubble up/down bitcast ops. /// /// These patterns move vector.bitcast ops to be before insert ops or after diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp index ae6f3949c399..c9438c4a28f4 100644 --- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp @@ -10,12 +10,14 @@ // transfer_write ops. // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" @@ -209,6 +211,128 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { opToErase.push_back(read.getOperation()); } +/// Drops unit dimensions from the input MemRefType. 
+static MemRefType dropUnitDims(MemRefType inputType) { + ArrayRef none{}; + Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( + 0, inputType, none, none, none); + return canonicalizeStridedLayout(rankReducedType.cast()); +} + +/// Creates a rank-reducing memref.subview op that drops unit dims from its +/// input. Or just returns the input if it was already without unit dims. +static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, + mlir::Location loc, + Value input) { + MemRefType inputType = input.getType().cast(); + assert(inputType.hasStaticShape()); + MemRefType resultType = dropUnitDims(inputType); + if (resultType == inputType) + return input; + SmallVector subviewOffsets(inputType.getRank(), 0); + SmallVector subviewStrides(inputType.getRank(), 1); + return rewriter.create( + loc, resultType, input, subviewOffsets, inputType.getShape(), + subviewStrides); +} + +/// Returns the number of dims that aren't unit dims. +static int getReducedRank(ArrayRef shape) { + return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; }); +} + +/// Returns true if all values are `arith.constant 0 : index` +static bool isZero(Value v) { + auto cst = v.getDefiningOp(); + return cst && cst.value() == 0; +} + +/// Rewrites vector.transfer_read ops where the source has unit dims, by +/// inserting a memref.subview dropping those unit dims. +class TransferReadDropUnitDimsPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, + PatternRewriter &rewriter) const override { + auto loc = transferReadOp.getLoc(); + Value vector = transferReadOp.vector(); + VectorType vectorType = vector.getType().cast(); + Value source = transferReadOp.source(); + MemRefType sourceType = source.getType().dyn_cast(); + // TODO: support tensor types. 
+ if (!sourceType || !sourceType.hasStaticShape()) + return failure(); + if (sourceType.getNumElements() != vectorType.getNumElements()) + return failure(); + // TODO: generalize this pattern, relax the requirements here. + if (transferReadOp.hasOutOfBoundsDim()) + return failure(); + if (!transferReadOp.permutation_map().isMinorIdentity()) + return failure(); + int reducedRank = getReducedRank(sourceType.getShape()); + if (reducedRank == sourceType.getRank()) + return failure(); // The source shape can't be further reduced. + if (reducedRank != vectorType.getRank()) + return failure(); // This pattern requires the vector shape to match the + // reduced source shape. + if (llvm::any_of(transferReadOp.indices(), + [](Value v) { return !isZero(v); })) + return failure(); + Value reducedShapeSource = + rankReducingSubviewDroppingUnitDims(rewriter, loc, source); + Value c0 = rewriter.create(loc, 0); + SmallVector zeros(reducedRank, c0); + auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); + rewriter.replaceOpWithNewOp( + transferReadOp, vectorType, reducedShapeSource, zeros, identityMap); + return success(); + } +}; + +/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has +/// unit dims, by inserting a memref.subview dropping those unit dims. +class TransferWriteDropUnitDimsPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, + PatternRewriter &rewriter) const override { + auto loc = transferWriteOp.getLoc(); + Value vector = transferWriteOp.vector(); + VectorType vectorType = vector.getType().cast(); + Value source = transferWriteOp.source(); + MemRefType sourceType = source.getType().dyn_cast(); + // TODO: support tensor type. 
+ if (!sourceType || !sourceType.hasStaticShape()) + return failure(); + if (sourceType.getNumElements() != vectorType.getNumElements()) + return failure(); + // TODO: generalize this pattern, relax the requirements here. + if (transferWriteOp.hasOutOfBoundsDim()) + return failure(); + if (!transferWriteOp.permutation_map().isMinorIdentity()) + return failure(); + int reducedRank = getReducedRank(sourceType.getShape()); + if (reducedRank == sourceType.getRank()) + return failure(); // The source shape can't be further reduced. + if (reducedRank != vectorType.getRank()) + return failure(); // This pattern requires the vector shape to match the + // reduced source shape. + if (llvm::any_of(transferWriteOp.indices(), + [](Value v) { return !isZero(v); })) + return failure(); + Value reducedShapeSource = + rankReducingSubviewDroppingUnitDims(rewriter, loc, source); + Value c0 = rewriter.create(loc, 0); + SmallVector zeros(reducedRank, c0); + auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); + rewriter.replaceOpWithNewOp( + transferWriteOp, vector, reducedShapeSource, zeros, identityMap); + return success(); + } +}; + } // namespace void mlir::vector::transferOpflowOpt(FuncOp func) { @@ -226,3 +350,11 @@ void mlir::vector::transferOpflowOpt(FuncOp func) { }); opt.removeDeadOp(); } + +void mlir::vector::populateVectorTransferDropUnitDimsPatterns( + RewritePatternSet &patterns) { + patterns + .add( + patterns.getContext()); + populateShapeCastFoldingPatterns(patterns); +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir new file mode 100644 index 000000000000..a3d34a646c2f --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -test-vector-transfer-drop-unit-dims-patterns -split-input-file | FileCheck %s + +// ----- + +func @transfer_read_rank_reducing( + %arg : memref<1x1x3x2xi8, 
offset:?, strides:[6, 6, 2, 1]>) -> vector<3x2xi8> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, vector<3x2xi8> + return %v : vector<3x2xi8> +} + +// CHECK-LABEL: func @transfer_read_rank_reducing +// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8 +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] +// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> +// CHECK: vector.transfer_read %[[SUBVIEW]] + +// ----- + +func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, %vec : vector<3x2xi8>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : + vector<3x2xi8>, memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]> + return +} + +// CHECK-LABEL: func @transfer_write_rank_reducing +// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8 +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] +// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> +// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]] \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f7e13bc330d4..cf33b0d7117d 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -583,6 +583,21 @@ struct TestVectorReduceToContractPatternsPatterns } }; +struct TestVectorTransferDropUnitDimsPatterns + : public PassWrapper { + StringRef getArgument() const final { + return "test-vector-transfer-drop-unit-dims-patterns"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { + RewritePatternSet patterns(&getContext()); + populateVectorTransferDropUnitDimsPatterns(patterns); + 
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + } +}; + } // namespace namespace mlir { @@ -613,6 +628,8 @@ void registerTestVectorLowerings() { PassRegistration(); PassRegistration(); + + PassRegistration(); } } // namespace test } // namespace mlir