This seems to be working without periodically duplicating the data:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):
    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))

data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)

    # forward + backward on "data"

    if batch_count == 5:
        break
Result:
Batch 1:
tensor([[4, 7, 7],
[0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
[2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
[8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
[2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
[2, 7, 5]])
So I suspect the problem lies in your function sample_func_to_be_parallelized().
Edit: If, instead of torch.randint(0, 10, (3,)), I use np.random.randint(10, size=3) in __getitem__ (as an example of sample_func_to_be_parallelized()), then the data is indeed duplicated within each batch. See this issue: https://github.com/pytorch/pytorch/issues/5059.
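For reference, here is a minimal sketch of that numpy-based variant (same dataset as above, only __getitem__ changed; the class name InfiniteNumpy is just illustrative):

class InfiniteNumpy(Dataset):
    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        # each forked worker inherits the same global numpy RNG state,
        # so with num_workers > 0 they all draw identical "random" samples
        return np.random.randint(10, size=3)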
So if you use numpy's RNG anywhere in your sample_func_to_be_parallelized(), the workaround is to pass
worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id)
to the DataLoader, and to reset the numpy seed with np.random.seed() before each call to data = next(iter(data_loader)).
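Putting both pieces together, a minimal sketch of the workaround (using the illustrative InfiniteNumpy dataset from above; the same idea applies to any numpy-based sample_func_to_be_parallelized()):

data_loader = DataLoader(
    InfiniteNumpy(),
    batch_size=BATCH_SIZE,
    num_workers=16,
    # derive a distinct numpy seed per worker from the parent process's RNG state
    worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id),
)

batch_count = 0
while True:
    batch_count += 1
    np.random.seed()  # re-randomize the parent state so each new iterator seeds its workers differently
    data = next(iter(data_loader))
    print(f'Batch {batch_count}:')
    print(data)
    if batch_count == 5:
        break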