Promote transpose from linalg to standard dialect

While affine maps are part of the builtin memref type, there is very
limited support for manipulating them in the standard dialect. Add
transpose to the set of ops to complement the existing view/subview ops.
This is a metadata transformation that encodes the transpose into the
strides of a memref.

I'm planning to use this when lowering operations on strided memrefs,
using the transpose to remove the stride without adding a dependency on
linalg dialect.

Differential Revision: https://reviews.llvm.org/D88651
This commit is contained in:
Benjamin Kramer 2020-09-30 19:19:04 +02:00
parent 16778b19f2
commit 6e2b267d1c
15 changed files with 237 additions and 228 deletions

View file

@ -554,9 +554,9 @@ are:
* `std.view`,
* `std.subview`,
* `std.transpose`.
* `linalg.range`,
* `linalg.slice`,
* `linalg.transpose`.
* `linalg.reshape`,
Future ops are added on a per-need basis but should include:

View file

@ -287,36 +287,6 @@ def Linalg_SliceOp : Linalg_Op<"slice", [
let hasFolder = 1;
}
def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
let summary = "`transpose` produces a new strided memref (metadata-only)";
let description = [{
The `linalg.transpose` op produces a strided memref whose sizes and strides
are a permutation of the original `view`. This is a pure metadata
transformation.
Example:
```mlir
%1 = linalg.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, stride_spec>
```
}];
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, Value view, "
"AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
let verifier = [{ return ::verify(*this); }];
let extraClassDeclaration = [{
static StringRef getPermutationAttrName() { return "permutation"; }
ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
}];
let hasFolder = 1;
}
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, Terminator]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";

View file

