diff --git a/velox/connectors/hive/iceberg/CMakeLists.txt b/velox/connectors/hive/iceberg/CMakeLists.txt index d71d964d2c6..f6369cb6d9e 100644 --- a/velox/connectors/hive/iceberg/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/CMakeLists.txt @@ -19,6 +19,7 @@ set( IcebergColumnHandle.cpp IcebergConfig.cpp IcebergConnector.cpp + IcebergDataFileStatistics.cpp IcebergDataSink.cpp IcebergDataSource.cpp IcebergPartitionName.cpp @@ -31,6 +32,10 @@ set( WriterOptionsAdapter.cpp ) +if(VELOX_ENABLE_PARQUET) + list(APPEND ICEBERG_SOURCES IcebergParquetStatsCollector.cpp) +endif() + velox_add_library( velox_hive_iceberg_splitreader ${ICEBERG_SOURCES} @@ -41,10 +46,12 @@ velox_add_library( IcebergColumnHandle.h IcebergConfig.h IcebergConnector.h + IcebergDataFileStatistics.h IcebergDataSink.h IcebergDataSource.h IcebergDeleteFile.h IcebergMetadataColumns.h + IcebergParquetStatsCollector.h IcebergPartitionName.h IcebergSplit.h IcebergSplitReader.h @@ -58,13 +65,14 @@ velox_add_library( velox_link_libraries( velox_hive_iceberg_splitreader velox_connector + velox_dwio_parquet_field_id velox_functions_iceberg velox_dwio_dwrf_writer Folly::folly ) if(VELOX_ENABLE_PARQUET) - velox_link_libraries(velox_hive_iceberg_splitreader velox_dwio_parquet_field_id) + velox_link_libraries(velox_hive_iceberg_splitreader velox_dwio_arrow_parquet_writer) endif() if(${VELOX_BUILD_TESTING}) diff --git a/velox/connectors/hive/iceberg/IcebergDataFileStatistics.cpp b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.cpp new file mode 100644 index 00000000000..018f3f8f68c --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" + +namespace facebook::velox::connector::hive::iceberg { + +folly::dynamic IcebergDataFileStatistics::toJson() const { + folly::dynamic json = folly::dynamic::object; + json["recordCount"] = numRecords; + + folly::dynamic columnSizes = folly::dynamic::object; + folly::dynamic valueCounts = folly::dynamic::object; + folly::dynamic nullValueCounts = folly::dynamic::object; + folly::dynamic nanValueCounts = folly::dynamic::object; + folly::dynamic lowerBounds = folly::dynamic::object; + folly::dynamic upperBounds = folly::dynamic::object; + + for (const auto& [fieldId, stats] : columnStats) { + auto fieldIdStr = folly::to(fieldId); + columnSizes[fieldIdStr] = stats.columnSize; + valueCounts[fieldIdStr] = stats.valueCount; + nullValueCounts[fieldIdStr] = stats.nullValueCount; + if (stats.nanValueCount.has_value()) { + nanValueCounts[fieldIdStr] = stats.nanValueCount.value(); + } + if (stats.lowerBound.has_value()) { + lowerBounds[fieldIdStr] = stats.lowerBound.value(); + } + if (stats.upperBound.has_value()) { + upperBounds[fieldIdStr] = stats.upperBound.value(); + } + } + + json["columnSizes"] = std::move(columnSizes); + json["valueCounts"] = std::move(valueCounts); + json["nullValueCounts"] = std::move(nullValueCounts); + json["nanValueCounts"] = std::move(nanValueCounts); + json["lowerBounds"] = std::move(lowerBounds); + json["upperBounds"] = std::move(upperBounds); + + return json; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataFileStatistics.h b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.h new file mode 100644 index 00000000000..5bcfb84b83f --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataFileStatistics.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace facebook::velox::connector::hive::iceberg { + +/// Statistics for an Iceberg data file, corresponding to the `data_file` +/// structure defined in the Iceberg specification: +/// https://iceberg.apache.org/spec/#data-file-fields. +/// +/// All column-level statistics maps are keyed by Iceberg field IDs (`int32_t`), +/// which uniquely identify columns in the Iceberg schema independent of column +/// names or physical column positions. +struct IcebergDataFileStatistics { + struct ColumnStats { + int64_t columnSize{0}; + + /// Total number of values for this field ID in the file, including null and + /// NaN values. + /// + /// For primitive (flat) columns, this is equal to the number of rows in the + /// file: numRows = valueCount = (nonNullValues + numNulls + numNaNs). + /// + /// For nested columns (e.g. elements inside an array), this represents the + /// total occurrences of the field across all rows, which is not necessarily + /// related to the top-level record count. + int64_t valueCount{0}; + int64_t nullValueCount{0}; + std::optional nanValueCount; + /// Base64 encoded lower bound. + std::optional lowerBound; + /// Base64 encoded upper bound. + std::optional upperBound; + }; + + int64_t numRecords{0}; + folly::F14FastMap columnStats; + + /// Returns a IcebergDataFileStatistics with all values set to zero/empty. + /// Useful for empty data files that have no actual data. + static IcebergDataFileStatistics empty() { + return IcebergDataFileStatistics{.numRecords = 0, .columnStats = {}}; + } + + folly::dynamic toJson() const; +}; + +using IcebergDataFileStatisticsPtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.cpp b/velox/connectors/hive/iceberg/IcebergDataSink.cpp index bf5bd9f377c..dce118b5b8c 100644 --- a/velox/connectors/hive/iceberg/IcebergDataSink.cpp +++ b/velox/connectors/hive/iceberg/IcebergDataSink.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -26,11 +27,23 @@ #include "velox/common/base/Fs.h" #include "velox/common/encode/Base64.h" +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/PartitionIdGenerator.h" #include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" + +#ifdef VELOX_ENABLE_PARQUET +#include "velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h" +#include "velox/dwio/parquet/writer/Writer.h" +#endif + #include "velox/connectors/hive/iceberg/TransformExprBuilder.h" #include "velox/connectors/hive/iceberg/WriterOptionsAdapter.h" +#include "velox/dwio/dwrf/writer/Writer.h" #include "velox/exec/OperatorUtils.h" +#include "velox/type/Type.h" + +using facebook::velox::common::testutil::TestValue; namespace facebook::velox::connector::hive::iceberg { @@ -303,39 +316,51 @@ IcebergDataSink::IcebergDataSink( partitionSpec_ != nullptr ? std::make_unique(partitionSpec_) : nullptr), - partitionRowType_(std::move(partitionRowType)) { + partitionRowType_(std::move(partitionRowType)), + icebergInsertTableHandle_(insertTableHandle) { commitPartitionValue_.resize(maxOpenWriters_); + +#ifdef VELOX_ENABLE_PARQUET + // Only initialize Parquet stats collector for Parquet format tables + if (insertTableHandle->storageFormat() == dwio::common::FileFormat::PARQUET) { + std::vector columnHandles; + columnHandles.reserve(insertTableHandle->inputColumns().size()); + for (auto& column : insertTableHandle->inputColumns()) { + columnHandles.emplace_back( + checkedPointerCast(column)); + } + parquetStatsCollector_ = std::make_shared( + std::move(columnHandles)); + } +#endif } std::vector IcebergDataSink::commitMessage() const { std::vector commitTasks; commitTasks.reserve(writerInfo_.size()); - auto icebergInsertTableHandle = - std::dynamic_pointer_cast( - insertTableHandle_); - for (auto i = 0; i < writerInfo_.size(); ++i) { const auto& writerInfo = writerInfo_.at(i); VELOX_CHECK_NOT_NULL(writerInfo); // Following metadata (json format) is consumed by Presto CommitTaskData. // It contains the minimal subset of metadata. - // TODO: Complete metrics is missing now and this could lead to suboptimal - // query plan, will collect full iceberg metrics in following PR. - for (const auto& fileInfo : writerInfo->writtenFiles) { + VELOX_CHECK_EQ(writerInfo->writtenFiles.size(), dataFileStats_[i].size()); + for (auto fileIdx = 0; fileIdx < writerInfo->writtenFiles.size(); + ++fileIdx) { + const auto& fileInfo = writerInfo->writtenFiles[fileIdx]; // clang-format off folly::dynamic commitData = folly::dynamic::object( "path", (fs::path(writerInfo->writerParameters.targetDirectory()) / fileInfo.targetFileName).string()) ("fileSizeInBytes", fileInfo.fileSize) - ("metrics", - folly::dynamic::object("recordCount", fileInfo.numRows)) + ("metrics", dataFileStats_[i][fileIdx]->toJson()) ("partitionSpecJson", - icebergInsertTableHandle->partitionSpec() ? - icebergInsertTableHandle->partitionSpec()->specId : 0) - ("fileFormat", - toManifestFormatString(icebergInsertTableHandle->storageFormat())) + icebergInsertTableHandle_->partitionSpec() ? + icebergInsertTableHandle_->partitionSpec()->specId : 0) + // Sort order evolution is not supported. Set default id to 0 ( unsorted order). + ("sortOrderId", 0) + ("fileFormat", toManifestFormatString(icebergInsertTableHandle_->storageFormat())) ("content", "DATA"); // clang-format on if (!commitPartitionValue_.empty() && @@ -386,27 +411,37 @@ uint32_t IcebergDataSink::ensureWriter(const WriterId& id) { } std::shared_ptr -IcebergDataSink::createWriterOptions() const { - auto options = HiveDataSink::createWriterOptions(); - auto icebergInsertTableHandle = - std::dynamic_pointer_cast( - insertTableHandle_); - const auto storageFormat = icebergInsertTableHandle - ? icebergInsertTableHandle->storageFormat() - : dwio::common::FileFormat::UNKNOWN; - - // Format support is enforced in the constructor; the adapter must not - // be null here. - auto adapter = createWriterOptionsAdapter(storageFormat); - VELOX_CHECK_NOT_NULL( - adapter, - "Unsupported file format for Iceberg writer: {}", - dwio::common::toString(storageFormat)); - - adapter->applyPreConfigs(*options); +IcebergDataSink::createWriterOptions(size_t writerIndex) const { + auto options = HiveDataSink::createWriterOptions(writerIndex); + +#ifdef VELOX_ENABLE_PARQUET + if (auto parquetOptions = + std::dynamic_pointer_cast(options)) { + // Per Iceberg specification (https://iceberg.apache.org/spec/#parquet): + // - Timestamps must be stored with microsecond precision. + // - Timestamps must NOT be adjusted to UTC timezone; they should be written + // as-is without timezone conversion (empty string disables conversion). + // + // These settings are passed via serdeParameters. The keys must match + // kParquetSerdeTimestampUnit and kParquetSerdeTimestampTimezone defined + // in velox/dwio/parquet/writer/Writer.h. The value "6" represents + // microseconds (TimestampPrecision::kMicroseconds). + parquetOptions + ->serdeParameters[parquet::WriterConfig::kParquetSerdeTimestampUnit] = + "6"; + parquetOptions->serdeParameters + [parquet::WriterConfig::kParquetSerdeTimestampTimezone] = ""; + + if (parquetStatsCollector_) { + parquetOptions->parquetFieldIds = + parquetStatsCollector_->parquetFieldIds().children; + } + } +#endif + options->processConfigs( *hiveConfig_->config(), *connectorQueryCtx_->sessionProperties()); - adapter->applyPostConfigs(*options); + return options; } @@ -426,4 +461,100 @@ folly::dynamic IcebergDataSink::makeCommitPartitionValue( return partitionValues; } +void IcebergDataSink::rotateWriter(size_t index) { + VELOX_CHECK_LT(index, writers_.size()); + VELOX_CHECK_NOT_NULL(writers_[index]); + + // Ensure dataFileStats_ has an entry for this writer index. + if (dataFileStats_.size() <= index) { + dataFileStats_.resize(index + 1); + } + + // Close the writer to flush the footer and obtain file metadata, then + // aggregate Iceberg stats from the metadata. The base rotateWriter() would + // also call writers_[index]->close() but discards the returned metadata. + // We close the writer ourselves to capture the metadata, then reset the + // writer to prevent double close. + { + const memory::NonReclaimableSectionGuard nonReclaimableGuard( + writerInfo_[index]->nonReclaimableSectionHolder.get()); + auto metadata = writers_[index]->close(); + const bool fileAdded = getCurrentFileBytes(index) > 0; + + // Finalize file info (capture file size, add to writtenFiles). + finalizeWriterFile(index); + + if (fileAdded) { +#ifdef VELOX_ENABLE_PARQUET + if (parquetStatsCollector_) { + dataFileStats_[index].emplace_back( + parquetStatsCollector_->aggregate(std::move(metadata))); + } else +#endif + { + dataFileStats_[index].emplace_back( + std::make_shared( + IcebergDataFileStatistics::empty())); + } + } + } + + // Release old writer. The new writer will be created lazily on the next + // write call. + writers_[index].reset(); + + ++writerInfo_[index]->fileSequenceNumber; +} + +void IcebergDataSink::closeInternal() { + VELOX_CHECK_NE(state_, State::kRunning); + VELOX_CHECK_NE(state_, State::kFinishing); + + TestValue::adjust( + "facebook::velox::connector::hive::FileDataSink::closeInternal", this); + + if (state_ == State::kClosed) { + // Ensure dataFileStats_ has entries for all writers. + dataFileStats_.resize(writers_.size()); + + for (auto i = 0; i < writers_.size(); ++i) { + if (writers_[i] == nullptr) { + // Writer was rotated and is null. Stats for rotated files were already + // collected in rotateWriter(). No final file to close. + continue; + } + const memory::NonReclaimableSectionGuard nonReclaimableGuard( + writerInfo_[i]->nonReclaimableSectionHolder.get()); + + auto metadata = writers_[i]->close(); + const bool fileAdded = getCurrentFileBytes(i) > 0; + + finalizeWriterFile(i); + + if (fileAdded) { +#ifdef VELOX_ENABLE_PARQUET + if (parquetStatsCollector_) { + dataFileStats_[i].emplace_back( + parquetStatsCollector_->aggregate(std::move(metadata))); + } else +#endif + { + dataFileStats_[i].emplace_back( + std::make_shared( + IcebergDataFileStatistics::empty())); + } + } + } + } else { + for (auto i = 0; i < writers_.size(); ++i) { + if (writers_[i] == nullptr) { + continue; + } + memory::NonReclaimableSectionGuard nonReclaimableGuard( + writerInfo_[i]->nonReclaimableSectionHolder.get()); + writers_[i]->abort(); + } + } +} + } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.h b/velox/connectors/hive/iceberg/IcebergDataSink.h index f7056761633..6c6155045f3 100644 --- a/velox/connectors/hive/iceberg/IcebergDataSink.h +++ b/velox/connectors/hive/iceberg/IcebergDataSink.h @@ -24,6 +24,12 @@ #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" + +#ifdef VELOX_ENABLE_PARQUET +#include "velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h" +#endif + #include "velox/connectors/hive/iceberg/IcebergConfig.h" #include "velox/connectors/hive/iceberg/IcebergPartitionName.h" #include "velox/connectors/hive/iceberg/PartitionSpec.h" @@ -156,8 +162,8 @@ class IcebergDataSink : public HiveDataSink { // base HiveDataSink writer options with Iceberg-specific settings: // - Sets timestamp timezone to nullopt (UTC) for Iceberg compliance. // - Sets timestamp precision to microseconds. - std::shared_ptr createWriterOptions() - const override; + std::shared_ptr createWriterOptions( + size_t writerIndex) const override; // Extracts partition values for a specific writer to be included in the // commit message. Converts the transformed partition values from columnar @@ -167,6 +173,10 @@ class IcebergDataSink : public HiveDataSink { // Returns nullptr for null partition values. folly::dynamic makeCommitPartitionValue(uint32_t writerIndex) const; + void rotateWriter(size_t index) override; + + void closeInternal() override; + // Iceberg partition specification defining how the table is partitioned. // Contains partition fields with source column names, transform types // (e.g., identity, year, month, day, hour, bucket, truncate), transform @@ -213,6 +223,23 @@ class IcebergDataSink : public HiveDataSink { // folly::dynamic array of values across all partition fields), ready for JSON // serialization. std::vector commitPartitionValue_; + + // Statistics for all data files written by this sink, organized by writer + // index and file index within each writer. These statistics are populated + // during rotateWriter() (for rotated files) and during closeInternal() + // (for the final file of each writer). These metrics are subsequently used + // to construct Iceberg commit messages. + // Outer vector: indexed by writer index (same as writerInfo_). + // Inner vector: one entry per file written by that writer (including + // rotated files and the final file). Each entry corresponds to one + // individual data file. + std::vector> dataFileStats_; + + const IcebergInsertTableHandlePtr icebergInsertTableHandle_; + +#ifdef VELOX_ENABLE_PARQUET + std::shared_ptr parquetStatsCollector_; +#endif }; } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.cpp b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.cpp new file mode 100644 index 00000000000..10b04574c0a --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h" + +#include "velox/common/Casts.h" +#include "velox/common/base/Exceptions.h" +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" +#include "velox/dwio/common/FileMetadata.h" +#include "velox/dwio/parquet/writer/Writer.h" +#include "velox/dwio/parquet/writer/arrow/Metadata.h" +#include "velox/dwio/parquet/writer/arrow/Statistics.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +void addAllRecursive( + const parquet::ParquetFieldId& field, + const TypePtr& type, + std::unordered_set& fieldIds) { + fieldIds.insert(field.fieldId); + + VELOX_CHECK_EQ(field.children.size(), type->size()); + for (auto i = 0; i < type->size(); ++i) { + addAllRecursive(field.children[i], type->childAt(i), fieldIds); + } +} + +// Recursively collects field IDs that should skip bounds collection. +// Repeated fields (e.g. MAP and ARRAY) are not currently supported by Iceberg. +// These fields, along with all their descendants, should skip bounds +// collection. +// @param field The Parquet field ID structure to process. +// @param type The Velox type corresponding to this field. +// @param fieldIds Output set to populate with field IDs to skip. +void collectSkipBoundsFieldIds( + const parquet::ParquetFieldId& field, + const TypePtr& type, + std::unordered_set& fieldIds) { + VELOX_CHECK_NOT_NULL(type, "Input column type cannot be null."); + + if (type->isMap() || type->isArray()) { + addAllRecursive(field, type, fieldIds); + return; + } + + VELOX_CHECK_EQ(field.children.size(), type->size()); + for (auto i = 0; i < type->size(); ++i) { + collectSkipBoundsFieldIds(field.children[i], type->childAt(i), fieldIds); + } +} + +} // namespace + +IcebergParquetStatsCollector::IcebergParquetStatsCollector( + const std::vector& inputColumns) { + parquetFieldIds_.children.reserve(inputColumns.size()); + for (const auto& columnHandle : inputColumns) { + parquetFieldIds_.children.emplace_back(columnHandle->field()); + collectSkipBoundsFieldIds( + columnHandle->field(), columnHandle->dataType(), skipBoundsFieldIds_); + } +} + +IcebergDataFileStatisticsPtr IcebergParquetStatsCollector::aggregate( + std::unique_ptr fileMetadata) { + // Empty data file. + if (!fileMetadata) { + return std::make_shared( + IcebergDataFileStatistics::empty()); + } + + auto parquetMetadata = + checkedPointerCast(std::move(fileMetadata)); + auto metadata = parquetMetadata->arrowMetadata(); + auto dataFileStats = std::make_shared(); + dataFileStats->numRecords = metadata->numRows(); + const auto numRowGroups = metadata->numRowGroups(); + + // Track global min/max statistics for each column across all row groups. + // Key: Iceberg field ID. + // Value: A pair of Statistics objects where: + // - first: The statistics from the row group containing the global minimum + // value. + // - second: The statistics from the row group containing the global maximum + // value. Two separate objects are stored because the global minimum and + // global maximum for a single column may originate from different row groups. + folly::F14FastMap< + int32_t, + std::pair< + std::shared_ptr, + std::shared_ptr>> + globalMinMaxStats; + + std::unordered_set fieldIds; + for (auto i = 0; i < numRowGroups; ++i) { + const auto& rowGroup = metadata->rowGroup(i); + + for (auto j = 0; j < rowGroup->numColumns(); ++j) { + const auto& columnChunk = rowGroup->columnChunk(j); + const auto fieldId = columnChunk->fieldId(); + fieldIds.insert(fieldId); + + auto& stats = dataFileStats->columnStats[fieldId]; + stats.valueCount += columnChunk->numValues(); + stats.columnSize += columnChunk->totalCompressedSize(); + + const auto& columnChunkStats = columnChunk->statistics(); + if (columnChunkStats) { + stats.nullValueCount += columnChunkStats->nullCount(); + + if (columnChunkStats->hasMinMax() && shouldStoreBounds(fieldId)) { + auto [it, inserted] = globalMinMaxStats.emplace( + fieldId, std::pair{columnChunkStats, columnChunkStats}); + + if (!inserted) { + auto& [minStats, maxStats] = it->second; + + if (columnChunkStats->maxGreaterThan(*maxStats)) { + maxStats = columnChunkStats; + } + if (columnChunkStats->minLessThan(*minStats)) { + minStats = columnChunkStats; + } + } + } + } + } + } + + for (const auto fieldId : fieldIds) { + const auto& [nanCount, hasNanCount] = metadata->getNaNCount(fieldId); + if (hasNanCount) { + dataFileStats->columnStats[fieldId].nanValueCount = nanCount; + } + } + + for (const auto& [fieldId, stats] : globalMinMaxStats) { + const auto& [minStats, maxStats] = stats; + + auto& columnStats = dataFileStats->columnStats[fieldId]; + const auto& lowerBound = + minStats->icebergLowerBoundInclusive(kDefaultTruncateLength); + columnStats.lowerBound = + encoding::Base64::encode(lowerBound.data(), lowerBound.size()); + + const auto upperBound = + maxStats->icebergUpperBoundExclusive(kDefaultTruncateLength); + if (upperBound.has_value()) { + columnStats.upperBound = + encoding::Base64::encode(upperBound->data(), upperBound->size()); + } + } + return dataFileStats; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h new file mode 100644 index 00000000000..cc442a48722 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergParquetStatsCollector.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "velox/connectors/hive/iceberg/IcebergColumnHandle.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" +#include "velox/dwio/common/FileMetadata.h" +#include "velox/dwio/parquet/ParquetFieldId.h" + +namespace facebook::velox::connector::hive::iceberg { + +class IcebergParquetStatsCollector { + public: + explicit IcebergParquetStatsCollector( + const std::vector& inputColumns); + + /// Returns the Parquet field IDs for all input columns. + /// The field IDs are written to the Parquet data file's column metadata. + /// The return object describes a multi-column input. + const parquet::ParquetFieldId& parquetFieldIds() const { + return parquetFieldIds_; + } + + /// Aggregates Parquet file metadata into Iceberg data file statistics. + /// Iterates through all row groups and columns to collect: + /// - Record count, split offsets, value counts, column sizes, null counts. + /// - Min/max bounds (base64-encoded). Currently not collected for MAP and + /// ARRAY types and all their descendants. + /// @param fileMetadata The Parquet file metadata to aggregate. + IcebergDataFileStatisticsPtr aggregate( + std::unique_ptr fileMetadata); + + /// TODO: Need to support this config property. + /// 16 is default value. See DEFAULT_WRITE_METRICS_MODE_DEFAULT in + /// org.apache.iceberg.TableProperties. + constexpr static int32_t kDefaultTruncateLength{16}; + + private: + bool shouldStoreBounds(int32_t fieldId) const { + return !skipBoundsFieldIds_.contains(fieldId); + } + + // Hierarchical Parquet field IDs for all input columns. A single + // ParquetFieldId can describe all the columns including their nested + // children. + parquet::ParquetFieldId parquetFieldIds_; + + // Set of field IDs for which bounds collection should be skipped. + // This includes MAP and ARRAY types and all their descendants. + std::unordered_set skipBoundsFieldIds_; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/CMakeLists.txt b/velox/connectors/hive/iceberg/tests/CMakeLists.txt index 5168ae01408..8931bb6d0e9 100644 --- a/velox/connectors/hive/iceberg/tests/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/tests/CMakeLists.txt @@ -63,6 +63,7 @@ if(NOT VELOX_DISABLE_GOOGLETEST) velox_hive_iceberg_insert_test IcebergConnectorTest.cpp IcebergInsertTest.cpp + IcebergParquetStatsTest.cpp IcebergTestBase.cpp Main.cpp PartitionNameTest.cpp diff --git a/velox/connectors/hive/iceberg/tests/IcebergParquetStatsTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergParquetStatsTest.cpp new file mode 100644 index 00000000000..ebab93a2510 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergParquetStatsTest.cpp @@ -0,0 +1,881 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergDataFileStatistics.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +using namespace facebook::velox::common::testutil; + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +#ifdef VELOX_ENABLE_PARQUET + +class IcebergParquetStatsTest : public test::IcebergTestBase { + protected: + static IcebergDataFileStatisticsPtr statsFromMetrics( + const folly::dynamic& metrics) { + VELOX_CHECK(metrics.isObject()); + VELOX_CHECK(metrics.count("recordCount") > 0); + auto stats = std::make_shared(); + stats->numRecords = metrics["recordCount"].asInt(); + + auto setIntField = [&](const folly::dynamic& map, auto setter) { + if (!map.isObject()) { + return; + } + for (const auto& item : map.items()) { + const auto fieldId = folly::to(item.first.asString()); + auto& column = stats->columnStats[fieldId]; + setter(column, item.second); + } + }; + + setIntField(metrics["columnSizes"], [](auto& column, const auto& value) { + column.columnSize = value.asInt(); + }); + setIntField(metrics["valueCounts"], [](auto& column, const auto& value) { + column.valueCount = value.asInt(); + }); + setIntField( + metrics["nullValueCounts"], [](auto& column, const auto& value) { + column.nullValueCount = value.asInt(); + }); + setIntField(metrics["nanValueCounts"], [](auto& column, const auto& value) { + column.nanValueCount = value.asInt(); + }); + + const auto& lowerBounds = metrics["lowerBounds"]; + if (lowerBounds.isObject()) { + for (const auto& item : lowerBounds.items()) { + const auto fieldId = folly::to(item.first.asString()); + stats->columnStats[fieldId].lowerBound = item.second.asString(); + } + } + const auto& upperBounds = metrics["upperBounds"]; + if (upperBounds.isObject()) { + for (const auto& item : upperBounds.items()) { + const auto fieldId = folly::to(item.first.asString()); + stats->columnStats[fieldId].upperBound = item.second.asString(); + } + } + + return stats; + } + + static std::vector statsFromCommitTasks( + const std::vector& commitTasks) { + std::vector stats; + stats.reserve(commitTasks.size()); + for (const auto& task : commitTasks) { + auto taskJson = folly::parseJson(task); + VELOX_CHECK(taskJson.isObject()); + VELOX_CHECK(taskJson.count("metrics") > 0); + stats.emplace_back(statsFromMetrics(taskJson["metrics"])); + } + return stats; + } + + // Write data and get all stats (for partitioned tables). + std::vector> + writeDataAndGetAllStats( + const RowVectorPtr& data, + const std::vector& partitionFields = {}) { + const auto outputDir = TempDirectoryPath::create(); + auto dataSink = createDataSinkAndAppendData( + {data}, outputDir->getPath(), partitionFields); + auto commitTasks = dataSink->close(); + EXPECT_FALSE(commitTasks.empty()); + return statsFromCommitTasks(commitTasks); + } + + // Decode and extract typed value from base64 encoded bounds. + template + static std::pair decodeBounds( + const std::shared_ptr& stats, + int32_t fieldId) { + auto decode = [](const std::string& base64Encoded) { + const std::string decoded = encoding::Base64::decode(base64Encoded); + T value; + std::memcpy(&value, decoded.data(), sizeof(T)); + return value; + }; + + const auto& columnStats = stats->columnStats.at(fieldId); + VELOX_CHECK(columnStats.lowerBound.has_value()); + VELOX_CHECK(columnStats.upperBound.has_value()); + return { + decode(columnStats.lowerBound.value()), + decode(columnStats.upperBound.value()), + }; + } + + // Verify basic statistics (record count, value counts, null counts). + static void verifyBasicStats( + const std::shared_ptr& stats, + int64_t expectedRecords, + const std::unordered_map& expectedValueCounts, + const std::unordered_map& expectedNullCounts) { + EXPECT_EQ(stats->numRecords, expectedRecords); + + for (const auto& [fieldId, count] : expectedValueCounts) { + ASSERT_TRUE(stats->columnStats.contains(fieldId)); + EXPECT_EQ(stats->columnStats.at(fieldId).valueCount, count); + } + + if (!expectedNullCounts.empty()) { + for (const auto& [fieldId, count] : expectedNullCounts) { + ASSERT_TRUE(stats->columnStats.contains(fieldId)); + EXPECT_EQ(stats->columnStats.at(fieldId).nullValueCount, count); + } + } + } + + // Verify bounds exist for given field IDs. + static void verifyBoundsExist( + const std::shared_ptr& stats, + const std::vector& fieldIds) { + for (const int32_t fieldId : fieldIds) { + ASSERT_TRUE(stats->columnStats.contains(fieldId)); + const auto& columnStats = stats->columnStats.at(fieldId); + ASSERT_TRUE(columnStats.lowerBound.has_value()); + ASSERT_TRUE(columnStats.upperBound.has_value()); + EXPECT_FALSE(columnStats.lowerBound.value().empty()); + EXPECT_FALSE(columnStats.upperBound.value().empty()); + } + } + + // Verify bounds do not exist for given field IDs. + static void verifyBoundsNotExist( + const std::shared_ptr& stats, + const std::vector& fieldIds) { + for (const int32_t fieldId : fieldIds) { + if (stats->columnStats.contains(fieldId)) { + const auto& columnStats = stats->columnStats.at(fieldId); + ASSERT_FALSE(columnStats.lowerBound.has_value()); + ASSERT_FALSE(columnStats.upperBound.has_value()); + } + } + } +}; + +TEST_F(IcebergParquetStatsTest, mixedNull) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedIntNulls = 34; + constexpr int32_t intColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, [](vector_size_t row) { return row * 10; }, nullEvery(3))})); + verifyBasicStats( + stats[0], size, {{intColId, size}}, {{intColId, expectedIntNulls}}); + verifyBoundsExist(stats[0], {intColId}); + + const auto& [minVal, maxVal] = decodeBounds(stats[0], intColId); + EXPECT_EQ(minVal, 10); + EXPECT_EQ(maxVal, 980); +} + +TEST_F(IcebergParquetStatsTest, bigint) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 25; + constexpr int32_t bigintColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return row * 1'000'000'000LL; }, + nullEvery(4))})); + verifyBasicStats( + stats[0], size, {{bigintColId, size}}, {{bigintColId, expectedNulls}}); + verifyBoundsExist(stats[0], {bigintColId}); + + const auto& [minVal, maxVal] = decodeBounds(stats[0], bigintColId); + EXPECT_EQ(minVal, 1'000'000'000LL); + EXPECT_EQ(maxVal, 99'000'000'000LL); +} + +TEST_F(IcebergParquetStatsTest, decimal) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 20; + constexpr int32_t decimalColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return HugeInt::build(row, row * 123); }, + nullEvery(5), + DECIMAL(38, 3))})); + verifyBasicStats( + stats[0], size, {{decimalColId, size}}, {{decimalColId, expectedNulls}}); + verifyBoundsExist(stats[0], {decimalColId}); +} + +TEST_F(IcebergParquetStatsTest, varchar) { + constexpr vector_size_t size = 100; + constexpr int32_t varcharColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { + return "Customer#00000" + std::to_string(row) + "_" + + std::string(row % 10, 'a'); + }, + nullEvery(6))})); + + constexpr int32_t expectedNulls = 17; + verifyBasicStats( + stats[0], size, {{varcharColId, size}}, {{varcharColId, expectedNulls}}); + verifyBoundsExist(stats[0], {varcharColId}); + + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(varcharColId).lowerBound.value()), + "Customer#0000010"); + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(varcharColId).upperBound.value()), + "Customer#000009`"); +} + +TEST_F(IcebergParquetStatsTest, varbinary) { + constexpr vector_size_t size = 100; + constexpr int32_t varbinaryColId = 1; + + auto rowVector = makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { + std::string value(17, 11); + value[0] = static_cast(row % 256); + value[1] = static_cast((row * 3) % 256); + value[2] = static_cast((row * 7) % 256); + value[3] = static_cast((row * 11) % 256); + return value; + }, + nullEvery(5), + VARBINARY())}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + constexpr int32_t expectedNulls = 20; + verifyBasicStats( + stats[0], + size, + {{varbinaryColId, size}}, + {{varbinaryColId, expectedNulls}}); + verifyBoundsExist(stats[0], {varbinaryColId}); +} + +TEST_F(IcebergParquetStatsTest, varbinaryWithTransform) { + const auto& fileStats = writeDataAndGetAllStats( + makeRowVector({makeFlatVector( + {"01020304", + "05060708", + "090A0B0C", + "0D0E0F10", + "11121314", + "15161718", + "191A1B1C", + "1D1E1F20", + "21222324", + "25262728"}, + VARBINARY())}), + {{0, TransformType::kBucket, 4}}); + ASSERT_EQ(fileStats.size(), 3); + const auto& stats = fileStats[0]; + EXPECT_EQ(stats->numRecords, 5); + constexpr int32_t varbinaryColId = 1; + EXPECT_EQ(stats->columnStats.at(varbinaryColId).valueCount, 5); +} + +TEST_F(IcebergParquetStatsTest, multipleDataTypes) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t bigintColId = 2; + constexpr int32_t decimalColId = 3; + constexpr int32_t varcharColId = 4; + constexpr int32_t varbinaryColId = 5; + + constexpr int32_t expectedIntNulls = 34; + constexpr int32_t expectedBigintNulls = 25; + constexpr int32_t expectedDecimalNulls = 20; + constexpr int32_t expectedVarcharNulls = 17; + constexpr int32_t expectedVarbinaryNulls = 15; + + auto rowVector = makeRowVector( + {makeFlatVector( + size, [](vector_size_t row) { return row * 10; }, nullEvery(3)), + makeFlatVector( + size, + [](vector_size_t row) { return row * 1'000'000'000LL; }, + nullEvery(4)), + makeFlatVector( + size, + [](vector_size_t row) { return HugeInt::build(row, row * 12'345); }, + nullEvery(5), + DECIMAL(38, 3)), + makeFlatVector( + size, + [](vector_size_t row) { return "str_" + std::to_string(row); }, + nullEvery(6)), + makeFlatVector( + size, + [](vector_size_t row) { + std::string value(4, 0); + value[0] = static_cast(row % 256); + value[1] = static_cast((row * 3) % 256); + value[2] = static_cast((row * 7) % 256); + value[3] = static_cast((row * 11) % 256); + return value; + }, + nullEvery(7), + VARBINARY())}); + const auto& stats = writeDataAndGetAllStats(rowVector); + + verifyBasicStats( + stats[0], + size, + { + {intColId, size}, + {bigintColId, size}, + {decimalColId, size}, + {varcharColId, size}, + {varbinaryColId, size}, + }, + { + {intColId, expectedIntNulls}, + {bigintColId, expectedBigintNulls}, + {decimalColId, expectedDecimalNulls}, + {varcharColId, expectedVarcharNulls}, + {varbinaryColId, expectedVarbinaryNulls}, + }); + + verifyBoundsExist( + stats[0], + {intColId, bigintColId, decimalColId, varcharColId, varbinaryColId}); +} + +TEST_F(IcebergParquetStatsTest, date) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 20; + constexpr int32_t dateColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return 18262 + row; }, + nullEvery(5), + DATE())})); + verifyBasicStats( + stats[0], size, {{dateColId, size}}, {{dateColId, expectedNulls}}); + verifyBoundsExist(stats[0], {dateColId}); + + const auto& [minVal, maxVal] = decodeBounds(stats[0], dateColId); + EXPECT_EQ(minVal, 18263); + EXPECT_EQ(maxVal, 18262 + 99); +} + +TEST_F(IcebergParquetStatsTest, boolean) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 10; + constexpr int32_t boolColId = 1; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [](vector_size_t row) { return row % 2 == 1; }, + nullEvery(10), + BOOLEAN())})); + verifyBasicStats( + stats[0], size, {{boolColId, size}}, {{boolColId, expectedNulls}}); + verifyBoundsExist(stats[0], {boolColId}); + + // For boolean, the lower bound should be false (0) and upper bound should be + // true (1) if both values are present. + const auto& [minVal, maxVal] = decodeBounds(stats[0], boolColId); + EXPECT_FALSE(minVal); + EXPECT_TRUE(maxVal); +} + +TEST_F(IcebergParquetStatsTest, empty) { + const auto outputDir = TempDirectoryPath::create(); + auto dataSink = createDataSinkAndAppendData( + {makeRowVector( + {makeFlatVector(0), makeFlatVector(0)})}, + outputDir->getPath()); + auto commitTasks = dataSink->close(); + EXPECT_TRUE(commitTasks.empty()); +} + +TEST_F(IcebergParquetStatsTest, nullValues) { + constexpr vector_size_t size = 100; + + const auto& stats = writeDataAndGetAllStats(makeRowVector( + {makeNullConstant(TypeKind::INTEGER, size), + makeNullConstant(TypeKind::VARCHAR, size)})); + EXPECT_EQ(stats[0]->numRecords, size); + ASSERT_EQ(stats[0]->columnStats.at(1).nullValueCount, size); + // Do not collect lower and upper bounds for NULLs. + for (const auto& [fieldId, columnStats] : stats[0]->columnStats) { + ASSERT_FALSE(columnStats.lowerBound.has_value()); + ASSERT_FALSE(columnStats.upperBound.has_value()); + } +} + +TEST_F(IcebergParquetStatsTest, real) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 20; + constexpr int32_t realColId = 1; + int32_t expectedNaNs = 0; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [&expectedNaNs](vector_size_t row) { + if (row % 6 == 0) { + expectedNaNs++; + return std::numeric_limits::quiet_NaN(); + } + return row * 1.5f; + }, + nullEvery(5), + REAL())})); + verifyBasicStats( + stats[0], size, {{realColId, size}}, {{realColId, expectedNulls}}); + + EXPECT_EQ( + stats[0]->columnStats.at(realColId).nanValueCount.value_or(0), + expectedNaNs); + verifyBoundsExist(stats[0], {realColId}); + const auto& [minVal, maxVal] = decodeBounds(stats[0], realColId); + EXPECT_FLOAT_EQ(minVal, 1.5f); + EXPECT_FLOAT_EQ(maxVal, 148.5f); +} + +TEST_F(IcebergParquetStatsTest, double) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedNulls = 15; + constexpr int32_t doubleColId = 1; + int32_t expectedNaNs = 0; + + auto rowVector = makeRowVector({makeFlatVector( + size, + [&expectedNaNs](vector_size_t row) { + if (row % 3 == 0) { + expectedNaNs++; + return std::numeric_limits::quiet_NaN(); + } + if (row % 4 == 0) { + return std::numeric_limits::infinity(); + } + if (row % 5 == 0) { + return -std::numeric_limits::infinity(); + } + return row * 2.5; + }, + nullEvery(7), + DOUBLE())}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + verifyBasicStats( + stats[0], size, {{doubleColId, size}}, {{doubleColId, expectedNulls}}); + + EXPECT_EQ( + stats[0]->columnStats.at(doubleColId).nanValueCount.value_or(0), + expectedNaNs); + + verifyBoundsExist(stats[0], {doubleColId}); + + // Verify bounds are set correctly and NaN/infinity values don't affect + // min/max incorrectly. + const auto& [minVal, maxVal] = decodeBounds(stats[0], doubleColId); + EXPECT_DOUBLE_EQ(minVal, -std::numeric_limits::infinity()) + << "Lower bound should be -infinity"; + EXPECT_DOUBLE_EQ(maxVal, std::numeric_limits::infinity()) + << "Upper bound should be infinity"; +} + +TEST_F(IcebergParquetStatsTest, mixedDoubleFloat) { + constexpr vector_size_t size = 6; + + auto rowVector = makeRowVector( + {makeFlatVector(size, [](vector_size_t row) { return 1; }), + makeFlatVector( + size, + [](vector_size_t row) { + return -std::numeric_limits::infinity(); + }), + makeFlatVector( + size, + [](vector_size_t row) { + return std::numeric_limits::infinity(); + }), + makeFlatVector(size, [](vector_size_t row) { + switch (row) { + case 0: + return 1.23; + case 1: + return -1.23; + case 2: + return std::numeric_limits::infinity(); + case 3: + return 2.23; + case 4: + return -std::numeric_limits::infinity(); + default: + return -2.23; + } + })}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + constexpr int32_t doubleColId = 4; + verifyBasicStats(stats[0], size, {{doubleColId, size}}, {{doubleColId, 0}}); + const auto& [minVal, maxVal] = decodeBounds(stats[0], doubleColId); + EXPECT_DOUBLE_EQ(minVal, -std::numeric_limits::infinity()); + EXPECT_DOUBLE_EQ(maxVal, std::numeric_limits::infinity()); + + constexpr int32_t floatColId = 2; + const auto& [minFloatVal, maxFloatVal] = + decodeBounds(stats[0], floatColId); + EXPECT_FLOAT_EQ(minFloatVal, -std::numeric_limits::infinity()); + EXPECT_FLOAT_EQ(maxFloatVal, -std::numeric_limits::infinity()); +} + +TEST_F(IcebergParquetStatsTest, NaN) { + constexpr vector_size_t size = 1'000; + constexpr int32_t expectedNulls = 500; + constexpr int32_t doubleColId = 1; + int32_t expectedNaNs = 0; + + const auto& stats = + writeDataAndGetAllStats(makeRowVector({makeFlatVector( + size, + [&expectedNaNs](vector_size_t /*row*/) { + expectedNaNs++; + return std::numeric_limits::quiet_NaN(); + }, + nullEvery(2), + DOUBLE())})); + verifyBasicStats( + stats[0], size, {{doubleColId, size}}, {{doubleColId, expectedNulls}}); + + EXPECT_EQ( + stats[0]->columnStats.at(doubleColId).nanValueCount.value_or(0), + expectedNaNs); + // Do not collect bounds for NULLs and NaNs. + for (const auto& [fieldId, columnStats] : stats[0]->columnStats) { + ASSERT_FALSE(columnStats.lowerBound.has_value()); + ASSERT_FALSE(columnStats.upperBound.has_value()); + } +} + +TEST_F(IcebergParquetStatsTest, partitionedTable) { + std::vector partitionTransforms = { + {0, TransformType::kBucket, 4}, + {1, TransformType::kDay, std::nullopt}, + {2, TransformType::kTruncate, 2}, + }; + + constexpr vector_size_t size = 100; + + auto rowVector = makeRowVector( + {makeFlatVector(size, [](vector_size_t row) { return row; }), + makeFlatVector( + size, + [](vector_size_t row) { return 18262 + (row % 5); }, + nullptr, + DATE()), + makeFlatVector(size, [](vector_size_t row) { + return fmt::format("str{}", row % 10); + })}); + + const auto& fileStats = + writeDataAndGetAllStats(rowVector, partitionTransforms); + + EXPECT_GT(fileStats.size(), 1) + << "Expected multiple files due to partitioning"; + + for (const auto& stats : fileStats) { + EXPECT_GT(stats->numRecords, 0); + ASSERT_FALSE(stats->columnStats.empty()); + + constexpr int32_t intColId = 1; + constexpr int32_t dateColId = 2; + constexpr int32_t varcharColId = 3; + EXPECT_EQ(stats->columnStats.at(intColId).valueCount, stats->numRecords); + EXPECT_EQ(stats->columnStats.at(dateColId).valueCount, stats->numRecords); + EXPECT_EQ( + stats->columnStats.at(varcharColId).valueCount, stats->numRecords); + + for (const auto fieldId : {intColId, dateColId, varcharColId}) { + const auto& columnStats = stats->columnStats.at(fieldId); + ASSERT_TRUE(columnStats.lowerBound.has_value()); + ASSERT_TRUE(columnStats.upperBound.has_value()); + EXPECT_FALSE(columnStats.lowerBound.value().empty()); + EXPECT_FALSE(columnStats.upperBound.value().empty()); + } + } + + // Verify total record count across all partitions. + int64_t totalRecords = 0; + for (const auto& stats : fileStats) { + totalRecords += stats->numRecords; + } + EXPECT_EQ(totalRecords, size); +} + +TEST_F(IcebergParquetStatsTest, multiplePartitionTransforms) { + std::vector partitionTransforms = { + {0, TransformType::kBucket, 2}, + {1, TransformType::kYear, std::nullopt}, + {2, TransformType::kTruncate, 3}, + {3, TransformType::kIdentity, std::nullopt}}; + + constexpr vector_size_t size = 100; + + auto rowVector = makeRowVector( + {makeFlatVector( + size, [](vector_size_t row) { return row * 10; }), + makeFlatVector( + size, + [](vector_size_t row) { return 18262 + (row * 100); }, + nullptr, + DATE()), + makeFlatVector( + size, + [](vector_size_t row) { + return fmt::format("prefix{}_value", row % 5); + }), + makeFlatVector( + size, [](vector_size_t row) { return (row % 3) * 1'000; })}); + + const auto& fileStats = + writeDataAndGetAllStats(rowVector, partitionTransforms); + EXPECT_GT(fileStats.size(), 1) + << "Expected multiple files due to partitioning"; + // Check each file's stats. + for (const auto& stats : fileStats) { + EXPECT_GT(stats->numRecords, 0); + constexpr int32_t intColId = 1; + constexpr int32_t dateColId = 2; + constexpr int32_t bigintColId = 4; + + if (stats->columnStats.contains(intColId)) { + const auto& [minVal, maxVal] = decodeBounds(stats, intColId); + EXPECT_LE(minVal, maxVal) + << "Lower bound should be <= upper bound for int column"; + } + + if (stats->columnStats.contains(dateColId)) { + const auto& [minVal, maxVal] = decodeBounds(stats, dateColId); + EXPECT_LE(minVal, maxVal) + << "Lower bound should be <= upper bound for date column"; + } + + if (stats->columnStats.contains(bigintColId)) { + const auto& [minVal, maxVal] = decodeBounds(stats, bigintColId); + EXPECT_LE(minVal, maxVal) + << "Lower bound should be <= upper bound for bigint column"; + } + } + int64_t totalRecords = 0; + for (const auto& stats : fileStats) { + totalRecords += stats->numRecords; + } + EXPECT_EQ(totalRecords, size); +} + +TEST_F(IcebergParquetStatsTest, partitionedTableWithNulls) { + constexpr vector_size_t size = 100; + constexpr int32_t expectedIntNulls = 20; + constexpr int32_t expectedDateNulls = 15; + constexpr int32_t expectedVarcharNulls = 10; + + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}, + {1, TransformType::kMonth, std::nullopt}, + {2, TransformType::kTruncate, 2}}; + auto rowVector = makeRowVector( + {makeFlatVector( + size, + [](vector_size_t row) { return row % 10; }, + nullEvery(5), + INTEGER()), + makeFlatVector( + size, + [](vector_size_t row) { return 18262 + (row % 3) * 30; }, + nullEvery(7), + DATE()), + makeFlatVector( + size, + [](vector_size_t row) { return fmt::format("val{}", row % 5); }, + nullEvery(11))}); + const auto& fileStats = + writeDataAndGetAllStats(rowVector, partitionTransforms); + int32_t totalIntNulls = 0; + int32_t totalDateNulls = 0; + int32_t totalVarcharNulls = 0; + int32_t totalRecords = 0; + + constexpr int32_t intColId = 1; + constexpr int32_t dateColId = 2; + constexpr int32_t varcharColId = 3; + + for (const auto& stats : fileStats) { + totalRecords += stats->numRecords; + // Add null counts if present. + if (stats->columnStats.contains(intColId)) { + totalIntNulls += stats->columnStats.at(intColId).nullValueCount; + } + + if (stats->columnStats.contains(dateColId)) { + totalDateNulls += stats->columnStats.at(dateColId).nullValueCount; + } + + if (stats->columnStats.contains(varcharColId)) { + totalVarcharNulls += stats->columnStats.at(varcharColId).nullValueCount; + } + } + + // Verify total counts match expected. + EXPECT_EQ(totalRecords, size); + EXPECT_EQ(totalIntNulls, expectedIntNulls); + EXPECT_EQ(totalDateNulls, expectedDateNulls); + EXPECT_EQ(totalVarcharNulls, expectedVarcharNulls); +} + +TEST_F(IcebergParquetStatsTest, mapType) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t mapValueColId = 3; // Map value field ID. + + std::vector>>>> + mapData; + for (auto i = 0; i < size; ++i) { + std::vector>> mapRow; + for (auto j = 0; j < 5; ++j) { + mapRow.emplace_back(j, fmt::format("value_{}", i * 5 + j)); + } + mapData.push_back(std::move(mapRow)); + } + + const auto& stats = writeDataAndGetAllStats(makeRowVector({ + makeFlatVector(size, [](auto row) { return row * 10; }), + makeNullableMapVector(mapData), + })); + verifyBasicStats(stats[0], size, {{intColId, size}}, {{intColId, 0}}); + + EXPECT_EQ(stats[0]->columnStats.at(mapValueColId).valueCount, size * 5); + // Map values have stats but no bounds (skipBounds=true for maps). + verifyBoundsNotExist(stats[0], {mapValueColId}); +} + +TEST_F(IcebergParquetStatsTest, arrayType) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t arrayElementColId = 3; // Array element field ID. + + std::vector>> arrayData; + for (auto i = 0; i < size; ++i) { + std::vector> arrayRow; + for (auto j = 0; j < 3; ++j) { + arrayRow.emplace_back(fmt::format("item_{}", i * 3 + j)); + } + arrayData.push_back(std::move(arrayRow)); + } + + const auto& stats = writeDataAndGetAllStats(makeRowVector( + {makeFlatVector(size, [](auto row) { return row * 10; }), + makeNullableArrayVector(arrayData)})); + verifyBasicStats(stats[0], size, {{intColId, size}}, {{intColId, 0}}); + + EXPECT_EQ(stats[0]->columnStats.at(arrayElementColId).valueCount, size * 3); + // Array elements have stats but no bounds (skipBounds=true for arrays). + verifyBoundsNotExist(stats[0], {arrayElementColId}); +} + +// Test statistics collection for nested struct fields. +// Field ID assignment: +// int_col: 1 +// struct_col: 2 (parent, no stats) +// first_level_id: 3 +// first_level_name: 4 +// nested_struct: 5 (parent, no stats) +// second_level_id: 6 +// second_level_name: 7 +// Statistics collected for leaf fields: [1, 3, 4, 6, 7] +TEST_F(IcebergParquetStatsTest, structType) { + constexpr vector_size_t size = 100; + constexpr int32_t intColId = 1; + constexpr int32_t firstLevelIdColId = 3; + constexpr int32_t secondLevelIdColId = 6; + constexpr int32_t secondLevelNameColId = 7; + + auto firstLevelId = makeFlatVector( + size, [](vector_size_t row) { return row % size; }, nullEvery(5)); + + auto firstLevelName = makeFlatVector( + size, + [](vector_size_t row) { return fmt::format("name_{}", row * 10); }, + nullEvery(7)); + + auto secondLevelId = makeFlatVector( + size, [](vector_size_t row) { return row * size; }, nullEvery(6)); + + auto secondLevelName = makeFlatVector( + size, + [](vector_size_t row) { return fmt::format("nested_{}", row * 100); }, + nullEvery(8)); + + auto nestedStruct = makeRowVector({secondLevelId, secondLevelName}); + auto structVector = + makeRowVector({firstLevelId, firstLevelName, nestedStruct}); + auto rowVector = makeRowVector( + {makeFlatVector(size, [](auto row) { return row * 10; }), + structVector}); + + const auto& stats = writeDataAndGetAllStats(rowVector); + EXPECT_EQ(stats[0]->numRecords, size); + EXPECT_EQ(stats[0]->columnStats.size(), 5); + + verifyBasicStats( + stats[0], + size, + {{intColId, size}, {firstLevelIdColId, size}, {secondLevelIdColId, size}}, + {{intColId, 0}, {firstLevelIdColId, 20}}); + + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(secondLevelNameColId).lowerBound.value()), + "nested_100"); + EXPECT_EQ( + encoding::Base64::decode( + stats[0]->columnStats.at(secondLevelNameColId).upperBound.value()), + "nested_9900"); +} + +#endif + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/dwio/parquet/writer/arrow/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/CMakeLists.txt index d1685bc9d45..b00b2f8a662 100644 --- a/velox/dwio/parquet/writer/arrow/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/CMakeLists.txt @@ -37,6 +37,7 @@ velox_add_library( Properties.cpp Schema.cpp Statistics.cpp + StringTruncation.cpp Types.cpp Writer.cpp HEADERS @@ -59,6 +60,7 @@ velox_add_library( Schema.h SchemaInternal.h Statistics.h + StringTruncation.h ThriftInternal.h Types.h Writer.h diff --git a/velox/dwio/parquet/writer/arrow/Statistics.cpp b/velox/dwio/parquet/writer/arrow/Statistics.cpp index 019157b8a6c..13e1c304120 100644 --- a/velox/dwio/parquet/writer/arrow/Statistics.cpp +++ b/velox/dwio/parquet/writer/arrow/Statistics.cpp @@ -39,8 +39,8 @@ #include "velox/dwio/parquet/writer/arrow/Exception.h" #include "velox/dwio/parquet/writer/arrow/Platform.h" #include "velox/dwio/parquet/writer/arrow/Schema.h" +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" -#include "velox/functions/lib/string/StringImpl.h" #include "velox/type/DecimalUtil.h" #include "velox/type/HugeInt.h" @@ -792,8 +792,8 @@ class TypedStatisticsImpl : public TypedStatistics { return encodeDecimalToBigEndian(min_); } if constexpr (std::is_same_v) { - const auto truncatedMin = functions::stringImpl::truncateUtf8( - std::string_view(min_), truncateTo); + const auto truncatedMin = + truncateUtf8(std::string_view(min_), truncateTo); std::string s; this->plainEncode( ByteArray( @@ -816,8 +816,23 @@ class TypedStatisticsImpl : public TypedStatistics { return encodeDecimalToBigEndian(max_); } if constexpr (std::is_same_v) { - const auto truncatedMax = functions::stringImpl::roundUpUtf8( - std::string_view(max_), truncateTo); + // For ByteArray, we need to determine if this is UTF-8 text (STRING) + // or raw binary data (BINARY/VARBINARY). The Parquet logical type tells + // us this. + const bool isUtf8String = descr_->logicalType()->isString(); + + std::optional truncatedMax; + + if (isUtf8String) { + // Use UTF-8 string logic for STRING type + truncatedMax = roundUpUtf8(std::string_view(max_), truncateTo); + } else { + // Use binary byte logic for BINARY type (VARBINARY) + // Implementation follows Apache Iceberg's + // BinaryUtil.truncateBinaryMax() + truncatedMax = roundUpBinary(std::string_view(max_), truncateTo); + } + if (!truncatedMax.has_value()) { return std::nullopt; } diff --git a/velox/dwio/parquet/writer/arrow/StringTruncation.cpp b/velox/dwio/parquet/writer/arrow/StringTruncation.cpp new file mode 100644 index 00000000000..a7b4bad6bcb --- /dev/null +++ b/velox/dwio/parquet/writer/arrow/StringTruncation.cpp @@ -0,0 +1,180 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" + +#include +#include +#include +#include +#include +#include + +#include "velox/functions/lib/string/StringCore.h" +#include "velox/functions/lib/string/StringImpl.h" + +namespace facebook::velox::parquet::arrow { + +// Import necessary functions from stringImpl namespace +using facebook::velox::functions::stringCore::isAscii; +using facebook::velox::functions::stringImpl::cappedByteLength; + +namespace { + +// Increments a Unicode code point to the next valid Unicode scalar value. +// Returns 0 if overflow (input is max code point). +FOLLY_ALWAYS_INLINE int32_t incrementCodePoint(int32_t codePoint) { + static constexpr int32_t kMaxCodePoint = 0x10FFFF; + static constexpr int32_t kMinSurrogate = 0xD800; + static constexpr int32_t kMaxSurrogate = 0xDFFF; + if (codePoint == (kMinSurrogate - 1)) { + // Skip the surrogate range. + return kMaxSurrogate + 1; + } else if (codePoint == kMaxCodePoint) { + return 0; + } + return codePoint + 1; +} + +// ASCII fast-path for roundUp. +FOLLY_ALWAYS_INLINE std::optional roundUpAscii( + std::string_view input, + int32_t numCodePoints) { + const size_t truncatedLength = + std::min(input.size(), static_cast(numCodePoints)); + + if (truncatedLength == input.size()) { + return std::string(input); + } + + if (truncatedLength == 0) { + return std::nullopt; + } + + for (int32_t i = truncatedLength - 1; i >= 0; --i) { + const auto byte = static_cast(input[i]); + if (byte < 0x7F) { + std::string result(input.data(), i); + result.push_back(static_cast(byte + 1)); + return result; + } + } + + // All bytes are 0x7F (DEL character), no valid upper bound. + return std::nullopt; +} + +// Unicode path for roundUp. +FOLLY_ALWAYS_INLINE std::optional roundUpUnicode( + std::string_view input, + int32_t numCodePoints) { + const auto truncatedLength = cappedByteLength(input, numCodePoints); + + if (truncatedLength == input.size()) { + return std::string(input); + } + + if (truncatedLength == 0) { + return std::nullopt; + } + + const char* data = input.data(); + const char* truncatedEnd = data + truncatedLength; + + // Collect the byte offset of each code point. + std::vector codePointOffsets; + codePointOffsets.reserve(numCodePoints); + const char* current = data; + while (current < truncatedEnd) { + codePointOffsets.push_back(current - data); + int32_t charLength; + utf8proc_codepoint(current, truncatedEnd, charLength); + current += charLength; + } + + // Try incrementing from the last code point backwards. + for (int32_t i = codePointOffsets.size() - 1; i >= 0; --i) { + const char* pos = data + codePointOffsets[i]; + int32_t charLength; + const auto codePoint = utf8proc_codepoint(pos, truncatedEnd, charLength); + const auto nextCodePoint = incrementCodePoint(codePoint); + if (nextCodePoint != 0) { + std::string result(data, codePointOffsets[i]); + char buffer[4]; + const auto bytesWritten = utf8proc_encode_char( + nextCodePoint, reinterpret_cast(buffer)); + result.append(buffer, bytesWritten); + return result; + } + } + + // No valid upper bound can be found. + return std::nullopt; +} + +} // namespace + +std::string_view truncateUtf8(std::string_view input, int32_t numCodePoints) { + if (isAscii(input.data(), input.size())) { + return std::string_view( + input.data(), std::min(input.size(), (size_t)numCodePoints)); + } + const auto truncatedLength = cappedByteLength(input, numCodePoints); + return std::string_view(input.data(), truncatedLength); +} + +std::optional roundUpUtf8( + std::string_view input, + int32_t numCodePoints) { + if (isAscii(input.data(), input.size())) { + return roundUpAscii(input, numCodePoints); + } + return roundUpUnicode(input, numCodePoints); +} + +std::optional roundUpBinary( + std::string_view input, + int32_t truncateLength) { + if (truncateLength <= 0) { + return std::nullopt; + } + + const size_t length = static_cast(truncateLength); + if (input.size() <= length) { + return std::string(input); + } + + // Create a mutable copy of the truncated input. + std::string result(input.data(), length); + + // Try incrementing bytes from the end. + for (size_t i = length; i-- > 0;) { + unsigned char byte = static_cast(result[i]); + + if (byte != 0xFF) { // Can increment without overflow. + result[i] = static_cast(byte + 1); + // Truncate to i + 1 bytes (remove trailing bytes after increment point). + result.resize(i + 1); + return result; + } + // If byte == 0xFF, it will overflow, continue to previous byte. + } + + // All bytes were 0xFF and overflowed - no valid upper bound. + return std::nullopt; +} + +} // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/StringTruncation.h b/velox/dwio/parquet/writer/arrow/StringTruncation.h new file mode 100644 index 00000000000..2a0fcb35ec4 --- /dev/null +++ b/velox/dwio/parquet/writer/arrow/StringTruncation.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::parquet::arrow { + +/// Truncates a UTF-8 encoded string to at most 'numCodePoints' Unicode code +/// points. Returns a string_view pointing to the truncated portion of the +/// input string. This is used for computing lower bound statistics, +/// as the truncated string is guaranteed to be less than or equal to the +/// original string in lexicographic order. +/// +/// @param input The UTF-8 encoded input string. +/// @param numCodePoints Maximum number of Unicode code points to retain. +/// @return A string_view of the truncated string. +std::string_view truncateUtf8(std::string_view input, int32_t numCodePoints); + +/// Rounds up a UTF-8 encoded string to produce an exclusive upper bound. +/// The result is guaranteed to be greater than any string that shares the +/// same prefix up to 'numCodePoints' code points. This is used for computing +/// upper bound statistics. +/// +/// The function behaves as follows: +/// - If the string has fewer than or equal to 'numCodePoints' code points, +/// returns the original string unchanged. +/// - Otherwise, truncates to 'numCodePoints' code points and increments +/// code points from the last to the first, returning immediately on the +/// first successful increment. +/// - If no code point can be incremented (e.g., all are at max value +/// U+10FFFF), returns std::nullopt. +/// +/// @param input The UTF-8 encoded input string. +/// @param numCodePoints Maximum number of Unicode code points to retain. +/// @return A new string containing the rounded-up result, or std::nullopt if +/// no valid upper bound can be computed. +std::optional roundUpUtf8( + std::string_view input, + int32_t numCodePoints); + +/// Computes an upper bound for binary data by truncating to a specified length +/// and incrementing the last byte that is not 0xFF. +/// +/// This function is used for computing upper bounds on binary statistics +/// (e.g., for Parquet file metadata). It follows the algorithm described in +/// Apache Iceberg's BinaryUtil.truncateBinaryMax(). +/// +/// The algorithm: +/// 1. If the input is shorter than or equal to truncateLength, return it as-is. +/// 2. Otherwise, truncate to truncateLength bytes. +/// 3. Starting from the last byte, find the first byte that is not 0xFF. +/// 4. Increment that byte and truncate everything after it. +/// 5. If all bytes are 0xFF, return std::nullopt (no valid upper bound). +/// +/// @param input The binary data as a string_view. +/// @param truncateLength Maximum number of bytes to retain before incrementing. +/// @return An optional string containing the upper bound, or std::nullopt if +/// no valid upper bound exists (e.g., all bytes are 0xFF). +std::optional roundUpBinary( + std::string_view input, + int32_t truncateLength); + +} // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt index cd1d4be684f..19e3d0e3b92 100644 --- a/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt @@ -25,6 +25,7 @@ add_executable( PropertiesTest.cpp SchemaTest.cpp StatisticsTest.cpp + StringTruncationTest.cpp TypesTest.cpp ) diff --git a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp index 091d50e2a9c..f4be7ed69e9 100644 --- a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp @@ -24,6 +24,7 @@ #include "velox/common/testutil/TempFilePath.h" #include "velox/dwio/parquet/reader/ParquetReader.h" #include "velox/dwio/parquet/writer/arrow/FileWriter.h" +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" #include "velox/dwio/parquet/writer/arrow/tests/TestUtil.h" using arrow::default_memory_pool; diff --git a/velox/dwio/parquet/writer/arrow/tests/StringTruncationTest.cpp b/velox/dwio/parquet/writer/arrow/tests/StringTruncationTest.cpp new file mode 100644 index 00000000000..a0486b0c1d0 --- /dev/null +++ b/velox/dwio/parquet/writer/arrow/tests/StringTruncationTest.cpp @@ -0,0 +1,271 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/dwio/parquet/writer/arrow/StringTruncation.h" + +namespace facebook::velox::parquet::arrow { + +// Tests for string utility functions used in statistics +TEST(StringTruncation, truncateUtf8) { + auto testTruncate = [](const std::string& input, + int32_t numCodePoints, + const std::string& expected) { + EXPECT_EQ(truncateUtf8(input, numCodePoints), expected); + }; + + // ASCII string. + std::string ascii = "Hello, world!"; + testTruncate(ascii, 0, ""); + testTruncate(ascii, 1, "H"); + testTruncate(ascii, 5, "Hello"); + testTruncate(ascii, 13, ascii); + testTruncate(ascii, 20, ascii); + + // String with multi-bytes characters. + std::string unicode = "Hello, 世界!"; + testTruncate(unicode, 7, "Hello, "); + testTruncate(unicode, 8, "Hello, 世"); + testTruncate(unicode, 9, "Hello, 世界"); + testTruncate(unicode, 10, unicode); + testTruncate(unicode, 20, unicode); + + // String with emoji (surrogate pairs). + std::string emoji = "Hello 🌍!"; + testTruncate(emoji, 6, "Hello "); + testTruncate(emoji, 7, "Hello 🌍"); + testTruncate(emoji, 8, emoji); + testTruncate(emoji, 10, emoji); + + std::string empty = ""; + testTruncate(empty, 0, ""); + testTruncate(empty, 5, ""); + + std::string mixed = "café世界🌍"; + testTruncate(mixed, 3, "caf"); + testTruncate(mixed, 4, "café"); + testTruncate(mixed, 5, "café世"); + testTruncate(mixed, 6, "café世界"); + testTruncate(mixed, 7, mixed); +} + +TEST(StringTruncation, roundUpUtf8) { + auto testRoundUp = [](const std::string& input, + int32_t numCodePoints, + const std::optional& expected) { + EXPECT_EQ(roundUpUtf8(input, numCodePoints), expected); + }; + + std::string ascii = "Hello, world!"; + // Empty truncation returns nullopt. + testRoundUp(ascii, 0, std::nullopt); + // 'o' -> 'p'. + testRoundUp(ascii, 5, "Hellp"); + testRoundUp(ascii, ascii.length(), ascii); + + ascii = "Customer#000001500"; + // '5' -> '6'. + testRoundUp(ascii, 16, "Customer#0000016"); + + std::string unicode = "Hello, 世界!"; + testRoundUp(unicode, 8, "Hello, 丗"); + + // No truncation needed. + std::string shortString = "Hi"; + testRoundUp(shortString, 2, shortString); + testRoundUp(shortString, 20, shortString); + + // Last character is already at maximum value, returns nullopt. + std::string maxChar = "Hello\U0010FFFF"; + testRoundUp(maxChar, 6, maxChar); + + std::string empty = ""; + testRoundUp(empty, 0, ""); + testRoundUp(empty, 5, ""); + + std::string single = "a"; + // No truncation needed. + testRoundUp(single, 1, "a"); + + std::string zChar = "zz"; + // 'z' -> '{'. + testRoundUp(zChar, 1, "{"); + + std::string emojiTest = "🌍!!"; + // U1F30D (🌍) -> U1F30E. + testRoundUp(emojiTest, 1, "\U0001F30E"); + + std::string multiByteTest = "café+"; + // 'f' -> 'g'. + testRoundUp(multiByteTest, 3, "cag"); + // 'é' -> 'ê'. + testRoundUp(multiByteTest, 4, "cafê"); + + // Test surrogate boundary: U+D7FF should increment to U+E000 (skipping + // surrogate range U+D800-U+DFFF). + // U+D7FF followed by "!!" + std::string surrogateTest = "\xED\x9F\xBF!!"; + // U+E000 + testRoundUp(surrogateTest, 1, "\xEE\x80\x80"); + + // Test all max code points - should return nullopt. + std::string allMax = "\U0010FFFF\U0010FFFF"; + testRoundUp(allMax, 1, std::nullopt); +} + +TEST(StringTruncation, roundUpBinary) { + auto testRoundUpBinary = [](const std::string& input, + int32_t truncateLength, + const std::optional& expected) { + EXPECT_EQ(roundUpBinary(input, truncateLength), expected); + }; + + // Basic binary data with truncation. + std::string binary = "Hello, world!"; + // Empty truncation returns nullopt. + testRoundUpBinary(binary, 0, std::nullopt); + // 'o' (0x6F) -> 'p' (0x70). + testRoundUpBinary(binary, 5, "Hellp"); + // No truncation needed - returns input unchanged. + testRoundUpBinary(binary, binary.length(), binary); + testRoundUpBinary(binary, binary.length() + 10, binary); + + // Test with numeric data. + std::string numeric = "Customer#000001500"; + // '5' (0x35) -> '6' (0x36). + testRoundUpBinary(numeric, 16, "Customer#0000016"); + + // Test with binary data containing high bytes. + std::string highBytes = "data\xFE\xFD"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(highBytes, 6, highBytes); + // Truncate to 5 bytes "data\xFE", 0xFE -> 0xFF. + testRoundUpBinary(highBytes, 5, "data\xFF"); + + // Test with all 0xFF bytes - should return nullopt. + std::string allFF = "\xFF\xFF\xFF"; + testRoundUpBinary(allFF, 1, std::nullopt); + testRoundUpBinary(allFF, 2, std::nullopt); + // No truncation needed - returns input unchanged. + testRoundUpBinary(allFF, 3, allFF); + + // Test with trailing 0xFF bytes. + std::string trailingFF = "abc\xFF\xFF"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(trailingFF, 5, trailingFF); + // Truncate to 4 bytes "abc\xFF", 0xFF overflows, 'c' (0x63) -> 'd' (0x64). + testRoundUpBinary(trailingFF, 4, "abd"); + // Truncate to 3 bytes "abc", 'c' (0x63) -> 'd' (0x64). + testRoundUpBinary(trailingFF, 3, "abd"); + + // Test empty string. + std::string empty = ""; + testRoundUpBinary(empty, 0, std::nullopt); + testRoundUpBinary(empty, 5, ""); + + // Test single byte. + std::string single = "a"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(single, 1, "a"); + testRoundUpBinary(single, 10, "a"); + + // Test incrementing single byte with truncation. + std::string singleZ = "zz"; + // Truncate to 1 byte "z", 'z' (0x7A) -> '{' (0x7B). + testRoundUpBinary(singleZ, 1, "{"); + + // Test with null bytes. + std::string withNull = std::string("ab\0cd", 5); + // No truncation needed - returns input unchanged. + testRoundUpBinary(withNull, 5, withNull); + // Truncate to 4 bytes "ab\0c", 'c' (0x63) -> 'd' (0x64). + testRoundUpBinary(withNull, 4, std::string("ab\0d", 4)); + + // Test boundary case: 0xFE -> 0xFF. + std::string boundaryFE = "test\xFE"; + // No truncation needed - returns input unchanged. + testRoundUpBinary(boundaryFE, 5, boundaryFE); + // Truncate to 5 bytes and increment would give same result. + std::string boundaryFE2 = std::string("test\xFE", 5) + "abc"; + testRoundUpBinary(boundaryFE2, 5, "test\xFF"); + + // Test mixed case with overflow in middle. + std::string mixedOverflow = "a\xFF\xFFz"; + // Truncate to 3 bytes "a\xFF\xFF", both 0xFF overflow, 'a' (0x61) -> 'b' + // (0x62). + testRoundUpBinary(mixedOverflow, 3, "b"); + + // Test truncation removes trailing bytes after increment. + std::string longString = "abcdefgh"; + // Truncate to 3 bytes "abc", 'c' (0x63) -> 'd' (0x64), result is "abd". + testRoundUpBinary(longString, 3, "abd"); + // Truncate to 5 bytes "abcde", 'e' (0x65) -> 'f' (0x66), result is "abcdf". + testRoundUpBinary(longString, 5, "abcdf"); + + // Test with UTF-8 multi-byte sequences (treated as raw bytes). + std::string utf8Bytes = "café"; + // Truncate to 3 bytes "caf", 'f' (0x66) -> 'g' (0x67). + testRoundUpBinary(utf8Bytes, 3, "cag"); + // No truncation needed - returns input unchanged. + testRoundUpBinary(utf8Bytes, 5, utf8Bytes); + // Truncate to 5 bytes and increment last byte. + std::string utf8Bytes2 = "café!"; + // Truncate to 5 bytes "café" (caf + 0xC3 0xA9), 0xA9 -> 0xAA. + testRoundUpBinary(utf8Bytes2, 5, "caf\xC3\xAA"); + + // Test with INVALID UTF-8 sequences - this is the key use case for + // roundUpBinary. These sequences would cause roundUpUtf8 to fail, but + // roundUpBinary treats them as raw bytes. + + // Invalid UTF-8: lone continuation byte 0x80. + std::string invalidUtf8_1 = std::string("test\x80", 5); + testRoundUpBinary(invalidUtf8_1, 5, invalidUtf8_1); + testRoundUpBinary(invalidUtf8_1, 4, "tesu"); + + // Invalid UTF-8: incomplete multi-byte sequence (0xC3 without continuation). + std::string invalidUtf8_2 = std::string("data\xC3", 5); + testRoundUpBinary(invalidUtf8_2, 5, invalidUtf8_2); + testRoundUpBinary(invalidUtf8_2, 4, "datb"); + + // Invalid UTF-8: overlong encoding (0xC0 0x80 for null byte). + std::string invalidUtf8_3 = std::string("ab\xC0\x80", 4); + testRoundUpBinary(invalidUtf8_3, 4, invalidUtf8_3); + testRoundUpBinary(invalidUtf8_3, 3, std::string("ab\xC1", 3)); + + // Invalid UTF-8: invalid start byte 0xFE. + std::string invalidUtf8_4 = std::string("xyz\xFE", 4); + testRoundUpBinary(invalidUtf8_4, 4, invalidUtf8_4); + testRoundUpBinary(invalidUtf8_4, 3, "xy{"); + + // Invalid UTF-8: truncated 3-byte sequence (0xE0 0x80 without third byte). + std::string invalidUtf8_5 = std::string("foo\xE0\x80", 5); + testRoundUpBinary(invalidUtf8_5, 5, invalidUtf8_5); + testRoundUpBinary(invalidUtf8_5, 4, std::string("foo\xE1", 4)); + + // Invalid UTF-8: sequence with 0xFF (which is never valid in UTF-8). + std::string invalidUtf8_6 = std::string("bar\xFF", 4); + testRoundUpBinary(invalidUtf8_6, 4, invalidUtf8_6); + // Truncate to 3 bytes "bar", 'r' (0x72) -> 's' (0x73). + testRoundUpBinary(invalidUtf8_6, 3, "bas"); + + // Test with all 0xFF in invalid UTF-8 context. + std::string invalidUtf8_7 = std::string("\xFF\xFF\xFF", 3); + testRoundUpBinary(invalidUtf8_7, 2, std::nullopt); + testRoundUpBinary(invalidUtf8_7, 3, invalidUtf8_7); +} + +} // namespace facebook::velox::parquet::arrow diff --git a/velox/functions/lib/string/StringImpl.h b/velox/functions/lib/string/StringImpl.h index 67f7728f3b6..3322ac36991 100644 --- a/velox/functions/lib/string/StringImpl.h +++ b/velox/functions/lib/string/StringImpl.h @@ -788,97 +788,6 @@ FOLLY_ALWAYS_INLINE bool initcapUtf8Impl( return true; } -// Increments a Unicode code point to the next valid Unicode scalar value. -// Returns 0 if overflow (input is max code point). -FOLLY_ALWAYS_INLINE int32_t incrementCodePoint(int32_t codePoint) { - static constexpr int32_t kMaxCodePoint = 0x10FFFF; - static constexpr int32_t kMinSurrogate = 0xD800; - static constexpr int32_t kMaxSurrogate = 0xDFFF; - if (codePoint == (kMinSurrogate - 1)) { - // Skip the surrogate range. - return kMaxSurrogate + 1; - } else if (codePoint == kMaxCodePoint) { - return 0; - } - return codePoint + 1; -} - -// ASCII fast-path for roundUp. -FOLLY_ALWAYS_INLINE std::optional roundUpAscii( - std::string_view input, - int32_t numCodePoints) { - const size_t truncatedLength = - std::min(input.size(), static_cast(numCodePoints)); - - if (truncatedLength == input.size()) { - return std::string(input); - } - - if (truncatedLength == 0) { - return std::nullopt; - } - - for (int32_t i = truncatedLength - 1; i >= 0; --i) { - const auto byte = static_cast(input[i]); - if (byte < 0x7F) { - std::string result(input.data(), i); - result.push_back(static_cast(byte + 1)); - return result; - } - } - - // All bytes are 0x7F (DEL character), no valid upper bound. - return std::nullopt; -} - -// Unicode path for roundUp. -FOLLY_ALWAYS_INLINE std::optional roundUpUnicode( - std::string_view input, - int32_t numCodePoints) { - const auto truncatedLength = cappedByteLength(input, numCodePoints); - - if (truncatedLength == input.size()) { - return std::string(input); - } - - if (truncatedLength == 0) { - return std::nullopt; - } - - const char* data = input.data(); - const char* truncatedEnd = data + truncatedLength; - - // Collect the byte offset of each code point. - std::vector codePointOffsets; - codePointOffsets.reserve(numCodePoints); - const char* current = data; - while (current < truncatedEnd) { - codePointOffsets.push_back(current - data); - int32_t charLength; - utf8proc_codepoint(current, truncatedEnd, charLength); - current += charLength; - } - - // Try incrementing from the last code point backwards. - for (int32_t i = codePointOffsets.size() - 1; i >= 0; --i) { - const char* pos = data + codePointOffsets[i]; - int32_t charLength; - const auto codePoint = utf8proc_codepoint(pos, truncatedEnd, charLength); - const auto nextCodePoint = incrementCodePoint(codePoint); - if (nextCodePoint != 0) { - std::string result(data, codePointOffsets[i]); - char buffer[4]; - const auto bytesWritten = utf8proc_encode_char( - nextCodePoint, reinterpret_cast(buffer)); - result.append(buffer, bytesWritten); - return result; - } - } - - // No valid upper bound can be found. - return std::nullopt; -} - } // namespace detail /// Converts the first character of each word to uppercase and all other @@ -905,51 +814,4 @@ FOLLY_ALWAYS_INLINE bool initcap(TOutString& output, const TInString& input) { } } -/// Truncates a UTF-8 encoded string to at most 'numCodePoints' Unicode code -/// points. Returns a string_view pointing to the truncated portion of the -/// input string. This is used for computing Iceberg lower bound statistics, -/// as the truncated string is guaranteed to be less than or equal to the -/// original string in lexicographic order. -/// -/// @param input The UTF-8 encoded input string. -/// @param numCodePoints Maximum number of Unicode code points to retain. -/// @return A string_view of the truncated string. -FOLLY_ALWAYS_INLINE std::string_view truncateUtf8( - std::string_view input, - int32_t numCodePoints) { - if (isAscii(input.data(), input.size())) { - return std::string_view( - input.data(), std::min(input.size(), (size_t)numCodePoints)); - } - const auto truncatedLength = cappedByteLength(input, numCodePoints); - return std::string_view(input.data(), truncatedLength); -} - -/// Rounds up a UTF-8 encoded string to produce an exclusive upper bound. -/// The result is guaranteed to be greater than any string that shares the -/// same prefix up to 'numCodePoints' code points. This is used for computing -/// Iceberg upper bound statistics. -/// -/// The function behaves as follows: -/// - If the string has fewer than or equal to 'numCodePoints' code points, -/// returns the original string unchanged. -/// - Otherwise, truncates to 'numCodePoints' code points and increments -/// code points from the last to the first, returning immediately on the -/// first successful increment. -/// - If no code point can be incremented (e.g., all are at max value -/// U+10FFFF), returns std::nullopt. -/// -/// @param input The UTF-8 encoded input string. -/// @param numCodePoints Maximum number of Unicode code points to retain. -/// @return A new string containing the rounded-up result, or std::nullopt if -/// no valid upper bound can be computed. -FOLLY_ALWAYS_INLINE std::optional roundUpUtf8( - std::string_view input, - int32_t numCodePoints) { - if (isAscii(input.data(), input.size())) { - return detail::roundUpAscii(input, numCodePoints); - } - return detail::roundUpUnicode(input, numCodePoints); -} - } // namespace facebook::velox::functions::stringImpl diff --git a/velox/functions/lib/string/tests/StringImplTest.cpp b/velox/functions/lib/string/tests/StringImplTest.cpp index c4716a16e5b..c9a0a671c95 100644 --- a/velox/functions/lib/string/tests/StringImplTest.cpp +++ b/velox/functions/lib/string/tests/StringImplTest.cpp @@ -214,20 +214,6 @@ class StringImplTest : public testing::Test { {"CAPS_LOCK@DOMAIN.COM", "Caps_lock@domain.com"}, {"__init__.py@example.dev", "__init__.py@example.dev"}}; } - - static void testTruncate( - const std::string& input, - int32_t numCodePoints, - const std::string& expected) { - EXPECT_EQ(truncateUtf8(input, numCodePoints), expected); - } - - static void testRoundUp( - const std::string& input, - int32_t numCodePoints, - const std::optional& expected) { - EXPECT_EQ(roundUpUtf8(input, numCodePoints), expected); - } }; TEST_F(StringImplTest, upperAscii) { @@ -1366,97 +1352,3 @@ TEST_F(StringImplTest, initcapAsciiSpark) { ASSERT_EQ(output, expected); } } - -TEST_F(StringImplTest, truncate) { - // ASCII string. - std::string ascii = "Hello, world!"; - testTruncate(ascii, 0, ""); - testTruncate(ascii, 1, "H"); - testTruncate(ascii, 5, "Hello"); - testTruncate(ascii, 13, ascii); - testTruncate(ascii, 20, ascii); - - // String with multi-bytes characters. - std::string unicode = "Hello, 世界!"; - testTruncate(unicode, 7, "Hello, "); - testTruncate(unicode, 8, "Hello, 世"); - testTruncate(unicode, 9, "Hello, 世界"); - testTruncate(unicode, 10, unicode); - testTruncate(unicode, 20, unicode); - - // String with emoji (surrogate pairs). - std::string emoji = "Hello 🌍!"; - testTruncate(emoji, 6, "Hello "); - testTruncate(emoji, 7, "Hello 🌍"); - testTruncate(emoji, 8, emoji); - testTruncate(emoji, 10, emoji); - - std::string empty = ""; - testTruncate(empty, 0, ""); - testTruncate(empty, 5, ""); - - std::string mixed = "café世界🌍"; - testTruncate(mixed, 3, "caf"); - testTruncate(mixed, 4, "café"); - testTruncate(mixed, 5, "café世"); - testTruncate(mixed, 6, "café世界"); - testTruncate(mixed, 7, mixed); -} - -TEST_F(StringImplTest, roundUp) { - std::string ascii = "Hello, world!"; - // Empty truncation returns nullopt. - testRoundUp(ascii, 0, std::nullopt); - // 'o' -> 'p'. - testRoundUp(ascii, 5, "Hellp"); - testRoundUp(ascii, ascii.length(), ascii); - - ascii = "Customer#000001500"; - // '5' -> '6'. - testRoundUp(ascii, 16, "Customer#0000016"); - - std::string unicode = "Hello, 世界!"; - testRoundUp(unicode, 8, "Hello, 丗"); - - // No truncation needed. - std::string shortString = "Hi"; - testRoundUp(shortString, 2, shortString); - testRoundUp(shortString, 20, shortString); - - // Last character is already at maximum value, returns nullopt. - std::string maxChar = "Hello\U0010FFFF"; - testRoundUp(maxChar, 6, maxChar); - - std::string empty = ""; - testRoundUp(empty, 0, ""); - testRoundUp(empty, 5, ""); - - std::string single = "a"; - // No truncation needed. - testRoundUp(single, 1, "a"); - - std::string zChar = "zz"; - // 'z' -> '{'. - testRoundUp(zChar, 1, "{"); - - std::string emojiTest = "🌍!!"; - // U1F30D (🌍) -> U1F30E. - testRoundUp(emojiTest, 1, "\U0001F30E"); - - std::string multiByteTest = "café+"; - // 'f' -> 'g'. - testRoundUp(multiByteTest, 3, "cag"); - // 'é' -> 'ê'. - testRoundUp(multiByteTest, 4, "cafê"); - - // Test surrogate boundary: U+D7FF should increment to U+E000 (skipping - // surrogate range U+D800-U+DFFF). - // U+D7FF followed by "!!" - std::string surrogateTest = "\xED\x9F\xBF!!"; - // U+E000 - testRoundUp(surrogateTest, 1, "\xEE\x80\x80"); - - // Test all max code points - should return nullopt. - std::string allMax = "\U0010FFFF\U0010FFFF"; - testRoundUp(allMax, 1, std::nullopt); -}