diff --git a/src/datatrove/io.py b/src/datatrove/io.py index dd2ced344..f0a4c5443 100644 --- a/src/datatrove/io.py +++ b/src/datatrove/io.py @@ -21,7 +21,7 @@ def __init__(self, fs, mode: str = "wt", compression: str | None = "infer"): def get_file(self, filename): """ - Opens file `filename` if it hasn't been opened yet. Otherwise just returns it from the file cache + Opens file `filename` if it hasn't been opened yet. Otherwise, just returns it from the file cache Args: filename: name of the file to open/get if previously opened diff --git a/src/datatrove/pipeline/writers/__init__.py b/src/datatrove/pipeline/writers/__init__.py index ba5515ce9..bf664183b 100644 --- a/src/datatrove/pipeline/writers/__init__.py +++ b/src/datatrove/pipeline/writers/__init__.py @@ -1 +1,2 @@ from .jsonl import JsonlWriter +from .parquet import ParquetWriter diff --git a/src/datatrove/pipeline/writers/disk_base.py b/src/datatrove/pipeline/writers/disk_base.py index 9245a924f..b7d1f8345 100644 --- a/src/datatrove/pipeline/writers/disk_base.py +++ b/src/datatrove/pipeline/writers/disk_base.py @@ -1,7 +1,7 @@ import dataclasses from abc import ABC, abstractmethod from string import Template -from typing import Callable +from typing import IO, Callable from datatrove.data import Document, DocumentsPipeline from datatrove.io import DataFolderLike, get_datafolder @@ -31,6 +31,7 @@ def __init__( output_filename: str = None, compression: str | None = "infer", adapter: Callable = None, + mode: str = "wt", ): """ Base writer block to save data to disk. @@ -47,7 +48,7 @@ def __init__( if self.compression == "gzip" and not output_filename.endswith(".gz"): output_filename += ".gz" self.output_filename = Template(output_filename) - self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression) + self.output_mg = self.output_folder.get_output_file_manager(mode=mode, compression=compression) self.adapter = adapter if adapter else _default_adapter def __enter__(self): @@ -81,13 +82,13 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs ) @abstractmethod - def _write(self, document: dict, file_handler): + def _write(self, document: dict, file_handler: IO, filename: str): """ Main method that subclasses should implement. Receives an adapted (after applying self.adapter) dictionary with data to save to `file_handler` Args: document: dictionary with the data to save file_handler: file_handler where it should be saved - + filename: to use as a key for writer helpers and other data Returns: """ @@ -105,7 +106,7 @@ def write(self, document: Document, rank: int = 0, **kwargs): """ output_filename = self._get_output_filename(document, rank, **kwargs) - self._write(self.adapter(document), self.output_mg.get_file(output_filename)) + self._write(self.adapter(document), self.output_mg.get_file(output_filename), output_filename) self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs)) self.stat_update(StatHints.total) self.update_doc_stats(document) diff --git a/src/datatrove/pipeline/writers/jsonl.py b/src/datatrove/pipeline/writers/jsonl.py index 10fff3bfa..0fd5b099f 100644 --- a/src/datatrove/pipeline/writers/jsonl.py +++ b/src/datatrove/pipeline/writers/jsonl.py @@ -18,5 +18,5 @@ def __init__( ): super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter) - def _write(self, document: dict, file: IO): - file.write(json.dumps(document, ensure_ascii=False) + "\n") + def _write(self, document: dict, file_handler: IO, _filename: str): + file_handler.write(json.dumps(document, ensure_ascii=False) + "\n") diff --git a/src/datatrove/pipeline/writers/parquet.py b/src/datatrove/pipeline/writers/parquet.py new file mode 100644 index 000000000..ca25ae4e8 --- /dev/null +++ b/src/datatrove/pipeline/writers/parquet.py @@ -0,0 +1,55 @@ +from collections import defaultdict +from typing import IO, Callable + +from datatrove.io import DataFolderLike +from datatrove.pipeline.writers.disk_base import DiskWriter + + +class ParquetWriter(DiskWriter): + default_output_filename: str = "${rank}.parquet" + name = "📒 Parquet" + _requires_dependencies = ["pyarrow"] + + def __init__( + self, + output_folder: DataFolderLike, + output_filename: str = None, + compression: str | None = None, + adapter: Callable = None, + batch_size: int = 1000, + ): + super().__init__(output_folder, output_filename, compression, adapter, mode="wb") + self._writers = {} + self._batches = defaultdict(list) + self.batch_size = batch_size + + def _write_batch(self, filename): + if not self._batches[filename]: + return + import pyarrow as pa + + # prepare batch + batch = pa.RecordBatch.from_pylist(self._batches.pop(filename)) + # write batch + self._writers[filename].write_batch(batch) + + def _write(self, document: dict, file_handler: IO, filename: str): + import pyarrow as pa + import pyarrow.parquet as pq + + if filename not in self._writers: + self._writers[filename] = pq.ParquetWriter( + file_handler, schema=pa.RecordBatch.from_pylist([document]).schema + ) + self._batches[filename].append(document) + if len(self._batches[filename]) == self.batch_size: + self._write_batch(filename) + + def close(self): + for filename in list(self._batches.keys()): + self._write_batch(filename) + for writer in self._writers.values(): + writer.close() + self._batches.clear() + self._writers.clear() + super().close() diff --git a/tests/pipeline/test_parquet_writer.py b/tests/pipeline/test_parquet_writer.py new file mode 100644 index 000000000..dfd4a1d4b --- /dev/null +++ b/tests/pipeline/test_parquet_writer.py @@ -0,0 +1,33 @@ +import shutil +import tempfile +import unittest + +from datatrove.data import Document +from datatrove.pipeline.readers.parquet import ParquetReader +from datatrove.pipeline.writers.parquet import ParquetWriter + +from ..utils import require_pyarrow + + +@require_pyarrow +class TestParquetWriter(unittest.TestCase): + def setUp(self): + # Create a temporary directory + self.tmp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp_dir) + + def test_write(self): + data = [ + Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"}) + for i, text in enumerate(["hello", "text2", "more text"]) + ] + with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w: + for doc in data: + w.write(doc) + reader = ParquetReader(self.tmp_dir) + c = 0 + for read_doc, original in zip(reader(), data): + read_doc.metadata.pop("file_path", None) + assert read_doc == original + c += 1 + assert c == len(data)