@ -3428,6 +3428,38 @@ def TensorStoreOp : Std_Op<"tensor_store",
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
let summary = "`transpose` produces a new strided memref (metadata-only)";
let description = [{
The `transpose` op produces a strided memref whose sizes and strides
are a permutation of the original `in` memref. This is purely a metadata
transformation.
Example:
```mlir
%1 = transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
```
}];
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, Value in, "
"AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
let extraClassDeclaration = [{
static StringRef getPermutationAttrName() { return "permutation"; }
ShapedType getShapedType() { return in().getType().cast<ShapedType>(); }
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//

View file

@ -284,57 +284,6 @@ public:
}
};
/// Conversion pattern that transforms a linalg.transpose op into:
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
/// and stride. Size and stride are permutations of the original values.
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.transpose op is replaced by the alloca'ed pointer.
class TransposeOpConversion : public ConvertToLLVMPattern {
public:
explicit TransposeOpConversion(MLIRContext *context,
LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
lowering_) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
edsc::ScopedContext context(rewriter, op->getLoc());
TransposeOpAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
return rewriter.replaceOp(op, {baseDesc}), success();
BaseViewConversionHelper desc(
typeConverter.convertType(transposeOp.getShapedType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
desc.setAllocatedPtr(baseDesc.allocatedPtr());
desc.setAlignedPtr(baseDesc.alignedPtr());
// Copy the offset pointer from the old descriptor to the new one.
desc.setOffset(baseDesc.offset());
// Iterate over the dimensions and apply size/stride permutation.
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
int sourcePos = en.index();
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
desc.setSize(targetPos, baseDesc.size(sourcePos));
desc.setStride(targetPos, baseDesc.stride(sourcePos));
}
rewriter.replaceOp(op, {desc});
return success();
}
};
// YieldOp produces and LLVM::ReturnOp.
class YieldOpConversion : public ConvertToLLVMPattern {
public:
@ -356,7 +305,7 @@ void mlir::populateLinalgToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
TransposeOpConversion, YieldOpConversion>(ctx, converter);
YieldOpConversion>(ctx, converter);
// Populate the type conversions for the linalg types.
converter.addConversion(

View file

@ -206,12 +206,12 @@ public:
// If either inputPerm or outputPerm are non-identities, insert transposes.
auto inputPerm = op.inputPermutation();
if (inputPerm.hasValue() && !inputPerm->isIdentity())
in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
AffineMapAttr::get(*inputPerm));
in = rewriter.create<TransposeOp>(op.getLoc(), in,
AffineMapAttr::get(*inputPerm));
auto outputPerm = op.outputPermutation();
if (outputPerm.hasValue() && !outputPerm->isIdentity())
out = rewriter.create<linalg::TransposeOp>(
op.getLoc(), out, AffineMapAttr::get(*outputPerm));
out = rewriter.create<TransposeOp>(op.getLoc(), out,
AffineMapAttr::get(*outputPerm));
// If nothing was transposed, fail and let the conversion kick in.
if (in == op.input() && out == op.output())
@ -270,7 +270,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
target.addLegalOp<linalg::TransposeOp, linalg::ReshapeOp, linalg::RangeOp>();
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
OwningRewritePatternList patterns;
populateLinalgToStandardConversionPatterns(patterns, &getContext());
if (failed(applyFullConversion(module, target, patterns)))

View file

@ -3011,6 +3011,57 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
}
};
/// Conversion pattern that transforms a transpose op into:
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
/// and stride. Size and stride are permutations of the original values.
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The transpose op is replaced by the alloca'ed pointer.
class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
public:
using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
TransposeOpAdaptor adaptor(operands);
MemRefDescriptor viewMemRef(adaptor.in());
auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
return rewriter.replaceOp(op, {viewMemRef}), success();
auto targetMemRef = MemRefDescriptor::undef(
rewriter, loc, typeConverter.convertType(transposeOp.getShapedType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
targetMemRef.setAllocatedPtr(rewriter, loc,
viewMemRef.allocatedPtr(rewriter, loc));
targetMemRef.setAlignedPtr(rewriter, loc,
viewMemRef.alignedPtr(rewriter, loc));
// Copy the offset pointer from the old descriptor to the new one.
targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
// Iterate over the dimensions and apply size/stride permutation.
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
int sourcePos = en.index();
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
targetMemRef.setSize(rewriter, loc, targetPos,
viewMemRef.size(rewriter, loc, sourcePos));
targetMemRef.setStride(rewriter, loc, targetPos,
viewMemRef.stride(rewriter, loc, sourcePos));
}
rewriter.replaceOp(op, {targetMemRef});
return success();
}
};
/// Conversion pattern that transforms an op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
@ -3425,6 +3476,7 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
RankOpLowering,
StoreOpLowering,
SubViewOpLowering,
TransposeOpLowering,
ViewOpLowering,
AllocOpLowering>(converter);
// clang-format on

View file

@ -973,86 +973,6 @@ static LogicalResult verify(SliceOp op) {
Value SliceOp::getViewSource() { return view(); }
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto rank = memRefType.getRank();
auto originalSizes = memRefType.getShape();
// Compute permuted sizes.
SmallVector<int64_t, 4> sizes(rank, 0);
for (auto en : llvm::enumerate(permutationMap.getResults()))
sizes[en.index()] =
originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
// Compute permuted strides.
int64_t offset;
SmallVector<int64_t, 4> strides;
auto res = getStridesAndOffset(memRefType, strides, offset);
assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
(void)res;
auto map =
makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
map = permutationMap ? map.compose(permutationMap) : map;
return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}
void mlir::linalg::TransposeOp::build(OpBuilder &b, OperationState &result,
Value view, AffineMapAttr permutation,
ArrayRef<NamedAttribute> attrs) {
auto permutationMap = permutation.getValue();
assert(permutationMap);
auto memRefType = view.getType().cast<MemRefType>();
// Compute result type.
MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
build(b, result, resultType, view, attrs);
result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}
static void print(OpAsmPrinter &p, TransposeOp op) {
p << op.getOperationName() << " " << op.view() << " " << op.permutation();
p.printOptionalAttrDict(op.getAttrs(),
{TransposeOp::getPermutationAttrName()});
p << " : " << op.view().getType() << " to " << op.getType();
}
static ParseResult parseTransposeOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType view;
AffineMap permutation;
MemRefType srcType, dstType;
if (parser.parseOperand(view) || parser.parseAffineMap(permutation) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.resolveOperand(view, srcType, result.operands) ||
parser.parseKeywordType("to", dstType) ||
parser.addTypeToList(dstType, result.types))
return failure();
result.addAttribute(TransposeOp::getPermutationAttrName(),
AffineMapAttr::get(permutation));
return success();
}
static LogicalResult verify(TransposeOp op) {
if (!op.permutation().isPermutation())
return op.emitOpError("expected a permutation map");
if (op.permutation().getNumDims() != op.getShapedType().getRank())
return op.emitOpError(
"expected a permutation map of same rank as the view");
auto srcType = op.view().getType().cast<MemRefType>();
auto dstType = op.getType().cast<MemRefType>();
if (dstType != inferTransposeResultType(srcType, op.permutation()))
return op.emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType;
return success();
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@ -1359,11 +1279,6 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp(*this, operands);
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
}
//===----------------------------------------------------------------------===//
// Auto-generated Linalg named ops.

