Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions DatasetLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,33 @@ def __iter__(self):
flattened_list.append([data[i] for i in indices])

## Mix data in random order
mixid = torch.randperm(len(flattened_label), generator=g).tolist()
mixid = torch.randperm(len(flattened_label), generator=g).tolist() # numpy.arange(len(flattened_label)).tolist()
mixlabel = []
mixmap = []

## Prevent two pairs of the same speaker in the same batch
for ii in mixid:
startbatch = round_down(len(mixlabel), self.batch_size)
if flattened_label[ii] not in mixlabel[startbatch:]:
mixlabel.append(flattened_label[ii])
mixmap.append(ii)
resmixid = []
mixlabel_ins = 1 # start non-zero so the while loop below runs at least once

# ## Prevent two pairs of the same speaker in the same batch
# for ii in mixid:
# startbatch = round_down(len(mixlabel), self.batch_size)
# if flattened_label[ii] not in mixlabel[startbatch:]:
# mixlabel.append(flattened_label[ii])
# mixmap.append(ii)
# mixlabel_ins += 1

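## Prevent two pairs of the same speaker in the same batch. Indices skipped in one pass are collected in "resmixid" and retried in the next pass (reduces data waste); stops when a pass inserts nothing.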
while len(mixid)>0 and mixlabel_ins>0:
mixlabel_ins = 0
for ii in mixid:
startbatch = round_down(len(mixlabel), self.batch_size)
if flattened_label[ii] not in mixlabel[startbatch:]:
mixlabel.append(flattened_label[ii])
mixmap.append(ii)
mixlabel_ins += 1
else:
resmixid.append(ii)
mixid = resmixid
resmixid = []

mixed_list = [flattened_list[i] for i in mixmap]

Expand Down