From 1e6a93b7cae61c777ec4ce06a5f6d7d2b81af0ea Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Fri, 13 Sep 2019 18:18:21 -0700 Subject: [PATCH] add missing memref cast fold pattern for dim op - add missing canonicalization pattern to fold memref_cast + dim to dim (needed to propagate constant when folding a dynamic shape to a static one) - also fix an outdated/inconsistent comment in StandardOps/Ops.td Signed-off-by: Uday Bondhugula Closes tensorflow/mlir#126 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/126 from bondhugula:quickfix 4566e75e49685c532faffff91d64c5d83d4da524 PiperOrigin-RevId: 269020058 --- mlir/include/mlir/Dialect/StandardOps/Ops.td | 5 +++-- mlir/lib/Dialect/StandardOps/Ops.cpp | 6 ++++++ mlir/test/Transforms/canonicalize.mlir | 12 ++++++++---- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 4f4165349fa3..629bddfe2830 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -556,6 +556,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> { }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def DivFOp : FloatArithmeticOp<"divf"> { @@ -580,9 +581,9 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { with the same type as the elements of the tensor or vector. The arity of indices matches the rank of the accessed value (i.e., if a tensor is of rank 3, then 3 indices are required for the extract). The indices should all be - of affine_int type. For example: + of index type. For example: - %0 = extract_element %0[%1, %2] : vector<4x4xi32> + %3 = extract_element %0[%1, %2] : vector<4x4xi32> }]; let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate, diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index ef7b795d5f0e..3b6800058910 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1381,6 +1381,12 @@ OpFoldResult DimOp::fold(ArrayRef operands) { return {}; } +void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + /// dim(memrefcast) -> dim + results.insert(getOperationName(), context); +} + //===----------------------------------------------------------------------===// // DivISOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 680d6feda16e..b954b697d4cd 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -256,21 +256,25 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> { // CHECK-LABEL: func @memref_cast_folding func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 { + // CHECK-NOT: memref_cast %1 = memref_cast %arg0 : memref<4xf32> to memref - - // CHECK-NEXT: %c0 = constant 0 : index %c0 = constant 0 : index + // CHECK-NOT: dim + %dim = dim %1, 0 : memref + + // CHECK: affine.load %arg0[%c4 - 1] + affine.load %1[%dim - 1] : memref // CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32> store %arg1, %1[%c0] : memref - // CHECK-NEXT: %0 = load %arg0[%c0] : memref<4xf32> + // CHECK-NEXT: %{{.*}} = load %arg0[%c0] : memref<4xf32> %0 = load %1[%c0] : memref // CHECK-NEXT: dealloc %arg0 : memref<4xf32> dealloc %1: memref - // CHECK-NEXT: return %0 + // CHECK-NEXT: return %{{.*}} return %0 : f32 }