View file

@ -3491,6 +3491,96 @@ static Type getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto rank = memRefType.getRank();
auto originalSizes = memRefType.getShape();
// Compute permuted sizes.
SmallVector<int64_t, 4> sizes(rank, 0);
for (auto en : llvm::enumerate(permutationMap.getResults()))
sizes[en.index()] =
originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
// Compute permuted strides.
int64_t offset;
SmallVector<int64_t, 4> strides;
auto res = getStridesAndOffset(memRefType, strides, offset);
assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
(void)res;
auto map =
makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
map = permutationMap ? map.compose(permutationMap) : map;
return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
AffineMapAttr permutation,
ArrayRef<NamedAttribute> attrs) {
auto permutationMap = permutation.getValue();
assert(permutationMap);
auto memRefType = in.getType().cast<MemRefType>();
// Compute result type.
MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
build(b, result, resultType, in, attrs);
result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}
// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
p << "transpose " << op.in() << " " << op.permutation();
p.printOptionalAttrDict(op.getAttrs(),
{TransposeOp::getPermutationAttrName()});
p << " : " << op.in().getType() << " to " << op.getType();
}
static ParseResult parseTransposeOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType in;
AffineMap permutation;
MemRefType srcType, dstType;
if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.resolveOperand(in, srcType, result.operands) ||
parser.parseKeywordType("to", dstType) ||
parser.addTypeToList(dstType, result.types))
return failure();
result.addAttribute(TransposeOp::getPermutationAttrName(),
AffineMapAttr::get(permutation));
return success();
}
static LogicalResult verify(TransposeOp op) {
if (!op.permutation().isPermutation())
return op.emitOpError("expected a permutation map");
if (op.permutation().getNumDims() != op.getShapedType().getRank())
return op.emitOpError(
"expected a permutation map of same rank as the input");
auto srcType = op.in().getType().cast<MemRefType>();
auto dstType = op.getType().cast<MemRefType>();
auto transposedType = inferTransposeResultType(srcType, op.permutation());
if (dstType != transposedType)
return op.emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType
<< ", " << transposedType;
return success();
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
}
//===----------------------------------------------------------------------===//
// TruncateIOp
//===----------------------------------------------------------------------===//

View file

