[MLIR][Shape] Simplify shape lowering

Differential Revision: https://reviews.llvm.org/D84161

Parent: d4e4d5d780
Commit: a85ca6be2a

@@ -172,39 +172,37 @@ LogicalResult
ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ShapeOfOp::Adaptor transformed(operands);
|
||||
auto tensorVal = transformed.arg();
|
||||
auto tensorTy = tensorVal.getType();
|
||||
Value arg = transformed.arg();
|
||||
Type argTy = arg.getType();
|
||||
|
||||
// For ranked tensors `shape_of` lowers to `std` and the pattern can be
|
||||
// found in the corresponding pass.
|
||||
if (tensorTy.isa<RankedTensorType>())
|
||||
if (argTy.isa<RankedTensorType>())
|
||||
return failure();
|
||||
|
||||
// Allocate stack memory.
|
||||
auto loc = op.getLoc();
|
||||
auto rankVal = rewriter.create<mlir::RankOp>(loc, tensorVal);
|
||||
auto i64Ty = rewriter.getI64Type();
|
||||
auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
|
||||
auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
|
||||
Value rank = rewriter.create<mlir::RankOp>(loc, arg);
|
||||
Type i64Ty = rewriter.getI64Type();
|
||||
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
|
||||
Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
|
||||
|
||||
// Copy shape extents to stack-allocated memory.
|
||||
auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
rewriter.create<scf::ForOp>(
|
||||
loc, zeroVal, rankVal, oneVal, llvm::None,
|
||||
[&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
|
||||
auto dimVal = rewriter.create<DimOp>(loc, tensorVal, iVal);
|
||||
auto dimIntVal = rewriter.create<IndexCastOp>(loc, dimVal, i64Ty);
|
||||
rewriter.create<StoreOp>(loc, dimIntVal, memVal, ValueRange{iVal});
|
||||
loc, zero, rank, one, llvm::None,
|
||||
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
|
||||
Value dim = rewriter.create<DimOp>(loc, arg, iv);
|
||||
Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
|
||||
rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
|
||||
rewriter.create<scf::YieldOp>(loc);
|
||||
});
|
||||
|
||||
// Load extents to tensor value.
|
||||
auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
|
||||
auto indexTy = rewriter.getIndexType();
|
||||
auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
|
||||
rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
|
||||
shapeTy);
|
||||
Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
|
||||
rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
|
||||
op.getType());
|
||||
return success();
|
||||
}