[mlir] Fold shape.eq %a, %a to true

Differential Revision: https://reviews.llvm.org/D95430
This commit is contained in:
Tres Popp 2021-01-26 11:18:24 +01:00
parent 7b3ba8dd02
commit bc8d8e69a6
2 changed files with 16 additions and 1 deletions

View file

@ -572,6 +572,8 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs())
return BoolAttr::get(true, getContext());
auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (lhs == nullptr)
return {};

View file

@ -787,7 +787,7 @@ func @shape_eq_fold_0() -> i1 {
// -----
// Do not fold `shape_eq` for non-constant shapes.
// Do not fold `shape_eq` for non-constant different shapes.
// CHECK-LABEL: @shape_eq_do_not_fold
// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
@ -799,6 +799,19 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
return %result : i1
}
// -----
// Fold `shape_eq` for non-constant but same shapes.
// CHECK-LABEL: @shape_eq_do_fold
// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
// CHECK: %[[RESULT:.*]] = constant true
// CHECK: return %[[RESULT]] : i1
%result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
return %result : i1
}
// -----
// Fold `mul` for constant sizes.