[mlir] Add additional Canonicalization of shape.cstr_broadcastable.
Summary: Added the following canonicalization and folding: - Folding when either input is an attribute indicating a scalar input, which can always be broadcast. - Canonicalization where it can be determined that either input shape is a scalar. - Canonicalization where the partially specified input shapes can be proven to always be broadcastable. Differential Revision: https://reviews.llvm.org/D83194
This commit is contained in:
parent
e4ec6d0afe
commit
2ef71cb7fd
|
@ -47,6 +47,21 @@ namespace util {
|
||||||
bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
|
bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
|
||||||
SmallVectorImpl<int64_t> &resultShape);
|
SmallVectorImpl<int64_t> &resultShape);
|
||||||
|
|
||||||
|
/// Returns true if a broadcast between the 2 shapes is guaranteed to be
|
||||||
|
/// successful and not result in an error. False does not guarantee that the
|
||||||
|
/// shapes are not broadcastable; it might guarantee that they are not
|
||||||
|
/// broadcastable or it might mean that this function does not have enough
|
||||||
|
/// information to know.
|
||||||
|
///
|
||||||
|
/// Conceptually, this returns true if getBroadcastedShape would have returned
|
||||||
|
/// true and vice versa, with one exception. If a dimension is unknown in both
|
||||||
|
/// shapes, getBroadcastedShape would return true and have a result with unknown
|
||||||
|
/// dimension, while this function will return false because it's possible for
|
||||||
|
/// both shapes to have a dimension greater than 1 and different which would
|
||||||
|
/// fail to broadcast.
|
||||||
|
bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
|
||||||
|
ArrayRef<int64_t> shape2);
|
||||||
|
|
||||||
/// Returns the result broadcast composition type from the two given types by
|
/// Returns the result broadcast composition type from the two given types by
|
||||||
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
|
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
|
||||||
/// either of the input types has dynamic shape. Returns null type if the two
|
/// either of the input types has dynamic shape. Returns null type if the two
|
||||||
|
|
|
@ -317,21 +317,101 @@ OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
|
||||||
// CstrBroadcastableOp
|
// CstrBroadcastableOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Given an input shape Value, try to obtain the shape's values.
|
||||||
|
LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
|
||||||
|
if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
|
||||||
|
auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
|
||||||
|
if (!type.hasRank())
|
||||||
|
return failure();
|
||||||
|
shapeValues = llvm::to_vector<6>(type.getShape());
|
||||||
|
return success();
|
||||||
|
} else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
|
||||||
|
shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
|
||||||
|
return success();
|
||||||
|
} else {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For shapes that were created by some operations, we can obtain partial
|
||||||
|
// information on the shapes and sometimes determine if they will be
|
||||||
|
// broadcastable with that.
|
||||||
|
struct CstrBroadcastablePartialInfo
|
||||||
|
: public OpRewritePattern<CstrBroadcastableOp> {
|
||||||
|
using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(CstrBroadcastableOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<int64_t, 6> lhsShape, rhsShape;
|
||||||
|
if (failed(getShapeVec(op.lhs(), lhsShape)))
|
||||||
|
return failure();
|
||||||
|
if (failed(getShapeVec(op.rhs(), rhsShape)))
|
||||||
|
return failure();
|
||||||
|
if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Scalars are always broadcastable.
|
||||||
|
struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
|
||||||
|
using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(CstrBroadcastableOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<int64_t, 6> shape;
|
||||||
|
if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
|
||||||
|
return failure();
|
||||||
|
if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void CstrBroadcastableOp::getCanonicalizationPatterns(
|
void CstrBroadcastableOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||||
// If inputs are equal, return passing witness
|
// Canonicalization patterns have overlap with the considerations during
|
||||||
patterns.insert<CstrBroadcastableEqOps>(context);
|
// folding in case additional shape information is inferred at some point that
|
||||||
|
// does not result in folding.
|
||||||
|
patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
|
||||||
|
CstrBroadcastableScalar>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (!operands[0] || !operands[1])
|
// Both operands are not needed if one is a scalar.
|
||||||
|
if (operands[0] &&
|
||||||
|
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
|
||||||
|
return BoolAttr::get(true, getContext());
|
||||||
|
if (operands[1] &&
|
||||||
|
operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
|
||||||
|
return BoolAttr::get(true, getContext());
|
||||||
|
|
||||||
|
if (operands[0] && operands[1]) {
|
||||||
|
auto lhsShape = llvm::to_vector<6>(
|
||||||
|
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||||
|
auto rhsShape = llvm::to_vector<6>(
|
||||||
|
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||||
|
SmallVector<int64_t, 6> resultShape;
|
||||||
|
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
|
||||||
|
return BoolAttr::get(true, getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lastly, see if folding can be completed based on what constraints are known
|
||||||
|
// on the input shapes.
|
||||||
|
SmallVector<int64_t, 6> lhsShape, rhsShape;
|
||||||
|
if (failed(getShapeVec(lhs(), lhsShape)))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto lhsShape = llvm::to_vector<6>(
|
if (failed(getShapeVec(rhs(), rhsShape)))
|
||||||
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
return nullptr;
|
||||||
auto rhsShape = llvm::to_vector<6>(
|
|
||||||
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
|
||||||
SmallVector<int64_t, 6> resultShape;
|
|
||||||
if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
|
|
||||||
return BoolAttr::get(true, getContext());
|
return BoolAttr::get(true, getContext());
|
||||||
|
|
||||||
// Because a failing witness result here represents an eventual assertion
|
// Because a failing witness result here represents an eventual assertion
|
||||||
|
|
|
@ -13,6 +13,23 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
/// Reports whether broadcasting `shape1` against `shape2` is statically known
/// to succeed. Dimensions are paired starting from the trailing end of both
/// shapes. A pair is compatible when
///   1. one of the two dimensions is 1, or
///   2. both dimensions are equal and not dynamic.
/// Returns false as soon as a pair satisfies neither rule.
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  // Only the dimensions present in both shapes are paired up; the extra
  // leading dimensions of the longer shape are not inspected.
  for (auto dimPair :
       llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2))) {
    auto lhsDim = std::get<0>(dimPair);
    auto rhsDim = std::get<1>(dimPair);

    // A unit dimension broadcasts against anything, including a dynamic one.
    if (lhsDim == 1 || rhsDim == 1)
      continue;

    // Otherwise both extents must be the same statically known value.
    if (lhsDim != rhsDim || ShapedType::isDynamic(lhsDim))
      return false;
  }
  return true;
}
|
||||||
|
|
||||||
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
||||||
ArrayRef<int64_t> shape2,
|
ArrayRef<int64_t> shape2,
|
||||||
SmallVectorImpl<int64_t> &resultShape) {
|
SmallVectorImpl<int64_t> &resultShape) {
|
||||||
|
|
|
@ -403,8 +403,8 @@ func @f() {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// Broadcastable with non-broadcastable constant shapes is always false
|
// Broadcastable with non-broadcastable constant shapes is always false
|
||||||
// CHECK-LABEL: func @f
|
// CHECK-LABEL: func @static_non_broadcastable
|
||||||
func @f() {
|
func @static_non_broadcastable() {
|
||||||
// CHECK-NEXT: shape.const_shape
|
// CHECK-NEXT: shape.const_shape
|
||||||
// CHECK-NEXT: shape.const_shape
|
// CHECK-NEXT: shape.const_shape
|
||||||
// CHECK-NEXT: shape.cstr_broadcastable
|
// CHECK-NEXT: shape.cstr_broadcastable
|
||||||
|
@ -515,3 +515,49 @@ func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
|
||||||
return %result : !shape.size
|
return %result : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Canonicalize scalar cstr_broadcastable checks
|
||||||
|
// CHECK-LABEL: @cstr_broadcastable_scalar
|
||||||
|
func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
|
||||||
|
// CHECK-NEXT: shape.const_witness true
|
||||||
|
// CHECK-NEXT: consume.witness
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
%0 = shape.const_shape []
|
||||||
|
%1 = shape.shape_of %arg0 : tensor<?xf32>
|
||||||
|
%2 = shape.cstr_broadcastable %0, %1
|
||||||
|
"consume.witness"(%2) : (!shape.witness) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Do not canonicalize cstr_broadcastable checks when both operand shapes are dynamic (unknown)
|
||||||
|
// CHECK-LABEL: @cstr_broadcastable_unknown
|
||||||
|
func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
|
||||||
|
// CHECK-NEXT: shape.shape_of %arg0
|
||||||
|
// CHECK-NEXT: shape.shape_of %arg1
|
||||||
|
// CHECK-NEXT: shape.cstr_broadcastable
|
||||||
|
// CHECK-NEXT: consume.witness
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
%0 = shape.shape_of %arg0 : tensor<?xf32>
|
||||||
|
%1 = shape.shape_of %arg1 : tensor<?xf32>
|
||||||
|
%2 = shape.cstr_broadcastable %0, %1
|
||||||
|
"consume.witness"(%2) : (!shape.witness) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Scalars are safe to broadcast with unranked shapes.
|
||||||
|
// CHECK-LABEL: @cstr_broadcastable_scalar_unranked
|
||||||
|
func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<index>) {
|
||||||
|
// CHECK-NEXT: shape.const_witness true
|
||||||
|
// CHECK-NEXT: consume.witness
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
%0 = shape.shape_of %arg1 : tensor<index>
|
||||||
|
%1 = shape.shape_of %arg0 : tensor<*xf32>
|
||||||
|
%2 = shape.cstr_broadcastable %0, %1
|
||||||
|
"consume.witness"(%2) : (!shape.witness) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue