• 问答
  • 技术
  • 实践
  • 资源
一分钟学 Trick: PyTorch 动态更新 DataLoader
技术讨论

文章来源:董鑫 哈佛大学·人工智能博士@知乎
file

PyTorch 动态更新 DataLoader

我们知道 PyTorch 里面的 DataLoader 是多线程的,用起来非常方便。但是多线程带来问题就是,会有额外的内存消耗。

最近遇到了一个需求,需要不断的更新 DataLoader 。当然,最简单的方法还是每次需要更新 DataLoader 的时候,重新新建一个 DataLoader,但是这种方法的代价是很大的,尤其是当我们需要频繁的更新 DataLoader 的时候。

举个例子,我们只想取 Dataset 中的一部分,所以可以使用 SubsetRandomSampler

from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler, BatchSampler
import torch

candidate = [1]
dataset = TensorDataset(torch.tensor(list(range(10))))
dataloader = DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(candidate))

for idx, data in enumerate(dataloader):
    print(data)

输出的结果肯定是只取第二个数据,也就是

[tensor([1])]

但是假如我们在不重新新建一个 DataLoader 的情况下,更新这个 SubsetRandomSampler ,怎么办?

一个简单的想法是,我们能否直接更新 dataloader.sampler 以及 dataloader.batch_sampler

candidate = [2]
dataloader.sampler = SubsetRandomSampler(candidate)
dataloader.batch_sampler = BatchSampler(SubsetRandomSampler(candidate), batch_size=1, drop_last=True)

遗憾的是,PyTorch 并不支持这个操作。
file


本文给出一个非常简单的解决方法。

In-place 地改变 candidate 这个变量就行了。

from torch.utils.data import DataLoader, TensorDataset
import torch

candidate = [1,2]
dataset = TensorDataset(torch.tensor(list(range(10))))
dataloader = DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(candidate))

for idx, data in enumerate(dataloader):
    print(data)

[candidate.pop() for i in range(len(candidate))]
candidate.extend([3,4])

for idx, data in enumerate(dataloader):
    print(data)

我们来看下效果:

[tensor([1])]
[tensor([2])]
[tensor([4])]
[tensor([3])]

可以看到这个方法是成功的。

中间两行:

[candidate.pop() for i in range(len(candidate))]
candidate.extend([3,4])

是先 in-place 地把这个 list 清空,然后再 in-place 地更新这个 list。

注意,这里的关键是 in-place!

如果你直接对 candidate 进行赋值,是无法达到修改效果的。

from torch.utils.data import DataLoader, TensorDataset
import torch

candidate = [1,2]
dataset = TensorDataset(torch.tensor(list(range(10))))
dataloader = DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(candidate))

for idx, data in enumerate(dataloader):
    print(data)

candidate = [3,4]

for idx, data in enumerate(dataloader):
    print(data)

无法达到效果:

[tensor([1])]
[tensor([2])]
[tensor([2])]
[tensor([1])]
  • 0
  • 0
  • 2814
收藏
暂无评论