• 问答
  • 技术
  • 实践
  • 资源
PyTorch DataLoader 初探
技术讨论
文章来源:nero@知乎



1.相关问题背景

我最先接触的深度学习框架是Torch7,感觉是一个文档写的特别好的框架,各种神经网络模块的使用方式一目了然,由于是动态图,所以随调随用,debug非常方便。同时也有facebook提供的官方resnet repo,一般的分类任务只用简单改动几个地方就能很快跑起来。但由于lua语言自身不像python有那么多系统的库,一般我训练模型都是预先做好相关的训练/验证文件清单,对于一些需要on-the-fly之类的预处理操作显得稍有麻烦。当PyTorch发布的时候,感觉一切不愉快都解决了!

但是,在使用PyTorch进行模型训练的时候,当任务是IO稠密或者网络本身较浅,使用较大的BatchSize的时候往往会遇到某个迭代数据装载耗时特别久的现象,如下图所示,在iter=12时发生了这个情况。

数据装载”阻塞“

继续观察一段时间后,发现”阻塞“现象是规律发生的,即每12次迭代就会发生一次较长时间的数据装载,而随后的11次迭代基本都是无耗时装载(0.001s)。而我的程序设定的num_workers正好是12,一开始我以为是workers数量不够大,随后我将其调整到14个,但仍然会出现这个问题,并且阻塞的周期变成了每14个。之后尝试了不同的num_workers,都会规律的在num_workers倍数次迭代(包括第0次)阻塞住,导致了GPU大部分时间都处在空闲状态。

其实去年我在PyTorch论坛就提了一下这个问题,但没有得到什么回应。当时我考虑可能是数据读写太慢,用iotop看了一下的确IO的量比较大。因此做的优化就是尝试不同的数据读取方式(PIL, OpenCV,制作HDF5格式文件等),一定程度上对上面这个问题有所缓解,但其实还是存在周期阻塞现象。之后由于没有做相关的IO密集性任务就搁置了一段时间,最近又遇到了这个问题,看各种论坛上的解决方案,众说纷纭,有的认为是速度不够快,直接换SSD;有的说是PyTorch DataLoader实现的时候只有当线程取空才会去读下一批数据;有的说是PyTorch不像TensorFlow使用C++的op进行数据装载,受到到GIL的限制;也有的直接改用Nvidia Dali,通过Pipeline来提升效率

但是我的任务需要需要保持随机采样,有的操作需要one-the-fly处理,没办法那么灵活的直接改用Dali。所以,我就对PyTorch自身的DataLoader实现原理做了一下分析,想看看具体造成这个问题的原因是什么。


2. 代码初探

我主要基于PyTorch v0.4.1源码进行的分析,它和master(v1.3)中的其实是一样的主逻辑,但是1.3里面包括了一些多进程安全退出的处理,代码也进行了一定程度重构解耦,代码相对分散一些。

torch/utils/data下面一共含有4个主文件

|---- dataloader.py
|---- dataset.py
|---- distributed.py
|---- sample.py
我们先聚焦在dataloader.py上

DataLoader类在L450-491定义了接口API,也就是我们平时指定的dataset,num_workers,pin_memory等参数,并对用户输入合法性进行校验,例如 [公式]

DataLoader类实际上在实现__iter__ 迭代器方法是调用的 _DataLoader类(我猜是为了防止用户修改类内成员做的隐藏?)。

为了快速理解,我们先看不开启多进程(只含主进程)时的实现逻辑。

2.1 num_workers=0(主进程阻塞式读取数据)

num_workers=0时,并不会出现周期性的长时间阻塞,但是由于所有操作都在一个主进程中进行,训练速度受I/O效率影响也比较明显

def __next__(self):
    if self.num_workers == 0:  # same-process loading
        indices = next(self.sample_iter)  # may raise StopIteration
        batch = self.collate_fn([self.dataset[i] for i in indices])
        if self.pin_memory:
            batch = pin_memory_batch(batch)
        return batch

可以看到其每次取数据时从sample_iter拿一个batch的dataset的索引,随后通过for循环[self.dataset[i] for i in indices]的方式去读数据,self.dataset[i]处的操作,其实就是对应torch.utils.data.Dataset类中的__getitem__函数,对于一般图像分类任务我们通常我们会在其中进行一些resize,crop,flip等预处理的操作,并返回image和相应label。

由于这里是通过for循环进行处理,因此基本上数据装载时间和batch size呈线性正比关系,而这块是造成训练阻塞的一个原因,此处也关联导致了多进程每num_workers次迭代阻塞的问题。

其中函数collate_fn是用来将list中的batch size个tensor(也可能是numpy或者纯python数值),默认的实现逻辑就是将其在0维叠起来,这块不会直接对数据读取性能造成时延,不再赘述,有兴趣的同学可以去直接看相关源码

torch.stack(batch, 0, out=out)

