[mlir][Vector] Simplify code a bit. NFCI.
This commit is contained in:
parent
04b99a4d18
commit
eb41f9edde
|
@ -184,9 +184,9 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
|
|||
auto lhsType = types[0].cast<VectorType>();
|
||||
auto rhsType = types[1].cast<VectorType>();
|
||||
auto maskElementType = parser.getBuilder().getI1Type();
|
||||
SmallVector<Type, 2> maskTypes;
|
||||
maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
|
||||
maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
|
||||
std::array<Type, 2> maskTypes = {
|
||||
VectorType::get(lhsType.getShape(), maskElementType),
|
||||
VectorType::get(rhsType.getShape(), maskElementType)};
|
||||
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
|
@ -462,12 +462,10 @@ std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
|
|||
}
|
||||
|
||||
SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
||||
SmallVector<AffineMap, 4> res;
|
||||
auto mapAttrs = indexing_maps().getValue();
|
||||
res.reserve(mapAttrs.size());
|
||||
for (auto mapAttr : mapAttrs)
|
||||
res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
|
||||
return res;
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
|
||||
return mapAttr.cast<AffineMapAttr>().getValue();
|
||||
}));
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
|
||||
|
@ -1854,8 +1852,7 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
|
|||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
|
||||
auto s = getVectorType().getShape();
|
||||
return SmallVector<int64_t, 4>{s.begin(), s.end()};
|
||||
return llvm::to_vector<4>(getVectorType().getShape());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2014,11 +2011,8 @@ static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
|
|||
auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
|
||||
SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
|
||||
memRefType.getShape().end());
|
||||
if (vectorType) {
|
||||
res.reserve(memRefType.getRank() + vectorType.getRank());
|
||||
for (auto s : vectorType.getShape())
|
||||
res.push_back(s);
|
||||
}
|
||||
if (vectorType)
|
||||
res.append(vectorType.getShape().begin(), vectorType.getShape().end());
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -1707,7 +1707,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
|
|||
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
|
||||
AffineExpr m, n, k;
|
||||
bindDims(rewriter.getContext(), m, n, k);
|
||||
SmallVector<int64_t, 2> perm{1, 0};
|
||||
static constexpr std::array<int64_t, 2> perm = {1, 0};
|
||||
auto iteratorTypes = op.iterator_types().getValue();
|
||||
SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
|
||||
if (isParallelIterator(iteratorTypes[0]) &&
|
||||
|
@ -1911,10 +1911,10 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
|
|||
assert(lookup.hasValue() && "parallel index not listed in reduction");
|
||||
int64_t resIndex = lookup.getValue();
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
SmallVector<AffineMap, 4> lowIndexingMaps;
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
|
||||
std::array<AffineMap, 3> lowIndexingMaps = {
|
||||
adjustMap(iMap[0], iterIndex, rewriter),
|
||||
adjustMap(iMap[1], iterIndex, rewriter),
|
||||
adjustMap(iMap[2], iterIndex, rewriter)};
|
||||
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
|
||||
auto lowIter =
|
||||
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
||||
|
@ -1962,10 +1962,10 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
|
|||
op.acc());
|
||||
}
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
SmallVector<AffineMap, 4> lowIndexingMaps;
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
|
||||
std::array<AffineMap, 3> lowIndexingMaps = {
|
||||
adjustMap(iMap[0], iterIndex, rewriter),
|
||||
adjustMap(iMap[1], iterIndex, rewriter),
|
||||
adjustMap(iMap[2], iterIndex, rewriter)};
|
||||
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
|
||||
auto lowIter =
|
||||
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
||||
|
|
Loading…
Reference in a new issue