Split InferShapedTypeOpInterface to create ReifyRankedShapedTypeInterface.

The `reifyReturnTypeShapesPerResultDim` method supports shape
inference for rsults that are ranked types. These are used lower in
the codegeneration stack than its counter part `reifyReturnTypeShapes`
which also supports unranked types, and is more suited for use higher
up the compilation stack. To have separation of concerns, this method
is split into its own interface.
See discussion : https://llvm.discourse.group/t/better-layering-for-infershapedtypeopinterface/3823

Differential Revision: https://reviews.llvm.org/D106133
This commit is contained in:
MaheshRavishankar 2021-07-19 14:35:20 -07:00
parent 49289bd943
commit 9afc065743
17 changed files with 181 additions and 200 deletions

View file

@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {

View file

@ -928,8 +928,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/// Returns the value that expresses the shape of the output in terms of
/// shape of the input operands where possible
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes);
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.

View file

@ -36,8 +36,7 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
[NoSideEffect,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]> {
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to define a tensor of particular value";
let description = [{
@ -130,10 +129,8 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
}
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>,
NoSideEffect]> {
[AttrSizedOperandSegments, NoSideEffect,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "tensor pad operation";
let description = [{
`linalg.pad_tensor` is an operation that pads the `source` tensor
@ -398,8 +395,7 @@ def IndexListArrayAttr :
class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
mnemonic,
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]>,
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>,
Arguments<(ins AnyTensor:$src,
IndexListArrayAttr:$reassociation)>,
Results<(outs AnyTensor:$result)> {

View file

@ -26,7 +26,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic, !listconcat(props, [
LinalgStructuredInterface, InferShapedTypeOpInterface])> {
LinalgStructuredInterface, ReifyRankedShapedTypeOpInterface])> {
code structuredOpsBaseDecls = [{
// Return whether the op accesses the iteration indices.
bool hasIndexSemantics() {
@ -36,9 +36,9 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
return !op->getRegion(0).front().getOps<IndexOp>().empty();
}
LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
reifiedReturnShapes);
}
}];

View file

@ -35,6 +35,13 @@ namespace memref {
/// into `patterns`.
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
/// operands.
void populateResolveRankedShapeTypeResultDimsPatterns(
RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of shapes of its input operands.
@ -50,7 +57,14 @@ std::unique_ptr<Pass> createFoldSubViewOpsPass();
/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
/// `InferShapedTypeOpInterface`, in terms of shapes of its input operands.
/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
/// operands.
std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
/// `InferShapedTypeOpInterface` or the `ReifyRankedShapeTypeShapeOpInterface`,
/// in terms of shapes of its input operands.
std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
//===----------------------------------------------------------------------===//

View file

@ -23,12 +23,28 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
];
}
def ResolveRankedShapeTypeResultDims :
Pass<"resolve-ranked-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values of ranked shape type";
let description = [{
The pass resolves memref.dim of result of operations that
implement the `ReifyRankedShapedTypeOpInterface` in terms of
shapes of its operands.
}];
let constructor =
"mlir::memref::createResolveRankedShapeTypeResultDimsPass()";
let dependentDialects = [
"memref::MemRefDialect", "tensor::TensorDialect"
];
}
def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values";
let description = [{
The pass resolves memref.dim of result of operations that
implement the `InferShapedTypeOpInterface` in terms of shapes of
its operands.
implement the `InferShapedTypeOpInterface` or
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
operands.
}];
let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
let dependentDialects = [

View file

@ -432,8 +432,7 @@ def Tensor_InsertOp : Tensor_Op<"insert",
def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
Tensor_Dialect, "insert_slice",
[NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
TypesMatchWith<"expected result type to match dest type",
"dest", "result", "$_self">]> {
let summary = "insert_slice operation";

View file

@ -23,6 +23,8 @@
namespace mlir {
using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
/// ShapedTypeComponents that represents the components of a ShapedType.
/// The components consist of
/// - A ranked or unranked shape with the dimension specification match those

View file

@ -105,9 +105,7 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
/*desc=*/[{Reify the shape computation for the operation.
Insert operations using the given OpBuilder that computes the
result shape. Only one of this method or
`reifyReturnTypeShapesPerResultDim` needs to be overriden by the
operation. This interface is supposed to be workable during dialect
result shape. This interface is supposed to be workable during dialect
conversion (e.g. convert from tensor world to buffer world),
where `getOperand` may be invalid. For example, some ops (e.g.
dynamic_reshape(input, target_shape)) may depend on their operands
@ -127,34 +125,6 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
"::mlir::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
>,
InterfaceMethod<
/*desc=*/[{Reify the shape computation for the operation.
Insert operations using the given OpBuilder that computes the
result shape. The `reifiedReturnShapes` is expected to be
populated with as many vectors as the number of results of the
op (empty if the shape of a result value cannot be computed). If
the returned shape for a result is not empty, its size must
match the rank of the shaped type returned. Consequently, this
interface can only be overridden if the return types are ranked.
If both this method and `reifyReturnTypeShapes` are overridden
by the operation, `reifyReturnTypeShapes` takes precedence. This
method is intended to be used when the shape of each result, dim
pair can be computed independently. Using this method avoids
adding additional instructions to aggregate individual dimension
of a result shape into an single `Value` (and consequently
avoids the need to extract the value from the shape on the
client side).
}],
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"reifyReturnTypeShapesPerResultDim",
/*args=*/(ins "::mlir::OpBuilder&":$builder,
"::mlir::SmallVectorImpl<::mlir::SmallVector<::mlir::Value>>&"
:$reifiedReturnShapes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
>
];
}
@ -176,4 +146,35 @@ class InferTensorType<list<string> overridenMethods = []> {
defvar InferTensorTypeWithReify = InferTensorType<[
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
def ReifyRankedShapedTypeOpInterface :
OpInterface<"ReifyRankedShapedTypeOpInterface"> {
let description = [{
Interface to compute the shape of the result of an operation when
the result is a ranked shape type, i.e. `RankedTensorType` or
`MemRefType`.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Reify the shape of the result of an operation (typically in
terms of shape of its operands)
Insert operations using the given `OpBuilder` that computes
the result shape. The `reifiedReturnShapes` is expected to be
populated with as many vectors as the number of results of the
op. Each of these vectors is expected to be of size equal to
rank of the corresponding result. If the shape of a particular
result cannot be computed it must be empty.
}],
/*retTy=*/"LogicalResult",
/*methodName=*/"reifyResultShapes",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
"ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
>
];
}
#endif // MLIR_INFERTYPEOPINTERFACE

View file

@ -274,8 +274,9 @@ private:
llvm::SmallSet<unsigned, 4> positions;
};
LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
// An example that helps understand the logic below.
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
// We want to express the shape of dim 0 of O in terms of shape of the inputs.

View file

@ -779,9 +779,8 @@ struct FoldInitTensorWithTensorReshapeOp
if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
return failure();
Location loc = reshapeOp.getLoc();
SmallVector<SmallVector<Value>, 4> resultShapes;
if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter,
resultShapes)) ||
ReifiedRankedShapedTypeDims resultShapes;
if (failed(reshapeOp.reifyResultShapes(rewriter, resultShapes)) ||
!llvm::hasSingleElement(resultShapes))
return failure();
Value initTensor = rewriter.create<InitTensorOp>(
@ -825,9 +824,8 @@ void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
ReplaceStaticShapeDims>(context);
}
LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
LogicalResult InitTensorOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
auto shapes = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
if (isDynamicSize(dim))
@ -1003,8 +1001,8 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
builder);
}
LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
LogicalResult PadTensorOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
Location loc = getLoc();
auto lowPad = getMixedLowPad();
auto highPad = getMixedHighPad();
@ -1429,8 +1427,8 @@ void TensorCollapseShapeOp::getCanonicalizationPatterns(
FoldReshapeWithConstant<TensorCollapseShapeOp>>(context);
}
LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
LogicalResult TensorExpandShapeOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
auto resultShape =
getAsValues(b, getLoc(),
getReshapeOutputShapeFromInputShape(
@ -1440,8 +1438,8 @@ LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
return success();
}
LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
LogicalResult TensorCollapseShapeOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
auto resultShape =
getAsValues(b, getLoc(),
getReshapeOutputShapeFromInputShape(

View file

@ -1,5 +1,4 @@
//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values
//-------===//
//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -21,52 +20,6 @@
using namespace mlir;
/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as a interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
int64_t dimIndex) {
unsigned resultNumber = result.getResultNumber();
auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
Location loc = result.getOwner()->getLoc();
if (!shapedTypeOp)
return nullptr;
// The interface exposes two methods, one that returns the shape of all the
// results as `Value` and other that returns the shape as a list of
// `SmallVector<Value>`. The former takes precedence over the latter. So first
// check if the op implements the first interface method or the second, and
// get the value to use appropriately.
SmallVector<Value> reifiedResultShapes;
if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
if (reifiedResultShapes.size() <= resultNumber)
return nullptr;
Value resultShape = reifiedResultShapes[resultNumber];
auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
return nullptr;
return builder.create<tensor::ExtractOp>(
loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
}
SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
builder, reifiedResultShapesPerDim)))
return nullptr;
if (reifiedResultShapesPerDim.size() <= resultNumber ||
reifiedResultShapesPerDim[resultNumber].size() !=
static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
return nullptr;
OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
if (auto attr = valueOrAttr.dyn_cast<Attribute>())
return builder.createOrFold<ConstantIndexOp>(
loc, attr.cast<IntegerAttr>().getInt());
return valueOrAttr.get<Value>();
}
namespace {
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
template <typename OpTy>
@ -86,11 +39,62 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
if (!dimIndex)
return failure();
Value replacement =
getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
if (!replacement)
SmallVector<Value> reifiedResultShapes;
if (failed(shapedTypeOp.reifyReturnTypeShapes(
rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
return failure();
rewriter.replaceOp(dimOp, replacement);
if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
return failure();
Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
return failure();
Location loc = dimOp->getLoc();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
dimOp, resultShape,
rewriter.createOrFold<ConstantIndexOp>(loc, *dimIndex));
return success();
}
};
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
template <typename OpTy>
struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
OpResult dimValue = dimOp.source().template dyn_cast<OpResult>();
if (!dimValue)
return failure();
auto rankedShapeTypeOp =
dyn_cast<ReifyRankedShapedTypeOpInterface>(dimValue.getOwner());
if (!rankedShapeTypeOp)
return failure();
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
if (!dimIndex)
return failure();
SmallVector<SmallVector<Value>> reifiedResultShapes;
if (failed(
rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
return failure();
if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults())
return failure();
unsigned resultNumber = dimValue.getResultNumber();
auto sourceType = dimValue.getType().dyn_cast<RankedTensorType>();
if (reifiedResultShapes[resultNumber].size() !=
static_cast<size_t>(sourceType.getRank()))
return failure();
rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]);
return success();
}
};
@ -104,12 +108,26 @@ namespace {
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
struct ResolveRankedShapeTypeResultDimsPass final
: public ResolveRankedShapeTypeResultDimsBase<
ResolveRankedShapeTypeResultDimsPass> {
void runOnOperation() override;
};
struct ResolveShapedTypeResultDimsPass final
: public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
void runOnOperation() override;
};
} // namespace
void memref::populateResolveRankedShapeTypeResultDimsPatterns(
RewritePatternSet &patterns) {
patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
patterns.getContext());
}
void memref::populateResolveShapedTypeResultDimsPatterns(
RewritePatternSet &patterns) {
// TODO: Move tensor::DimOp pattern to the Tensor dialect.
@ -118,8 +136,17 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
patterns.getContext());
}
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns))))
return signalPassFailure();
}
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns))))
@ -129,3 +156,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
return std::make_unique<ResolveShapedTypeResultDimsPass>();
}
std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
}