pin_memory是用来加速数据从cpu->gpu的函数,详细定义可看此处(然而提示403 Forbidden,用google上存的cache看了一下)

原文:When allocating CPU memory that will be used to transfer data to the GPU, there are two types of memory to choose from: pinned and non-pinned memory. Pinned memory is memory allocated using thecudaMallocHost function, which prevents the memory from being swapped out and provides improved transfer speeds. Non-pinned memory is memory allocated using themalloc function. As described in Memory Management Overhead and Memory Transfer Overhead, pinned memory is much more expensive to allocate and deallocate but provides higher transfer throughput for large memory transfers.

我理解其大致意思是,当所分配用于传输给GPU的CPU内存中时,有pinned和non-pinned两种方式,pinned方式分配和释放所用开销较大,但是对于较大容量的数据传输,可以提供更高的吞吐。

在这里的实现方式其实是直接调用torch.Tensor下面的pin_memory方法。

最后直接将相应的数据传递回训练/验证函数中,进行模型的前向传递。

2.2 num_workers>0

上面对单进程数据装载实现原理进行了简单的分析,下面我们看看在多进程(多workers)情况下,PyTorch是怎么管理的。

我们先回到__init__函数的L251-297里,看一下其初始化的相关变量。

if self.num_workers > 0:
    self.worker_init_fn = loader.worker_init_fn
    # 定义了workers相同数量个Queue并放置在index_queues这个list中,
    # 这些Queue与worker一一对应,用来给worker传递“工作内容”
    self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
    # worker_queue_idx用于下一个工作的workre序号,主进程轮询使用不同workers,以模拟负载均衡
    self.worker_queue_idx = 0
    # 各个workre将自己所取得的数据传递给wokrker_result_queue,供主进程fetch
    self.worker_result_queue = multiprocessing.SimpleQueue()

    # 记录当前时刻分配了多少个任务(可能有处于等待状态的任务)
    self.batches_outstanding = 0
    self.worker_pids_set = False
    self.shutdown = False
    # 发送出去数据的编号
    self.send_idx = 0
    # 接受到数据的编号
    self.rcvd_idx = 0

    # 缓存区
    self.reorder_dict = {}
    self.workers = [
        multiprocessing.Process(
            target=_worker_loop,
            args=(self.dataset, self.index_queues[i],
                  self.worker_result_queue, self.collate_fn, base_seed + i,
                  self.worker_init_fn, i))
        for i in range(self.num_workers)]
    # 初始化相应的进程,目标函数为_worker_loop
    # 参数:dataset(用于数据读取),index_queues[i]为worker对应的index_queue
    # 以及用于输出的queue

    # 此处主要用于数据读取后的pin_memory操作,不影响多进程主逻辑,暂不展开
    if self.pin_memory or self.timeout > 0:
        ...
    else:
        self.data_queue = self.worker_result_queue
    for w in self.workers:
        w.daemon = True  # ensure that the worker exits on process exit
        # 将父进程设置为守护进程,保证父进程结束后,worker进程也结束,必须设置在start之前
        w.start()

    # 下面是一些系统信号处理逻辑
    _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
    _set_SIGCHLD_handler()
    self.worker_pids_set = True

    # 初始化后生成2*num_workers数量个prefetch的数据,使dataloader提前工作,提升整体效率。
    # prime the prefetch loop
    for _ in range(2 * self.num_workers):
        self._put_indices()

再来看一下_put_indices的相关定义。

def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    # 默认设定是只允许分配2*num_workers个任务,保证内存等资源不被耗尽
    indices = next(self.sample_iter, None)
    # 从sample_iter中拿到dataset中下一轮次的索引,用于fetch数据
    if indices is None:
        return
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    # 轮询选择worker,找到其对应的队列,向其中发送工作内容(数据编号,数据索引)
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    # worker_queue_idx自增
    self.batches_outstanding += 1
    # 任务分配数+1
    self.send_idx += 1
    # 已发送任务总数+1(下批数据编号)

看到这里我其实是有疑惑的,既然生成了2*num_workers数量个工作任务,为什么整体流程还会出现阻塞的情况呢?

我们接着看各worker进程的工作流程_worker_loop

def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    # 父进程状态监测
    watchdog = ManagerWatchdog()

    # 死循环查询是否有任务传进来
    while True:
        try:
            # 从index_queue获取相应数据
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            # 获得以后for循环进行读取数据读取,此处和单进程的工作原理是一样的
            # 因此时间花费和batchsize数量呈线性关系
            samples = collate_fn([dataset[i] for i in batch_indices])
            # 经过collate_fn后变成torch.Tensor
        except Exception:
            # 异常处理
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            # 通过data_queue传回处理好的batch数据
            data_queue.put((idx, samples))
            # 显示删除中间变量,降低内存消耗
            del samples

可以看到,worker的工作函数比较简单,就是不断从输入队列里面获取dataset索引,随后通过相应函数转换成Tensor数据再传回到输出队列中,所以说每个进程是独立的获取到一整个batch的数据。但是由于batch个索引是通过for循环处理的,总体耗时几乎是单张图像处理的batchsize倍。

