-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathdataset.py
More file actions
752 lines (666 loc) · 27.1 KB
/
dataset.py
File metadata and controls
752 lines (666 loc) · 27.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from torch import distributed as dist
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from tzrec.constant import Mode
from tzrec.datasets.data_parser import DataParser
from tzrec.datasets.sampler import BaseSampler, TDMSampler
from tzrec.datasets.utils import (
C_NEG_SAMPLE_MASK,
C_SAMPLE_MASK,
CKPT_ROW_IDX,
CKPT_SOURCE_ID,
Batch,
RecordBatchTensor,
combine_neg_as_candidate_sequence,
remove_nullable,
)
from tzrec.features.feature import BaseFeature
from tzrec.protos import data_pb2
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger
# Class registries. ``get_register_class_meta`` turns each dict into a
# metaclass; BaseDataset / BaseReader / BaseWriter below are built with these
# metaclasses, and their ``create_class(name)`` lookups used by the factory
# functions are presumably served from the matching dict.
_DATASET_CLASS_MAP = {}
_dataset_meta_cls = get_register_class_meta(_DATASET_CLASS_MAP)

_READER_CLASS_MAP = {}
_reader_meta_cls = get_register_class_meta(_READER_CLASS_MAP)

_WRITER_CLASS_MAP = {}
_writer_meta_cls = get_register_class_meta(_WRITER_CLASS_MAP)
def _supported_pa_types() -> set:
    """Enumerate every Arrow column type the datasets accept.

    The supported set is: the scalar types below, lists of them, lists of
    lists of them, and maps whose key is one of ``map_key_types`` and whose
    value is one of the scalar types.
    """
    scalar_types = [pa.int64(), pa.float64(), pa.float32(), pa.string(), pa.int32()]
    map_key_types = [pa.string(), pa.int64(), pa.int32()]

    supported = set(scalar_types)
    supported.update(pa.list_(t) for t in scalar_types)
    supported.update(pa.list_(pa.list_(t)) for t in scalar_types)
    supported.update(
        pa.map_(key_type, value_type)
        for key_type in map_key_types
        for value_type in scalar_types
    )
    return supported


