[mlir][Vector] Add linalg.copy-based pattern for splitting vector.transfer_read into full and partial copies.

This revision adds a transformation and a pattern that rewrites a "maybe masked" `vector.transfer_read %view[...], %pad` into IR resembling:

```
   %1:3 = scf.if (%inBounds) {
      scf.yield %view, ... : memref<A...>, index, index
    } else {
      %2 = linalg.fill(%alloc, %pad)
      %3 = subview %view [...][...][...]
      linalg.copy(%3, %alloc)
      %4 = memref_cast %alloc : memref<B...> to memref<A...>
      scf.yield %4, ... : memref<A...>, index, index
   }
   %res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
```
where `%alloc` is a single-vector buffer `alloca`'ed at the top of the function.

This rewrite makes it possible to realize the "always full tile" abstraction where vector.transfer_read operations are guaranteed to read from a padded full buffer.
The extra work only occurs on the boundary tiles.
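
The split mode is selected through `VectorTransformsOptions` when registering the pattern. Below is a minimal sketch mirroring the test pass added in this commit; the helper name `populateSplitPatterns` and the surrounding pass boilerplate are illustrative assumptions, while the option, enum, and pattern names are the ones introduced here:

```
// A minimal sketch (not part of the commit): select the linalg.copy-based
// split and register the rewrite pattern. Assumes the usual includes and
// `using namespace mlir; using namespace mlir::vector;`.
static void populateSplitPatterns(MLIRContext *ctx,
                                  OwningRewritePatternList &patterns) {
  VectorTransformsOptions options;
  // Slow path: linalg.fill + linalg.copy into the alloca'ed full-tile buffer,
  // then an unmasked vector.transfer_read from it.
  options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
  patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
}
```

`VectorTransferSplit::VectorTransfer` keeps a masked vector.transfer on the slow path instead, and `ForceUnmasked` only marks the original transfer as unmasked without splitting it.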
Nicolas Vasilache, 2020-08-03 05:34:07 -04:00
parent 04e45ae1c6
commit 1a4263d394
6 changed files with 348 additions and 74 deletions


@@ -56,22 +56,48 @@ enum class VectorContractLowering {
};
/// Enum to control the lowering of `vector.transpose` operations.
enum class VectorTransposeLowering {
// Lower transpose into element-wise extract and inserts.
/// Lower transpose into element-wise extract and inserts.
EltWise = 0,
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
/// intrinsics.
Flat = 1,
};
/// Enum to control the splitting of `vector.transfer` operations into masked
/// and unmasked variants.
enum class VectorTransferSplit {
/// Do not split vector transfer operations.
None = 0,
/// Split using masked + unmasked vector.transfer operations.
VectorTransfer = 1,
/// Split using an unmasked vector.transfer + linalg.fill + linalg.copy
/// operations.
LinalgCopy = 2,
/// Do not split the vector transfer operation but instead mark it as "unmasked".
ForceUnmasked = 3
};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
/// Option to control the lowering of vector.transpose.
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransposeLowering(VectorTransposeLowering opt) {
vectorTransposeLowering = opt;
return *this;
}
/// Option to control the splitting of vector transfers.
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
vectorTransferSplit = opt;
return *this;
}
};
/// Collect a set of transformation patterns that are related to contracting


@@ -109,13 +109,13 @@ private:
FilterConstraintType filter;
};
/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success, the `ifOp` points to the newly created conditional upon
/// function return. To accomodate for the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
/// in the temporary buffer, the scf.if op returns a view and values of type
/// index. At this time, only vector.transfer_read is implemented.
/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -124,17 +124,17 @@ private:
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// scf.yield %0 : memref<A...>, index, index
/// } else {
/// %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// %3 = vector.type_cast %extra_alloc : memref<...> to
/// memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
/// memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
/// memref<A...>, index, index
///    // fastpath, direct cast
///    %view = memref_cast %A: memref<A...> to compatibleMemRefType
///    scf.yield %view, ... : compatibleMemRefType, index, index
///  } else {
///    // slowpath, masked vector.transfer or linalg.copy.
///    %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
///    scf.yield %4, ... : compatibleMemRefType, index, index
///  }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
/// where `alloc` is a top-of-the-function `alloca`'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
@@ -143,9 +143,10 @@ private:
/// rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
VectorTransferOpInterface xferOp,
scf::IfOp *ifOp = nullptr);
LogicalResult splitFullAndPartialTransfer(
OpBuilder &b, VectorTransferOpInterface xferOp,
VectorTransformsOptions options = VectorTransformsOptions(),
scf::IfOp *ifOp = nullptr);
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
@@ -155,16 +156,19 @@ struct VectorTransferFullPartialRewriter : public RewritePattern {
explicit VectorTransferFullPartialRewriter(
MLIRContext *context,
VectorTransformsOptions options = VectorTransformsOptions(),
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
: RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
filter(filter) {}
/// Performs the rewrite.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
VectorTransformsOptions options;
FilterConstraintType filter;
};


@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVector
MLIRIR
MLIRStandardOps
MLIRAffineOps
MLIRLinalgOps
MLIRSCF
MLIRLoopAnalysis
MLIRSideEffectInterfaces


@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -2056,7 +2057,16 @@ LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
return success();
}
MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
/// 1. if `aT` and `bT` are cast-compatible, return `aT`.
/// 2. else return a new MemRefType obtained by iterating over the shape and
/// strides and:
/// a. keeping the ones that are static and equal across `aT` and `bT`.
/// b. using a dynamic shape and/or stride for the dimensions that don't
/// agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
if (MemRefCastOp::areCastCompatible(aT, bT))
return aT;
if (aT.getRank() != bT.getRank())
@@ -2086,13 +2096,154 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}
/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success, the `ifOp` points to the newly created conditional upon
/// function return. To accomodate for the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
/// in the temporary buffer, the scf.if op returns a view and values of type
/// index. At this time, only vector.transfer_read is implemented.
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
Value alloc) {
using namespace edsc::intrinsics;
int64_t memrefRank = xferOp.getMemRefType().getRank();
// TODO: relax this precondition, will require rank-reducing subviews.
assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
"Expected memref rank to match the alloc rank");
Value one = std_constant_index(1);
ValueRange leadingIndices =
xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
SmallVector<Value, 4> sizes;
sizes.append(leadingIndices.begin(), leadingIndices.end());
xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
Value dimAlloc = std_dim(alloc, resultIdx);
Value index = xferOp.indices()[indicesIdx];
AffineExpr i, j, k;
bindDims(xferOp.getContext(), i, j, k);
SmallVector<AffineMap, 4> maps =
AffineMap::inferFromExprList(MapList{{i - j, k}});
// affine_min(%dimMemRef - %index, %dimAlloc)
Value affineMin = affine_min(index.getType(), maps[0],
ValueRange{dimMemRef, index, dimAlloc});
sizes.push_back(affineMin);
});
return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
SmallVector<Value, 4>(memrefRank, one));
}
/// Given an `xferOp` for which:
/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
/// 2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// %view = memref_cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view, ... : compatibleMemRefType, index, index
/// } else {
/// %2 = linalg.fill(%alloc, %pad)
/// %3 = subview %view [...][...][...]
/// linalg.copy(%3, %alloc)
/// %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4, ... : compatibleMemRefType, index, index
/// }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createScopedFullPartialLinalgCopy(
vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
MemRefType compatibleMemRefType, Value alloc) {
using namespace edsc;
using namespace edsc::intrinsics;
scf::IfOp fullPartialIfOp;
Value zero = std_constant_index(0);
Value memref = xferOp.memref();
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
if (compatibleMemRefType != xferOp.getMemRefType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
return viewAndIndices;
},
[&]() -> scf::ValueVector {
linalg_fill(alloc, xferOp.padding());
// Take partial subview of memref which guarantees no dimension
// overflows.
Value memRefSubView = createScopedSubViewIntersection(
cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
linalg_copy(memRefSubView, alloc);
Value casted = std_memref_cast(alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
return viewAndIndices;
},
&fullPartialIfOp);
return fullPartialIfOp;
}
/// Given an `xferOp` for which:
/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
/// 2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// %view = memref_cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view, ... : compatibleMemRefType, index, index
/// } else {
/// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
/// %3 = vector.type_cast %alloc :
/// memref<...> to memref<vector<...>>
/// store %2, %3[] : memref<vector<...>>
/// %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4, ... : compatibleMemRefType, index, index
/// }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createScopedFullPartialVectorTransferRead(
vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
MemRefType compatibleMemRefType, Value alloc) {
using namespace edsc;
using namespace edsc::intrinsics;
scf::IfOp fullPartialIfOp;
Value zero = std_constant_index(0);
Value memref = xferOp.memref();
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
if (compatibleMemRefType != xferOp.getMemRefType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
return viewAndIndices;
},
[&]() -> scf::ValueVector {
Operation *newXfer =
ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
std_store(vector, vector_type_cast(
MemRefType::get({}, vector.getType()), alloc));
Value casted = std_memref_cast(alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
return viewAndIndices;
},
&fullPartialIfOp);
return fullPartialIfOp;
}
/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -2101,17 +2252,17 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// scf.yield %0 : memref<A...>, index, index
/// } else {
/// %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// %3 = vector.type_cast %extra_alloc : memref<...> to
/// memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
/// memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
/// memref<A...>, index, index
///    // fastpath, direct cast
///    %view = memref_cast %A: memref<A...> to compatibleMemRefType
///    scf.yield %view, ... : compatibleMemRefType, index, index
///  } else {
///    // slowpath, masked vector.transfer or linalg.copy.
///    %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
///    scf.yield %4, ... : compatibleMemRefType, index, index
///  }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
/// where `alloc` is a top-of-the-function `alloca`'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
@@ -2119,10 +2270,21 @@ MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
OpBuilder &b, VectorTransferOpInterface xferOp,
VectorTransformsOptions options, scf::IfOp *ifOp) {
using namespace edsc;
using namespace edsc::intrinsics;
if (options.vectorTransferSplit == VectorTransferSplit::None)
return failure();
SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
auto unmaskedAttr = b.getBoolArrayAttr(bools);
if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
return success();
}
assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
"Expected splitFullAndPartialTransferPrecondition to hold");
auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
@@ -2154,45 +2316,21 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
b.getI64IntegerAttr(32));
}
Value memref = xferOp.memref();
SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
auto unmaskedAttr = b.getBoolArrayAttr(bools);
MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
// Read case: full fill + partial copy -> unmasked vector.xfer_read.
Value zero = std_constant_index(0);
SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
b.getIndexType());
returnTypes[0] = compatibleMemRefType;
scf::IfOp fullPartialIfOp;
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
if (compatibleMemRefType != xferOp.getMemRefType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
return viewAndIndices;
},
[&]() -> scf::ValueVector {
Operation *newXfer =
ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
std_store(vector, vector_type_cast(
MemRefType::get({}, vector.getType()), alloc));
Value casted = std_memref_cast(alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
return viewAndIndices;
},
&fullPartialIfOp);
scf::IfOp fullPartialIfOp =
options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
? createScopedFullPartialVectorTransferRead(
xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
alloc)
: createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
inBoundsCond,
compatibleMemRefType, alloc);
if (ifOp)
*ifOp = fullPartialIfOp;
@@ -2211,7 +2349,7 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
failed(filter(xferOp)))
return failure();
rewriter.startRootUpdate(xferOp);
if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) {
if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
rewriter.finalizeRootUpdate(xferOp);
return success();
}


@@ -1,13 +1,26 @@
// RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-linalg-copy | FileCheck %s --check-prefix=LINALG
// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// LINALG-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// LINALG-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// LINALG-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// LINALG-DAG: #[[$map_2d_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
// LINALG-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// LINALG-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
// CHECK-LABEL: split_vector_transfer_read_2d(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
// LINALG-LABEL: split_vector_transfer_read_2d(
// LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
// LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
// LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
%c0 = constant 0 : index
%f0 = constant 0.0 : f32
@@ -43,9 +56,45 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
// CHECK: }
// CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
// CHECK-SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
// LINALG-DAG: %[[c0:.*]] = constant 0 : index
// LINALG-DAG: %[[c1:.*]] = constant 1 : index
// LINALG-DAG: %[[c4:.*]] = constant 4 : index
// LINALG-DAG: %[[c8:.*]] = constant 8 : index
// LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
// alloca for boundary full tile
// LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
// %i + 4 <= dim(%A, 0)
// LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
// LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
// LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index
// %j + 8 <= dim(%A, 1)
// LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
// LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
// are both conds true
// LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
// LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
// inBounds, just yield %A
// LINALG: scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
// LINALG: } else {
// slow path, fill tmp alloc and yield a memref_casted version of it
// LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
// LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
// LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
// LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
// LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
// LINALG-SAME: memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_dynamic]]>
// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
// LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
// LINALG-SAME: memref<4x8xf32> to memref<?x8xf32>
// LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
// LINALG-SAME: memref<?x8xf32>, index, index
// LINALG: }
// LINALG: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
// LINALG-SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
%1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>
// CHECK: return %[[res]] : vector<4x8xf32>
// LINALG: return %[[res]] : vector<4x8xf32>
return %1: vector<4x8xf32>
}
@@ -53,6 +102,11 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
// LINALG-LABEL: split_vector_transfer_read_strided_2d(
// LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
// LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
// LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
func @split_vector_transfer_read_strided_2d(
%A: memref<7x8xf32, offset:?, strides:[?, 1]>,
%i: index, %j: index) -> vector<4x8xf32> {
@@ -94,6 +148,44 @@ func @split_vector_transfer_read_strided_2d(
// CHECK: }
// CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
// CHECK-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
// LINALG-DAG: %[[c0:.*]] = constant 0 : index
// LINALG-DAG: %[[c1:.*]] = constant 1 : index
// LINALG-DAG: %[[c4:.*]] = constant 4 : index
// LINALG-DAG: %[[c7:.*]] = constant 7 : index
// LINALG-DAG: %[[c8:.*]] = constant 8 : index
// LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
// alloca for boundary full tile
// LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
// %i + 4 <= dim(%A, 0)
// LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
// LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index
// %j + 8 <= dim(%A, 1)
// LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
// LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
// are both conds true
// LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
// LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index) {
// inBounds but not cast-compatible: yield a memref_casted form of %A
// LINALG: %[[casted:.*]] = memref_cast %arg0 :
// LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x8xf32, #[[$map_2d_stride_1]]>
// LINALG: scf.yield %[[casted]], %[[i]], %[[j]] :
// LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
// LINALG: } else {
// slow path, fill tmp alloc and yield a memref_casted version of it
// LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
// LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
// LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
// LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
// LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_dynamic]]>
// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
// LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
// LINALG-SAME: memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
// LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
// LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
// LINALG: }
// LINALG: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
// LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
%1 = vector.transfer_read %A[%i, %j], %f0 :
memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32>


@@ -125,10 +125,23 @@ struct TestVectorUnrollingPatterns
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
FunctionPass> {
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass) {}
Option<bool> useLinalgOps{
*this, "use-linalg-copy",
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
"linalg.copy operations."),
llvm::cl::init(false)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
patterns.insert<VectorTransferFullPartialRewriter>(ctx);
VectorTransformsOptions options;
if (useLinalgOps)
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
else
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};