llvm/flang/runtime/dot-product.cpp
peter klausler 5e1421b22f [flang] Implement MATMUL in the runtime
Define an API for the transformational intrinsic function MATMUL,
implement it, and add some basic unit tests.  The large number of
possible argument type combinations is covered by a set of
generalized templates that are instantiated for each valid
pair of possible argument types.

Places where BLAS-2/3 routines could be called for acceleration
are marked with TODOs.  Handling for other special cases (e.g.,
known-shape 3x3 matrices and vectors) is deferred.

Some minor tweaks were made to the recent related implementation
of DOT_PRODUCT to reflect lessons learned.

Differential Revision: https://reviews.llvm.org/D102652
2021-05-18 10:59:52 -07:00

181 lines
7.1 KiB
C++

//===-- runtime/dot-product.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "cpp-type.h"
#include "descriptor.h"
#include "reduction.h"
#include "terminator.h"
#include "tools.h"
#include <cinttypes>
namespace Fortran::runtime {
// Running DOT_PRODUCT accumulator over two descriptors.  XCAT (the category
// of the first operand) selects how one pair of elements is folded in:
//   Logical   - ORs in (x element AND y element)
//   Complex   - adds conj(x element) * y element
//   otherwise - adds x element * y element
// Numeric elements are converted to RESULT before being multiplied.
template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
class Accumulator {
public:
  using Result = RESULT;

  Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {}

  // Folds the element of x at subscript xAt and the element of y at
  // subscript yAt into the running value.
  void Accumulate(SubscriptValue xAt, SubscriptValue yAt) {
    if constexpr (XCAT == TypeCategory::Logical) {
      // Once the running value is true it stays true, and the elements are
      // no longer examined.
      if (!sum_) {
        sum_ =
            IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt);
      }
    } else {
      const Result left{static_cast<Result>(*x_.Element<XT>(&xAt))};
      const Result right{static_cast<Result>(*y_.Element<YT>(&yAt))};
      if constexpr (XCAT == TypeCategory::Complex) {
        // The first operand is conjugated when it is complex.
        sum_ += std::conj(left) * right;
      } else {
        sum_ += left * right;
      }
    }
  }

