add missing memref cast fold pattern for dim op

- add missing canonicalization pattern to fold memref_cast + dim to
  dim (needed to propagate constant when folding a dynamic shape to
  a static one)

- also fix an outdated/inconsistent comment in StandardOps/Ops.td

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#126

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/126 from bondhugula:quickfix 4566e75e49685c532faffff91d64c5d83d4da524
PiperOrigin-RevId: 269020058
This commit is contained in:
Uday Bondhugula 2019-09-13 18:18:21 -07:00 committed by A. Unique TensorFlower
parent d780bdef20
commit 1e6a93b7ca
3 changed files with 17 additions and 6 deletions

View file

@ -556,6 +556,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def DivFOp : FloatArithmeticOp<"divf"> {
@ -580,9 +581,9 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
with the same type as the elements of the tensor or vector. The arity of
indices matches the rank of the accessed value (i.e., if a tensor is of rank
3, then 3 indices are required for the extract). The indices should all be
of affine_int type. For example:
of index type. For example:
%0 = extract_element %0[%1, %2] : vector<4x4xi32>
%3 = extract_element %0[%1, %2] : vector<4x4xi32>
}];
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,

View file

@ -1381,6 +1381,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
}
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
/// dim(memrefcast) -> dim
results.insert<MemRefCastFolder>(getOperationName(), context);
}
//===----------------------------------------------------------------------===//
// DivISOp
//===----------------------------------------------------------------------===//

View file

@ -256,21 +256,25 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// CHECK-LABEL: func @memref_cast_folding
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
// CHECK-NOT: memref_cast
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
// CHECK-NEXT: %c0 = constant 0 : index
%c0 = constant 0 : index
// CHECK-NOT: dim
%dim = dim %1, 0 : memref<? x f32>
// CHECK: affine.load %arg0[%c4 - 1]
affine.load %1[%dim - 1] : memref<?xf32>
// CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>
store %arg1, %1[%c0] : memref<?xf32>
// CHECK-NEXT: %0 = load %arg0[%c0] : memref<4xf32>
// CHECK-NEXT: %{{.*}} = load %arg0[%c0] : memref<4xf32>
%0 = load %1[%c0] : memref<?xf32>
// CHECK-NEXT: dealloc %arg0 : memref<4xf32>
dealloc %1: memref<?xf32>
// CHECK-NEXT: return %0
// CHECK-NEXT: return %{{.*}}
return %0 : f32
}