diff --git a/dx2/reflection.cxx b/dx2/reflection.cxx index 369efcc..1e1cf3a 100644 --- a/dx2/reflection.cxx +++ b/dx2/reflection.cxx @@ -129,6 +129,8 @@ std::vector ReflectionTable::get_column_names() const { } return names; } + +size_t ReflectionTable::size() const { return get_row_count(); } #pragma endregion #pragma region Private Helper Methods @@ -184,14 +186,11 @@ void ReflectionTable::write(std::string_view filename, // Suppress errors when opening non-existent files, groups, datasets.. H5ErrorSilencer silencer; - // 🗂️ Ensure the file exists or create it before writing - h5utils::H5File file(H5Fopen(fname.c_str(), H5F_ACC_RDWR, H5P_DEFAULT)); + // 🗂️ Create (or truncate) the file + h5utils::H5File file( + H5Fcreate(fname.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT)); if (!file) { - file = h5utils::H5File( - H5Fcreate(fname.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT)); - if (!file) { - throw std::runtime_error("Failed to create or open file: " + fname); - } + throw std::runtime_error("Failed to create or open file: " + fname); } // Open or create group diff --git a/include/dx2/reflection.hpp b/include/dx2/reflection.hpp index 4a9c468..df44bbb 100644 --- a/include/dx2/reflection.hpp +++ b/include/dx2/reflection.hpp @@ -257,6 +257,20 @@ class ReflectionTable { return col.get_name() == name && col.get_type() == typeid(T); } + /// Throw if type T is not supported by the HDF5 backend. + template void ensure_supported_type() const { + const auto &registry = h5dispatch::get_supported_types(); + bool supported = std::any_of(registry.begin(), registry.end(), + [](const h5dispatch::H5TypeInfo &info) { + return info.cpp_type == typeid(T); + }); + if (!supported) { + throw std::runtime_error( + "Attempted to add column with unsupported type: " + + std::string(typeid(T).name())); + } + } + /** * @brief Merges a vector of row indices into a set. * @@ -378,6 +392,14 @@ class ReflectionTable { * @brief Get a list of all column names in the table. 
*/ std::vector get_column_names() const; + + /** + * @brief Returns the number of rows (reflections) in the table. + * + * All columns are required to share the same row count, so this is the + * length of any column. Returns 0 if the table has no columns. + */ + size_t size() const; #pragma endregion #pragma region Column Access @@ -559,17 +581,7 @@ class ReflectionTable { } // Check if type T is supported - const auto &registry = h5dispatch::get_supported_types(); - bool supported = std::any_of(registry.begin(), registry.end(), - [](const h5dispatch::H5TypeInfo &info) { - return info.cpp_type == typeid(T); - }); - - if (!supported) { - throw std::runtime_error( - "Attempted to add column with unsupported type: " + - std::string(typeid(T).name())); - } + ensure_supported_type(); // Ensure row count consistency if (!data.empty() && col->get_shape()[0] != get_row_count()) { @@ -616,6 +628,82 @@ class ReflectionTable { const std::vector &column_data) { add_column(name, std::vector{rows, cols}, column_data); } + + /** + * @brief Replaces the data of an existing column. + * + * Unlike `add_column`, this requires a column of the given name to + * already exist and replaces it in place (preserving column order). + * The replacement is wholesale, so the element type and shape may + * change, but the new row count must match the rest of the table. + * + * @tparam T The data type of the column. + * @param name The name of the (existing) column to update. + * @param shape A vector describing the shape of the new column data. + * @param column_data A flat vector of `T` values for the column. + * + * @throws std::runtime_error if no column with `name` exists, the type + * is unsupported, or the row count does not match the other columns. + */ + template + void update_column(const std::string &name, const std::vector &shape, + const std::vector &column_data) { + // Check if the type T is a bool. If so, convert to BoolEnum and update. 
+ if constexpr (std::is_same_v) { + std::vector converted(column_data.size()); + for (size_t i = 0; i < column_data.size(); ++i) { + converted[i] = column_data[i] ? h5dispatch::BoolEnum::TRUE + : h5dispatch::BoolEnum::FALSE; + } + update_column(name, shape, converted); + } else { + // Locate the existing column (strict: it must already exist) + auto it = std::find_if(data.begin(), data.end(), + [&](const std::unique_ptr &c) { + return c->get_name() == name; + }); + if (it == data.end()) { + throw std::runtime_error("Column not found for update: " + name); + } + + auto col = std::make_unique>(name, shape, column_data); + + // Check if type T is supported + ensure_supported_type(); + + // Ensure row count consistency with the other columns in the table + for (const auto &other : data) { + if (other->get_name() != name) { + if (col->get_shape()[0] != other->get_shape()[0]) { + throw std::runtime_error( + "Row count mismatch when updating column: " + name); + } + break; + } + } + + // Replace in place, preserving the column's position + *it = std::move(col); + } + } + + /** + * @brief Replaces an existing 1D column. + */ + template + void update_column(const std::string &name, + const std::vector &column_data) { + update_column(name, std::vector{column_data.size()}, column_data); + } + + /** + * @brief Replaces an existing 2D column. 
+ */ + template + void update_column(const std::string &name, const size_t rows, + const size_t cols, const std::vector &column_data) { + update_column(name, std::vector{rows, cols}, column_data); + } #pragma endregion #pragma region Write diff --git a/tests/test_reflection_table.cxx b/tests/test_reflection_table.cxx index 2599110..8ec7900 100644 --- a/tests/test_reflection_table.cxx +++ b/tests/test_reflection_table.cxx @@ -173,6 +173,30 @@ TEST_F(ReflectionTableTest, AccessNonExistentColumn) { } #pragma endregion +#pragma region Size +TEST_F(ReflectionTableTest, SizeOfEmptyTableIsZero) { + ReflectionTable table; + EXPECT_EQ(table.size(), 0u); +} + +TEST_F(ReflectionTableTest, SizeMatchesAddedColumnRowCount) { + ReflectionTable table; + std::vector data{1.0, 2.0, 3.0, 4.0}; + table.add_column("col", data); + EXPECT_EQ(table.size(), 4u); +} + +TEST_F(ReflectionTableTest, SizeMatchesLoadedColumnExtent) { + ReflectionTable table(test_file_path.string()); + + auto col = table.column("xyzobs.px.value"); + ASSERT_TRUE(col.has_value()); + + EXPECT_EQ(table.size(), col->extent(0)); + EXPECT_GT(table.size(), 0u); +} +#pragma endregion + #pragma region Adding TEST_F(ReflectionTableTest, AddColumn1D) { ReflectionTable table; @@ -284,6 +308,69 @@ TEST_F(ReflectionTableTest, AddDuplicateColumnThrows) { } } +TEST_F(ReflectionTableTest, UpdateColumnReplacesData) { + ReflectionTable table; + table.add_column("col", std::vector{1.0, 2.0, 3.0}); + + table.update_column("col", std::vector{4.0, 5.0, 6.0}); + + auto col = table.column("col"); + ASSERT_TRUE(col.has_value()); + ASSERT_EQ(col->extent(0), 3); + EXPECT_DOUBLE_EQ((*col)(0, 0), 4.0); + EXPECT_DOUBLE_EQ((*col)(1, 0), 5.0); + EXPECT_DOUBLE_EQ((*col)(2, 0), 6.0); +} + +TEST_F(ReflectionTableTest, UpdateColumnThrowsIfMissing) { + ReflectionTable table; + table.add_column("col", std::vector{1.0, 2.0, 3.0}); + + try { + table.update_column("missing", std::vector{4.0, 5.0, 6.0}); + FAIL() << "Expected std::runtime_error for missing 
column"; + } catch (const std::runtime_error &e) { + std::cout << "[UpdateColumnThrowsIfMissing] Caught: " << e.what() << "\n"; + EXPECT_TRUE(std::string(e.what()).find("not found") != std::string::npos); + } catch (...) { + FAIL() << "Expected std::runtime_error: column not found"; + } +} + +TEST_F(ReflectionTableTest, UpdateColumnRejectsMismatchedRowCount) { + ReflectionTable table; + table.add_column("a", std::vector{1.0, 2.0, 3.0}); + table.add_column("b", std::vector{4.0, 5.0, 6.0}); + + try { + // 4 rows where the rest of the table has 3 - should throw + table.update_column("b", std::vector{7.0, 8.0, 9.0, 10.0}); + FAIL() << "Expected std::runtime_error due to row count mismatch"; + } catch (const std::runtime_error &e) { + std::cout << "[UpdateColumnRejectsMismatchedRowCount] Caught: " << e.what() + << "\n"; + EXPECT_TRUE(std::string(e.what()).find("Row count mismatch") != + std::string::npos); + } catch (...) { + FAIL() << "Expected std::runtime_error: row count mismatch"; + } +} + +TEST_F(ReflectionTableTest, UpdateColumnPreservesOrder) { + ReflectionTable table; + table.add_column("a", std::vector{1.0, 2.0}); + table.add_column("b", std::vector{3.0, 4.0}); + table.add_column("c", std::vector{5.0, 6.0}); + + table.update_column("b", std::vector{30.0, 40.0}); + + auto names = table.get_column_names(); + ASSERT_EQ(names.size(), 3u); + EXPECT_EQ(names[0], "a"); + EXPECT_EQ(names[1], "b"); + EXPECT_EQ(names[2], "c"); +} + TEST_F(ReflectionTableTest, AddUnsupportedColumnTypeThrows) { ReflectionTable table; @@ -353,6 +440,42 @@ TEST_F(ReflectionTableTest, WriteTableFromScratchAndReload) { std::filesystem::remove(temp_file); } +TEST_F(ReflectionTableTest, WriteOverwritesExistingFile) { + std::filesystem::path temp_file = + std::filesystem::current_path() / "reflection_test_overwrite.h5"; + + // First write: a table with two columns ("id" and "extra") + { + ReflectionTable first; + first.add_column("id", std::vector{1, 2, 3}); + first.add_column("extra", 
std::vector{10, 20, 30}); + first.write(temp_file.string()); + } + + // Second write to the SAME path: a different table with only "id" + { + ReflectionTable second; + second.add_column("id", std::vector{7, 8}); + second.write(temp_file.string()); + } + + // Reload and confirm the file reflects only the second table (overwrite, + // not update): "extra" should be gone and "id" should match the second write. + ReflectionTable loaded(temp_file.string()); + + auto extra = loaded.column("extra"); + EXPECT_FALSE(extra.has_value()) + << "Stale 'extra' column survived - write did not overwrite the file"; + + auto id = loaded.column("id"); + ASSERT_TRUE(id.has_value()); + ASSERT_EQ(id->extent(0), 2); + EXPECT_EQ((*id)(0, 0), 7); + EXPECT_EQ((*id)(1, 0), 8); + + std::filesystem::remove(temp_file); +} + TEST_F(ReflectionTableTest, AddBooleanColumnFromVectorBool) { ReflectionTable table;