  // Value accumulated so far; value-initialized when nothing was folded in.
  Result GetResult() const { return sum_; }

private:
  const Descriptor &x_, &y_;
  Result sum_{};
};
// Computes the dot product of the rank-1 descriptors x and y, whose elements
// have C++ types XT and YT, and returns it as RESULT.
//  - Fails a RUNTIME_CHECK unless both operands have rank 1.
//  - Crashes through the terminator when the two extents differ.
//  - Each operand is traversed from its own lower bound, so the two vectors
//    do not need to share bounds.  With a zero extent the loop body never
//    runs and the accumulator's initial value is returned.
template <typename RESULT, TypeCategory XCAT, typename XT, typename YT>
static inline RESULT DoDotProduct(
    const Descriptor &x, const Descriptor &y, Terminator &terminator) {
  RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1);
  SubscriptValue n{x.GetDimension(0).Extent()};
  if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) {
    terminator.Crash(
        "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd",
        static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN));
  }
  // Placeholders for BLAS-1 fast paths when both operands share one element
  // type.  Each branch tests a distinct type: single/double real, then
  // single/double complex.
  if constexpr (std::is_same_v<XT, YT>) {
    if constexpr (std::is_same_v<XT, float>) {
      // TODO: call BLAS-1 SDOT or SDSDOT
    } else if constexpr (std::is_same_v<XT, double>) {
      // TODO: call BLAS-1 DDOT
    } else if constexpr (std::is_same_v<XT, std::complex<float>>) {
      // TODO: call BLAS-1 CDOTC
    } else if constexpr (std::is_same_v<XT, std::complex<double>>) {
      // TODO: call BLAS-1 ZDOTC
    }
  }
  // Generic element-by-element path.
  SubscriptValue xAt{x.GetDimension(0).LowerBound()};
  SubscriptValue yAt{y.GetDimension(0).LowerBound()};
  Accumulator<RESULT, XCAT, XT, YT> accumulator{x, y};
  for (SubscriptValue j{0}; j < n; ++j) {
    accumulator.Accumulate(xAt++, yAt++);
  }
  return accumulator.GetResult();
}
// Function object that evaluates DOT_PRODUCT for a result of category RCAT
// and kind RKIND.  The operand types are only known at run time, so the call
// is dispatched in two stages: DP1 is selected from x's category and kind,
// then DP2 from y's.
// NOTE(review): ApplyType is defined elsewhere; it presumably instantiates
// the given template for a runtime (category, kind) pair and invokes it with
// the trailing arguments -- confirm against its definition.
template <TypeCategory RCAT, int RKIND> struct DotProduct {
// C++ type used for the accumulated and returned value.
using Result = CppTypeFor<RCAT, RKIND>;
// First dispatch stage: the category and kind of x are fixed.
template <TypeCategory XCAT, int XKIND> struct DP1 {
// Second dispatch stage: the category and kind of y are fixed as well.
template <TypeCategory YCAT, int YKIND> struct DP2 {
// Runs DoDotProduct when GetResultType yields a value for the operand pair
// whose category equals RCAT and whose kind does not exceed RKIND.  Every
// other combination falls through to the Crash call below.
Result operator()(const Descriptor &x, const Descriptor &y,
Terminator &terminator) const {
if constexpr (constexpr auto resultType{
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
if constexpr (resultType->first == RCAT &&
resultType->second <= RKIND) {
return DoDotProduct<Result, XCAT, CppTypeFor<XCAT, XKIND>,
CppTypeFor<YCAT, YKIND>>(x, y, terminator);
}
}
// No return follows; this relies on Crash not returning (presumably it is
// declared noreturn -- confirm).
terminator.Crash(
"DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))",
static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND,
static_cast<int>(YCAT), YKIND);
}
};
// Dispatches on y's runtime category and kind.  The terminator appears
// twice: once as an argument to ApplyType itself and once among the
// arguments passed on to DP2::operator().
Result operator()(const Descriptor &x, const Descriptor &y,
Terminator &terminator, TypeCategory yCat, int yKind) const {
return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator);
}
};
// Entry point: builds a Terminator from the source position, reads the
// category and kind of both operands, checks that both are available, and
// starts the dispatch on x's type.
Result operator()(const Descriptor &x, const Descriptor &y,
const char *source, int line) const {
Terminator terminator{source, line};
auto xCatKind{x.type().GetCategoryAndKind()};
auto yCatKind{y.type().GetCategoryAndKind()};
RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, terminator,
x, y, terminator, yCatKind->first, yCatKind->second);
}
};
extern "C" {
// DOT_PRODUCT entry point returning std::int8_t.  The product is evaluated
// by the Integer kind-8 instantiation and then narrowed to the return type.
std::int8_t RTNAME(DotProductInteger1)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Integer, 8>;
  return static_cast<std::int8_t>(Impl{}(x, y, source, line));
}
// DOT_PRODUCT entry point returning std::int16_t.  The product is evaluated
// by the Integer kind-8 instantiation and then narrowed to the return type.
std::int16_t RTNAME(DotProductInteger2)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Integer, 8>;
  return static_cast<std::int16_t>(Impl{}(x, y, source, line));
}
// DOT_PRODUCT entry point returning std::int32_t.  The product is evaluated
// by the Integer kind-8 instantiation and then narrowed to the return type.
std::int32_t RTNAME(DotProductInteger4)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Integer, 8>;
  return static_cast<std::int32_t>(Impl{}(x, y, source, line));
}
// DOT_PRODUCT entry point returning std::int64_t, evaluated by the Integer
// kind-8 instantiation.
std::int64_t RTNAME(DotProductInteger8)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Integer, 8>;
  return static_cast<std::int64_t>(Impl{}(x, y, source, line));
}
#ifdef __SIZEOF_INT128__
// DOT_PRODUCT entry point returning common::int128_t, evaluated by the
// Integer kind-16 instantiation.
common::int128_t RTNAME(DotProductInteger16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Integer, 16>;
  return static_cast<common::int128_t>(Impl{}(x, y, source, line));
}
#endif
// TODO: REAL/COMPLEX(2 & 3)
// DOT_PRODUCT entry point returning float.  The product is evaluated by the
// Real kind-8 instantiation and then narrowed to the return type.
float RTNAME(DotProductReal4)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Real, 8>;
  return static_cast<float>(Impl{}(x, y, source, line));
}
// DOT_PRODUCT entry point returning double, evaluated by the Real kind-8
// instantiation.
double RTNAME(DotProductReal8)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Real, 8>;
  return static_cast<double>(Impl{}(x, y, source, line));
}
#if LONG_DOUBLE == 80
// DOT_PRODUCT entry point returning long double, evaluated by the Real
// kind-10 instantiation.
long double RTNAME(DotProductReal10)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Real, 10>;
  return static_cast<long double>(Impl{}(x, y, source, line));
}
#elif LONG_DOUBLE == 128
// DOT_PRODUCT entry point returning long double, evaluated by the Real
// kind-16 instantiation.
long double RTNAME(DotProductReal16)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Real, 16>;
  return static_cast<long double>(Impl{}(x, y, source, line));
}
#endif
// DOT_PRODUCT entry point that stores a std::complex<float> into `result`.
// The product is evaluated by the Complex kind-8 instantiation; its real and
// imaginary parts are then each narrowed to float.
void RTNAME(CppDotProductComplex4)(std::complex<float> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Complex, 8>;
  const auto wide{Impl{}(x, y, source, line)};
  const float re{static_cast<float>(wide.real())};
  const float im{static_cast<float>(wide.imag())};
  result = std::complex<float>{re, im};
}
// DOT_PRODUCT entry point that stores a std::complex<double> into `result`,
// evaluated by the Complex kind-8 instantiation.
void RTNAME(CppDotProductComplex8)(std::complex<double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Complex, 8>;
  const auto value{Impl{}(x, y, source, line)};
  result = value;
}
#if LONG_DOUBLE == 80
// DOT_PRODUCT entry point that stores a std::complex<long double> into
// `result`, evaluated by the Complex kind-10 instantiation.
void RTNAME(CppDotProductComplex10)(std::complex<long double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Complex, 10>;
  const auto value{Impl{}(x, y, source, line)};
  result = value;
}
#elif LONG_DOUBLE == 128
// DOT_PRODUCT entry point that stores a std::complex<long double> into
// `result`, evaluated by the Complex kind-16 instantiation.
void RTNAME(CppDotProductComplex16)(std::complex<long double> &result,
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Complex, 16>;
  const auto value{Impl{}(x, y, source, line)};
  result = value;
}
#endif
// DOT_PRODUCT entry point for logical operands, evaluated by the Logical
// kind-1 instantiation and returned as bool.
bool RTNAME(DotProductLogical)(
    const Descriptor &x, const Descriptor &y, const char *source, int line) {
  using Impl = DotProduct<TypeCategory::Logical, 1>;
  return static_cast<bool>(Impl{}(x, y, source, line));
}
} // extern "C"
} // namespace Fortran::runtime