diff --git a/Core/include/Acts/Geometry/SurfaceArrayCreator.hpp b/Core/include/Acts/Geometry/SurfaceArrayCreator.hpp index 7ac78342408..c14a3582c4b 100644 --- a/Core/include/Acts/Geometry/SurfaceArrayCreator.hpp +++ b/Core/include/Acts/Geometry/SurfaceArrayCreator.hpp @@ -20,7 +20,6 @@ #include "Acts/Utilities/BinningType.hpp" #include "Acts/Utilities/Logger.hpp" -#include #include #include #include diff --git a/Core/include/Acts/Utilities/Axis.hpp b/Core/include/Acts/Utilities/Axis.hpp index 69407e01f6d..a347495a5a8 100644 --- a/Core/include/Acts/Utilities/Axis.hpp +++ b/Core/include/Acts/Utilities/Axis.hpp @@ -218,7 +218,7 @@ class Axis : public IAxis { /// get bin width /// @return constant width for all bins - double getBinWidth(std::size_t /*bin*/ = 0) const { return m_width; } + double getBinWidth(std::size_t /*bin*/ = 0) const final { return m_width; } /// get lower bound of bin /// @param bin index of bin @@ -229,7 +229,7 @@ class Axis : public IAxis { /// /// @note Bin intervals have a closed lower bound, i.e. the lower boundary /// belongs to the bin with the given bin index. - double getBinLowerBound(std::size_t bin) const { + double getBinLowerBound(std::size_t bin) const final { return getMin() + (bin - 1) * getBinWidth(); } @@ -240,7 +240,7 @@ class Axis : public IAxis { /// i.e. \f$0 \le \text{bin} \le \text{nBins}\f$ /// @note Bin intervals have an open upper bound, i.e. the upper boundary /// does @b not belong to the bin with the given bin index. - double getBinUpperBound(std::size_t bin) const { + double getBinUpperBound(std::size_t bin) const final { return getMin() + bin * getBinWidth(); } @@ -249,7 +249,7 @@ class Axis : public IAxis { /// @return bin center position /// @pre @c bin must be a valid bin index (excluding under-/overflow bins), /// i.e. \f$1 \le \text{bin} \le \text{nBins}\f$ - double getBinCenter(std::size_t bin) const { + double getBinCenter(std::size_t bin) const final { return getMin() + (bin - 0.5) * getBinWidth(); } @@ -271,7 +271,7 @@ class Axis : public IAxis { /// @c false /// @post If @c true is returned, the bin containing the given value is a /// valid bin, i.e. it is neither the underflow nor the overflow bin. - bool isInside(double x) const { return (m_min <= x) && (x < m_max); } + bool isInside(double x) const final { return (m_min <= x) && (x < m_max); } /// Return a vector of bin edges /// @return Vector which contains the bin edges @@ -503,7 +503,7 @@ class Axis : public IAxis { /// @return width of given bin /// @pre @c bin must be a valid bin index (excluding under-/overflow bins), /// i.e. \f$1 \le \text{bin} \le \text{nBins}\f$ - double getBinWidth(std::size_t bin) const { + double getBinWidth(std::size_t bin) const final { return m_binEdges.at(bin) - m_binEdges.at(bin - 1); } @@ -514,7 +514,7 @@ class Axis : public IAxis { /// i.e. \f$1 \le \text{bin} \le \text{nBins} + 1\f$ /// @note Bin intervals have a closed lower bound, i.e. the lower boundary /// belongs to the bin with the given bin index. - double getBinLowerBound(std::size_t bin) const { + double getBinLowerBound(std::size_t bin) const final { return m_binEdges.at(bin - 1); } @@ -525,14 +525,16 @@ class Axis : public IAxis { /// i.e. \f$0 \le \text{bin} \le \text{nBins}\f$ /// @note Bin intervals have an open upper bound, i.e. the upper boundary /// does @b not belong to the bin with the given bin index. - double getBinUpperBound(std::size_t bin) const { return m_binEdges.at(bin); } + double getBinUpperBound(std::size_t bin) const final { + return m_binEdges.at(bin); + } /// get bin center /// @param bin index of bin /// @return bin center position /// @pre @c bin must be a valid bin index (excluding under-/overflow bins), /// i.e. \f$1 \le \text{bin} \le \text{nBins}\f$ - double getBinCenter(std::size_t bin) const { + double getBinCenter(std::size_t bin) const final { return 0.5 * (getBinLowerBound(bin) + getBinUpperBound(bin)); } @@ -553,7 +555,7 @@ class Axis : public IAxis { /// @return @c true if \f$\text{xmin} \le x < \text{xmax}\f$, otherwise @c false /// @post If @c true is returned, the bin containing the given value is a /// valid bin, i.e. it is neither the underflow nor the overflow bin. - bool isInside(double x) const { + bool isInside(double x) const final { return (m_binEdges.front() <= x) && (x < m_binEdges.back()); } diff --git a/Core/include/Acts/Utilities/Grid.hpp b/Core/include/Acts/Utilities/Grid.hpp index 89c0834250f..7708edaa101 100644 --- a/Core/include/Acts/Utilities/Grid.hpp +++ b/Core/include/Acts/Utilities/Grid.hpp @@ -12,6 +12,7 @@ #include "Acts/Utilities/IAxis.hpp" #include "Acts/Utilities/IGrid.hpp" #include "Acts/Utilities/Interpolation.hpp" +#include "Acts/Utilities/MultiAxis.hpp" #include "Acts/Utilities/TypeTag.hpp" #include "Acts/Utilities/detail/MultiAxisHelper.hpp" @@ -47,6 +48,8 @@ class Grid final : public IGrid { /// number of dimensions of the grid static constexpr std::size_t DIM = sizeof...(Axes); + /// multi axis type + using multi_axis_t = MultiAxis; /// type of values stored using value_type = T; /// reference type to values stored @@ -103,6 +106,34 @@ class Grid final : public IGrid { m_values.resize(size()); } + /// @brief Constructor from const axis tuple, this will allow + /// creating a grid with a different value type from a template + /// grid object. + /// + /// @param axes + explicit Grid(const multi_axis_t& axes) : m_axes(axes) { + m_values.resize(size()); + } + + /// @brief Move constructor from axis tuple + /// @param axes + explicit Grid(multi_axis_t&& axes) : m_axes(std::move(axes)) { + m_values.resize(size()); + } + + /// @brief constructor from parameters pack of axes and type tag + /// @param axes + explicit Grid(TypeTag /*tag*/, const multi_axis_t& axes) : m_axes(axes) { + m_values.resize(size()); + } + + /// @brief constructor from parameters pack of axes and type tag + /// @param axes + explicit Grid(TypeTag /*tag*/, multi_axis_t&& axes) + : m_axes(std::move(axes)) { + m_values.resize(size()); + } + // Grid(TypeTag /*tag*/, Axes&... axes) = delete; /// @brief access value stored in bin for a given point @@ -224,7 +255,7 @@ class Grid final : public IGrid { /// @pre All local bin indices must be a valid index for the corresponding /// axis (excluding the under-/overflow bins for each axis). point_t binCenter(const index_t& localBins) const { - return detail::MultiAxisHelper::getBinCenter(localBins, m_axes); + return m_axes.getBinCenter(localBins); } AnyPointType binCenterAny(AnyIndexType indices) const override { @@ -255,8 +286,7 @@ class Grid final : public IGrid { /// @pre All local bin indices must be a valid index for the corresponding /// axis (including the under-/overflow bin for this axis). std::size_t globalBinFromLocalBins(const index_t& localBins) const { - return detail::MultiAxisHelper::getFlatIndexFromMultiIndex(localBins, - m_axes); + return m_axes.getFlatIndexFromMultiIndex(localBins); } /// @brief determine global bin index of the bin with the lower left edge @@ -290,7 +320,8 @@ class Grid final : public IGrid { /// @note This could be a under-/overflow bin along one or more axes. template index_t localBinsFromPosition(const Point& point) const { - return detail::MultiAxisHelper::getMultiIndexFromPoint(point, m_axes); + return detail::MultiAxisHelper::getMultiIndexFromPoint( + point, m_axes.getAxesTuple()); } /// @brief determine local bin index for each axis from global bin index @@ -302,7 +333,7 @@ class Grid final : public IGrid { /// @note Local bin indices can contain under-/overflow bins along the /// corresponding axis. index_t localBinsFromGlobalBin(std::size_t bin) const { - return detail::MultiAxisHelper::getMultiIndexFromFlatIndex(bin, m_axes); + return m_axes.getMultiIndexFromFlatIndex(bin); } /// @brief determine local bin index of the bin with the lower left edge @@ -321,12 +352,12 @@ class Grid final : public IGrid { template index_t localBinsFromLowerLeftEdge(const Point& point) const { Point shiftedPoint; - point_t width = detail::MultiAxisHelper::getWidth(m_axes); + point_t width = detail::MultiAxisHelper::getWidth(m_axes.getAxesTuple()); for (std::size_t i = 0; i < DIM; i++) { shiftedPoint[i] = point[i] + width[i] / 2; } - return detail::MultiAxisHelper::getMultiIndexFromPoint(shiftedPoint, - m_axes); + return detail::MultiAxisHelper::getMultiIndexFromPoint( + shiftedPoint, m_axes.getAxesTuple()); } /// @brief retrieve lower-left bin edge from set of local bin indices @@ -337,7 +368,7 @@ class Grid final : public IGrid { /// @pre @c localBins must only contain valid bin indices (excluding /// underflow bins). point_t lowerLeftBinEdge(const index_t& localBins) const { - return detail::MultiAxisHelper::getLowerLeftBinCorner(localBins, m_axes); + return m_axes.getLowerLeftBinCorner(localBins); } /// @copydoc Acts::IGrid::lowerLeftBinEdgeAny @@ -353,7 +384,7 @@ class Grid final : public IGrid { /// @pre @c localBins must only contain valid bin indices (excluding /// overflow bins). point_t upperRightBinEdge(const index_t& localBins) const { - return detail::MultiAxisHelper::getUpperRightBinCorner(localBins, m_axes); + return m_axes.getUpperRightBinCorner(localBins); } /// @copydoc Acts::IGrid::upperRightBinEdgeAny @@ -364,16 +395,16 @@ class Grid final : public IGrid { /// @brief get bin width along each specific axis /// /// @return array giving the bin width alonf all axes - point_t binWidth() const { return detail::MultiAxisHelper::getWidth(m_axes); } + point_t binWidth() const { + return detail::MultiAxisHelper::getWidth(m_axes.getAxesTuple()); + } /// @brief get number of bins along each specific axis /// /// @return array giving the number of bins along all axes /// /// @note Not including under- and overflow bins - index_t numLocalBins() const { - return detail::MultiAxisHelper::getNBins(m_axes); - } + index_t numLocalBins() const { return m_axes.getNBins(); } /// @copydoc Acts::IGrid::numLocalBinsAny AnyIndexType numLocalBinsAny() const override { @@ -383,16 +414,12 @@ class Grid final : public IGrid { /// @brief get the minimum value of all axes of one grid /// /// @return array returning the minima of all given axes - point_t minPosition() const { - return detail::MultiAxisHelper::getMin(m_axes); - } + point_t minPosition() const { return m_axes.getMinPoint(); } /// @brief get the maximum value of all axes of one grid /// /// @return array returning the maxima of all given axes - point_t maxPosition() const { - return detail::MultiAxisHelper::getMax(m_axes); - } + point_t maxPosition() const { return m_axes.getMaxPoint(); } /// @brief set all overflow and underflow bins to a certain value /// @@ -401,7 +428,7 @@ class Grid final : public IGrid { /// void setExteriorBins(const value_type& value) { for (std::size_t index : - detail::MultiAxisHelper::exteriorBinIndices(m_axes)) { + detail::MultiAxisHelper::exteriorBinIndices(m_axes.getAxesTuple())) { at(index) = value; } } @@ -471,7 +498,7 @@ class Grid final : public IGrid { /// along any axis. template bool isInside(const Point& position) const { - return detail::MultiAxisHelper::isInside(position, m_axes); + return detail::MultiAxisHelper::isInside(position, m_axes.getAxesTuple()); } /// @brief get global bin indices for neighborhood @@ -492,7 +519,7 @@ class Grid final : public IGrid { detail::FlatNeighborHoodIndices neighborHoodIndices( const index_t& localBins, std::size_t size = 1u) const { return detail::MultiAxisHelper::neighborHoodIndices(localBins, size, - m_axes); + m_axes.getAxesTuple()); } /// @brief get global bin indices for neighborhood @@ -515,7 +542,7 @@ class Grid final : public IGrid { const index_t& localBins, std::array, DIM>& sizePerAxis) const { return detail::MultiAxisHelper::neighborHoodIndices(localBins, sizePerAxis, - m_axes); + m_axes.getAxesTuple()); } /// @brief total number of bins @@ -525,21 +552,7 @@ class Grid final : public IGrid { /// /// @note This number contains under-and overflow bins along all axes. std::size_t size(bool fullCounter = true) const { - index_t nBinsArray = numLocalBins(); - std::size_t current_size = 1; - // add under-and overflow bins for each axis and multiply all bins - if (fullCounter) { - for (const auto& value : nBinsArray) { - current_size *= value + 2; - } - } - // ignore under-and overflow bins for each axis and multiply all bins - else { - for (const auto& value : nBinsArray) { - current_size *= value; - } - } - return current_size; + return m_axes.getNTotalBins(fullCounter); } /// @brief Convenience function to convert the type of the grid @@ -579,15 +592,12 @@ class Grid final : public IGrid { /// @brief get the axes as a tuple /// @return Reference to the tuple containing all grid axes - const std::tuple& axesTuple() const { return m_axes; } + const std::tuple& axesTuple() const { return m_axes.getAxesTuple(); } /// @brief get the axes as an array of IAxis pointers /// @return Vector containing pointers to all grid axes boost::container::small_vector axes() const override { - boost::container::small_vector result; - auto axes = detail::MultiAxisHelper::getAxes(m_axes); - std::ranges::copy(axes, std::back_inserter(result)); - return result; + return m_axes.getAnyAxesVector(); } /// begin iterator for global bins @@ -622,13 +632,11 @@ class Grid final : public IGrid { } protected: - void toStream(std::ostream& os) const override { - printAxes(os, std::make_index_sequence()); - } + void toStream(std::ostream& os) const override { os << m_axes; } private: - /// set of axis defining the multi-dimensional grid - std::tuple m_axes; + /// multi axis for the grid + multi_axis_t m_axes; /// linear value store for each bin std::vector m_values; @@ -637,19 +645,7 @@ class Grid final : public IGrid { // doesn't make that much sense from an API design standpoint. detail::FlatNeighborHoodIndices rawClosestPointsIndices( const index_t& localBins) const { - return detail::MultiAxisHelper::closestPointsIndices(localBins, m_axes); - } - - template - void printAxes(std::ostream& os, std::index_sequence /*s*/) const { - auto printOne = [&os, this]( - std::integral_constant) { - if constexpr (index > 0) { - os << ", "; - } - os << std::get(m_axes); - }; - (printOne(std::integral_constant()), ...); + return m_axes.getClosestPointsIndices(localBins); } static AnyIndexType toAnyIndexType(const index_t& indices) { diff --git a/Core/include/Acts/Utilities/IAxis.hpp b/Core/include/Acts/Utilities/IAxis.hpp index 8ba6dfedbe6..b80d1e62b35 100644 --- a/Core/include/Acts/Utilities/IAxis.hpp +++ b/Core/include/Acts/Utilities/IAxis.hpp @@ -17,7 +17,7 @@ namespace Acts { -/// Common base class for all Axis instance. This allows generice handling +/// Common base class for all Axis instances. This allows generice handling /// such as for inspection. class IAxis { public: @@ -73,6 +73,37 @@ class IAxis { /// while the index nBins + 1 indicates the overflow bin . virtual std::size_t getBin(double x) const = 0; + /// Check whether value is inside axis limits + /// @param x The value to check + /// @return @c true if the value is within the axis range, otherwise @c false + /// @post If @c true is returned, the bin containing the given value is a + /// valid bin, i.e. it is neither the underflow nor the overflow bin. + virtual bool isInside(double x) const = 0; + + /// Get bin width + /// @param bin index of bin + /// @return width of given bin + virtual double getBinWidth(std::size_t bin) const = 0; + + /// Get lower bound of bin + /// @param bin index of bin + /// @return lower bin boundary + /// @note Bin intervals have a closed lower bound, i.e. the lower boundary + /// belongs to the bin with the given bin index. + virtual double getBinLowerBound(std::size_t bin) const = 0; + + /// Get upper bound of bin + /// @param bin index of bin + /// @return upper bin boundary + /// @note Bin intervals have an open upper bound, i.e. the upper boundary + /// does @b not belong to the bin with the given bin index. + virtual double getBinUpperBound(std::size_t bin) const = 0; + + /// Get bin center + /// @param bin index of bin + /// @return bin center position + virtual double getBinCenter(std::size_t bin) const = 0; + /// Centralized axis factory for equidistant binning /// @param aBoundaryType the axis boundary type /// @param min the minimum edge of the axis diff --git a/Core/include/Acts/Utilities/IMultiAxis.hpp b/Core/include/Acts/Utilities/IMultiAxis.hpp new file mode 100644 index 00000000000..e424bdf0fa0 --- /dev/null +++ b/Core/include/Acts/Utilities/IMultiAxis.hpp @@ -0,0 +1,511 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/IAxis.hpp" +#include "Acts/Utilities/detail/MultiAxisHelper.hpp" + +#include + +#include + +namespace Acts { + +/// @brief Common base class for all MultiAxis instances. This allows generic +/// handling such as for inspection. +/// +/// A multi-axis describes the binning of a multi-dimensional grid as the +/// product of several one-dimensional @c IAxis objects. The number of axes +/// (i.e. the dimension of the grid) is only known at runtime through this +/// interface; the dimension-aware variant is exposed by the derived +/// @c IMultiAxisND template. +/// +/// This base class exposes a type-erased, dynamically sized API (the @c *Any +/// methods, using small vectors) so that grids of differing dimension can be +/// handled through a common pointer. Bins are addressed either via a +/// multi-index (one local bin index per axis) or via a single flattened global +/// bin index. As for @c IAxis, local bin indices start at @c 1, with index +/// @c 0 and nBins + 1 denoting the underflow and overflow bins of an +/// axis; flattened global indices include these under-/overflow bins. +class IMultiAxis { + public: + /// Small vector type used to hold per-axis values without heap allocation + /// for the common low-dimensional cases. + template + using SmallVector = boost::container::small_vector; + + /// Flattened global bin index across all axes + using FlatIndex = std::size_t; + /// Dynamically sized multi-index holding one local bin index per axis + using AnyMultiIndex = SmallVector; + /// Dynamically sized point holding one coordinate per axis + using AnyPoint = SmallVector; + /// Dynamically sized vector of (non-owning) pointers to the contained axes + using AnyAxesVector = SmallVector; + + virtual ~IMultiAxis() = default; + + /// Get the number of axes spanning the grid + /// @return number of axes (i.e. the dimension of the grid) + virtual std::size_t getNAxes() const = 0; + + /// Get the axis at the given dimension + /// @param i index of the axis + /// @return const reference to the requested axis + virtual const IAxis& getAxis(std::size_t i) const = 0; + + /// Get the number of bins along each axis + /// @return per-axis number of bins (excluding under-/overflow bins) + virtual AnyMultiIndex getNBinsAny() const { + AnyMultiIndex result; + result.reserve(getNAxes()); + for (const IAxis& axis : *this) { + result.push_back(axis.getNBins()); + } + return result; + } + + /// Get the total number of bins in the grid + /// @param fullCounter if @c true the under-/overflow bins of every axis are + /// included in the count, otherwise only the regular bins are counted + /// @return product of the per-axis bin counts + virtual std::size_t getNTotalBins(bool fullCounter = true) const { + std::size_t result = 1; + for (const IAxis& axis : *this) { + result *= axis.getNBins() + (fullCounter ? 2 : 0); + } + return result; + } + + /// Get (non-owning) pointers to all contained axes + /// @return vector of pointers to the axes, in axis order + virtual AnyAxesVector getAnyAxesVector() const { + AnyAxesVector result; + std::ranges::transform(*this, std::back_inserter(result), + [](const IAxis& axis) { return &axis; }); + return result; + } + + /// Get the lower boundary of the grid range along each axis + /// @return point holding the minimum of each axis + virtual AnyPoint getMinPointAny() const { + AnyPoint result; + result.reserve(getNAxes()); + for (const IAxis& axis : *this) { + result.push_back(axis.getMin()); + } + return result; + } + + /// Get the upper boundary of the grid range along each axis + /// @return point holding the maximum of each axis + virtual AnyPoint getMaxPointAny() const { + AnyPoint result; + result.reserve(getNAxes()); + for (const IAxis& axis : *this) { + result.push_back(axis.getMax()); + } + return result; + } + + /// Check whether a point lies inside the grid limits + /// @param point coordinates to check, one per axis + /// @return @c true if the point is within range along every axis + /// @throws std::invalid_argument if the number of coordinates does not match + /// the number of axes + virtual bool isInsideAny(const AnyPoint& point) const { + if (point.size() != getNAxes()) { + throw std::invalid_argument("Invalid number of coordinates"); + } + for (std::size_t i = 0; i < point.size(); ++i) { + const IAxis& axis = getAxis(i); + if (!axis.isInside(point[i])) { + return false; + } + } + return true; + } + + /// Get the lower-left corner of the bin given by a multi-index + /// @param indices local bin indices along each axis + /// @return point holding the lower bin boundary of each axis + virtual AnyPoint getLowerLeftBinCornerAny( + const AnyMultiIndex& indices) const { + AnyPoint result; + result.reserve(getNAxes()); + for (std::size_t i = 0; i < indices.size(); ++i) { + const IAxis& axis = getAxis(i); + result.push_back(axis.getBinLowerBound(indices[i])); + } + return result; + } + + /// Get the upper-right corner of the bin given by a multi-index + /// @param indices local bin indices along each axis + /// @return point holding the upper bin boundary of each axis + virtual AnyPoint getUpperRightBinCornerAny( + const AnyMultiIndex& indices) const { + AnyPoint result; + result.reserve(getNAxes()); + for (std::size_t i = 0; i < indices.size(); ++i) { + const IAxis& axis = getAxis(i); + result.push_back(axis.getBinUpperBound(indices[i])); + } + return result; + } + + /// Get the center of the bin given by a multi-index + /// @param indices local bin indices along each axis + /// @return point holding the bin center coordinate of each axis + virtual AnyPoint getBinCenterAny(const AnyMultiIndex& indices) const { + AnyPoint result; + result.reserve(getNAxes()); + for (std::size_t i = 0; i < indices.size(); ++i) { + const IAxis& axis = getAxis(i); + result.push_back(axis.getBinCenter(indices[i])); + } + return result; + } + + /// Random-access iterator over the contained axes. Dereferencing yields a + /// const reference to the @c IAxis at the current dimension. + class iterator { + public: + /// The type of the values the iterator points to + using value_type = const IAxis; + /// The type used to represent the distance between two iterators + using difference_type = std::ptrdiff_t; + /// The pointer type of the values the iterator points to + using pointer = const IAxis*; + /// The reference type of the values the iterator points to + using reference = const IAxis&; + + /// The iterator category (random-access) + using iterator_category = std::random_access_iterator_tag; + /// The iterator concept (random-access) + using iterator_concept = std::random_access_iterator_tag; + + constexpr iterator() noexcept = default; + /// Construct an iterator pointing at the given axis dimension + /// @param multiAxis the multi-axis to iterate over + /// @param index the axis dimension to point at + constexpr iterator(const IMultiAxis& multiAxis, std::size_t index) noexcept + : m_multiAxis(&multiAxis), m_index(index) {} + + /// Dereference the iterator + /// @return a const reference to the axis at the current dimension + constexpr reference operator*() const { + return m_multiAxis->getAxis(m_index); + } + /// Pre-increment the iterator + /// @return a reference to the incremented iterator + constexpr iterator& operator++() noexcept { + ++m_index; + return *this; + } + /// Post-increment the iterator + /// @return a copy of the iterator before incrementing + constexpr iterator operator++(int) noexcept { + auto tmp = *this; + ++(*this); + return tmp; + } + /// Pre-decrement the iterator + /// @return a reference to the decremented iterator + constexpr iterator& operator--() noexcept { + --m_index; + return *this; + } + /// Post-decrement the iterator + /// @return a copy of the iterator before decrementing + constexpr iterator operator--(int) noexcept { + auto tmp = *this; + --(*this); + return tmp; + } + /// Advance the iterator by @p n positions + /// @param n the number of positions to advance + /// @return a reference to the advanced iterator + constexpr iterator& operator+=(difference_type n) noexcept { + m_index += n; + return *this; + } + /// Move the iterator back by @p n positions + /// @param n the number of positions to move back + /// @return a reference to the moved iterator + constexpr iterator& operator-=(difference_type n) noexcept { + m_index -= n; + return *this; + } + + private: + const IMultiAxis* m_multiAxis{}; + std::size_t m_index{}; + + friend constexpr iterator operator+(iterator it, + difference_type n) noexcept { + return it += n; + } + + friend constexpr iterator operator+(difference_type n, + iterator it) noexcept { + return it += n; + } + + friend constexpr iterator operator-(iterator it, + difference_type n) noexcept { + return it -= n; + } + + friend constexpr difference_type operator-(const iterator& lhs, + const iterator& rhs) noexcept { + return lhs.m_index - rhs.m_index; + } + + friend constexpr auto operator<=>(const iterator& a, + const iterator& b) noexcept { + return a.m_index <=> b.m_index; + } + + friend constexpr bool operator==(const iterator& a, + const iterator& b) noexcept { + return a.m_index == b.m_index; + } + }; + + /// @return iterator to the first axis + iterator begin() const { return iterator(*this, 0); } + + /// @return iterator past the last axis + iterator end() const { return iterator(*this, getNAxes()); } + + protected: + /// Print the contained axes to the given stream + /// @param os output stream + virtual void toStream(std::ostream& os) const { + for (std::size_t i = 0; i < getNAxes(); ++i) { + os << getAxis(i); + if (i < getNAxes() - 1) { + os << ", "; + } + } + } + + private: + /// Check if two multi-axes are equal + /// @param lhs first multi-axis + /// @param rhs second multi-axis + /// @return @c true if both have the same number of axes and all axes compare + /// equal + friend bool operator==(const IMultiAxis& lhs, const IMultiAxis& rhs) { + if (lhs.getNAxes() != rhs.getNAxes()) { + return false; + } + return std::ranges::equal(lhs, rhs); + } + + /// Output stream operator + /// @param os output stream + /// @param multiAxis the multi-axis to be printed + /// @return the output stream + friend std::ostream& operator<<(std::ostream& os, + const IMultiAxis& multiAxis) { + multiAxis.toStream(os); + return os; + } +}; + +/// @brief Common base class for all multi-axes of a fixed, compile-time +/// dimension. +/// +/// On top of the dynamically sized @c IMultiAxis API this adds a statically +/// sized API (using @c std::array of fixed length @c DIM) and the grid index +/// conversions between points, multi-indices and flattened global indices. +/// The actual axis storage is provided by the concrete @c MultiAxis derived +/// class. +/// +/// @tparam _DIM number of axes (dimension of the grid) +template +class IMultiAxisND : public IMultiAxis { + public: + /// Dimension of the grid (number of axes) + static constexpr std::size_t DIM = _DIM; + + /// Statically sized multi-index holding one local bin index per axis + using MultiIndex = std::array; + /// Statically sized point holding one coordinate per axis + using Point = std::array; + /// Statically sized array of (non-owning) pointers to the contained axes + using AnyAxesArray = std::array; + /// Tuple of const references to the contained axes + using AnyAxesTuple = decltype(std::apply( + [](auto&&... xs) { return std::tie(*xs...); }, AnyAxesArray{})); + + /// Get the number of axes spanning the grid + /// @return the compile-time dimension @c DIM + std::size_t getNAxes() const override { return DIM; } + + /// Get (non-owning) pointers to all contained axes + /// @return fixed-size array of pointers to the axes, in axis order + virtual AnyAxesArray getAnyAxesArray() const { + AnyAxesArray result{}; + std::ranges::transform(*this, result.begin(), + [](const IAxis& axis) { return &axis; }); + return result; + } + + /// Get const references to all contained axes as a tuple + /// @return tuple of references to the axes, in axis order + virtual AnyAxesTuple getAnyAxesTuple() const { + return std::apply([](auto&&... xs) { return std::tie(*xs...); }, + getAnyAxesArray()); + } + + /// Get the number of bins along each axis + /// @return per-axis number of bins (excluding under-/overflow bins) + virtual MultiIndex getNBins() const { + MultiIndex result{}; + for (std::size_t i = 0; i < DIM; ++i) { + result[i] = getAxis(i).getNBins(); + } + return result; + } + + /// Get the lower boundary of the grid range along each axis + /// @return point holding the minimum of each axis + virtual Point getMinPoint() const { + Point result{}; + for (std::size_t i = 0; i < DIM; ++i) { + result[i] = getAxis(i).getMin(); + } + return result; + } + + /// Get the upper boundary of the grid range along each axis + /// @return point holding the maximum of each axis + virtual Point getMaxPoint() const { + Point result{}; + for (std::size_t i = 0; i < DIM; ++i) { + result[i] = getAxis(i).getMax(); + } + return result; + } + + /// Check whether a point lies inside the grid limits + /// @param point coordinates to check, one per axis + /// @return @c true if the point is within range along every axis + virtual bool isInside(const Point& point) const { + for (std::size_t i = 0; i < DIM; ++i) { + if (!getAxis(i).isInside(point[i])) { + return false; + } + } + return true; + } + + /// Get the lower-left corner of the bin given by a multi-index + /// @param multiIndex local bin indices along each axis + /// @return point holding the lower bin boundary of each axis + virtual Point getLowerLeftBinCorner(const MultiIndex& multiIndex) const { + Point result{}; + for (std::size_t i = 0; i < DIM; ++i) { + result[i] = getAxis(i).getBinLowerBound(multiIndex[i]); + } + return result; + } + + /// Get the upper-right corner of the bin given by a multi-index + /// @param multiIndex local bin indices along each axis + /// @return point holding the upper bin boundary of each axis + virtual Point getUpperRightBinCorner(const MultiIndex& multiIndex) const { + Point result{}; + for (std::size_t i = 0; i < DIM; ++i) { + result[i] = getAxis(i).getBinUpperBound(multiIndex[i]); + } + return result; + } + + /// Get the center of the bin given by a multi-index + /// @param multiIndex local bin indices along each axis + /// @return point holding the bin center coordinate of each axis + virtual Point getBinCenter(const MultiIndex& multiIndex) const { + Point result{}; + for (std::size_t i = 0; i < DIM; ++i) { + result[i] = getAxis(i).getBinCenter(multiIndex[i]); + } + return result; + } + + /// Determine the flattened global bin index for a given point + /// @param point coordinates to look up, one per axis + /// @return global bin index of the bin containing the point + virtual FlatIndex getFlatIndexFromPoint(const Point& point) const { + return getFlatIndexFromMultiIndex(getMultiIndexFromPoint(point)); + } + + /// Determine the flattened global bin index from a multi-index + /// @param multiIndex local bin indices along each axis (under-/overflow bins + /// are allowed) + /// @return global bin index of the bin + virtual FlatIndex getFlatIndexFromMultiIndex( + const MultiIndex& multiIndex) const { + return detail::MultiAxisHelper::getFlatIndexFromMultiIndex( + multiIndex, getAnyAxesTuple()); + } + + /// Determine the multi-index of local bin indices for a given point + /// @param point coordinates to look up, one per axis + /// @return local bin indices along each axis (may be under-/overflow bins) + virtual MultiIndex getMultiIndexFromPoint(const Point& point) const { + return detail::MultiAxisHelper::getMultiIndexFromPoint(point, + getAnyAxesTuple()); + } + + /// Determine the multi-index of local bin indices from a flattened global + /// bin index + /// @param flatIndex global bin index + /// @return local bin indices along each axis (may be under-/overflow bins) + virtual MultiIndex getMultiIndexFromFlatIndex(FlatIndex flatIndex) const { + return detail::MultiAxisHelper::getMultiIndexFromFlatIndex( + flatIndex, getAnyAxesTuple()); + } + + /// Get the global bin indices of the bins in the neighborhood of a bin + /// @param multiIndex local bin indices of the bin of interest + /// @param size number of adjacent bins to include along each axis (symmetric) + /// @return sorted collection of global bin indices in the neighborhood + virtual detail::FlatNeighborHoodIndices getNeighborHoodIndices( + const MultiIndex& multiIndex, std::size_t size = 1u) const = 0; + + /// Get the global bin indices of the bins in the neighborhood of a bin, with + /// a separate neighborhood size per axis + /// @param multiIndex local bin indices of the bin of interest + /// @param sizePerAxis per-axis lower/upper number of adjacent bins to include + /// @return sorted collection of global bin indices in the neighborhood + virtual detail::FlatNeighborHoodIndices getNeighborHoodIndices( + const MultiIndex& multiIndex, + std::array, DIM>& sizePerAxis) const = 0; + + /// Get the global bin indices of the grid points closest to the given bin + /// @param multiIndex local bin indices of the bin of interest + /// @return sorted collection of global bin indices whose lower-left corners + /// are the closest grid points to every point in the given bin + virtual detail::FlatNeighborHoodIndices getClosestPointsIndices( + const MultiIndex& multiIndex) const = 0; + + /// Get the global bin indices of the grid points closest to the given point + /// @param position coordinates to look up, one per axis + /// @return sorted collection of global bin indices of the closest grid points + virtual detail::FlatNeighborHoodIndices getClosestPointsIndices( + const Point& position) const { + return getClosestPointsIndices(getMultiIndexFromPoint(position)); + } +}; + +} // namespace Acts diff --git a/Core/include/Acts/Utilities/MultiAxis.hpp b/Core/include/Acts/Utilities/MultiAxis.hpp new file mode 100644 index 00000000000..98ed9e545f1 --- /dev/null +++ b/Core/include/Acts/Utilities/MultiAxis.hpp @@ -0,0 +1,197 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Helpers.hpp" +#include "Acts/Utilities/IMultiAxis.hpp" + +#include + +namespace Acts { + +/// @brief Multi-dimensional binning defined by a product of one-dimensional +/// axes +/// +/// This class stores a fixed set of concrete @c Axis objects in a tuple and +/// implements the @c IMultiAxisND interface for the resulting grid. The grid +/// dimension and the concrete axis types (binning and boundary types) are +/// fixed at compile time, while the @c IMultiAxis base allows handling +/// different multi-axes through a common pointer. The grid index conventions +/// (per-axis local bin indices starting at @c 1, under-/overflow bins at @c 0 +/// and nBins + 1, and flattened global bin indices including those +/// under-/overflow bins) are described on @c IMultiAxis. +/// +/// @tparam Axes parameter pack of concrete @c Axis types spanning the grid +template +class MultiAxis final : public IMultiAxisND { + public: + /// Base interface for this multi-axis' dimension + using Base = IMultiAxisND; + + /// Dimension of the grid (number of axes) + static constexpr std::size_t DIM = Base::DIM; + /// Flattened global bin index across all axes + using FlatIndex = typename Base::FlatIndex; + /// Statically sized multi-index holding one local bin index per axis + using MultiIndex = typename Base::MultiIndex; + /// Statically sized point holding one coordinate per axis + using Point = typename Base::Point; + + /// Tuple type holding the concrete axes + using AxesTuple = std::tuple; + + /// Construct from a tuple of axes (copy) + /// @param axes tuple of axes spanning the grid + explicit MultiAxis(const std::tuple& axes) : m_axes(axes) {} + + /// Construct from a tuple of axes (move) + /// @param axes tuple of axes spanning the grid + explicit MultiAxis(std::tuple&& axes) : m_axes(std::move(axes)) {} + + /// Construct from individual axes (forwarding) + /// @param axes axes spanning the grid + explicit MultiAxis(Axes&&... axes) : m_axes(std::forward_as_tuple(axes...)) {} + + /// Construct from individual axes (copy) + /// @param axes axes spanning the grid + explicit MultiAxis(const Axes&... axes) : m_axes(std::tuple(axes...)) {} + + /// Get the axis at the given dimension + /// @param i index of the axis + /// @return const reference to the requested axis + const IAxis& getAxis(std::size_t i) const override { + return template_switch_lambda<0, DIM - 1>( + i, [this](auto iType) -> const IAxis& { + constexpr std::size_t iValue = decltype(iType)::value; + return std::get(m_axes); + }); + } + + /// Get the tuple of concrete axes + /// @return const reference to the stored axes tuple + const AxesTuple& getAxesTuple() const { return m_axes; } + + /// Get the number of bins along each axis + /// @return per-axis number of bins (excluding under-/overflow bins) + MultiIndex getNBins() const override { + return detail::MultiAxisHelper::getNBins(m_axes); + } + + /// Get the lower boundary of the grid range along each axis + /// @return point holding the minimum of each axis + Point getMinPoint() const override { + return detail::MultiAxisHelper::getMin(m_axes); + } + + /// Get the upper boundary of the grid range along each axis + /// @return point holding the maximum of each axis + Point getMaxPoint() const override { + return detail::MultiAxisHelper::getMax(m_axes); + } + + /// Check whether a point lies inside the grid limits + /// @param point coordinates to check, one per axis + /// @return @c true if the point is within range along every axis + bool isInside(const Point& point) const override { + return detail::MultiAxisHelper::isInside(point, m_axes); + } + + /// Get the lower-left corner of the bin given by a multi-index + /// @param multiIndex local bin indices along each axis + /// @return point holding the lower bin boundary of each axis + Point getLowerLeftBinCorner(const MultiIndex& multiIndex) const override { + return detail::MultiAxisHelper::getLowerLeftBinCorner(multiIndex, m_axes); + } + + /// Get the upper-right corner of the bin given by a multi-index + /// @param multiIndex local bin indices along each axis + /// @return point holding the upper bin boundary of each axis + Point getUpperRightBinCorner(const MultiIndex& multiIndex) const override { + return detail::MultiAxisHelper::getUpperRightBinCorner(multiIndex, m_axes); + } + + /// Get the center of the bin given by a multi-index + /// @param multiIndex local bin indices along each axis + /// @return point holding the bin center coordinate of each axis + Point getBinCenter(const MultiIndex& multiIndex) const override { + return detail::MultiAxisHelper::getBinCenter(multiIndex, m_axes); + } + + /// Determine the flattened global bin index for a given point + /// @param point coordinates to look up, one per axis + /// @return global bin index of the bin containing the point + FlatIndex getFlatIndexFromPoint(const Point& point) const override { + return getFlatIndexFromMultiIndex(getMultiIndexFromPoint(point)); + } + + /// Determine the flattened global bin index from a multi-index + /// @param multiIndex local bin indices along each axis (under-/overflow bins + /// are allowed) + /// @return global bin index of the bin + FlatIndex getFlatIndexFromMultiIndex( + const MultiIndex& multiIndex) const override { + return detail::MultiAxisHelper::getFlatIndexFromMultiIndex(multiIndex, + m_axes); + } + + /// Determine the multi-index of local bin indices for a given point + /// @param point coordinates to look up, one per axis + /// @return local bin indices along each axis (may be under-/overflow bins) + MultiIndex getMultiIndexFromPoint(const Point& point) const override { + return detail::MultiAxisHelper::getMultiIndexFromPoint(point, m_axes); + } + + /// Determine the multi-index of local bin indices from a flattened global + /// bin index + /// @param flatIndex global bin index + /// @return local bin indices along each axis (may be under-/overflow bins) + MultiIndex getMultiIndexFromFlatIndex(FlatIndex flatIndex) const override { + return detail::MultiAxisHelper::getMultiIndexFromFlatIndex(flatIndex, + m_axes); + } + + /// Get the global bin indices of the bins in the neighborhood of a bin + /// @param multiIndex local bin indices of the bin of interest + /// @param size number of adjacent bins to include along each axis (symmetric) + /// @return sorted collection of global bin indices in the neighborhood + detail::FlatNeighborHoodIndices getNeighborHoodIndices( + const MultiIndex& multiIndex, std::size_t size = 1u) const override { + return detail::MultiAxisHelper::neighborHoodIndices(multiIndex, size, + m_axes); + } + + /// Get the global bin indices of the bins in the neighborhood of a bin, with + /// a separate neighborhood size per axis + /// @param multiIndex local bin indices of the bin of interest + /// @param sizePerAxis per-axis lower/upper number of adjacent bins to include + /// @return sorted collection of global bin indices in the neighborhood + detail::FlatNeighborHoodIndices getNeighborHoodIndices( + const MultiIndex& multiIndex, + std::array, DIM>& sizePerAxis) const override { + return detail::MultiAxisHelper::neighborHoodIndices(multiIndex, sizePerAxis, + m_axes); + } + + /// Get the global bin indices of the grid points closest to the given bin + /// @param multiIndex local bin indices of the bin of interest + /// @return sorted collection of global bin indices whose lower-left corners + /// are the closest grid points to every point in the given bin + detail::FlatNeighborHoodIndices getClosestPointsIndices( + const MultiIndex& multiIndex) const override { + return detail::MultiAxisHelper::closestPointsIndices(multiIndex, m_axes); + } + + using Base::getClosestPointsIndices; + + private: + /// tuple of concrete axes spanning the grid + std::tuple m_axes; +}; + +} // namespace Acts diff --git a/Tests/UnitTests/Core/Utilities/CMakeLists.txt b/Tests/UnitTests/Core/Utilities/CMakeLists.txt index f61d045092c..02b1ad45c5c 100644 --- a/Tests/UnitTests/Core/Utilities/CMakeLists.txt +++ b/Tests/UnitTests/Core/Utilities/CMakeLists.txt @@ -76,3 +76,4 @@ add_unittest(Histogram HistogramTests.cpp) add_unittest(ProfileEfficiency ProfileEfficiencyTests.cpp) add_unittest(ContainerHelpers ContainerHelpersTests.cpp) add_unittest(Ranges RangesTests.cpp) +add_unittest(MultiAxis MultiAxisTests.cpp) diff --git a/Tests/UnitTests/Core/Utilities/GridTests.cpp b/Tests/UnitTests/Core/Utilities/GridTests.cpp index 196e442f6bf..865b1bd7e2e 100644 --- a/Tests/UnitTests/Core/Utilities/GridTests.cpp +++ b/Tests/UnitTests/Core/Utilities/GridTests.cpp @@ -14,7 +14,6 @@ #include "ActsTests/CommonHelpers/FloatComparisons.hpp" #include -#include #include #include #include diff --git a/Tests/UnitTests/Core/Utilities/MultiAxisTests.cpp b/Tests/UnitTests/Core/Utilities/MultiAxisTests.cpp new file mode 100644 index 00000000000..8b10798120f --- /dev/null +++ b/Tests/UnitTests/Core/Utilities/MultiAxisTests.cpp @@ -0,0 +1,1220 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include + +#include "Acts/Utilities/Axis.hpp" +#include "Acts/Utilities/AxisDefinitions.hpp" +#include "Acts/Utilities/MultiAxis.hpp" +#include "ActsTests/CommonHelpers/FloatComparisons.hpp" + +#include +#include +#include +#include + +using namespace Acts; +using namespace Acts::detail; + +namespace ActsTests { + +BOOST_AUTO_TEST_SUITE(MultiAxisTests) + +BOOST_AUTO_TEST_CASE(test_1d_equidistant) { + using Point = std::array; + using MultiIndex = std::array; + + Axis a(0.0, 4.0, 4u); + + MultiAxis ma(a); + + // test general properties + BOOST_CHECK_EQUAL(ma.getNAxes(), 1u); + BOOST_CHECK_EQUAL(ma.getNBins().at(0), 4u); + BOOST_CHECK_EQUAL(ma.getNTotalBins(), 6u); + + // flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-0.3}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-0.}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.7}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1.2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2.}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2.7}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3.}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3.9999}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4.}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4.98}), 5u); + + // flat bin index -> multi bin indices + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(0) == MultiIndex{0}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(1) == MultiIndex{1}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(2) == MultiIndex{2}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(3) == MultiIndex{3}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(4) == MultiIndex{4}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(5) == MultiIndex{5}); + + // multi bin indices -> flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5}), 5u); + + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(ma.getFlatIndexFromPoint({2.7})) == + MultiIndex{3}); + + // inside checks + BOOST_CHECK(!ma.isInside({-2.})); + BOOST_CHECK(ma.isInside({0.})); + BOOST_CHECK(ma.isInside({2.5})); + BOOST_CHECK(!ma.isInside({4.})); + BOOST_CHECK(!ma.isInside({6.})); + + // test some bin centers + CHECK_CLOSE_ABS(ma.getBinCenter({1}), Point{0.5}, 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({2}), Point{1.5}, 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({3}), Point{2.5}, 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({4}), Point{3.5}, 1e-6); + + // test some lower-left bin edges + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1}), Point{0.}, 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({2}), Point{1.}, 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({3}), Point{2.}, 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({4}), Point{3.}, 1e-6); + + // test some upper right-bin edges + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1}), Point{1.}, 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({2}), Point{2.}, 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({3}), Point{3.}, 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({4}), Point{4.}, 1e-6); +} + +BOOST_AUTO_TEST_CASE(test_2d_equidistant) { + using Point = std::array; + using MultiIndex = std::array; + + Axis a(0.0, 4.0, 4u); + Axis b(0.0, 3.0, 3u); + + MultiAxis ma(a, b); + + // test general properties + BOOST_CHECK_EQUAL(ma.getNAxes(), 2u); + BOOST_CHECK_EQUAL(ma.getNBins().at(0), 4u); + BOOST_CHECK_EQUAL(ma.getNBins().at(1), 3u); + BOOST_CHECK_EQUAL(ma.getNTotalBins(), 30u); + + // flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-1, -1}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-1, 0}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-1, 1}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-1, 2}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-1, 3}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, -1}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 1}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 2}), 8u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, -1}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 1}), 12u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 2}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, -1}), 15u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 0}), 16u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 1}), 17u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 2}), 18u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 3}), 19u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, -1}), 20u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 0}), 21u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 1}), 22u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 2}), 23u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 3}), 24u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4, -1}), 25u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4, 0}), 26u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4, 1}), 27u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4, 2}), 28u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4, 3}), 29u); + + // test some arbitrary points + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1.2, 0.3}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2.2, 3.3}), 19u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.9, 1.8}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3.7, 3.1}), 24u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1.4, 2.3}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-3, 3}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({8, 1}), 27u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, -3}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 11}), 24u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-2, -3}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-2, 7}), 04u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({12, -1}), 25u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({12, 11}), 29u); + + // flat bin index -> multi bin indices + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(0) == (MultiIndex{0, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(1) == (MultiIndex{0, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(2) == (MultiIndex{0, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(3) == (MultiIndex{0, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(4) == (MultiIndex{0, 4})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(5) == (MultiIndex{1, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(6) == (MultiIndex{1, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(7) == (MultiIndex{1, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(8) == (MultiIndex{1, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(9) == (MultiIndex{1, 4})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(10) == (MultiIndex{2, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(11) == (MultiIndex{2, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(12) == (MultiIndex{2, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(13) == (MultiIndex{2, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(14) == (MultiIndex{2, 4})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(15) == (MultiIndex{3, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(16) == (MultiIndex{3, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(17) == (MultiIndex{3, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(18) == (MultiIndex{3, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(19) == (MultiIndex{3, 4})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(20) == (MultiIndex{4, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(21) == (MultiIndex{4, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(22) == (MultiIndex{4, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(23) == (MultiIndex{4, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(24) == (MultiIndex{4, 4})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(25) == (MultiIndex{5, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(26) == (MultiIndex{5, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(27) == (MultiIndex{5, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(28) == (MultiIndex{5, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(29) == (MultiIndex{5, 4})); + + // local bin indices -> global bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 3}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 4}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 0}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 2}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 3}), 8u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 4}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 0}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 1}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 2}), 12u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 4}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 0}), 15u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 1}), 16u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 2}), 17u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 3}), 18u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 4}), 19u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 0}), 20u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 1}), 21u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 2}), 22u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 3}), 23u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 4}), 24u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 0}), 25u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 1}), 26u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 2}), 27u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 3}), 28u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 4}), 29u); + + BOOST_CHECK(ma.getMultiIndexFromFlatIndex( + ma.getFlatIndexFromPoint({1.2, 0.7})) == (MultiIndex{2, 1})); + + // inside checks + BOOST_CHECK(!ma.isInside({-2., -1})); + BOOST_CHECK(!ma.isInside({-2., 1.})); + BOOST_CHECK(!ma.isInside({-2., 5.})); + BOOST_CHECK(!ma.isInside({1., -1.})); + BOOST_CHECK(!ma.isInside({6., -1.})); + BOOST_CHECK(ma.isInside({0.5, 1.3})); + BOOST_CHECK(!ma.isInside({4., -1.})); + BOOST_CHECK(!ma.isInside({4., 0.3})); + BOOST_CHECK(!ma.isInside({4., 3.})); + BOOST_CHECK(!ma.isInside({-1., 3.})); + BOOST_CHECK(!ma.isInside({2., 3.})); + BOOST_CHECK(!ma.isInside({5., 3.})); + + // test some bin centers + CHECK_CLOSE_ABS(ma.getBinCenter({1, 1}), (Point{0.5, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({2, 3}), (Point{1.5, 2.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({3, 1}), (Point{2.5, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({4, 2}), (Point{3.5, 1.5}), 1e-6); + + // test some lower-left bin edges + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 1}), (Point{0., 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({2, 3}), (Point{1., 2.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({3, 1}), (Point{2., 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({4, 2}), (Point{3., 1.}), 1e-6); + + // test some upper right-bin edges + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 1}), (Point{1., 1.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({2, 3}), (Point{2., 3.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({3, 1}), (Point{3., 1.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({4, 2}), (Point{4., 2.}), 1e-6); +} + +BOOST_AUTO_TEST_CASE(test_3d_equidistant) { + using Point = std::array; + using MultiIndex = std::array; + + Axis a(0.0, 2.0, 2u); + Axis b(0.0, 3.0, 3u); + Axis c(0.0, 2.0, 2u); + + MultiAxis ma(a, b, c); + + // test general properties + BOOST_CHECK_EQUAL(ma.getNAxes(), 3u); + BOOST_CHECK_EQUAL(ma.getNBins().at(0), 2u); + BOOST_CHECK_EQUAL(ma.getNBins().at(1), 3u); + BOOST_CHECK_EQUAL(ma.getNBins().at(2), 2u); + BOOST_CHECK_EQUAL(ma.getNTotalBins(), 80u); + + // test grid points + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0, 0}), 25u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0, 1}), 26u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0, 2}), 27u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 1, 0}), 29u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 1, 1}), 30u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 1, 2}), 31u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 2, 0}), 33u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 2, 1}), 34u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 2, 2}), 35u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3, 0}), 37u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3, 1}), 38u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3, 2}), 39u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0, 0}), 45u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0, 1}), 46u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0, 2}), 47u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 1, 0}), 49u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 1, 1}), 50u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 2, 0}), 53u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 2, 1}), 54u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 2, 2}), 55u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3, 0}), 57u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3, 1}), 58u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3, 2}), 59u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 0, 0}), 65u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 0, 1}), 66u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 0, 2}), 67u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 1, 0}), 69u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 1, 1}), 70u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 1, 2}), 71u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 2, 0}), 73u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 2, 1}), 74u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 2, 2}), 75u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 3, 0}), 77u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 3, 1}), 78u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, 3, 2}), 79u); + + // flat bin index -> multi bin indices + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(0) == (MultiIndex{0, 0, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(1) == (MultiIndex{0, 0, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(2) == (MultiIndex{0, 0, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(3) == (MultiIndex{0, 0, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(4) == (MultiIndex{0, 1, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(5) == (MultiIndex{0, 1, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(6) == (MultiIndex{0, 1, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(7) == (MultiIndex{0, 1, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(24) == (MultiIndex{1, 1, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(25) == (MultiIndex{1, 1, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(26) == (MultiIndex{1, 1, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(27) == (MultiIndex{1, 1, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(52) == (MultiIndex{2, 3, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(53) == (MultiIndex{2, 3, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(54) == (MultiIndex{2, 3, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(55) == (MultiIndex{2, 3, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(60) == (MultiIndex{3, 0, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(61) == (MultiIndex{3, 0, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(62) == (MultiIndex{3, 0, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(63) == (MultiIndex{3, 0, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(76) == (MultiIndex{3, 4, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(77) == (MultiIndex{3, 4, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(78) == (MultiIndex{3, 4, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(79) == (MultiIndex{3, 4, 3})); + + // multi bin indices -> flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0, 0}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0, 1}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0, 2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0, 3}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1, 0}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1, 1}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1, 2}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1, 3}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1, 0}), 24u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1, 1}), 25u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1, 2}), 26u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1, 3}), 27u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3, 0}), 52u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3, 1}), 53u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3, 2}), 54u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3, 3}), 55u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 0, 0}), 60u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 0, 1}), 61u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 0, 2}), 62u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 0, 3}), 63u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 4, 0}), 76u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 4, 1}), 77u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 4, 2}), 78u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 4, 3}), 79u); + + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(ma.getFlatIndexFromPoint( + {1.2, 0.7, 1.4})) == (MultiIndex{2, 1, 2})); + + // inside checks + BOOST_CHECK(!ma.isInside({-2., -1, -2})); + BOOST_CHECK(!ma.isInside({-2., 1., 0.})); + BOOST_CHECK(!ma.isInside({-2., 5., -1})); + BOOST_CHECK(!ma.isInside({1., -1., 1.})); + BOOST_CHECK(!ma.isInside({6., -1., 4.})); + BOOST_CHECK(ma.isInside({0.5, 1.3, 1.7})); + BOOST_CHECK(!ma.isInside({2., -1., -0.4})); + BOOST_CHECK(!ma.isInside({2., 0.3, 3.4})); + BOOST_CHECK(!ma.isInside({2., 3., 0.8})); + BOOST_CHECK(!ma.isInside({-1., 3., 5.})); + BOOST_CHECK(!ma.isInside({2., 3., -1.})); + BOOST_CHECK(!ma.isInside({5., 3., 0.5})); + + // test some bin centers + CHECK_CLOSE_ABS(ma.getBinCenter({1, 1, 1}), (Point{0.5, 0.5, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({2, 3, 2}), (Point{1.5, 2.5, 1.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({1, 1, 2}), (Point{0.5, 0.5, 1.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({2, 2, 1}), (Point{1.5, 1.5, 0.5}), 1e-6); + + // test some lower-left bin edges + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 1, 1}), (Point{0., 0., 0.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({2, 3, 2}), (Point{1., 2., 1.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 1, 2}), (Point{0., 0., 1.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({2, 2, 1}), (Point{1., 1., 0.}), + 1e-6); + + // test some upper right-bin edges + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{1, 1, 1}}), (Point{{1., 1., 1.}}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{2, 3, 2}}), (Point{{2., 3., 2.}}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{1, 1, 2}}), (Point{{1., 1., 2.}}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{2, 2, 1}}), (Point{{2., 2., 1.}}), + 1e-6); +} + +BOOST_AUTO_TEST_CASE(test_1d_variable) { + using Point = std::array; + using MultiIndex = std::array; + + Axis a({0.0, 1.0, 4.0}); + + MultiAxis ma(a); + + // test general properties + BOOST_CHECK_EQUAL(ma.getNAxes(), 1u); + BOOST_CHECK_EQUAL(ma.getNBins().at(0), 2u); + BOOST_CHECK_EQUAL(ma.getNTotalBins(), 4u); + + // flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-0.3}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.7}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1.2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2.7}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4.}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({4.98}), 3u); + + // flat bin index -> multi bin indices + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(0) == MultiIndex{0}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(1) == MultiIndex{1}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(2) == MultiIndex{2}); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(3) == MultiIndex{3}); + + // multi bin indices -> flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3}), 3u); + + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(ma.getFlatIndexFromPoint({0.8})) == + MultiIndex{1}); + + // inside checks + BOOST_CHECK(!ma.isInside({-2.})); + BOOST_CHECK(ma.isInside({0.})); + BOOST_CHECK(ma.isInside({2.5})); + BOOST_CHECK(!ma.isInside({4.})); + BOOST_CHECK(!ma.isInside({6.})); + + // test some bin centers + CHECK_CLOSE_ABS(ma.getBinCenter({1}), Point{0.5}, 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({2}), Point{2.5}, 1e-6); + + // test some lower-left bin edges + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1}), Point{0.}, 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({2}), Point{1.}, 1e-6); + + // test some upper right-bin edges + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1}), Point{1.}, 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({2}), Point{4.}, 1e-6); +} + +BOOST_AUTO_TEST_CASE(test_2d_variable) { + using Point = std::array; + using MultiIndex = std::array; + + Axis a({0.0, 0.5, 3.0}); + Axis b({0.0, 1.0, 4.0}); + + MultiAxis ma(a, b); + + // test general properties + BOOST_CHECK_EQUAL(ma.getNAxes(), 2u); + BOOST_CHECK_EQUAL(ma.getNBins().at(0), 2u); + BOOST_CHECK_EQUAL(ma.getNBins().at(1), 2u); + BOOST_CHECK_EQUAL(ma.getNTotalBins(), 16u); + + // test grid points + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 1}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 4}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.5, 0}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.5, 1}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.5, 4}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 0}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 1}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3, 4}), 15u); + + // test some arbitrary points + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.3, 1.2}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3.3, 2.2}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1.8, 0.9}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({3.1, 0.7}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2.3, 1.4}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({2, -3}), 8u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 8}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-3, 1}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({11, 3}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-3, -2}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({7, -2}), 12u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-1, 12}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({11, 12}), 15u); + + // flat bin index -> multi bin indices + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(0) == (MultiIndex{0, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(1) == (MultiIndex{0, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(2) == (MultiIndex{0, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(3) == (MultiIndex{0, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(4) == (MultiIndex{1, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(5) == (MultiIndex{1, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(6) == (MultiIndex{1, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(7) == (MultiIndex{1, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(8) == (MultiIndex{2, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(9) == (MultiIndex{2, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(10) == (MultiIndex{2, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(11) == (MultiIndex{2, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(12) == (MultiIndex{3, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(13) == (MultiIndex{3, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(14) == (MultiIndex{3, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(15) == (MultiIndex{3, 3})); + + // multi bin indices -> flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 3}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 0}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 2}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 3}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 0}), 8u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 1}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 2}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 0}), 12u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 1}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 2}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 3}), 15u); + + BOOST_CHECK(ma.getMultiIndexFromFlatIndex( + ma.getFlatIndexFromPoint({3.2, 1.8})) == (MultiIndex{3, 2})); + + // inside checks + BOOST_CHECK(!ma.isInside({-2., -1})); + BOOST_CHECK(!ma.isInside({-2., 1.})); + BOOST_CHECK(!ma.isInside({-2., 5.})); + BOOST_CHECK(!ma.isInside({1., -1.})); + BOOST_CHECK(!ma.isInside({6., -1.})); + BOOST_CHECK(ma.isInside({0.5, 1.3})); + BOOST_CHECK(!ma.isInside({3., -1.})); + BOOST_CHECK(!ma.isInside({3., 0.3})); + BOOST_CHECK(!ma.isInside({3., 4.})); + BOOST_CHECK(!ma.isInside({-1., 4.})); + BOOST_CHECK(!ma.isInside({2., 4.})); + BOOST_CHECK(!ma.isInside({5., 4.})); + + // test some bin centers + CHECK_CLOSE_ABS(ma.getBinCenter({{1, 1}}), (Point{0.25, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({{2, 1}}), (Point{1.75, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({{1, 2}}), (Point{0.25, 2.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({{2, 2}}), (Point{1.75, 2.5}), 1e-6); + + // test some lower-left bin edges + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({{1, 1}}), (Point{0., 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({{2, 1}}), (Point{0.5, 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({{1, 2}}), (Point{0., 1.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({{2, 2}}), (Point{0.5, 1.}), 1e-6); + + // test some upper right-bin edges + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{1, 1}}), (Point{0.5, 1.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{2, 1}}), (Point{3., 1.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{1, 2}}), (Point{0.5, 4.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({{2, 2}}), (Point{3., 4.}), 1e-6); +} + +BOOST_AUTO_TEST_CASE(test_3d_variable) { + using Point = std::array; + using MultiIndex = std::array; + + Axis a({0.0, 1.0}); + Axis b({0.0, 0.5, 3.0}); + Axis c({0.0, 0.5, 3.0, 3.3}); + + MultiAxis ma(a, b, c); + + // test general properties + BOOST_CHECK_EQUAL(ma.getNAxes(), 3u); + BOOST_CHECK_EQUAL(ma.getNBins().at(0), 1u); + BOOST_CHECK_EQUAL(ma.getNBins().at(1), 2u); + BOOST_CHECK_EQUAL(ma.getNBins().at(2), 3u); + BOOST_CHECK_EQUAL(ma.getNTotalBins(), 60u); + + // test grid points + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0, 0}), 26u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0, 0}), 46u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0.5, 0}), 31u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0.5, 0}), 51u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3, 0}), 36u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3, 0}), 56u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0, 0.5}), 27u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0, 0.5}), 47u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0.5, 0.5}), 32u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0.5, 0.5}), 52u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3, 0.5}), 37u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3, 0.5}), 57u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0, 3}), 28u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0, 3}), 48u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0.5, 3}), 33u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0.5, 3}), 53u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3, 3}), 38u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3, 3}), 58u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0, 3.3}), 29u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0, 3.3}), 49u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0.5, 3.3}), 34u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0.5, 3.3}), 54u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3, 3.3}), 39u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3, 3.3}), 59u); + + // flat bin index -> multi bin indices + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(0) == (MultiIndex{0, 0, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(1) == (MultiIndex{0, 0, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(2) == (MultiIndex{0, 0, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(3) == (MultiIndex{0, 0, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(4) == (MultiIndex{0, 0, 4})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(5) == (MultiIndex{0, 1, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(21) == (MultiIndex{1, 0, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(22) == (MultiIndex{1, 0, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(23) == (MultiIndex{1, 0, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(24) == (MultiIndex{1, 0, 4})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(25) == (MultiIndex{1, 1, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(26) == (MultiIndex{1, 1, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(57) == (MultiIndex{2, 3, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(58) == (MultiIndex{2, 3, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(59) == (MultiIndex{2, 3, 4})); + + // multi bin indices -> flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0, 0}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 0, 0}), 20u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 0, 0}), 40u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1, 0}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1, 0}), 25u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 1, 0}), 45u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 3, 1}), 16u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 3, 1}), 36u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3, 1}), 56u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0, 2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 0, 2}), 22u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 0, 2}), 42u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 3, 4}), 19u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 3, 4}), 39u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3, 4}), 59u); + + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(ma.getFlatIndexFromPoint( + {1.8, 0.7, 3.2})) == (MultiIndex{2, 2, 3})); + + // inside checks + BOOST_CHECK(!ma.isInside({-2., -1, -2})); + BOOST_CHECK(!ma.isInside({-2., 1., 0.})); + BOOST_CHECK(!ma.isInside({-2., 5., -1})); + BOOST_CHECK(!ma.isInside({1., -1., 1.})); + BOOST_CHECK(!ma.isInside({6., -1., 4.})); + BOOST_CHECK(ma.isInside({0.5, 1.3, 1.7})); + BOOST_CHECK(!ma.isInside({1., -1., -0.4})); + BOOST_CHECK(!ma.isInside({1., 0.3, 3.4})); + BOOST_CHECK(!ma.isInside({1., 3., 0.8})); + BOOST_CHECK(!ma.isInside({-1., 3., 5.})); + BOOST_CHECK(!ma.isInside({2., 3., -1.})); + BOOST_CHECK(!ma.isInside({5., 3., 0.5})); + + // test some bin centers + CHECK_CLOSE_ABS(ma.getBinCenter({1, 1, 1}), (Point{0.5, 0.25, 0.25}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({1, 1, 2}), (Point{0.5, 0.25, 1.75}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({1, 1, 3}), (Point{0.5, 0.25, 3.15}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({1, 2, 1}), (Point{0.5, 1.75, 0.25}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({1, 2, 2}), (Point{0.5, 1.75, 1.75}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({1, 2, 3}), (Point{0.5, 1.75, 3.15}), 1e-6); + + // test some lower-left bin edges + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 1, 1}), (Point{0., 0., 0.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 1, 2}), (Point{0., 0., 0.5}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 1, 3}), (Point{0., 0., 3.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 2, 1}), (Point{0., 0.5, 0.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 2, 2}), (Point{0., 0.5, 0.5}), + 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 2, 3}), (Point{0., 0.5, 3.}), + 1e-6); + + // test some upper right-bin edges + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 1, 1}), (Point{1., 0.5, 0.5}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 1, 2}), (Point{1., 0.5, 3.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 1, 3}), (Point{1., 0.5, 3.3}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 2, 1}), (Point{1., 3., 0.5}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 2, 2}), (Point{1., 3., 3.}), + 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 2, 3}), (Point{1., 3., 3.3}), + 1e-6); +} + +BOOST_AUTO_TEST_CASE(test_2d_mixed) { + using Point = std::array; + using MultiIndex = std::array; + + Axis a(0.0, 1.0, 4u); + Axis b({0.0, 0.5, 3.0}); + + MultiAxis ma(a, b); + + // test general properties + BOOST_CHECK_EQUAL(ma.getNAxes(), 2u); + BOOST_CHECK_EQUAL(ma.getNBins().at(0), 4u); + BOOST_CHECK_EQUAL(ma.getNBins().at(1), 2u); + BOOST_CHECK_EQUAL(ma.getNTotalBins(), 24u); + + // test grid points + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.25, 0}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.5, 0}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.75, 0}), 17u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0}), 21u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 0.5}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.25, 0.5}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.5, 0.5}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.75, 0.5}), 18u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 0.5}), 22u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0, 3}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.25, 3}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.5, 3}), 15u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.75, 3}), 19u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1, 3}), 23u); + + // test some arbitrary points + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({1.2, 0.3}), 21u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.2, 1.3}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.9, 1.8}), 18u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.7, 2.1}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.4, 0.3}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-3, 2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({8, 1}), 22u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.1, -3}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({0.8, 11}), 19u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-2, -3}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({-2, 7}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({12, -1}), 20u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromPoint({12, 11}), 23u); + + // flat bin index -> multi bin indices + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(0) == (MultiIndex{0, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(1) == (MultiIndex{0, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(2) == (MultiIndex{0, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(3) == (MultiIndex{0, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(4) == (MultiIndex{1, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(5) == (MultiIndex{1, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(6) == (MultiIndex{1, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(7) == (MultiIndex{1, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(8) == (MultiIndex{2, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(9) == (MultiIndex{2, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(10) == (MultiIndex{2, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(11) == (MultiIndex{2, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(12) == (MultiIndex{3, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(13) == (MultiIndex{3, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(14) == (MultiIndex{3, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(15) == (MultiIndex{3, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(16) == (MultiIndex{4, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(17) == (MultiIndex{4, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(18) == (MultiIndex{4, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(19) == (MultiIndex{4, 3})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(20) == (MultiIndex{5, 0})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(21) == (MultiIndex{5, 1})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(22) == (MultiIndex{5, 2})); + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(23) == (MultiIndex{5, 3})); + + // multi bin indices -> flat bin index + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 0}), 0u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 1}), 1u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 2}), 2u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({0, 3}), 3u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 0}), 4u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 1}), 5u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 2}), 6u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({1, 3}), 7u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 0}), 8u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 1}), 9u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 2}), 10u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({2, 3}), 11u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 0}), 12u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 1}), 13u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 2}), 14u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({3, 3}), 15u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 0}), 16u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 1}), 17u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 2}), 18u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({4, 3}), 19u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 0}), 20u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 1}), 21u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 2}), 22u); + BOOST_CHECK_EQUAL(ma.getFlatIndexFromMultiIndex({5, 3}), 23u); + + BOOST_CHECK(ma.getMultiIndexFromFlatIndex(ma.getFlatIndexFromPoint( + Point({{1.1, 1.7}}))) == MultiIndex({{5, 2}})); + + // inside checks + BOOST_CHECK(!ma.isInside({-2., -1})); + BOOST_CHECK(!ma.isInside({-2., 1.})); + BOOST_CHECK(!ma.isInside({-2., 5.})); + BOOST_CHECK(!ma.isInside({0.1, -1.})); + BOOST_CHECK(!ma.isInside({6., -1.})); + BOOST_CHECK(ma.isInside({0.5, 1.3})); + BOOST_CHECK(!ma.isInside({1., -1.})); + BOOST_CHECK(!ma.isInside({1., 0.3})); + BOOST_CHECK(!ma.isInside({1., 3.})); + BOOST_CHECK(!ma.isInside({-1., 3.})); + BOOST_CHECK(!ma.isInside({0.2, 3.})); + BOOST_CHECK(!ma.isInside({5., 3.})); + + // test some bin centers + CHECK_CLOSE_ABS(ma.getBinCenter({1, 1}), (Point{0.125, 0.25}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({1, 2}), (Point{0.125, 1.75}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({2, 1}), (Point{0.375, 0.25}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({2, 2}), (Point{0.375, 1.75}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({3, 1}), (Point{0.625, 0.25}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({3, 2}), (Point{0.625, 1.75}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({4, 1}), (Point{0.875, 0.25}), 1e-6); + CHECK_CLOSE_ABS(ma.getBinCenter({4, 2}), (Point{0.875, 1.75}), 1e-6); + + // test some lower-left bin edges + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 1}), (Point{0., 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({1, 2}), (Point{0., 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({2, 1}), (Point{0.25, 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({2, 2}), (Point{0.25, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({3, 1}), (Point{0.5, 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({3, 2}), (Point{0.5, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({4, 1}), (Point{0.75, 0.}), 1e-6); + CHECK_CLOSE_ABS(ma.getLowerLeftBinCorner({4, 2}), (Point{0.75, 0.5}), 1e-6); + + // test some upper-right bin edges + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 1}), (Point{0.25, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({1, 2}), (Point{0.25, 3.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({2, 1}), (Point{0.5, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({2, 2}), (Point{0.5, 3.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({3, 1}), (Point{0.75, 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({3, 2}), (Point{0.75, 3.}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({4, 1}), (Point{1., 0.5}), 1e-6); + CHECK_CLOSE_ABS(ma.getUpperRightBinCorner({4, 2}), (Point{1., 3.}), 1e-6); +} + +BOOST_AUTO_TEST_CASE(neighborhood) { + using bins_t = std::vector; + + Axis a(0.0, 1.0, 10u); + Axis b(0.0, 1.0, 10u); + Axis c(0.0, 1.0, 10u); + + MultiAxis ma1(a); + MultiAxis ma2(a, b); + MultiAxis ma3(a, b, c); + + // 1D case + BOOST_CHECK(ma1.getNeighborHoodIndices({0}, 1).collectVector() == + (bins_t{0, 1})); + BOOST_CHECK(ma1.getNeighborHoodIndices({0}, 2).collectVector() == + (bins_t{0, 1, 2})); + BOOST_CHECK(ma1.getNeighborHoodIndices({1}, 1).collectVector() == + (bins_t{0, 1, 2})); + BOOST_CHECK(ma1.getNeighborHoodIndices({1}, 3).collectVector() == + (bins_t{0, 1, 2, 3, 4})); + BOOST_CHECK(ma1.getNeighborHoodIndices({4}, 2).collectVector() == + (bins_t{2, 3, 4, 5, 6})); + BOOST_CHECK(ma1.getNeighborHoodIndices({9}, 2).collectVector() == + (bins_t{7, 8, 9, 10, 11})); + BOOST_CHECK(ma1.getNeighborHoodIndices({10}, 2).collectVector() == + (bins_t{8, 9, 10, 11})); + BOOST_CHECK(ma1.getNeighborHoodIndices({11}, 2).collectVector() == + (bins_t{9, 10, 11})); + + // 2D case + BOOST_CHECK(ma2.getNeighborHoodIndices({0, 0}, 1).collectVector() == + (bins_t{0, 1, 12, 13})); + BOOST_CHECK(ma2.getNeighborHoodIndices({0, 1}, 1).collectVector() == + (bins_t{0, 1, 2, 12, 13, 14})); + BOOST_CHECK(ma2.getNeighborHoodIndices({1, 0}, 1).collectVector() == + (bins_t{0, 1, 12, 13, 24, 25})); + BOOST_CHECK(ma2.getNeighborHoodIndices({1, 1}, 1).collectVector() == + (bins_t{0, 1, 2, 12, 13, 14, 24, 25, 26})); + BOOST_CHECK(ma2.getNeighborHoodIndices({5, 5}, 1).collectVector() == + (bins_t{52, 53, 54, 64, 65, 66, 76, 77, 78})); + BOOST_CHECK(ma2.getNeighborHoodIndices({9, 10}, 2).collectVector() == + (bins_t{92, 93, 94, 95, 104, 105, 106, 107, 116, 117, + 118, 119, 128, 129, 130, 131, 140, 141, 142, 143})); + + // 3D case + BOOST_CHECK(ma3.getNeighborHoodIndices({0, 0, 0}, 1).collectVector() == + (bins_t{0, 1, 12, 13, 144, 145, 156, 157})); + BOOST_CHECK(ma3.getNeighborHoodIndices({0, 0, 1}, 1).collectVector() == + (bins_t{0, 1, 2, 12, 13, 14, 144, 145, 146, 156, 157, 158})); + BOOST_CHECK(ma3.getNeighborHoodIndices({0, 1, 0}, 1).collectVector() == + (bins_t{0, 1, 12, 13, 24, 25, 144, 145, 156, 157, 168, 169})); + BOOST_CHECK(ma3.getNeighborHoodIndices({1, 0, 0}, 1).collectVector() == + (bins_t{0, 1, 12, 13, 144, 145, 156, 157, 288, 289, 300, 301})); + BOOST_CHECK(ma3.getNeighborHoodIndices({0, 1, 1}, 1).collectVector() == + (bins_t{0, 1, 2, 12, 13, 14, 24, 25, 26, 144, 145, 146, 156, 157, + 158, 168, 169, 170})); + BOOST_CHECK(ma3.getNeighborHoodIndices({1, 1, 1}, 1).collectVector() == + (bins_t{0, 1, 2, 12, 13, 14, 24, 25, 26, + 144, 145, 146, 156, 157, 158, 168, 169, 170, + 288, 289, 290, 300, 301, 302, 312, 313, 314})); + BOOST_CHECK(ma3.getNeighborHoodIndices({11, 10, 9}, 1).collectVector() == + (bins_t{1556, 1557, 1558, 1568, 1569, 1570, 1580, 1581, 1582, + 1700, 1701, 1702, 1712, 1713, 1714, 1724, 1725, 1726})); + + // Neighbors array + std::array, 1> a1; + a1.at(0) = std::make_pair(-1, 1); + BOOST_CHECK(ma1.getNeighborHoodIndices({0}, a1).collectVector() == + (bins_t{0, 1})); + BOOST_CHECK(ma1.getNeighborHoodIndices({2}, a1).collectVector() == + (bins_t{1, 2, 3})); + + a1.at(0) = std::make_pair(2, 3); + BOOST_CHECK(ma1.getNeighborHoodIndices({2}, a1).collectVector() == + (bins_t{4, 5})); + + a1.at(0) = std::make_pair(-2, -1); + BOOST_CHECK(ma1.getNeighborHoodIndices({2}, a1).collectVector() == + (bins_t{0, 1})); + + a1.at(0) = std::make_pair(-3, -1); + BOOST_CHECK(ma1.getNeighborHoodIndices({2}, a1).collectVector() == + (bins_t{0, 1})); + + Axis d(AxisClosed, 0.0, 1.0, 10u); + + MultiAxis ma1Cl(d); + + BOOST_CHECK(ma1Cl.getNeighborHoodIndices({0}, 1) + .collectVector() + .empty()); // underflow, makes no sense + BOOST_CHECK(ma1Cl.getNeighborHoodIndices({11}, 1) + .collectVector() + .empty()); // overflow, makes no sense + BOOST_CHECK(ma1Cl.getNeighborHoodIndices({1}, 1).collectVector() == + (bins_t{10, 1, 2})); + BOOST_CHECK(ma1Cl.getNeighborHoodIndices({5}, 1).collectVector() == + (bins_t{4, 5, 6})); + + Axis f(AxisClosed, 0.0, 1.0, 5u); + Axis e(AxisClosed, 0.0, 1.0, 5u); + + MultiAxis ma2Cl(e, f); + + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({3, 3}, 1).collectVector() == + (bins_t{16, 17, 18, 23, 24, 25, 30, 31, 32})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({1, 1}, 1).collectVector() == + (bins_t{40, 36, 37, 12, 8, 9, 19, 15, 16})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({1, 5}, 1).collectVector() == + (bins_t{39, 40, 36, 11, 12, 8, 18, 19, 15})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({5, 1}, 1).collectVector() == + (bins_t{33, 29, 30, 40, 36, 37, 12, 8, 9})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({5, 5}, 1).collectVector() == + (bins_t{32, 33, 29, 39, 40, 36, 11, 12, 8})); + + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({3, 3}, 2).collectVector() == + (bins_t{8, 9, 10, 11, 12, 15, 16, 17, 18, 19, 22, 23, 24, + 25, 26, 29, 30, 31, 32, 33, 36, 37, 38, 39, 40})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({1, 1}, 2).collectVector() == + (bins_t{32, 33, 29, 30, 31, 39, 40, 36, 37, 38, 11, 12, 8, + 9, 10, 18, 19, 15, 16, 17, 25, 26, 22, 23, 24})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({1, 5}, 2).collectVector() == + (bins_t{31, 32, 33, 29, 30, 38, 39, 40, 36, 37, 10, 11, 12, + 8, 9, 17, 18, 19, 15, 16, 24, 25, 26, 22, 23})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({5, 1}, 2).collectVector() == + (bins_t{25, 26, 22, 23, 24, 32, 33, 29, 30, 31, 39, 40, 36, + 37, 38, 11, 12, 8, 9, 10, 18, 19, 15, 16, 17})); + BOOST_CHECK(ma2Cl.getNeighborHoodIndices({5, 5}, 2).collectVector() == + (bins_t{24, 25, 26, 22, 23, 31, 32, 33, 29, 30, 38, 39, 40, + 36, 37, 10, 11, 12, 8, 9, 17, 18, 19, 15, 16})); + + std::array, 2> a2; + a2.at(0) = + std::make_pair(-2, -1); // only 2 bins left of requested bin + // (not including the requested bin) + a2.at(1) = std::make_pair( + -1, 2); // one bin left of requested bin, the requested bin itself and 2 + // bins right of requested bin + std::set returnedBins; + + auto returnedBinsVec = + ma2Cl.getNeighborHoodIndices({3, 2}, a2).collectVector(); + returnedBins.insert(returnedBinsVec.begin(), returnedBinsVec.end()); + std::set expectedBins{{8, 9, 10, 11, 15, 16, 17, 18}}; + BOOST_CHECK(returnedBins == expectedBins); + + returnedBinsVec = ma2Cl.getNeighborHoodIndices({1, 5}, a2).collectVector(); + returnedBins.clear(); + returnedBins.insert(returnedBinsVec.begin(), returnedBinsVec.end()); + expectedBins = {{29, 30, 32, 33, 36, 37, 39, 40}}; + BOOST_CHECK(returnedBins == expectedBins); + + a2.at(0) = {-6, 7}; + a2.at(1) = {0, 0}; + returnedBinsVec = ma2Cl.getNeighborHoodIndices({1, 5}, a2).collectVector(); + returnedBins.clear(); + returnedBins.insert(returnedBinsVec.begin(), returnedBinsVec.end()); + expectedBins = {{12, 19, 26, 33, 40}}; + BOOST_CHECK(returnedBins == expectedBins); + + // @TODO 3D test would be nice, but should essentially not be a problem if + // 2D works. + + // clang-format off + /* + * 1 2 3 4 5 + * |------------------------| + * 1 | 8 | 9 | 10 | 11 | 12 | + * |----|----|----|----|----| + * 2 | 15 | 16 | 17 | 18 | 19 | + * |----|----|----|----|----| + * 3 | 22 | 23 | 24 | 25 | 26 | + * |----|----|----|----|----| + * 4 | 29 | 30 | 31 | 32 | 33 | + * |----|----|----|----|----| + * 5 | 36 | 37 | 38 | 39 | 40 | + * |------------------------| + */ + // clang-format on +} + +BOOST_AUTO_TEST_CASE(closestPoints) { + using Point1 = std::array; + using Point2 = std::array; + using Point3 = std::array; + using bins_t = std::vector; + + Axis a(0.0, 1.0, 10u); + Axis b(0.0, 1.0, 5u); + Axis c(0.0, 1.0, 3u); + + MultiAxis g1(a); + MultiAxis g2(a, b); + MultiAxis g3(a, b, c); + + // 1D case + BOOST_CHECK(g1.getClosestPointsIndices(Point1{0.52}).collectVector() == + (bins_t{6, 7})); + BOOST_CHECK(g1.getClosestPointsIndices(Point1{0.98}).collectVector() == + (bins_t{10, 11})); + + // 2D case + BOOST_CHECK(g2.getClosestPointsIndices(Point2{0.52, 0.08}).collectVector() == + (bins_t{43, 44, 50, 51})); + BOOST_CHECK(g2.getClosestPointsIndices(Point2{0.05, 0.08}).collectVector() == + (bins_t{8, 9, 15, 16})); + + // 3D case + BOOST_CHECK( + g3.getClosestPointsIndices(Point3{0.23, 0.13, 0.61}).collectVector() == + (bins_t{112, 113, 117, 118, 147, 148, 152, 153})); + BOOST_CHECK( + g3.getClosestPointsIndices(Point3{0.52, 0.35, 0.71}).collectVector() == + (bins_t{223, 224, 228, 229, 258, 259, 263, 264})); + + using EAxisClosed = Axis; + + using MultiAxis1Cl_t = MultiAxis; + using MultiAxis2Cl_t = MultiAxis; + // using MultiAxis3Cl_t = MultiAxis; + + EAxisClosed aCl(0.0, 1.0, 10u); + EAxisClosed bCl(0.0, 1.0, 5u); + // EAxisClosed cCl(0.0, 1.0, 3u); + + MultiAxis1Cl_t ma1Cl(std::make_tuple(aCl)); + MultiAxis2Cl_t ma2Cl(std::make_tuple(aCl, bCl)); + + // 1D case + BOOST_CHECK(ma1Cl.getClosestPointsIndices(Point1{0.52}).collectVector() == + (bins_t{6, 7})); + BOOST_CHECK(ma1Cl.getClosestPointsIndices(Point1{0.98}).collectVector() == + (bins_t{10, 1})); + + // 2D case + BOOST_CHECK( + ma2Cl.getClosestPointsIndices(Point2{0.52, 0.08}).collectVector() == + (bins_t{43, 44, 50, 51})); + BOOST_CHECK( + ma2Cl.getClosestPointsIndices(Point2{0.52, 0.68}).collectVector() == + (bins_t{46, 47, 53, 54})); + BOOST_CHECK( + ma2Cl.getClosestPointsIndices(Point2{0.52, 0.88}).collectVector() == + (bins_t{47, 43, 54, 50})); + BOOST_CHECK( + ma2Cl.getClosestPointsIndices(Point2{0.05, 0.08}).collectVector() == + (bins_t{8, 9, 15, 16})); + BOOST_CHECK( + ma2Cl.getClosestPointsIndices(Point2{0.9, 0.95}).collectVector() == + (bins_t{75, 71, 12, 8})); + + // @TODO: 3D checks would also be nice + + Axis aOp(AxisBound, 0.0, 1.0, 10u); + Axis bOp(AxisBound, 0.0, 1.0, 5u); + + MultiAxis ma1Op(aOp); + MultiAxis ma2Op(aOp, bOp); + + // 1D case + BOOST_CHECK(ma1Op.getClosestPointsIndices(Point1{0.52}).collectVector() == + (bins_t{6, 7})); + BOOST_CHECK(ma1Op.getClosestPointsIndices(Point1{0.98}).collectVector() == + (bins_t{10})); + BOOST_CHECK(ma1Op.getClosestPointsIndices(Point1{0.88}).collectVector() == + (bins_t{9, 10})); + + // 2D case + BOOST_CHECK( + ma2Op.getClosestPointsIndices(Point2{0.52, 0.08}).collectVector() == + (bins_t{43, 44, 50, 51})); + BOOST_CHECK( + ma2Op.getClosestPointsIndices(Point2{0.52, 0.68}).collectVector() == + (bins_t{46, 47, 53, 54})); + BOOST_CHECK( + ma2Op.getClosestPointsIndices(Point2{0.52, 0.88}).collectVector() == + (bins_t{47, 54})); + BOOST_CHECK( + ma2Op.getClosestPointsIndices(Point2{0.05, 0.1}).collectVector() == + (bins_t{8, 9, 15, 16})); + BOOST_CHECK( + ma2Op.getClosestPointsIndices(Point2{0.95, 0.95}).collectVector() == + (bins_t{75})); + + // @TODO: 3D checks would also be nice + + // clang-format off + /* + * 1 2 3 4 5 + * |------------------------| + * 1 | 8 | 9 | 10 | 11 | 12 | + * |----|----|----|----|----| + * 2 | 15 | 16 | 17 | 18 | 19 | + * |----|----|----|----|----| + * 3 | 22 | 23 | 24 | 25 | 26 | + * |----|----|----|----|----| + * 4 | 29 | 30 | 31 | 32 | 33 | + * |----|----|----|----|----| + * 5 | 36 | 37 | 38 | 39 | 40 | + * |------------------------| + * 6 | 43 | 44 | 45 | 46 | 47 | + * |------------------------| + * 7 | 50 | 51 | 52 | 53 | 54 | + * |------------------------| + * 8 | 57 | 58 | 59 | 60 | 61 | + * |------------------------| + * 9 | 64 | 65 | 66 | 67 | 68 | + * |------------------------| + * 10 | 71 | 72 | 73 | 74 | 75 | + * |------------------------| + * 77 78 79 80 81 82 83 + */ + // clang-format on +} + +BOOST_AUTO_TEST_CASE(Output) { + Axis a{AxisOpen, 0.0, 1.0, 10u}; + Axis b{AxisBound, {1, 2, 3}}; + + MultiAxis ma(a, b); + + std::stringstream ss; + ss << ma; + BOOST_CHECK_EQUAL( + ss.str(), + "Axis(0, 1, 10), Axis(1, 2, 3)"); + + const IMultiAxis& ima = ma; + + ss.str(""); + + ss << ima; + + BOOST_CHECK_EQUAL( + ss.str(), + "Axis(0, 1, 10), Axis(1, 2, 3)"); +} + +BOOST_AUTO_TEST_CASE(Equality) { + Axis a{AxisOpen, 0.0, 1.0, 10u}; + Axis b{AxisBound, {1, 2, 3}}; + Axis c{AxisClosed, {1, 2, 5}}; + + MultiAxis ma_ab(a, b); + MultiAxis ma_ac(a, c); + + BOOST_CHECK_EQUAL(ma_ab, ma_ab); + BOOST_CHECK_EQUAL(ma_ac, ma_ac); + BOOST_CHECK_NE(ma_ab, ma_ac); + + const IMultiAxis& ima_ab = ma_ab; + const IMultiAxis& ima_ac = ma_ac; + + BOOST_CHECK_EQUAL(ima_ab, ima_ab); + BOOST_CHECK_EQUAL(ima_ac, ima_ac); +} + +BOOST_AUTO_TEST_SUITE_END() + +} // namespace ActsTests