# Column types accepted by BaseDataset._init_input_fields.
AVAILABLE_PA_TYPES = _supported_pa_types()
def _expand_tdm_sample(
    input_data: Dict[str, pa.Array],
    pos_sampled: Dict[str, pa.Array],
    neg_sampled: Dict[str, pa.Array],
    data_config: data_pb2.DataConfig,
) -> Dict[str, pa.Array]:
    """Expand input data with sampled data for tdm.

    Combine the sampled positive and negative samples with the item
    features, then expand the user features based on the original user-item
    relationships, and supplement the corresponding labels according to the
    positive and negative samples. Note that in the sampling results, the
    sampled outcomes for each item are contiguous.

    for example:
        user_fea:[1, 2], item_fea:[0.1, 0.2], labels:[1,1],
        pos_sample:[0.11, 0.12, 0.21, 0.22], neg_sample:[-0.11, -0.12, -0.21, -0.22]
        concat item_fea:[0.1, 0.2, 0.11, 0.12, 0.21, 0.22, -0.11, -0.12, -0.21, -0.22]
        duplicate user_fea and keep origin user-item
        relationship: [1, 2, 1, 1, 2, 2, 1, 1, 2, 2]
        expand label: [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]

    Args:
        input_data (dict): column name -> array of the original batch. It is
            updated in place and also returned.
        pos_sampled (dict): item feature name -> sampled positive items, with
            each original row's samples stored contiguously.
        neg_sampled (dict): item feature name -> sampled negative items, same
            layout as ``pos_sampled``.
        data_config (DataConfig): only ``label_fields`` is read.

    Returns:
        input_data, where item feature columns are extended with the positive
        and negative samples, every other non-label column is repeated so it
        stays row-aligned, and label columns are cast to int64 and extended
        with 1 for positive and 0 for negative samples.

    Raises:
        ValueError: if the number of positive or negative samples is not an
            exact multiple of the batch size, because user rows could then not
            be repeated to line up with the sampled item rows.
    """
    item_fea_names = list(pos_sampled.keys())
    label_fields = set(data_config.label_fields)
    user_fea_names = input_data.keys() - set(item_fea_names) - label_fields

    # Measure the batch before any column is extended. An item feature column
    # always exists here (it is concatenated below), unlike a label column.
    batch_size = len(input_data[item_fea_names[0]])
    num_pos_sampled = len(pos_sampled[item_fea_names[0]])
    num_neg_sampled = len(neg_sampled[item_fea_names[0]])

    def _samples_per_row(num_sampled: int, kind: str) -> int:
        # Every original row must own the same number of contiguous samples;
        # otherwise np.repeat below would silently misalign user/item rows.
        if batch_size == 0:
            per_row, remainder = 0, num_sampled
        else:
            per_row, remainder = divmod(num_sampled, batch_size)
        if remainder:
            raise ValueError(
                f"tdm sampler returned {num_sampled} {kind} samples, which is "
                f"not a multiple of the batch size {batch_size}."
            )
        return per_row

    pos_per_row = _samples_per_row(num_pos_sampled, "positive")
    neg_per_row = _samples_per_row(num_neg_sampled, "negative")

    for item_fea_name in item_fea_names:
        input_data[item_fea_name] = pa.concat_arrays(
            [
                input_data[item_fea_name],
                pos_sampled[item_fea_name],
                neg_sampled[item_fea_name],
            ]
        )

    # In the sampling results, the sampled outcomes for each item are
    # contiguous, so repeating each row index ``per_row`` times reproduces the
    # user side of every sampled pair in order.
    user_pos_index = np.repeat(np.arange(batch_size), pos_per_row)
    user_neg_index = np.repeat(np.arange(batch_size), neg_per_row)
    for user_fea_name in user_fea_names:
        user_fea = input_data[user_fea_name]
        input_data[user_fea_name] = pa.concat_arrays(
            [
                user_fea,
                user_fea.take(user_pos_index),
                user_fea.take(user_neg_index),
            ]
        )

    for label_field in label_fields:
        input_data[label_field] = pa.concat_arrays(
            [
                input_data[label_field].cast(pa.int64()),
                pa.array(np.ones(num_pos_sampled, dtype=np.int64), type=pa.int64()),
                pa.array(np.zeros(num_neg_sampled, dtype=np.int64), type=pa.int64()),
            ]
        )
    return input_data
