Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions dx2/reflection.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ std::vector<std::string> ReflectionTable::get_column_names() const {
}
return names;
}

size_t ReflectionTable::size() const { return get_row_count(); }
#pragma endregion

#pragma region Private Helper Methods
Expand Down Expand Up @@ -184,14 +186,11 @@ void ReflectionTable::write(std::string_view filename,
// Suppress errors when opening non-existent files, groups, datasets..
H5ErrorSilencer silencer;

// 🗂️ Ensure the file exists or create it before writing
h5utils::H5File file(H5Fopen(fname.c_str(), H5F_ACC_RDWR, H5P_DEFAULT));
// 🗂️ Create (or truncate) the file
h5utils::H5File file(
H5Fcreate(fname.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT));
if (!file) {
file = h5utils::H5File(
H5Fcreate(fname.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT));
if (!file) {
throw std::runtime_error("Failed to create or open file: " + fname);
}
throw std::runtime_error("Failed to create or open file: " + fname);
}

// Open or create group
Expand Down
110 changes: 99 additions & 11 deletions include/dx2/reflection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,20 @@ class ReflectionTable {
return col.get_name() == name && col.get_type() == typeid(T);
}

/// Throw if type T is not supported by the HDF5 backend.
template <typename T> void ensure_supported_type() const {
const auto &registry = h5dispatch::get_supported_types();
bool supported = std::any_of(registry.begin(), registry.end(),
[](const h5dispatch::H5TypeInfo &info) {
return info.cpp_type == typeid(T);
});
if (!supported) {
throw std::runtime_error(
"Attempted to add column with unsupported type: " +
std::string(typeid(T).name()));
}
}

/**
* @brief Merges a vector of row indices into a set.
*
Expand Down Expand Up @@ -378,6 +392,14 @@ class ReflectionTable {
* @brief Get a list of all column names in the table.
*/
std::vector<std::string> get_column_names() const;

/**
* @brief Returns the number of rows (reflections) in the table.
*
* All columns are required to share the same row count, so this is the
* length of any column. Returns 0 if the table has no columns.
*/
size_t size() const;
#pragma endregion

#pragma region Column Access
Expand Down Expand Up @@ -559,17 +581,7 @@ class ReflectionTable {
}

// Check if type T is supported
const auto &registry = h5dispatch::get_supported_types();
bool supported = std::any_of(registry.begin(), registry.end(),
[](const h5dispatch::H5TypeInfo &info) {
return info.cpp_type == typeid(T);
});

if (!supported) {
throw std::runtime_error(
"Attempted to add column with unsupported type: " +
std::string(typeid(T).name()));
}
ensure_supported_type<T>();

// Ensure row count consistency
if (!data.empty() && col->get_shape()[0] != get_row_count()) {
Expand Down Expand Up @@ -616,6 +628,82 @@ class ReflectionTable {
const std::vector<T> &column_data) {
add_column(name, std::vector<size_t>{rows, cols}, column_data);
}

/**
* @brief Replaces the data of an existing column.
*
* Unlike `add_column`, this requires a column of the given name to
* already exist and replaces it in place (preserving column order).
* The replacement is wholesale, so the element type and shape may
* change, but the new row count must match the rest of the table.
*
* @tparam T The data type of the column.
* @param name The name of the (existing) column to update.
* @param shape A vector describing the shape of the new column data.
* @param column_data A flat vector of `T` values for the column.
*
* @throws std::runtime_error if no column with `name` exists, the type
* is unsupported, or the row count does not match the other columns.
*/
template <typename T>
void update_column(const std::string &name, const std::vector<size_t> &shape,
const std::vector<T> &column_data) {
// Check if the type T is a bool. If so, convert to BoolEnum and update.
if constexpr (std::is_same_v<T, bool>) {
std::vector<h5dispatch::BoolEnum> converted(column_data.size());
for (size_t i = 0; i < column_data.size(); ++i) {
converted[i] = column_data[i] ? h5dispatch::BoolEnum::TRUE
: h5dispatch::BoolEnum::FALSE;
}
update_column<h5dispatch::BoolEnum>(name, shape, converted);
} else {
// Locate the existing column (strict: it must already exist)
auto it = std::find_if(data.begin(), data.end(),
[&](const std::unique_ptr<ColumnBase> &c) {
return c->get_name() == name;
});
if (it == data.end()) {
throw std::runtime_error("Column not found for update: " + name);
}

auto col = std::make_unique<TypedColumn<T>>(name, shape, column_data);

// Check if type T is supported
ensure_supported_type<T>();

// Ensure row count consistency with the other columns in the table
for (const auto &other : data) {
if (other->get_name() != name) {
if (col->get_shape()[0] != other->get_shape()[0]) {
throw std::runtime_error(
"Row count mismatch when updating column: " + name);
}
break;
}
}

// Replace in place, preserving the column's position
*it = std::move(col);
}
}

/**
* @brief Replaces an existing 1D column.
*/
template <typename T>
void update_column(const std::string &name,
const std::vector<T> &column_data) {
update_column(name, std::vector<size_t>{column_data.size()}, column_data);
}

/**
* @brief Replaces an existing 2D column.
*/
template <typename T>
void update_column(const std::string &name, const size_t rows,
const size_t cols, const std::vector<T> &column_data) {
update_column(name, std::vector<size_t>{rows, cols}, column_data);
}
#pragma endregion

#pragma region Write
Expand Down
123 changes: 123 additions & 0 deletions tests/test_reflection_table.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,30 @@ TEST_F(ReflectionTableTest, AccessNonExistentColumn) {
}
#pragma endregion

#pragma region Size
TEST_F(ReflectionTableTest, SizeOfEmptyTableIsZero) {
ReflectionTable table;
EXPECT_EQ(table.size(), 0u);
}

TEST_F(ReflectionTableTest, SizeMatchesAddedColumnRowCount) {
ReflectionTable table;
std::vector<double> data{1.0, 2.0, 3.0, 4.0};
table.add_column("col", data);
EXPECT_EQ(table.size(), 4u);
}

TEST_F(ReflectionTableTest, SizeMatchesLoadedColumnExtent) {
ReflectionTable table(test_file_path.string());

auto col = table.column<double>("xyzobs.px.value");
ASSERT_TRUE(col.has_value());

EXPECT_EQ(table.size(), col->extent(0));
EXPECT_GT(table.size(), 0u);
}
#pragma endregion

#pragma region Adding
TEST_F(ReflectionTableTest, AddColumn1D) {
ReflectionTable table;
Expand Down Expand Up @@ -284,6 +308,69 @@ TEST_F(ReflectionTableTest, AddDuplicateColumnThrows) {
}
}

TEST_F(ReflectionTableTest, UpdateColumnReplacesData) {
ReflectionTable table;
table.add_column("col", std::vector<double>{1.0, 2.0, 3.0});

table.update_column("col", std::vector<double>{4.0, 5.0, 6.0});

auto col = table.column<double>("col");
ASSERT_TRUE(col.has_value());
ASSERT_EQ(col->extent(0), 3);
EXPECT_DOUBLE_EQ((*col)(0, 0), 4.0);
EXPECT_DOUBLE_EQ((*col)(1, 0), 5.0);
EXPECT_DOUBLE_EQ((*col)(2, 0), 6.0);
}

TEST_F(ReflectionTableTest, UpdateColumnThrowsIfMissing) {
ReflectionTable table;
table.add_column("col", std::vector<double>{1.0, 2.0, 3.0});

try {
table.update_column("missing", std::vector<double>{4.0, 5.0, 6.0});
FAIL() << "Expected std::runtime_error for missing column";
} catch (const std::runtime_error &e) {
std::cout << "[UpdateColumnThrowsIfMissing] Caught: " << e.what() << "\n";
EXPECT_TRUE(std::string(e.what()).find("not found") != std::string::npos);
} catch (...) {
FAIL() << "Expected std::runtime_error: column not found";
}
}

TEST_F(ReflectionTableTest, UpdateColumnRejectsMismatchedRowCount) {
ReflectionTable table;
table.add_column("a", std::vector<double>{1.0, 2.0, 3.0});
table.add_column("b", std::vector<double>{4.0, 5.0, 6.0});

try {
// 4 rows where the rest of the table has 3 - should throw
table.update_column("b", std::vector<double>{7.0, 8.0, 9.0, 10.0});
FAIL() << "Expected std::runtime_error due to row count mismatch";
} catch (const std::runtime_error &e) {
std::cout << "[UpdateColumnRejectsMismatchedRowCount] Caught: " << e.what()
<< "\n";
EXPECT_TRUE(std::string(e.what()).find("Row count mismatch") !=
std::string::npos);
} catch (...) {
FAIL() << "Expected std::runtime_error: row count mismatch";
}
}

TEST_F(ReflectionTableTest, UpdateColumnPreservesOrder) {
ReflectionTable table;
table.add_column("a", std::vector<double>{1.0, 2.0});
table.add_column("b", std::vector<double>{3.0, 4.0});
table.add_column("c", std::vector<double>{5.0, 6.0});

table.update_column("b", std::vector<double>{30.0, 40.0});

auto names = table.get_column_names();
ASSERT_EQ(names.size(), 3u);
EXPECT_EQ(names[0], "a");
EXPECT_EQ(names[1], "b");
EXPECT_EQ(names[2], "c");
}

TEST_F(ReflectionTableTest, AddUnsupportedColumnTypeThrows) {
ReflectionTable table;

Expand Down Expand Up @@ -353,6 +440,42 @@ TEST_F(ReflectionTableTest, WriteTableFromScratchAndReload) {
std::filesystem::remove(temp_file);
}

TEST_F(ReflectionTableTest, WriteOverwritesExistingFile) {
std::filesystem::path temp_file =
std::filesystem::current_path() / "reflection_test_overwrite.h5";

// First write: a table with two columns ("id" and "extra")
{
ReflectionTable first;
first.add_column<int>("id", std::vector<int>{1, 2, 3});
first.add_column<int>("extra", std::vector<int>{10, 20, 30});
first.write(temp_file.string());
}

// Second write to the SAME path: a different table with only "id"
{
ReflectionTable second;
second.add_column<int>("id", std::vector<int>{7, 8});
second.write(temp_file.string());
}

// Reload and confirm the file reflects only the second table (overwrite,
// not update): "extra" should be gone and "id" should match the second write.
ReflectionTable loaded(temp_file.string());

auto extra = loaded.column<int>("extra");
EXPECT_FALSE(extra.has_value())
<< "Stale 'extra' column survived - write did not overwrite the file";

auto id = loaded.column<int>("id");
ASSERT_TRUE(id.has_value());
ASSERT_EQ(id->extent(0), 2);
EXPECT_EQ((*id)(0, 0), 7);
EXPECT_EQ((*id)(1, 0), 8);

std::filesystem::remove(temp_file);
}

TEST_F(ReflectionTableTest, AddBooleanColumnFromVectorBool) {
ReflectionTable table;

Expand Down
Loading