From ec98bfbfe6b058a4ab1b1709364d3020a1e0f225 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Wed, 21 Feb 2024 16:30:14 +0100 Subject: [PATCH 1/5] added parquet writer --- src/datatrove/io.py | 2 +- src/datatrove/pipeline/writers/__init__.py | 1 + src/datatrove/pipeline/writers/disk_base.py | 11 ++-- src/datatrove/pipeline/writers/jsonl.py | 4 +- src/datatrove/pipeline/writers/parquet.py | 56 +++++++++++++++++++++ tests/pipeline/test_parquet_writer.py | 38 ++++++++++++++ 6 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 src/datatrove/pipeline/writers/parquet.py create mode 100644 tests/pipeline/test_parquet_writer.py diff --git a/src/datatrove/io.py b/src/datatrove/io.py index dd2ced344..f0a4c5443 100644 --- a/src/datatrove/io.py +++ b/src/datatrove/io.py @@ -21,7 +21,7 @@ def __init__(self, fs, mode: str = "wt", compression: str | None = "infer"): def get_file(self, filename): """ - Opens file `filename` if it hasn't been opened yet. Otherwise just returns it from the file cache + Opens file `filename` if it hasn't been opened yet. Otherwise, just returns it from the file cache Args: filename: name of the file to open/get if previously opened diff --git a/src/datatrove/pipeline/writers/__init__.py b/src/datatrove/pipeline/writers/__init__.py index ba5515ce9..bf664183b 100644 --- a/src/datatrove/pipeline/writers/__init__.py +++ b/src/datatrove/pipeline/writers/__init__.py @@ -1 +1,2 @@ from .jsonl import JsonlWriter +from .parquet import ParquetWriter diff --git a/src/datatrove/pipeline/writers/disk_base.py b/src/datatrove/pipeline/writers/disk_base.py index 9245a924f..b7d1f8345 100644 --- a/src/datatrove/pipeline/writers/disk_base.py +++ b/src/datatrove/pipeline/writers/disk_base.py @@ -1,7 +1,7 @@ import dataclasses from abc import ABC, abstractmethod from string import Template -from typing import Callable +from typing import IO, Callable from datatrove.data import Document, DocumentsPipeline from datatrove.io import DataFolderLike, get_datafolder @@ -31,6 +31,7 @@ def __init__( output_filename: str = None, compression: str | None = "infer", adapter: Callable = None, + mode: str = "wt", ): """ Base writer block to save data to disk. @@ -47,7 +48,7 @@ def __init__( if self.compression == "gzip" and not output_filename.endswith(".gz"): output_filename += ".gz" self.output_filename = Template(output_filename) - self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression) + self.output_mg = self.output_folder.get_output_file_manager(mode=mode, compression=compression) self.adapter = adapter if adapter else _default_adapter def __enter__(self): @@ -81,13 +82,13 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs ) @abstractmethod - def _write(self, document: dict, file_handler): + def _write(self, document: dict, file_handler: IO, filename: str): """ Main method that subclasses should implement. Receives an adapted (after applying self.adapter) dictionary with data to save to `file_handler` Args: document: dictionary with the data to save file_handler: file_handler where it should be saved - + filename: to use as a key for writer helpers and other data Returns: """ @@ -105,7 +106,7 @@ def write(self, document: Document, rank: int = 0, **kwargs): """ output_filename = self._get_output_filename(document, rank, **kwargs) - self._write(self.adapter(document), self.output_mg.get_file(output_filename)) + self._write(self.adapter(document), self.output_mg.get_file(output_filename), output_filename) self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs)) self.stat_update(StatHints.total) self.update_doc_stats(document) diff --git a/src/datatrove/pipeline/writers/jsonl.py b/src/datatrove/pipeline/writers/jsonl.py index 10fff3bfa..0fd5b099f 100644 --- a/src/datatrove/pipeline/writers/jsonl.py +++ b/src/datatrove/pipeline/writers/jsonl.py @@ -18,5 +18,5 @@ def __init__( ): super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter) - def _write(self, document: dict, file: IO): - file.write(json.dumps(document, ensure_ascii=False) + "\n") + def _write(self, document: dict, file_handler: IO, _filename: str): + file_handler.write(json.dumps(document, ensure_ascii=False) + "\n") diff --git a/src/datatrove/pipeline/writers/parquet.py b/src/datatrove/pipeline/writers/parquet.py new file mode 100644 index 000000000..9423d2818 --- /dev/null +++ b/src/datatrove/pipeline/writers/parquet.py @@ -0,0 +1,56 @@ +from collections import defaultdict +from typing import IO, Callable + +from datatrove.io import DataFolderLike +from datatrove.pipeline.writers.disk_base import DiskWriter + + +class ParquetWriter(DiskWriter): + default_output_filename: str = "${rank}.parquet" + name = "📒 Parquet" + _requires_dependencies = ["pyarrow"] + + def __init__( + self, + output_folder: DataFolderLike, + output_filename: str = None, + compression: str | None = None, + adapter: Callable = None, + batch_size: int = 1000, + ): + super().__init__(output_folder, output_filename, compression, adapter, mode="wb") + self._writers = {} + self._batches = defaultdict(list) + self.batch_size = batch_size + + def _write_batch(self, filename): + if not self._batches[filename]: + return + import pyarrow as pa + + names = list(self._batches[filename][0].keys()) + # prepare batch + batch = pa.record_batch(list(zip(*[d.values() for d in self._batches.pop(filename)])), names=names) + # write batch + self._writers[filename].write_batch(batch) + + def _write(self, document: dict, file_handler: IO, filename: str): + import pyarrow as pa + import pyarrow.parquet as pq + + if filename not in self._writers: + self._writers[filename] = pq.ParquetWriter( + file_handler, schema=pa.table({name: [val] for name, val in document.items()}).schema + ) + self._batches[filename].append(document) + if len(self._batches[filename]) == self.batch_size: + self._write_batch(filename) + + def close(self): + for filename in list(self._batches.keys()): + self._write_batch(filename) + for writer in self._writers.values(): + writer.close() + self._batches.clear() + self._writers.clear() + super().close() diff --git a/tests/pipeline/test_parquet_writer.py b/tests/pipeline/test_parquet_writer.py new file mode 100644 index 000000000..2bff3e569 --- /dev/null +++ b/tests/pipeline/test_parquet_writer.py @@ -0,0 +1,38 @@ +import shutil +import tempfile +import unittest + +from datatrove.data import Document +from datatrove.pipeline.readers.parquet import ParquetReader +from datatrove.pipeline.writers.parquet import ParquetWriter +from datatrove.utils._import_utils import is_pyarrow_available + +from ..utils import require_pyarrow + + +if is_pyarrow_available(): + pass # noqa: F811 + + +@require_pyarrow +class TestParquetWriter(unittest.TestCase): + def setUp(self): + # Create a temporary directory + self.tmp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp_dir) + + def test_write(self): + data = [ + Document(text=text, id=str(i), metadata={"somedata": 2 * i}) + for i, text in enumerate(["hello", "text2", "more text"]) + ] + with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w: + for doc in data: + w.write(doc) + reader = ParquetReader(self.tmp_dir) + c = 0 + for read_doc, original in zip(reader(), data): + read_doc.metadata.pop("file_path", None) + assert read_doc == original + c += 1 + assert c == len(data) From 9e00d75e9fab1a3ef3725e769b4be22e233436ca Mon Sep 17 00:00:00 2001 From: guipenedo Date: Wed, 21 Feb 2024 16:36:46 +0100 Subject: [PATCH 2/5] nit --- tests/pipeline/test_parquet_writer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/pipeline/test_parquet_writer.py b/tests/pipeline/test_parquet_writer.py index 2bff3e569..a279f5452 100644 --- a/tests/pipeline/test_parquet_writer.py +++ b/tests/pipeline/test_parquet_writer.py @@ -5,15 +5,10 @@ from datatrove.data import Document from datatrove.pipeline.readers.parquet import ParquetReader from datatrove.pipeline.writers.parquet import ParquetWriter -from datatrove.utils._import_utils import is_pyarrow_available from ..utils import require_pyarrow -if is_pyarrow_available(): - pass # noqa: F811 - - @require_pyarrow class TestParquetWriter(unittest.TestCase): def setUp(self): From 7cd4051259ac59c24e34fe0ab002f64b52440a6a Mon Sep 17 00:00:00 2001 From: Guilherme Penedo Date: Wed, 21 Feb 2024 17:45:30 +0100 Subject: [PATCH 3/5] Update src/datatrove/pipeline/writers/parquet.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mario Šaško --- src/datatrove/pipeline/writers/parquet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/datatrove/pipeline/writers/parquet.py b/src/datatrove/pipeline/writers/parquet.py index 9423d2818..38cf32fe7 100644 --- a/src/datatrove/pipeline/writers/parquet.py +++ b/src/datatrove/pipeline/writers/parquet.py @@ -28,9 +28,8 @@ def _write_batch(self, filename): return import pyarrow as pa - names = list(self._batches[filename][0].keys()) # prepare batch - batch = pa.record_batch(list(zip(*[d.values() for d in self._batches.pop(filename)])), names=names) + batch = pa.RecordBatch.from_pylist(self._batches.pop(filename)) # write batch self._writers[filename].write_batch(batch) From 9459e5031a509f453171de3601758dc50440a275 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Thu, 22 Feb 2024 11:40:31 +0100 Subject: [PATCH 4/5] updated test --- tests/pipeline/test_parquet_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipeline/test_parquet_writer.py b/tests/pipeline/test_parquet_writer.py index a279f5452..dfd4a1d4b 100644 --- a/tests/pipeline/test_parquet_writer.py +++ b/tests/pipeline/test_parquet_writer.py @@ -18,7 +18,7 @@ def setUp(self): def test_write(self): data = [ - Document(text=text, id=str(i), metadata={"somedata": 2 * i}) + Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"}) for i, text in enumerate(["hello", "text2", "more text"]) ] with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w: From 19be55f879dd555f439cc2a46cd2a32a2c59f8f0 Mon Sep 17 00:00:00 2001 From: guipenedo Date: Thu, 22 Feb 2024 11:42:32 +0100 Subject: [PATCH 5/5] nit --- src/datatrove/pipeline/writers/parquet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datatrove/pipeline/writers/parquet.py b/src/datatrove/pipeline/writers/parquet.py index 38cf32fe7..ca25ae4e8 100644 --- a/src/datatrove/pipeline/writers/parquet.py +++ b/src/datatrove/pipeline/writers/parquet.py @@ -39,7 +39,7 @@ def _write(self, document: dict, file_handler: IO, filename: str): if filename not in self._writers: self._writers[filename] = pq.ParquetWriter( - file_handler, schema=pa.table({name: [val] for name, val in document.items()}).schema + file_handler, schema=pa.RecordBatch.from_pylist([document]).schema ) self._batches[filename].append(document) if len(self._batches[filename]) == self.batch_size: