In addition to what has already been mentioned, cycle() and zip() can cause memory leak problems, especially when working with image datasets! To avoid this, instead of iterating like this:
from itertools import cycle
from torch.utils.data import DataLoader

dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10
for epoch in range(num_epochs):
    # cycle() keeps a copy of every batch it has yielded so it can replay them,
    # so the whole output of dataloaders1 accumulates in memory
    for i, (data1, data2) in enumerate(zip(cycle(dataloaders1), dataloaders2)):
        do_cool_things()
you can use:
from torch.utils.data import DataLoader

dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10
for epoch in range(num_epochs):
    dataloader_iterator = iter(dataloaders1)
    # here data1 comes from dataloaders2 and data2 from dataloaders1
    for i, data1 in enumerate(dataloaders2):
        try:
            data2 = next(dataloader_iterator)
        except StopIteration:
            # dataloaders1 is exhausted; restart it and keep going
            dataloader_iterator = iter(dataloaders1)
            data2 = next(dataloader_iterator)
        do_cool_things()
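Both snippets assume a DummyDataset class that the original answer never defines; a minimal sketch of what it could look like (a hypothetical map-style Dataset that simply yields the integers in [start, end)) is:

import torch
from torch.utils.data import Dataset

class DummyDataset(Dataset):
    # Hypothetical stand-in, not from the original answer: yields integers in [start, end)
    def __init__(self, start, end):
        self.data = list(range(start, end))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx])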
Keep in mind that if you are also using labels, you should replace data1 in this example with (inputs1, targets1) and data2 with (inputs2, targets2), as @Sajad Norouzi said; see the sketch below.
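As a purely illustrative sketch of that variant (the tuple names follow the sentence above, and each batch is assumed to be an (inputs, targets) pair), the loop would then look roughly like:

for epoch in range(num_epochs):
    dataloader_iterator = iter(dataloaders1)
    # (inputs1, targets1) play the role of data1, (inputs2, targets2) the role of data2
    for i, (inputs1, targets1) in enumerate(dataloaders2):
        try:
            inputs2, targets2 = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloaders1)
            inputs2, targets2 = next(dataloader_iterator)
        do_cool_things()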
Kudos to this: https://github.com/pytorch/pytorch/issues/1917#issuecomment-433698337