[mlir] Fix indexed_accessor_range to properly forward the derived class.

Summary: This fixes the return value of helper methods on the base range class.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D72127
This commit is contained in:
River Riddle 2020-01-03 13:12:25 -08:00
parent 21309eafde
commit 0d9ca98c1a
6 changed files with 70 additions and 25 deletions

View file

@ -598,8 +598,8 @@ public:
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
private:
/// See `detail::indexed_accessor_range_base` for details.
static OpResult dereference_iterator(Operation *op, ptrdiff_t index);
/// See `indexed_accessor_range` for details.
static OpResult dereference(Operation *op, ptrdiff_t index);
/// Allow access to `dereference_iterator`.
friend indexed_accessor_range<ResultRange, Operation *, OpResult, OpResult,

View file

@ -222,6 +222,8 @@ public:
count(end.getIndex() - begin.getIndex()) {}
indexed_accessor_range_base(const iterator_range<iterator> &range)
: indexed_accessor_range_base(range.begin(), range.end()) {}
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
: base(base), count(count) {}
iterator begin() const { return iterator(base, 0); }
iterator end() const { return iterator(base, count); }
@ -267,8 +269,6 @@ public:
}
protected:
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
: base(base), count(count) {}
indexed_accessor_range_base(const indexed_accessor_range_base &) = default;
indexed_accessor_range_base(indexed_accessor_range_base &&) = default;
indexed_accessor_range_base &
@ -286,18 +286,20 @@ protected:
/// bases that are offsetable should derive from indexed_accessor_range_base
/// instead. Derived range classes are expected to implement the following
/// static method:
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index)
/// - Derefence an iterator pointing to a parent base at the given index.
template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_range
: public detail::indexed_accessor_range_base<
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
public:
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
: detail::indexed_accessor_range_base<
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
std::make_pair(base, startIndex), count) {}
using detail::indexed_accessor_range_base<
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
std::pair<BaseT, ptrdiff_t>, T, PointerT,
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT,
ReferenceT>::indexed_accessor_range_base;
/// Returns the current base of the range.
@ -306,14 +308,6 @@ public:
/// Returns the current start index of the range.
ptrdiff_t getStartIndex() const { return this->base.second; }
protected:
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
: detail::indexed_accessor_range_base<
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
std::make_pair(base, startIndex), count) {}
private:
/// See `detail::indexed_accessor_range_base` for details.
static std::pair<BaseT, ptrdiff_t>
offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
@ -325,13 +319,8 @@ private:
static ReferenceT
dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
ptrdiff_t index) {
return DerivedT::dereference_iterator(base.first, base.second + index);
return DerivedT::dereference(base.first, base.second + index);
}
/// Allow access to `offset_base` and `dereference_iterator`.
friend detail::indexed_accessor_range_base<
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>;
};
/// Given a container of pairs, return a range over the second elements.

View file

@ -152,8 +152,8 @@ OperandRange::OperandRange(Operation *op)
ResultRange::ResultRange(Operation *op)
: ResultRange(op, /*startIndex=*/0, op->getNumResults()) {}
/// See `detail::indexed_accessor_range_base` for details.
OpResult ResultRange::dereference_iterator(Operation *op, ptrdiff_t index) {
/// See `indexed_accessor_range` for details.
OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) {
return op->getResult(index);
}

View file

@ -10,4 +10,5 @@ add_subdirectory(Dialect)
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(SDBM)
add_subdirectory(Support)
add_subdirectory(TableGen)

View file

@ -0,0 +1,6 @@
add_mlir_unittest(MLIRSupportTests
IndexedAccessorTest.cpp
)
target_link_libraries(MLIRSupportTests
PRIVATE MLIRSupport)

View file

@ -0,0 +1,49 @@
//===- IndexedAccessorTest.cpp - Indexed Accessor Tests -------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/ArrayRef.h"
#include "gmock/gmock.h"
using namespace mlir;
using namespace mlir::detail;
namespace {
/// Simple indexed accessor range that wraps an array.
template <typename T>
struct ArrayIndexedAccessorRange
: public indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *, T> {
ArrayIndexedAccessorRange(T *data, ptrdiff_t start, ptrdiff_t numElements)
: indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *, T>(
data, start, numElements) {}
using indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *,
T>::indexed_accessor_range;
/// See `indexed_accessor_range` for details.
static T &dereference(T *data, ptrdiff_t index) { return data[index]; }
};
} // end anonymous namespace
template <typename T>
static void compareData(ArrayIndexedAccessorRange<T> range,
ArrayRef<T> referenceData) {
ASSERT_TRUE(referenceData.size() == range.size());
ASSERT_TRUE(std::equal(range.begin(), range.end(), referenceData.begin()));
}
namespace {
TEST(AccessorRange, SliceTest) {
int rawData[] = {0, 1, 2, 3, 4};
ArrayRef<int> data = llvm::makeArrayRef(rawData);
ArrayIndexedAccessorRange<int> range(rawData, /*start=*/0, /*numElements=*/5);
compareData(range, data);
compareData(range.slice(2, 3), data.slice(2, 3));
compareData(range.slice(0, 5), data.slice(0, 5));
}
} // end anonymous namespace