[mlir][Shape] Make shape_eq nary
This gets rid of a dubious shape_eq %a, %a fold, that folds shape_eq even if %a is not an Attribute. Differential Revision: https://reviews.llvm.org/D97728
This commit is contained in:
parent
64f5d7e972
commit
24acadef8a
|
@ -168,20 +168,38 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
|
||||
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
|
||||
InferTypeOpInterface]> {
|
||||
let summary = "Returns whether the input shapes or extent tensors are equal";
|
||||
let description = [{
|
||||
Takes two shape or extent tensor operands and determines whether they are
|
||||
equal. When extent tensors are compared to shapes they are regarded as their
|
||||
equivalent non-error shapes. Error shapes can be tested for equality like
|
||||
any other shape value, meaning that the error value is equal to itself.
|
||||
Takes one or more shape or extent tensor operands and determines whether
|
||||
they are equal. When extent tensors are compared to shapes they are regarded
|
||||
as their equivalent non-error shapes. Error shapes can be tested for
|
||||
equality like any other shape value, meaning that the error value is equal
|
||||
to itself.
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
|
||||
Shape_ShapeOrExtentTensorType:$rhs);
|
||||
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
|
||||
let results = (outs I1:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
|
||||
// Convenience builder alias for the binary version.
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
|
||||
[{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
// TODO: This should really be automatic. Figure out how to not need this defined.
|
||||
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
|
||||
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
|
||||
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
|
||||
/*width=*/1));
|
||||
return success();
|
||||
};
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -474,46 +474,56 @@ struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
|
|||
LogicalResult
|
||||
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
|
||||
// on shapes.
|
||||
if (op.lhs().getType().isa<ShapeType>() ||
|
||||
op.rhs().getType().isa<ShapeType>()) {
|
||||
if (!llvm::all_of(op.shapes(),
|
||||
[](Value v) { return !v.getType().isa<ShapeType>(); }))
|
||||
return failure();
|
||||
|
||||
Type i1Ty = rewriter.getI1Type();
|
||||
if (op.shapes().size() <= 1) {
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
|
||||
rewriter.getBoolAttr(true));
|
||||
return success();
|
||||
}
|
||||
|
||||
ShapeEqOp::Adaptor transformed(operands);
|
||||
auto loc = op.getLoc();
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
|
||||
Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
|
||||
Value eqRank =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
|
||||
Type i1Ty = rewriter.getI1Type();
|
||||
rewriter.replaceOpWithNewOp<IfOp>(
|
||||
op, i1Ty, eqRank,
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
Value one = b.create<ConstantIndexOp>(loc, 1);
|
||||
Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
|
||||
auto loop = b.create<scf::ForOp>(
|
||||
loc, zero, lhsRank, one, ValueRange{init},
|
||||
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
|
||||
Value conj = args[0];
|
||||
Value lhsExtent =
|
||||
b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
|
||||
Value rhsExtent =
|
||||
b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
|
||||
Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
|
||||
lhsExtent, rhsExtent);
|
||||
Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
|
||||
b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
|
||||
});
|
||||
b.create<scf::YieldOp>(loc, loop.getResults());
|
||||
},
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
|
||||
b.create<scf::YieldOp>(loc, result);
|
||||
});
|
||||
Value firstShape = transformed.shapes().front();
|
||||
Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
|
||||
Value result = nullptr;
|
||||
// Generate a linear sequence of compares, all with firstShape as lhs.
|
||||
for (Value shape : transformed.shapes().drop_front(1)) {
|
||||
Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
|
||||
Value eqRank =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
|
||||
auto same = rewriter.create<IfOp>(
|
||||
loc, i1Ty, eqRank,
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
Value one = b.create<ConstantIndexOp>(loc, 1);
|
||||
Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
|
||||
auto loop = b.create<scf::ForOp>(
|
||||
loc, zero, firstRank, one, ValueRange{init},
|
||||
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
|
||||
Value conj = args[0];
|
||||
Value lhsExtent =
|
||||
b.create<tensor::ExtractOp>(loc, firstShape, iv);
|
||||
Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
|
||||
Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
|
||||
lhsExtent, rhsExtent);
|
||||
Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
|
||||
b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
|
||||
});
|
||||
b.create<scf::YieldOp>(loc, loop.getResults());
|
||||
},
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
|
||||
b.create<scf::YieldOp>(loc, result);
|
||||
});
|
||||
result = !result ? same.getResult(0)
|
||||
: rewriter.create<AndOp>(loc, result, same.getResult(0));
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -629,15 +629,15 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (lhs() == rhs())
|
||||
return BoolAttr::get(getContext(), true);
|
||||
auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
if (lhs == nullptr)
|
||||
bool allSame = true;
|
||||
if (!operands.empty() && !operands[0])
|
||||
return {};
|
||||
auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
if (rhs == nullptr)
|
||||
return {};
|
||||
return BoolAttr::get(getContext(), lhs == rhs);
|
||||
for (Attribute operand : operands.drop_front(1)) {
|
||||
if (!operand)
|
||||
return {};
|
||||
allSame = allSame && operand == operands[0];
|
||||
}
|
||||
return BoolAttr::get(getContext(), allSame);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -295,6 +295,53 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @shape_eq
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
|
||||
func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
|
||||
// CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
|
||||
// CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_B]]
|
||||
// CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[INIT:.*]] = constant true
|
||||
// CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
|
||||
// CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
|
||||
// CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
|
||||
// CHECK: scf.yield %[[CONJ_NEXT]] : i1
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
|
||||
// CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
|
||||
// CHECK: }
|
||||
// CHECK: %[[RANK_C:.*]] = dim %[[C]], %[[C0]] : tensor<?xindex>
|
||||
// CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_C]]
|
||||
// CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[INIT:.*]] = constant true
|
||||
// CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
|
||||
// CHECK: %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
|
||||
// CHECK: %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
|
||||
// CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
|
||||
// CHECK: scf.yield %[[CONJ_NEXT]] : i1
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
|
||||
// CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
|
||||
// CHECK: }
|
||||
// CHECK: %[[RESULT:.*]] = and %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
|
||||
// CHECK-LABEL: @broadcast
|
||||
func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
|
||||
|
|
|
@ -864,7 +864,8 @@ func @shape_eq_fold_1() -> i1 {
|
|||
// CHECK: return %[[RESULT]] : i1
|
||||
%a = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%b = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
|
||||
%c = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
|
@ -877,7 +878,8 @@ func @shape_eq_fold_0() -> i1 {
|
|||
// CHECK: return %[[RESULT]] : i1
|
||||
%a = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
%b = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
|
||||
%c = shape.const_shape [4, 5, 6] : tensor<?xindex>
|
||||
%result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
|
@ -908,19 +910,6 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
|
|||
return %result : i1
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Fold `shape_eq` for non-constant but same shapes.
|
||||
// CHECK-LABEL: @shape_eq_do_fold
|
||||
// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
|
||||
func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
|
||||
// CHECK: %[[RESULT:.*]] = constant true
|
||||
// CHECK: return %[[RESULT]] : i1
|
||||
%result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Fold `mul` for constant sizes.
|
||||
|
|
Loading…
Reference in a new issue