@ -673,7 +673,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
if (!transposeOp)
return failure();
@ -2521,7 +2521,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
// Eliminates transpose operations, which produce values identical to their
// input values. This happens when the dimensions of the input vector remain in
// their original order after the transpose operation.
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
SmallVector<int64_t, 4> transp;
getTransp(transp);
@ -2535,7 +2535,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
return vector();
}
static LogicalResult verify(TransposeOp op) {
static LogicalResult verify(vector::TransposeOp op) {
VectorType vectorType = op.getVectorType();
VectorType resultType = op.getResultType();
int64_t rank = resultType.getRank();
@ -2563,14 +2563,14 @@ static LogicalResult verify(TransposeOp op) {
namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<TransposeOp> {
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern<TransposeOp>::OpRewritePattern;
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
// Wrapper around TransposeOp::getTransp() for cleaner code.
auto getPermutation = [](TransposeOp transpose) {
// Wrapper around vector::TransposeOp::getTransp() for cleaner code.
auto getPermutation = [](vector::TransposeOp transpose) {
SmallVector<int64_t, 4> permutation;
transpose.getTransp(permutation);
return permutation;
@ -2586,15 +2586,15 @@ public:
};
// Return if the input of 'transposeOp' is not defined by another transpose.
TransposeOp parentTransposeOp =
transposeOp.vector().getDefiningOp<TransposeOp>();
vector::TransposeOp parentTransposeOp =
transposeOp.vector().getDefiningOp<vector::TransposeOp>();
if (!parentTransposeOp)
return failure();
SmallVector<int64_t, 4> permutation = composePermutations(
getPermutation(parentTransposeOp), getPermutation(transposeOp));
// Replace 'transposeOp' with a new transpose operation.
rewriter.replaceOpWithNewOp<TransposeOp>(
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
parentTransposeOp.vector(),
vector::getVectorSubscriptAttr(rewriter, permutation));
@ -2604,12 +2604,12 @@ public:
} // end anonymous namespace
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
void vector::TransposeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<TransposeFolder>(context);
}
void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(transp(), results);
}

View file

@ -114,3 +114,20 @@ func @assert_test_function(%arg : i1) {
return
}
// -----
// CHECK-LABEL: func @transpose
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
%0 = transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
return
}

View file

@ -33,27 +33,6 @@ func @store_number_of_indices(%v : memref<f32>) {
// -----
func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map}}
linalg.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map of same rank as the view}}
linalg.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
linalg.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// expected-error @+1 {{op expected parent op with LinalgOp interface}}
linalg.yield %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>

View file

@ -69,22 +69,6 @@ func @slice_with_range_and_index(%arg0: memref<?x?xf64, offset: ?, strides: [?,
// CHECK: llvm.insertvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<double>, ptr<double>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: llvm.insertvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<double>, ptr<double>, i64, array<1 x i64>, array<1 x i64>)>
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
%0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
return
}
// CHECK-LABEL: func @transpose
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<3 x i64>, array<3 x i64>)>
func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// Reshapes that expand a contiguous tensor with some 1's.
%0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,

View file

@ -126,11 +126,11 @@ func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
// CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
%0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
%0 = transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
return
}
// CHECK-LABEL: func @transpose
// CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
// CHECK: transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> to memref<?x?x?xf32, #[[$strided3DT]]>
// -----

View file

@ -55,9 +55,9 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
// CHECK-LABEL: func @copy_transpose(
// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>,
// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?x?x?xf32, #[[$map1]]>) {
// CHECK: %[[t0:.*]] = linalg.transpose %[[arg0]]
// CHECK: %[[t0:.*]] = transpose %[[arg0]]
// CHECK-SAME: (d0, d1, d2) -> (d0, d2, d1) : memref<?x?x?xf32, #[[$map1]]>
// CHECK: %[[t1:.*]] = linalg.transpose %[[arg1]]
// CHECK: %[[t1:.*]] = transpose %[[arg1]]
// CHECK-SAME: (d0, d1, d2) -> (d2, d1, d0) : memref<?x?x?xf32, #[[$map1]]>
// CHECK: %[[o0:.*]] = memref_cast %[[t0]] :
// CHECK-SAME: memref<?x?x?xf32, #[[$map2]]> to memref<?x?x?xf32, #[[$map8]]>

View file

@ -81,3 +81,24 @@ func @dynamic_tensor_from_elements(%m : index, %n : index)
} : tensor<?x3x?xf32>
return %tnsr : tensor<?x3x?xf32>
}
// -----
func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map}}
transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map of same rank as the input}}
transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}