但上述逻辑也没有反应出规律阻塞的原因,我们重新看回__next__函数里面多workers处理逻辑,主进程是怎么利用data_queue中各worker返回的批数据的。

def __next__(self):
    if self.num_workers == 0:  # same-process loading
        ...
    # check if the next sample has already been generated
    # 先查看数据是否在缓存dict中
    if self.rcvd_idx in self.reorder_dict:
        batch = self.reorder_dict.pop(self.rcvd_idx)
        return self._process_next_batch(batch)
    # 异常处理
    if self.batches_outstanding == 0:
        self._shutdown_workers()
        raise StopIteration
    while True:
        assert (not self.shutdown and self.batches_outstanding > 0)
        # 阻塞式的从data_queue里面获取处理好的批数据
        idx, batch = self._get_batch() 
        # 任务数减一
        self.batches_outstanding -= 1
        # 这一步就是造成的周期阻塞现象的原因
        # 因为该DataLoader设计要保证模型复现性(个人猜测),因此数据读取的顺序也是需要保证可复现的
        # 因此每次获取data以后,要校验和rcvd_idx是否一致
        # 若不一致,则先把获取到的数据放到reorder_dict这个缓存dict中,继续死循环
        # 直到获取到相应的idx编号于rcvd_idx可以对应上,并将数据返回
        if idx != self.rcvd_idx:
            # store out-of-order samples
            self.reorder_dict[idx] = batch
            continue
        return self._process_next_batch(batch)

由于上述实验需要确定从data_queue里面读到的数据idx和rcvd_idx一致才将数据返回。因此可能会存在如下这种情况:

假设num_workers=8,现在发送了8个数据给相应的worker,此时send_idx=8,rcvd_idx=0。过了一段时间以后,{1,2,3,5,6,7}进程数据准备完毕,此时主进程从data_queue读取到相关的数据,但由于和rcvd_idx不匹配,只能将其放在缓存里。直到send_idx=0数据准备齐以后,才能将数据返回出去,随后从缓存中弹出2,3的数据,之后又阻塞等待idx=4的数据。即输出的数据必须保持顺序性!因此在worker变多,出现这种逆序现象可能性会更大,这种现象也会出现在非num_workrers次迭代,只要相应的rcvd_idx没有得到相关数据,则主进程就会一直等待。

不过上述设计加入了缓存机制,之后从中拿数据也几乎无时延,从训练日志上的0.001s数据也反映了这一点。

再附上一段_process_next_batch源码

def _process_next_batch(self, batch):
    # 序号对上以后,rcvd_idx自加1
    self.rcvd_idx += 1
    # 添加一个fetchdata任务给worker
    self._put_indices()
    if isinstance(batch, ExceptionWrapper):
        raise batch.exc_type(batch.exc_msg)
    # 返回数据以供前向传递
    return batch

上述代码没有什么特别地方,但值得注意的是,只有rcvd_idx等于后,才会添加一个任务给worker。但当io效率不稳定时,可能已经有很多worker把自己的任务完成了,但由于序号不是rcvd_idx,只得将数据保存至缓存区。此处无法保证worker进程长期处于活跃的状态(可能PyTorch原生设计认为各个worker的数据读取时间差距会很大?)

最后来考虑一下,为什么会周期性的阻塞住?

其实就是IO耗时和模型前/后传耗时之间的GAP太大,下面简单画一个图说明一下。

为了方便理解,每次我们只给worker传入相应数量的任务(原实现中是2倍)
file

可以看到当io效率较低的时候,数据读取时间并不能被模型传递时间给遮掩掉,所以导致了阻塞。之前有同学回答的说dataloader只有所有worker数据取完才会进行下一批次的数据读取,是不正确的理解。


所以换SSD通过减少时间读取可以一定程度上解决这个问题。在内存容量允许的情况下,通过增加worker数来增加非阻塞的迭代次数(缓存数据量变多),从而给更多dataloader更多缓冲时间来减弱阻塞现象的发生。但像我自己的实验中,前后向耗时小于1s,但数据读取通常是几十秒(batchsize=128)这个级别,由于机器只有28核,所以也无法完全遮盖掉数据读取时延。

第一次写知乎文章,就先写这么多吧,之后有空再补上一些实验的数据和整体流程图。

相关推荐:
pytorch 实用工具总结

PyTorch trick 集锦

pytorch 排坑指南 | 常见的坑汇总 + 解决方案

【框架】PyTorch 图像检索框架

一分钟学 Trick: PyTorch 动态更新 DataLoader

  • 3
  • 0
  • 2820
收藏
暂无评论
xiaoxiaohui
大咖

北京大学

  • 19

    关注
  • 33

    获赞
  • 3

    精选文章
近期动态
  • 我是CV界的搬运工
文章专栏
  • 专注图像处理的CV小白菜