[mlir][Shape] Make shape_eq nary

This gets rid of a dubious shape_eq %a, %a fold, which folded shape_eq to
true even when %a is not a constant Attribute.

Differential Revision: https://reviews.llvm.org/D97728
Author: Benjamin Kramer
Date:   2021-03-01 20:34:17 +01:00
Parent: 64f5d7e972
Commit: 24acadef8a
5 changed files with 128 additions and 64 deletions
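
In IR terms, the change looks roughly as follows (a sketch based on the tests
updated in this commit; SSA names are illustrative):

  // Before: shape.shape_eq took exactly two operands.
  %eq = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>

  // After: shape.shape_eq takes one or more operands and returns true iff
  // all of them are equal.
  %eq3 = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>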

@@ -168,20 +168,38 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
+                                            InferTypeOpInterface]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
-    Takes two shape or extent tensor operands and determines whether they are
-    equal. When extent tensors are compared to shapes they are regarded as their
-    equivalent non-error shapes. Error shapes can be tested for equality like
-    any other shape value, meaning that the error value is equal to itself.
+    Takes one or more shape or extent tensor operands and determines whether
+    they are equal. When extent tensors are compared to shapes they are regarded
+    as their equivalent non-error shapes. Error shapes can be tested for
+    equality like any other shape value, meaning that the error value is equal
+    to itself.
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs I1:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  // Convenience builder alias for the binary version.
+  let builders = [
+    OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+      [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this
+    // defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+        ::llvm::Optional<::mlir::Location> location,
+        ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+        ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+                                                             /*width=*/1));
+      return success();
+    };
+  }];
+
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
   let hasFolder = 1;
 }
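
As the description notes, shape and extent-tensor operands can be mixed in one
comparison; a sketch of the new variadic assembly format (operand names are
illustrative):

  %eq = shape.shape_eq %s, %t0, %t1 : !shape.shape, tensor<?xindex>, tensor<?xindex>

The binary builder is kept as a convenience alias, so existing two-operand
call sites continue to build the same op.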

@@ -474,46 +474,56 @@ struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
 LogicalResult
 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
-  if (op.lhs().getType().isa<ShapeType>() ||
-      op.rhs().getType().isa<ShapeType>()) {
+  if (!llvm::all_of(op.shapes(),
+                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
-  }
 
+  Type i1Ty = rewriter.getI1Type();
+  if (op.shapes().size() <= 1) {
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
+                                            rewriter.getBoolAttr(true));
+    return success();
+  }
+
   ShapeEqOp::Adaptor transformed(operands);
   auto loc = op.getLoc();
   Type indexTy = rewriter.getIndexType();
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
-  Value eqRank =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
-  Type i1Ty = rewriter.getI1Type();
-  rewriter.replaceOpWithNewOp<IfOp>(
-      op, i1Ty, eqRank,
-      [&](OpBuilder &b, Location loc) {
-        Value one = b.create<ConstantIndexOp>(loc, 1);
-        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
-        auto loop = b.create<scf::ForOp>(
-            loc, zero, lhsRank, one, ValueRange{init},
-            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-              Value conj = args[0];
-              Value lhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
-              Value rhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
-              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                                lhsExtent, rhsExtent);
-              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
-              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
-            });
-        b.create<scf::YieldOp>(loc, loop.getResults());
-      },
-      [&](OpBuilder &b, Location loc) {
-        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
-        b.create<scf::YieldOp>(loc, result);
-      });
+  Value firstShape = transformed.shapes().front();
+  Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
+  Value result = nullptr;
+  // Generate a linear sequence of compares, all with firstShape as lhs.
+  for (Value shape : transformed.shapes().drop_front(1)) {
+    Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
+    Value eqRank =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
+    auto same = rewriter.create<IfOp>(
+        loc, i1Ty, eqRank,
+        [&](OpBuilder &b, Location loc) {
+          Value one = b.create<ConstantIndexOp>(loc, 1);
+          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+          auto loop = b.create<scf::ForOp>(
+              loc, zero, firstRank, one, ValueRange{init},
+              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+                Value conj = args[0];
+                Value lhsExtent =
+                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
+                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
+                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                  lhsExtent, rhsExtent);
+                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
+                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+              });
+          b.create<scf::YieldOp>(loc, loop.getResults());
+        },
+        [&](OpBuilder &b, Location loc) {
+          Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+          b.create<scf::YieldOp>(loc, result);
+        });
+    result = !result ? same.getResult(0)
+                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
+  }
+  rewriter.replaceOp(op, result);
   return success();
 }
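
For each pair of extent tensors, one compare in that linear sequence expands
to roughly the IR below (a hand-written sketch mirroring the CHECK lines in
the test further down; names are illustrative):

  %c0 = constant 0 : index
  %rank_a = dim %a, %c0 : tensor<?xindex>
  %rank_b = dim %b, %c0 : tensor<?xindex>
  %rank_eq = cmpi eq, %rank_a, %rank_b : index
  %same = scf.if %rank_eq -> (i1) {
    %c1 = constant 1 : index
    %init = constant true
    %all = scf.for %i = %c0 to %rank_a step %c1 iter_args(%conj = %init) -> (i1) {
      %ext_a = tensor.extract %a[%i] : tensor<?xindex>
      %ext_b = tensor.extract %b[%i] : tensor<?xindex>
      %ext_eq = cmpi eq, %ext_a, %ext_b : index
      %conj_next = and %conj, %ext_eq : i1
      scf.yield %conj_next : i1
    }
    scf.yield %all : i1
  } else {
    %false = constant false
    scf.yield %false : i1
  }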

@@ -629,15 +629,15 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
-  if (lhs() == rhs())
-    return BoolAttr::get(getContext(), true);
-  auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (lhs == nullptr)
+  bool allSame = true;
+  if (!operands.empty() && !operands[0])
     return {};
-  auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (rhs == nullptr)
-    return {};
-  return BoolAttr::get(getContext(), lhs == rhs);
+  for (Attribute operand : operands.drop_front(1)) {
+    if (!operand)
+      return {};
+    allSame = allSame && operand == operands[0];
+  }
+  return BoolAttr::get(getContext(), allSame);
 }
 
 //===----------------------------------------------------------------------===//
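
The net effect on folding: constant operands still fold, but the old
value-identity special case is gone, so shape_eq %a, %a on a non-constant %a
is now left alone. A sketch, in line with the canonicalization tests below:

  // All operands fold to the same constant attribute: folds to constant true.
  %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
  %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
  %t = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>

  // %arg is a non-constant value, so this no longer folds, even though both
  // operands are the same SSA value.
  %u = shape.shape_eq %arg, %arg : !shape.shape, !shape.shape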

@@ -295,6 +295,53 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
 
 // -----
 
+// CHECK-LABEL: @shape_eq
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[INIT:.*]] = constant true
+  // CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK: %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK: %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK: scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK: }
+  // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RANK_C:.*]] = dim %[[C]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_C]]
+  // CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[INIT:.*]] = constant true
+  // CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK: %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
+  // CHECK: %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
+  // CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK: scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK: }
+  // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RESULT:.*]] = and %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+// -----
+
 // Don't lower `shape.broadcast` if a `shape.shape` type is involved.
 // CHECK-LABEL: @broadcast
 func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {

@@ -864,7 +864,8 @@ func @shape_eq_fold_1() -> i1 {
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : !shape.shape
   %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
+  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
@@ -877,7 +878,8 @@ func @shape_eq_fold_0() -> i1 {
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
   %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
@@ -908,19 +910,6 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
   return %result : i1
 }
 
-// -----
-
-// Fold `shape_eq` for non-constant but same shapes.
-// CHECK-LABEL: @shape_eq_do_fold
-// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
-func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
-  // CHECK: %[[RESULT:.*]] = constant true
-  // CHECK: return %[[RESULT]] : i1
-  %result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
-  return %result : i1
-}
-
 // -----
 
 // Fold `mul` for constant sizes.