View file

@ -1042,9 +1042,8 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
return OpFoldResult();
}
LogicalResult InsertSliceOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
LogicalResult InsertSliceOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
reifiedReturnShapes[0][dim] =

View file

@ -55,34 +55,3 @@ func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
// -----
func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
%2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
%3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
%4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
%5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape_and_per_dim(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]

View file

@ -822,46 +822,8 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
return success();
}
LogicalResult
OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
Location loc = getLoc();
shapes.reserve(getNumOperands());
for (Value operand : llvm::reverse(getOperands())) {
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(
0, operand.getType().cast<RankedTensorType>().getRank()),
[&](int64_t dim) -> Value {
return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
}));
shapes.emplace_back(std::move(currShape));
}
return success();
}
LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
Location loc = getLoc();
shapes.reserve(operands.size());
for (Value operand : llvm::reverse(operands)) {
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(
0, operand.getType().cast<RankedTensorType>().getRank()),
[&](int64_t dim) -> Value {
return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
}));
shapes.push_back(builder.create<tensor::FromElementsOp>(
getLoc(), builder.getIndexType(), currShape));
}
return success();
}
LogicalResult
OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
Location loc = getLoc();
shapes.reserve(getNumOperands());
for (Value operand : llvm::reverse(getOperands())) {

View file

@ -579,16 +579,7 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
def OpWithResultShapePerDimInterfaceOp :
TEST_Op<"op_with_result_shape_per_dim_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
def OpWithResultShapeAndPerDimInterfaceOp :
TEST_Op<"op_with_result_shape_and_per_dim_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}

View file

@ -2046,6 +2046,7 @@ cc_library(
":Affine",
":DialectUtils",
":IR",
":InferTypeOpInterface",
":LinalgInterfacesIncGen",
":LinalgStructuredOpsIncGen",
":MemRefDialect",