diff --git a/velox/dwio/parquet/writer/arrow/Metadata.cpp b/velox/dwio/parquet/writer/arrow/Metadata.cpp index a56c40e5fdf..74b545ec312 100644 --- a/velox/dwio/parquet/writer/arrow/Metadata.cpp +++ b/velox/dwio/parquet/writer/arrow/Metadata.cpp @@ -101,10 +101,12 @@ static std::shared_ptr MakeTypedColumnStats( metadata.num_values - metadata.statistics.null_count, metadata.statistics.null_count, metadata.statistics.distinct_count, + /*nan_count=*/0, metadata.statistics.__isset.max_value || metadata.statistics.__isset.min_value, metadata.statistics.__isset.null_count, - metadata.statistics.__isset.distinct_count); + metadata.statistics.__isset.distinct_count, + /*has_nan_count=*/false); } // Default behavior return MakeStatistics( @@ -114,9 +116,11 @@ static std::shared_ptr MakeTypedColumnStats( metadata.num_values - metadata.statistics.null_count, metadata.statistics.null_count, metadata.statistics.distinct_count, + /*nan_count=*/0, metadata.statistics.__isset.max || metadata.statistics.__isset.min, metadata.statistics.__isset.null_count, - metadata.statistics.__isset.distinct_count); + metadata.statistics.__isset.distinct_count, + /*has_nan_count=*/false); } std::shared_ptr MakeColumnStats( @@ -1015,6 +1019,22 @@ class FileMetaData::FileMetaDataImpl { file_decryptor_ = file_decryptor; } + // Set NaN counts from the builder (called during Finish) + // This stores total NaN counts per field ID across all row groups. + void setNaNCounts( + std::unordered_map> nan_counts) { + field_nan_counts_ = std::move(nan_counts); + } + + // Get total NaN count for a specific field ID across all row groups. + std::pair getNaNCount(int32_t fieldId) const { + auto it = field_nan_counts_.find(fieldId); + if (it != field_nan_counts_.end()) { + return it->second; + } + return {0, false}; + } + private: friend FileMetaDataBuilder; uint32_t metadata_len_ = 0; @@ -1024,6 +1044,9 @@ class FileMetaData::FileMetaDataImpl { std::shared_ptr key_value_metadata_; const ReaderProperties properties_; std::shared_ptr file_decryptor_; + // Total NaN counts per field ID across all row groups: field_id -> + // (nan_count, has_nan_count). + std::unordered_map> field_nan_counts_; void InitSchema() { if (metadata_->schema.empty()) { @@ -1200,6 +1223,10 @@ std::shared_ptr FileMetaData::Subset( return impl_->Subset(row_groups); } +std::pair FileMetaData::getNaNCount(int32_t fieldId) const { + return impl_->getNaNCount(fieldId); +} + void FileMetaData::WriteTo( ::arrow::io::OutputStream* dst, const std::shared_ptr& encryptor) const { @@ -1715,6 +1742,19 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { // column metadata void SetStatistics(const EncodedStatistics& val) { column_chunk_->meta_data.__set_statistics(ToThrift(val)); + // Store NaN count separately since it's not written to the parquet file. + if (val.has_nan_count) { + nan_count_ = val.nan_count; + has_nan_count_ = true; + } + } + + int64_t nan_count() const { + return nan_count_; + } + + bool has_nan_count() const { + return has_nan_count_; } void Finish( @@ -1883,6 +1923,9 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { owned_column_chunk_; const std::shared_ptr properties_; const ColumnDescriptor* column_; + // NaN count is stored separately since it's not written to the parquet file. + int64_t nan_count_ = 0; + bool has_nan_count_ = false; }; std::unique_ptr ColumnChunkMetaDataBuilder::Make( @@ -1970,6 +2013,14 @@ int64_t ColumnChunkMetaDataBuilder::total_compressed_size() const { return impl_->total_compressed_size(); } +int64_t ColumnChunkMetaDataBuilder::nan_count() const { + return impl_->nan_count(); +} + +bool ColumnChunkMetaDataBuilder::has_nan_count() const { + return impl_->has_nan_count(); +} + class RowGroupMetaDataBuilder::RowGroupMetaDataBuilderImpl { public: explicit RowGroupMetaDataBuilderImpl( @@ -2062,6 +2113,16 @@ class RowGroupMetaDataBuilder::RowGroupMetaDataBuilderImpl { return row_group_->num_rows; } + // Returns a map of field_id -> (nan_count, has_nan_count). + std::unordered_map> nan_counts() const { + std::unordered_map> result; + for (const auto& builder : column_builders_) { + int32_t field_id = builder->descr()->schema_node()->field_id(); + result[field_id] = {builder->nan_count(), builder->has_nan_count()}; + } + return result; + } + private: void InitializeColumns(int ncols) { row_group_->columns.resize(ncols); @@ -2119,6 +2180,11 @@ void RowGroupMetaDataBuilder::Finish( impl_->Finish(total_bytes_written, row_group_ordinal); } +std::unordered_map> +RowGroupMetaDataBuilder::nan_counts() const { + return impl_->nan_counts(); +} + // file metadata class FileMetaDataBuilder::FileMetaDataBuilderImpl { public: @@ -2138,6 +2204,9 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { } RowGroupMetaDataBuilder* AppendRowGroup() { + // Accumulate NaN counts from the previous row group before creating a new + // one. + accumulateNaNCountsFromCurrentRowGroup(); row_groups_.emplace_back(); current_row_group_builder_ = RowGroupMetaDataBuilder::Make( properties_, schema_, &row_groups_.back()); @@ -2182,6 +2251,9 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { std::unique_ptr Finish( const std::shared_ptr& key_value_metadata) { + // Accumulate NaN counts from the last row group. + accumulateNaNCountsFromCurrentRowGroup(); + int64_t total_rows = 0; for (auto row_group : row_groups_) { total_rows += row_group.num_rows; @@ -2259,6 +2331,8 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { file_meta_data->impl_->metadata_ = std::move(metadata_); file_meta_data->impl_->InitSchema(); file_meta_data->impl_->InitKeyValueMetadata(); + // Pass total NaN counts per field ID to FileMetaData. + file_meta_data->impl_->setNaNCounts(std::move(field_nan_counts_)); return file_meta_data; } @@ -2290,12 +2364,31 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { crypto_metadata_; private: + // Helper to accumulate NaN counts from the current row group builder. + void accumulateNaNCountsFromCurrentRowGroup() { + if (!current_row_group_builder_) { + return; + } + auto rg_nan_counts = current_row_group_builder_->nan_counts(); + // Accumulate NaN counts from this row group (keyed by field ID). + for (const auto& [fieldId, countPair] : rg_nan_counts) { + const auto& [count, has_count] = countPair; + if (has_count) { + field_nan_counts_[fieldId].first += count; + field_nan_counts_[fieldId].second = true; + } + } + } + const std::shared_ptr properties_; std::vector row_groups_; std::unique_ptr current_row_group_builder_; const SchemaDescriptor* schema_; std::shared_ptr key_value_metadata_; + // Total NaN counts per field ID across all row groups: field_id -> + // (nan_count, has_nan_count). + std::unordered_map> field_nan_counts_; }; std::unique_ptr FileMetaDataBuilder::Make( diff --git a/velox/dwio/parquet/writer/arrow/Metadata.h b/velox/dwio/parquet/writer/arrow/Metadata.h index 7cb7670a038..38a7f40da89 100644 --- a/velox/dwio/parquet/writer/arrow/Metadata.h +++ b/velox/dwio/parquet/writer/arrow/Metadata.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -416,6 +417,12 @@ class PARQUET_EXPORT FileMetaData { std::shared_ptr Subset( const std::vector& row_groups) const; + /// \brief Get total NaN count for a specific field ID across all row groups. + /// Returns a pair of (nan_count, has_nan_count). + /// NaN counts are collected during writing but not written to the parquet + /// file. + std::pair getNaNCount(int32_t fieldId) const; + private: friend FileMetaDataBuilder; friend class SerializedFile; @@ -486,6 +493,13 @@ class PARQUET_EXPORT ColumnChunkMetaDataBuilder { const ColumnDescriptor* descr() const; int64_t total_compressed_size() const; + + // NaN count accessors - NaN counts are collected during writing but not + // written to the parquet file. + int64_t nan_count() const; + + bool has_nan_count() const; + // commit the metadata void Finish( @@ -537,6 +551,10 @@ class PARQUET_EXPORT RowGroupMetaDataBuilder { void set_num_rows(int64_t num_rows); + // Get NaN counts for all columns in current row group. + // Returns a map of field_id -> (nan_count, has_nan_count). + std::unordered_map> nan_counts() const; + // commit the metadata void Finish(int64_t total_bytes_written, int16_t row_group_ordinal = -1); diff --git a/velox/dwio/parquet/writer/arrow/Statistics.cpp b/velox/dwio/parquet/writer/arrow/Statistics.cpp index 143f9bd101b..dc23727ce94 100644 --- a/velox/dwio/parquet/writer/arrow/Statistics.cpp +++ b/velox/dwio/parquet/writer/arrow/Statistics.cpp @@ -18,8 +18,6 @@ #include "velox/dwio/parquet/writer/arrow/Statistics.h" -#include "velox/functions/lib/string/StringImpl.h" - #include #include #include @@ -42,6 +40,7 @@ #include "velox/dwio/parquet/writer/arrow/Platform.h" #include "velox/dwio/parquet/writer/arrow/Schema.h" #include "velox/functions/lib/Utf8Utils.h" +#include "velox/functions/lib/string/StringImpl.h" #include "velox/type/DecimalUtil.h" #include "velox/type/HugeInt.h" @@ -613,9 +612,11 @@ class TypedStatisticsImpl : public TypedStatistics { int64_t num_values, int64_t null_count, int64_t distinct_count, + int64_t nan_count, bool has_min_max, bool has_null_count, bool has_distinct_count, + bool has_nan_count, MemoryPool* pool) : TypedStatisticsImpl(descr, pool) { TypedStatisticsImpl::IncrementNumValues(num_values); @@ -630,6 +631,12 @@ class TypedStatisticsImpl : public TypedStatistics { has_distinct_count_ = false; } + if (has_nan_count) { + IncrementNaNValues(nan_count); + } else { + has_nan_count_ = false; + } + if (!encoded_min.empty()) { PlainDecode(encoded_min, &min_); } @@ -649,6 +656,10 @@ class TypedStatisticsImpl : public TypedStatistics { return has_null_count_; }; + bool HasNaNCount() const override { + return has_nan_count_; + }; + void IncrementNullCount(int64_t n) override { statistics_.null_count += n; has_null_count_ = true; @@ -658,6 +669,13 @@ class TypedStatisticsImpl : public TypedStatistics { num_values_ += n; } + void IncrementNaNValues(int64_t n) override { + if (n > 0) { + nan_count_ += n; + has_nan_count_ = true; + } + } + bool Equals(const Statistics& raw_other) const override { if (physical_type() != raw_other.physical_type()) return false; @@ -696,6 +714,10 @@ class TypedStatisticsImpl : public TypedStatistics { } else { this->has_null_count_ = false; } + if (other.HasNaNCount()) { + this->nan_count_ += other.nan_count(); + this->has_nan_count_ = true; + } if (has_distinct_count_ && other.HasDistinctCount() && (distinct_count() == 0 || other.distinct_count() == 0)) { // We can merge distinct counts if either side is zero. @@ -829,6 +851,9 @@ class TypedStatisticsImpl : public TypedStatistics { if (HasDistinctCount()) { s.set_distinct_count(this->distinct_count()); } + if (has_nan_count_) { + s.set_nan_count(nan_count_); + } return s; } @@ -842,6 +867,10 @@ class TypedStatisticsImpl : public TypedStatistics { return num_values_; } + int64_t nan_count() const override { + return nan_count_; + } + bool MaxGreaterThan(const Statistics& other) const override { const auto* typedOther = dynamic_cast*>(&other); @@ -859,6 +888,7 @@ class TypedStatisticsImpl : public TypedStatistics { bool has_min_max_ = false; bool has_null_count_ = false; bool has_distinct_count_ = false; + bool has_nan_count_ = false; T min_; T max_; ::arrow::MemoryPool* pool_; @@ -868,6 +898,8 @@ class TypedStatisticsImpl : public TypedStatistics { // a statistics thrift message which doesn't have the optional null_count, // `num_values_` may include null values. int64_t num_values_ = 0; + // NaN count is tracked separately since it's not written to the parquet file. + int64_t nan_count_ = 0; EncodedStatistics statistics_; std::shared_ptr> comparator_; std::shared_ptr min_buffer_, max_buffer_; @@ -888,6 +920,7 @@ class TypedStatisticsImpl : public TypedStatistics { void ResetCounts() { this->statistics_.null_count = 0; this->statistics_.distinct_count = 0; + this->nan_count_ = 0; this->num_values_ = 0; } @@ -900,6 +933,7 @@ class TypedStatisticsImpl : public TypedStatistics { this->has_distinct_count_ = false; // Null count calculation is cheap and enabled by default. this->has_null_count_ = true; + this->has_nan_count_ = false; } void SetMinMaxPair(std::pair min_max) { @@ -926,6 +960,46 @@ class TypedStatisticsImpl : public TypedStatistics { max_buffer_.get()); } } + + int64_t CountNaN(const T* values, int64_t length) { + if constexpr (!std::is_floating_point_v) { + return 0; + } else { + int64_t count = 0; + for (auto i = 0; i < length; i++) { + const auto val = SafeLoad(values + i); + if (std::isnan(val)) { + count++; + } + } + return count; + } + } + + int64_t CountNaNSpaced( + const T* values, + int64_t length, + const uint8_t* valid_bits, + int64_t valid_bits_offset) { + if constexpr (!std::is_floating_point_v) { + return 0; + } else { + int64_t count = 0; + ::arrow::internal::VisitSetBitRunsVoid( + valid_bits, + valid_bits_offset, + length, + [&](int64_t position, int64_t run_length) { + for (auto i = 0; i < run_length; i++) { + const auto val = SafeLoad(values + i + position); + if (std::isnan(val)) { + count++; + } + } + }); + return count; + } + } }; template <> @@ -981,6 +1055,7 @@ void TypedStatisticsImpl::Update( if (num_values == 0) return; SetMinMaxPair(comparator_->GetMinMax(values, num_values)); + IncrementNaNValues(CountNaN(values, num_values)); } template @@ -1001,6 +1076,8 @@ void TypedStatisticsImpl::UpdateSpaced( return; SetMinMaxPair(comparator_->GetMinMaxSpaced( values, num_spaced_values, valid_bits, valid_bits_offset)); + IncrementNaNValues( + CountNaNSpaced(values, num_spaced_values, valid_bits, valid_bits_offset)); } template @@ -1165,9 +1242,11 @@ std::shared_ptr Statistics::Make( num_values, encoded_stats->null_count, encoded_stats->distinct_count, + encoded_stats->nan_count, encoded_stats->has_min && encoded_stats->has_max, encoded_stats->has_null_count, encoded_stats->has_distinct_count, + encoded_stats->has_nan_count, pool); } @@ -1178,9 +1257,11 @@ std::shared_ptr Statistics::Make( int64_t num_values, int64_t null_count, int64_t distinct_count, + int64_t nan_count, bool has_min_max, bool has_null_count, bool has_distinct_count, + bool has_nan_count, ::arrow::MemoryPool* pool) { #define MAKE_STATS(CAP_TYPE, KLASS) \ case Type::CAP_TYPE: \ @@ -1191,9 +1272,11 @@ std::shared_ptr Statistics::Make( num_values, \ null_count, \ distinct_count, \ + nan_count, \ has_min_max, \ has_null_count, \ has_distinct_count, \ + has_nan_count, \ pool) switch (descr->physical_type()) { diff --git a/velox/dwio/parquet/writer/arrow/Statistics.h b/velox/dwio/parquet/writer/arrow/Statistics.h index a0f21d88840..afe5a46316d 100644 --- a/velox/dwio/parquet/writer/arrow/Statistics.h +++ b/velox/dwio/parquet/writer/arrow/Statistics.h @@ -141,11 +141,13 @@ class PARQUET_EXPORT EncodedStatistics { int64_t null_count = 0; int64_t distinct_count = 0; + int64_t nan_count = 0; bool has_min = false; bool has_max = false; bool has_null_count = false; bool has_distinct_count = false; + bool has_nan_count = false; // When all values in the statistics are null, it is set to true. // Otherwise, at least one value is not null, or we are not sure at all. @@ -204,6 +206,12 @@ class PARQUET_EXPORT EncodedStatistics { has_distinct_count = true; return *this; } + + EncodedStatistics& set_nan_count(int64_t value) { + nan_count = value; + has_nan_count = true; + return *this; + } }; /// \brief Base type for computing column statistics while writing a file @@ -227,10 +235,12 @@ class PARQUET_EXPORT Statistics { /// \param[in] num_values total number of values /// \param[in] null_count number of null values /// \param[in] distinct_count number of distinct values + /// \param[in] nan_count number of nan values /// \param[in] has_min_max whether the min/max statistics are set /// \param[in] has_null_count whether the null_count statistics are set /// \param[in] has_distinct_count whether the distinct_count statistics are - /// set \param[in] pool a memory pool to use for any memory allocations, + /// set \param[in] has_nan_count whether the nan_count statistics are set + /// \param[in] pool a memory pool to use for any memory allocations, /// optional static std::shared_ptr Make( const ColumnDescriptor* descr, @@ -239,9 +249,11 @@ class PARQUET_EXPORT Statistics { int64_t num_values, int64_t null_count, int64_t distinct_count, + int64_t nan_count, bool has_min_max, bool has_null_count, bool has_distinct_count, + bool has_nan_count, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); // Helper function to convert EncodedStatistics to Statistics. @@ -268,6 +280,12 @@ class PARQUET_EXPORT Statistics { /// \brief The number of non-null values in the column virtual int64_t num_values() const = 0; + /// \brief Return true if the count of nan values is set + virtual bool HasNaNCount() const = 0; + + /// \brief The number of NaN values, may not be set + virtual int64_t nan_count() const = 0; + /// \brief Return true if the min and max statistics are set. Obtain /// with TypedStatistics::min and max virtual bool HasMinMax() const = 0; @@ -413,6 +431,9 @@ class TypedStatistics : public Statistics { /// \brief Increments the number of values directly /// The same note on IncrementNullCount applies here virtual void IncrementNumValues(int64_t n) = 0; + + /// \brief Increments the NaN count directly + virtual void IncrementNaNValues(int64_t n) = 0; }; using BoolStatistics = TypedStatistics; @@ -458,9 +479,11 @@ std::shared_ptr> MakeStatistics( int64_t num_values, int64_t null_count, int64_t distinct_count, + int64_t nan_count, bool has_min_max, bool has_null_count, bool has_distinct_count, + bool has_nan_count, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { return std::static_pointer_cast>(Statistics::Make( descr, @@ -469,9 +492,11 @@ std::shared_ptr> MakeStatistics( num_values, null_count, distinct_count, + nan_count, has_min_max, has_null_count, has_distinct_count, + has_nan_count, pool)); } diff --git a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp index 90d356a2e6b..89975c548fd 100644 --- a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp @@ -344,9 +344,11 @@ class TestStatistics : public PrimitiveTypedTest { this->values_.size(), 0, 0, + 0, + true, true, true, - true); + false); auto statistics3 = MakeStatistics(this->schema_.Column(0)); std::vector valid_bits( @@ -610,9 +612,11 @@ void TestStatistics::TestMinMaxEncode() { this->values_.size(), 0, 0, + 0, + true, true, true, - true); + false); ASSERT_EQ(encoded_min, statistics2->EncodeMin()); ASSERT_EQ(encoded_max, statistics2->EncodeMax()); @@ -1533,6 +1537,7 @@ void CheckNaNs() { auto some_nan_stats = MakeStatistics(&descr); // Ingesting only nans should not yield valid min max AssertUnsetMinMax(some_nan_stats, all_nans); + EXPECT_EQ(some_nan_stats->nan_count(), all_nans.size()); // Ingesting a mix of NaNs and non-NaNs should not yield valid min max. AssertMinMaxAre(some_nan_stats, some_nans, min, max); // Ingesting only nans after a valid min/max, should have not effect @@ -1550,6 +1555,7 @@ void CheckNaNs() { 1.5f, max, -3.0f, -1.0f, nan, 2.0f, min, nan}; auto other_stats = MakeStatistics(&descr); AssertMinMaxAre(other_stats, other_nans, min, max); + EXPECT_EQ(other_stats->nan_count(), 2); } TEST(TestStatistic, NaNFloatValues) { @@ -1574,6 +1580,7 @@ TEST(TestStatisticsSortOrderFloatNaN, NaNAndNullsInfiniteLoop) { uint8_t all_but_last_valid = 0x7F; // 0b01111111 auto stats = MakeStatistics(&descr); AssertUnsetMinMax(stats, nans_but_last, &all_but_last_valid); + EXPECT_EQ(stats->nan_count(), kNumValues - 1); } template <