class BaseDataset(IterableDataset, metaclass=_dataset_meta_cls):
    """Dataset base class.

    Iterates raw column dicts from ``self._reader`` and converts each one into
    a ``Batch``. Along the way it optionally attaches sample masks and sampler
    output, parses features through ``DataParser``, and records checkpoint
    positions. ``self._reader`` is ``None`` in this class; presumably
    subclasses create it.

    Args:
        data_config (DataConfig): an instance of DataConfig.
        features (list): list of features.
        input_path (str): data input path.
        reserved_columns (list): reserved columns in predict mode. If the first
            entry is "ALL_COLUMNS", every input column is reserved.
        mode (Mode): train or eval or predict.
        debug_level (int): dataset debug level, when mode=predict and
            debug_level > 0, will dump fg encoded data into the reserved
            column "__features__".
    """

    def __init__(
        self,
        data_config: data_pb2.DataConfig,
        features: List[BaseFeature],
        input_path: str,
        reserved_columns: Optional[List[str]] = None,
        mode: Mode = Mode.EVAL,
        debug_level: int = 0,
    ) -> None:
        super(BaseDataset, self).__init__()
        self._data_config = data_config
        self._features = features
        self._input_path = input_path
        self._reserved_columns = reserved_columns or []
        self._mode = mode
        self._debug_level = debug_level
        # Name of the populated ``sampler`` oneof field, or None when unset.
        self.sampler_type = (
            self._data_config.WhichOneof("sampler")
            if self._data_config.HasField("sampler")
            else None
        )

        # Labels and sample weights are not parsed in predict mode.
        self._data_parser = DataParser(
            features=features,
            labels=list(data_config.label_fields)
            if self._mode != Mode.PREDICT
            else None,
            sample_weights=list(data_config.sample_weight_fields)
            if self._mode != Mode.PREDICT
            else None,
            mode=self._mode,
            fg_threads=data_config.fg_threads,
            force_base_data_group=data_config.force_base_data_group,
            sampler_type=self.sampler_type,
        )

        # Filled on first access of the ``input_fields`` property.
        self._input_fields = None

        # Names of the input columns this dataset consumes. These are the
        # feature inputs, plus reserved columns (predict) or labels and sample
        # weights (train/eval), plus the sample cost field and the sampler's
        # item/user id fields. This set is not read in this class; presumably
        # subclasses use it to select reader columns.
        self._selected_input_names = set()
        self._selected_input_names |= self._data_parser.feature_input_names
        if self._mode == Mode.PREDICT:
            self._selected_input_names |= set(self._reserved_columns)
        else:
            self._selected_input_names |= set(data_config.label_fields)
            self._selected_input_names |= set(data_config.sample_weight_fields)
        if data_config.HasField("sample_cost_field"):
            self._selected_input_names.add(data_config.sample_cost_field)
        if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
            sampler_type = self._data_config.WhichOneof("sampler")
            sampler_config = getattr(self._data_config, sampler_type)
            # Not every sampler config declares these fields, hence hasattr.
            if hasattr(sampler_config, "item_id_field") and sampler_config.HasField(
                "item_id_field"
            ):
                self._selected_input_names.add(sampler_config.item_id_field)
            if hasattr(sampler_config, "user_id_field") and sampler_config.HasField(
                "user_id_field"
            ):
                self._selected_input_names.add(sampler_config.user_id_field)
        # if set selected_input_names to None,
        # all columns will be reserved.
        if (
            len(self._reserved_columns) > 0
            and self._reserved_columns[0] == "ALL_COLUMNS"
        ):
            self._selected_input_names = None

        self._fg_mode = data_config.fg_mode
        self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep

        # Non-train modes use eval_batch_size when it is explicitly set.
        if mode != Mode.TRAIN and data_config.HasField("eval_batch_size"):
            self._batch_size = data_config.eval_batch_size
        else:
            self._batch_size = data_config.batch_size

        # Created by launch_sampler_cluster(); init() is deferred to the first
        # call of __iter__.
        self._sampler = None
        self._sampler_inited = False

        # Build mapping of field_name → sequence_delim for candidate sequence
        # auto-detection during negative sampling.
        self._seq_field_delims: Dict[str, str] = {}
        for feature in features:
            if hasattr(feature, "sequence_delim") and feature.sequence_delim:
                for input_name in feature.inputs:
                    self._seq_field_delims[input_name] = feature.sequence_delim

        # Presumably assigned by subclasses; __iter__, load_state_dict and
        # create_dataloader all dereference it.
        self._reader = None

    def launch_sampler_cluster(
        self,
        num_client_per_rank: int = 1,
        client_id_bias: int = 0,
        cluster: Optional[Dict[str, Union[int, str]]] = None,
    ) -> None:
        """Launch sampler cluster and server.

        Does nothing in predict mode or when no sampler is configured.
        Otherwise it creates the sampler and initializes its cluster. The
        sampler server is launched only when no existing ``cluster`` is given
        to reuse.

        Args:
            num_client_per_rank (int): sampler clients per rank.
            client_id_bias (int): offset applied to client ids.
            cluster (dict, optional): an existing sampler cluster to join.
        """
        if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
            sampler_type = self._data_config.WhichOneof("sampler")
            sampler_config = getattr(self._data_config, sampler_type)
            # pyre-ignore [16]
            self._sampler = BaseSampler.create_class(sampler_config.__class__.__name__)(
                sampler_config,
                self.input_fields,
                self._batch_size,
                is_training=self._mode == Mode.TRAIN,
                # Multi-value separator: the configured one when fg_mode is
                # FG_NONE, otherwise chr(29).
                multival_sep=self._fg_encoded_multival_sep
                if self._fg_mode == data_pb2.FgMode.FG_NONE
                else chr(29),
            )
            self._sampler.init_cluster(num_client_per_rank, client_id_bias, cluster)
            if cluster is None:
                self._sampler.launch_server()

    def get_sampler_cluster(self) -> Optional[Dict[str, Union[int, str]]]:
        """Get sampler cluster, or None when no sampler was created."""
        if self._sampler:
            return self._sampler._cluster

    def _init_input_fields(self) -> None:
        """Init input fields info from the reader schema.

        Raises:
            ValueError: if a column type is not in AVAILABLE_PA_TYPES.
        """
        self._input_fields = []
        for field in self._reader.schema:
            # remove_nullable presumably normalizes nullability so the type can
            # be compared against AVAILABLE_PA_TYPES; verify in datasets.utils.
            field_type = remove_nullable(field.type)
            if any(map(lambda x: x == field_type, AVAILABLE_PA_TYPES)):
                self._input_fields.append(field)
            else:
                raise ValueError(
                    f"column [{field.name}] with dtype {field.type} "
                    "is not supported now."
                )

    @property
    def input_fields(self) -> List[pa.Field]:
        """Input fields info, overwrote by subclass for auto infer the info."""
        if self._input_fields is None:
            self._init_input_fields()
        return self._input_fields

    def get_worker_info(self) -> Tuple[int, int]:
        """Get multiprocessing dataloader worker id and worker number.

        Returns:
            (global worker index, total worker count), computed over all
            distributed ranks times the dataloader workers of each rank.
        """
        worker_info = get_worker_info()
        if worker_info is None:
            # Running in the main process (no dataloader workers).
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
        return rank * num_workers + worker_id, num_workers * world_size

    def load_state_dict(self, state: Optional[Dict[str, int]]) -> None:
        """Set checkpoint state for resume.

        Args:
            state: dict mapping source_key to max consumed row index.
        """
        assert self._reader is not None
        self._reader.load_state_dict(state)

    def __iter__(self) -> Iterator[Batch]:
        # Initialize the sampler once, on first iteration.
        if self._sampler is not None and not self._sampler_inited:
            self._sampler.init()
            self._sampler_inited = True
        worker_id, num_workers = self.get_worker_info()
        for input_data in self._reader.to_batches(worker_id, num_workers):
            yield self._build_batch(input_data)

    def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
        """Process input data and build batch.

        Args:
            input_data (dict): raw input data, column name -> array. It is
                modified in place: checkpoint columns are popped, and mask and
                sampled columns are added.

        Returns:
            an instance of Batch.
        """
        # Extract checkpoint info if present
        checkpoint_info = None
        if CKPT_SOURCE_ID in input_data and CKPT_ROW_IDX in input_data:
            # Popped so these per-row columns do not reach the data parser.
            source_ids = input_data.pop(CKPT_SOURCE_ID)
            row_idxs = input_data.pop(CKPT_ROW_IDX)
            # Use PyArrow group_by + max aggregation
            table = pa.table({CKPT_SOURCE_ID: source_ids, CKPT_ROW_IDX: row_idxs})
            grouped = table.group_by(CKPT_SOURCE_ID).aggregate([(CKPT_ROW_IDX, "max")])
            # Convert to dict: {source_key: max_absolute_position}
            checkpoint_info = dict(
                zip(
                    grouped[CKPT_SOURCE_ID].to_pylist(),
                    grouped[f"{CKPT_ROW_IDX}_max"].to_pylist(),
                )
            )

        use_sample_mask = self._mode == Mode.TRAIN and (
            self._data_config.negative_sample_mask_prob > 0
            or self._data_config.sample_mask_prob > 0
        )
        if use_sample_mask:
            # One boolean per original row, True with probability
            # sample_mask_prob.
            input_data[C_SAMPLE_MASK] = pa.array(
                np.random.random(len(list(input_data.values())[0]))
                < self._data_config.sample_mask_prob
            )

        if self._sampler is not None:
            if isinstance(self._sampler, TDMSampler):
                pos_sampled, neg_sampled = self._sampler.get(input_data)
                input_data = _expand_tdm_sample(
                    input_data, pos_sampled, neg_sampled, self._data_config
                )
            else:
                sampled = self._sampler.get(input_data)
                for k, v in sampled.items():
                    if k in input_data:
                        # Inputs of sequence features (those with a
                        # sequence_delim) go through
                        # combine_neg_as_candidate_sequence; presumably the
                        # negatives are merged into each row's candidate
                        # sequence. Other columns get the sampled values
                        # appended as extra rows.
                        seq_delim = self._seq_field_delims.get(k)
                        if seq_delim is not None:
                            input_data[k] = combine_neg_as_candidate_sequence(
                                input_data[k],
                                v,
                                self._sampler._num_sample,
                                seq_delim,
                            )
                        else:
                            input_data[k] = pa.concat_arrays([input_data[k], v])
                    else:
                        input_data[k] = v

                if use_sample_mask:
                    # Original row masks followed by masks for the sampled
                    # rows, drawn with negative_sample_mask_prob.
                    input_data[C_NEG_SAMPLE_MASK] = pa.concat_arrays(
                        [
                            input_data[C_SAMPLE_MASK],
                            pa.array(
                                np.random.random(len(list(sampled.values())[0]))
                                < self._data_config.negative_sample_mask_prob
                            ),
                        ]
                    )

        # TODO(hongsheng.jhs): add additional field like hard_negative
        output_data = self._data_parser.parse(input_data)
        if self._mode == Mode.PREDICT:
            batch = self._data_parser.to_batch(output_data, force_no_tile=True)
            reserved_data = {}
            if (
                len(self._reserved_columns) > 0
                and self._reserved_columns[0] == "ALL_COLUMNS"
            ):
                # Same dict object: "__features__" below is also added to
                # input_data.
                reserved_data = input_data
            else:
                for k in self._reserved_columns:
                    reserved_data[k] = input_data[k]
            if self._debug_level > 0:
                reserved_data["__features__"] = self._data_parser.dump_parsed_inputs(
                    output_data
                )
            if len(reserved_data) > 0:
                batch.reserves = RecordBatchTensor(pa.record_batch(reserved_data))
        else:
            batch = self._data_parser.to_batch(output_data)

        # Set checkpoint info on batch
        batch.checkpoint_info = checkpoint_info
        return batch

    @property
    def sampled_batch_size(self) -> int:
        """Batch size with sampler (adds the sampler's estimated sample num)."""
        if self._sampler:
            return self._batch_size + self._sampler.estimated_sample_num
        else:
            return self._batch_size
