[MLIR][Shape] Allow shape.any to operate on extent tensors

Differential Revision: https://reviews.llvm.org/D84433
This commit is contained in:
Frederik Gossen 2020-07-24 11:01:23 +00:00
parent 274db1d21a
commit 7f600da828
4 changed files with 43 additions and 19 deletions

View file

@ -509,11 +509,14 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
//===----------------------------------------------------------------------===//
// TODO: Move the code below and witnesses to a different file.
def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
def Shape_AnyOp : Shape_Op<"any", [Commutative,
NoSideEffect,
SameOperandsAndResultType]> {
let summary = "Return any combination of the input shapes";
let description = [{
This operation takes multiple input shapes and returns some combination of
their dimensions. This can be best seen with examples below.
This operation takes multiple input shapes or extent tensors and returns
some combination of their dimensions. This can be best seen with examples
below.
The result is undefined, but still side-effect free, in cases where the
inputs have differing ranks or differ in extents of shared dimensions.
@ -525,11 +528,10 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
```
}];
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
let results = (outs Shape_ShapeType:$result);
let assemblyFormat = "$inputs attr-dict";
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
let results = (outs Shape_ShapeOrExtentTensorType:$result);
let assemblyFormat = "$inputs `:` type($result) attr-dict";
let hasFolder = 1;
}

View file

@ -165,11 +165,12 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
-> !shape.shape {
func @any_of_three(%a : tensor<?xindex>,
%b : tensor<?xindex>,
%c : tensor<?xindex>) -> tensor<?xindex> {
// CHECK: return %[[A]] : tensor<?xindex>
%result = shape.any %a, %b, %c
return %result : !shape.shape
%result = shape.any %a, %b, %c : tensor<?xindex>
return %result : tensor<?xindex>
}
// -----
@ -177,9 +178,9 @@ func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_one
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_one(%a : !shape.shape) -> !shape.shape {
func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
// CHECK: return %[[A]] : tensor<?xindex>
%result = shape.any %a
return %result : !shape.shape
%result = shape.any %a : tensor<?xindex>
return %result : tensor<?xindex>
}

View file

@ -364,14 +364,25 @@ func @f() {
// any can be replaced with a constant input if it has one.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
func @f(%arg : !shape.shape) -> !shape.shape {
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
// CHECK-NEXT: return %[[CS]]
%0 = shape.const_shape [2, 3, 4] : !shape.shape
%1 = shape.any %0, %arg0
%1 = shape.any %0, %arg : !shape.shape
return %1 : !shape.shape
}
// -----
// any can be replaced with a constant input if it has one.
// CHECK-LABEL: func @f
func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
// CHECK-NEXT: return %[[CS]] : tensor<?xindex>
%0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
%1 = shape.any %0, %arg : tensor<?xindex>
return %1 : tensor<?xindex>
}
// -----
@ -380,7 +391,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
// CHECK-NEXT: %[[CS:.*]] = shape.any
// CHECK-NEXT: return %[[CS]]
%1 = shape.any %arg0, %arg1
%1 = shape.any %arg0, %arg1 : !shape.shape
return %1 : !shape.shape
}

View file

@ -1,4 +1,3 @@
// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s
// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Verify the generic form can be parsed.
@ -99,7 +98,7 @@ func @test_constraints() {
%w3 = shape.const_witness false
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
shape.assuming %w4 -> !shape.shape {
%2 = shape.any %0, %1
%2 = shape.any %0, %1 : !shape.shape
shape.assuming_yield %2 : !shape.shape
}
return
@ -173,3 +172,14 @@ func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
%result = shape.get_extent %arg, %c0 : tensor<?xindex>
return %result : !shape.size
}
func @any() {
%0 = shape.const_shape [1, 2, 3] : !shape.shape
%1 = shape.const_shape [4, 5, 6] : !shape.shape
%2 = shape.any %0, %1 : !shape.shape
%3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
%4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
%5 = shape.any %3, %4 : tensor<?xindex>
return
}