class BaseReader(metaclass=_reader_meta_cls):
    """Reader base class.

    Args:
        input_path (str): data input path.
        batch_size (int): batch size.
        selected_cols (list): selection column names.
        drop_remainder (bool): drop last batch.
        shuffle (bool): shuffle data or not.
        shuffle_buffer_size (int): buffer size for shuffle.
        sample_cost_field (str): sample cost field name.
        batch_cost_size (int): batch cost limit size.
    """

    def __init__(
        self,
        input_path: str,
        batch_size: int,
        selected_cols: Optional[List[str]] = None,
        drop_remainder: bool = False,
        shuffle: bool = False,
        shuffle_buffer_size: int = 32,
        sample_cost_field: Optional[str] = None,
        batch_cost_size: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        self._input_path = input_path
        self._batch_size = batch_size
        self._selected_cols = selected_cols
        self._drop_remainder = drop_remainder
        self._shuffle = shuffle
        self._shuffle_buffer_size = shuffle_buffer_size
        self._sample_cost_field = sample_cost_field
        self._batch_cost_size = batch_cost_size
        self._use_sample_cost = False
        self._checkpoint_state: Optional[Dict[str, int]] = None
        # A positive batch_cost_size switches batching from a fixed row count
        # to a cumulative per-sample cost budget.
        if self._batch_cost_size is not None and self._batch_cost_size > 0:
            assert (
                self._sample_cost_field is not None and len(self._sample_cost_field) > 0
            ), "Should set data_config.sample_cost_field when use batch_cost_size"
            self._use_sample_cost = True

    def load_state_dict(self, state: Optional[Dict[str, int]]) -> None:
        """Set checkpoint state for resume.

        Args:
            state: dict mapping source_key to max consumed row index.
        """
        self._checkpoint_state = state

    @property
    def schema(self) -> pa.Schema:
        """Table schema."""
        raise NotImplementedError

    def to_batches(
        self, worker_id: int = 0, num_workers: int = 1
    ) -> Iterator[Dict[str, pa.Array]]:
        """Get batch iterator."""
        raise NotImplementedError

    def _slice_buff_data(
        self, buff_data: Union[pa.Table, pa.RecordBatch]
    ) -> Tuple[
        Union[pa.Table, pa.RecordBatch], Optional[Union[pa.Table, pa.RecordBatch]]
    ]:
        """Split the next batch off the front of ``buff_data``.

        The batch holds ``batch_size`` rows, or, in sample-cost mode, the
        longest prefix whose cumulative cost stays within ``batch_cost_size``.
        At least one row is always taken, so a single sample that exceeds the
        budget on its own forms a batch by itself instead of stalling.

        Returns:
            (batch, remaining rows or None when the buffer is used up).
        """
        if self._use_sample_cost:
            # calculate slice point by cost
            sample_cost = buff_data[self._sample_cost_field]
            cumsum_sample_cost = pc.cumulative_sum(sample_cost)
            slice_size = pc.sum(
                pc.less_equal(cumsum_sample_cost, self._batch_cost_size)
            ).as_py()
            # pc.sum yields 0 when the first sample already exceeds the budget
            # and null (None) when no cost value is valid. Either would make
            # the caller re-slice the same buffer forever, emitting empty
            # batches, so always consume at least one row.
            slice_size = max(slice_size or 0, 1)
        else:
            slice_size = self._batch_size
        if len(buff_data) <= slice_size:
            data = buff_data
            buff_data = None
        else:
            data = buff_data.slice(0, slice_size)
            buff_data = buff_data.slice(slice_size)
        return data, buff_data

    def _arrow_reader_iter(
        self, reader: Iterator[pa.RecordBatch]
    ) -> Iterator[Dict[str, pa.Array]]:
        """Re-batch an iterator of record batches into column dicts.

        Incoming record batches are buffered until at least ``batch_size``
        rows are available, then sliced with ``_slice_buff_data``. Each slice
        is yielded as a ``{column name: pa.Array}`` dict. When shuffling, up to
        ``shuffle_buffer_size`` dicts are held and a random one is emitted. The
        trailing partial batch is dropped if ``drop_remainder`` is set.
        """
        shuffle_buffer = []
        buff_data = None
        while True:
            data = None
            if buff_data is None or len(buff_data) < self._batch_size:
                try:
                    read_data = next(reader)
                    if buff_data is None:
                        buff_data = pa.Table.from_batches([read_data])
                    else:
                        buff_data = pa.concat_tables(
                            [buff_data, pa.Table.from_batches([read_data])]
                        )
                except StopIteration:
                    # Source exhausted: flush what is buffered (unless
                    # dropping the remainder). An exhausted iterator keeps
                    # raising StopIteration, so later rounds land here again.
                    if self._drop_remainder or buff_data is None:
                        data = buff_data = None
                    else:
                        data, buff_data = self._slice_buff_data(buff_data)
            else:
                data, buff_data = self._slice_buff_data(buff_data)

            if data is not None:
                data_dict = {}
                for name, column in zip(data.column_names, data.columns):
                    if isinstance(column, pa.ChunkedArray):
                        column = column.combine_chunks()
                    data_dict[name] = column

                if self._shuffle:
                    shuffle_buffer.append(data_dict)
                    if len(shuffle_buffer) < self._shuffle_buffer_size:
                        continue
                    else:
                        idx = random.randrange(len(shuffle_buffer))
                        data_dict = shuffle_buffer.pop(idx)

                yield data_dict

            if data is None and buff_data is None:
                break

        # Emit whatever is still held in the shuffle buffer, in random order.
        if len(shuffle_buffer) > 0:
            random.shuffle(shuffle_buffer)
            for data_dict in shuffle_buffer:
                yield data_dict

    def num_files(self) -> Optional[int]:
        """Get number of files in the dataset (None if not file based)."""
        return None
class BaseWriter(metaclass=_writer_meta_cls):
    """Writer base class.

    Args:
        output_path (str): data output path.
    """

    def __init__(self, output_path: str, **kwargs: Any) -> None:
        # Presumably set to True by subclasses once their output is opened;
        # close() resets it, and __del__ warns if it is still True.
        self._lazy_inited = False
        self._output_path = output_path

    def write(self, output_dict: OrderedDict[str, pa.Array]) -> None:
        """Write a batch of data."""
        raise NotImplementedError

    def close(self) -> None:
        """Close and commit data."""
        self._lazy_inited = False

    def __del__(self) -> None:
        # getattr: __del__ also runs on objects whose __init__ raised (or
        # never ran) before _lazy_inited was assigned. A plain attribute
        # access would then only print an "Exception ignored in __del__" error.
        if getattr(self, "_lazy_inited", False):
            # pyre-ignore [16]
            logger.warning(f"You should close {self.__class__.__name__} explicitly.")
def create_reader(
    input_path: str,
    batch_size: int,
    selected_cols: Optional[List[str]] = None,
    reader_type: Optional[str] = None,
    **kwargs: Any,
) -> BaseReader:
    """Build a data reader, choosing its class from the input path.

    The checks run in this order:

    - paths starting with ``odps://`` use OdpsReader;
    - paths starting with ``kafka://`` use KafkaReader;
    - paths ending with ``.csv`` use CsvReader;
    - paths ending with ``.parquet`` use ParquetReader.

    Any other path uses ``reader_type``.

    Args:
        input_path (str): data input path.
        batch_size (int): batch size.
        selected_cols (list): selection column names.
        reader_type (str, optional): reader class name to use when it cannot
            be inferred from input_path; ignored when it can.
        **kwargs: additional params forwarded to the reader constructor.

    Returns:
        reader: a data reader.
    """
    # (str method, pattern, reader class name); the first matching rule wins.
    path_rules = (
        ("startswith", "odps://", "OdpsReader"),
        ("startswith", "kafka://", "KafkaReader"),
        ("endswith", ".csv", "CsvReader"),
        ("endswith", ".parquet", "ParquetReader"),
    )
    reader_cls_name = next(
        (
            cls_name
            for method, pattern, cls_name in path_rules
            if getattr(input_path, method)(pattern)
        ),
        None,
    )
    if reader_cls_name is None:
        assert reader_type is not None, "You should set reader_type."
        reader_cls_name = reader_type

    # pyre-ignore [16]
    reader_cls = BaseReader.create_class(reader_cls_name)
    return reader_cls(
        input_path=input_path,
        batch_size=batch_size,
        selected_cols=selected_cols,
        **kwargs,
    )
def create_writer(
    output_path: str, writer_type: Optional[str] = None, **kwargs: Any
) -> BaseWriter:
    """Build a data writer, choosing its class from the output path.

    ``odps://`` paths always use OdpsWriter; every other path requires an
    explicit ``writer_type``.

    Args:
        output_path (str): data output path.
        writer_type (str, optional): writer class name to use when it cannot
            be inferred from output_path; ignored for ``odps://`` paths.
        **kwargs: additional params forwarded to the writer constructor.

    Returns:
        writer: a data writer.
    """
    if not output_path.startswith("odps://"):
        assert writer_type is not None, "You should set writer_type."
        writer_cls_name = writer_type
    else:
        writer_cls_name = "OdpsWriter"

    # pyre-ignore [16]
    writer_cls = BaseWriter.create_class(writer_cls_name)
    return writer_cls(output_path=output_path, **kwargs)
def create_dataloader(
    data_config: data_pb2.DataConfig,
    features: List[BaseFeature],
    input_path: str,
    reserved_columns: Optional[List[str]] = None,
    mode: Mode = Mode.TRAIN,
    gl_cluster: Optional[Dict[str, Union[int, str]]] = None,
    debug_level: int = 0,
) -> DataLoader:
    """Build dataloader.

    Args:
        data_config (DataConfig): dataloader config.
        features (list): list of feature.
        input_path (str): input data path.
        reserved_columns (list): reserved columns in predict mode.
        mode (Mode): train or eval or predict.
        gl_cluster (dict, bool): if set, reuse the graphlearn cluster.
        debug_level (int): dataset debug level, when mode=predict and
            debug_level > 0, will dump fg encoded data to debug_str

    Return:
        dataloader (dataloader): a DataLoader.

    Raises:
        ValueError: for a file based dataset that has fewer files than the
            WORLD_SIZE environment variable (default 1).
    """
    dataset_name = data_pb2.DatasetType.Name(data_config.dataset_type)
    # pyre-ignore [16]
    dataset_cls = BaseDataset.create_class(dataset_name)
    dataset = dataset_cls(
        data_config=data_config,
        features=features,
        input_path=input_path,
        reserved_columns=reserved_columns,
        mode=mode,
        debug_level=debug_level,
    )

    kwargs = {}
    # At least one dataloader worker per rank.
    num_workers = max(data_config.num_workers, 1)

    # check number of files is valid or not for file based dataset.
    num_files = dataset._reader.num_files()
    if num_files is not None:
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        if num_files < world_size:
            raise ValueError(
                f"Number of files in the dataset[{input_path}] must be greater "
                f"than or equal to world_size: {world_size}, but got {num_files}"
            )
        # Cap the workers at the number of files available per rank
        # (presumably so that no worker is left without a file).
        num_files_per_rank = num_files // world_size
        if num_files_per_rank < num_workers:
            logger.info(f"data_config.num_workers reset to {num_files_per_rank}")
            num_workers = num_files_per_rank

    kwargs["num_workers"] = num_workers
    kwargs["persistent_workers"] = True

    if mode == Mode.TRAIN:
        # When in train_and_eval mode, use 2x worker in gl cluster
        # for train_dataloader and eval_dataloader
        dataset.launch_sampler_cluster(num_client_per_rank=num_workers * 2)
    elif gl_cluster:
        # Reuse the gl cluster for eval_dataloader
        dataset.launch_sampler_cluster(
            num_client_per_rank=num_workers * 2,
            client_id_bias=num_workers,
            cluster=gl_cluster,
        )
    else:
        dataset.launch_sampler_cluster(num_client_per_rank=num_workers)

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=None,
        pin_memory=data_config.pin_memory if mode != Mode.PREDICT else False,
        collate_fn=lambda x: x,
        **kwargs,
    )
    # For PyTorch versions 2.6 and above, we initialize the data iterator before
    # beginning the training process to avoid potential CUDA-related issues following
    # model saving.
    iter(dataloader)
    return dataloader