• 问答
  • 技术
  • 实践
  • 资源
在 Pytorch 中构建流数据集
技术讨论
来源:原创 P**nHub兄弟网站 DeepHub IMBA


===

在处理监督机器学习任务时,最重要的东西是数据——而且是大量的数据。当面对少量数据时,特别是需要深度神经网络的任务时,该怎么办?如何创建一个快速高效的数据管道来生成更多的数据,从而在不花费数百美元在昂贵的云GPU单元上的情况下进行深度神经网络的训练?

这是我们在MAFAT雷达分类竞赛中遇到的一些问题。我的队友hezi hershkovitz为生成更多训练数据而进行的增强,以及我们首次尝试使用数据加载器在飞行中生成这些数据。


要解决的问题

我们在比赛中使用数据管道也遇到了一些问题,主要涉及速度和效率:

它没有利用Numpy和Pandas在Python中提供的快速矢量化操作的优势

每个批次所需的信息都首先编写并存储为字典,然后使用Python for循环在getitem方法中进行访问,从而导致迭代和处理速度缓慢。

从音轨生成“移位的”片段会导致每次检索新片段时都重新构建相同的音轨,这也会减缓管道的速度。

管道无法处理2D或3D输入,因为我们同时使用了scalograms和spectrograms但是无法处理。

如果我们简单地按照批处理的方式进行所有的移位和翻转,那么批处理中就会充斥着与其他示例过于相似的示例,从而使模型不能很好地泛化。

这些低效率的核心原因是,管道是以分段作为基本单元运行,而不是在音轨上运行。


数据格式概述

在制作我们的流数据之前,先再次介绍一下数据集,MAFAT数据由多普勒雷达信号的固定长度段组成,表示为128x32 I / Q矩阵;但是,在数据集中,有许多段属于同一磁道,即,雷达信号持续时间较长,一条磁道中有1到43个段。
file

上面的图像来自hezi hershkovitz 的文章,并显示了一个完整的跟踪训练数据集时,结合所有的片段。红色的矩形是包含在这条轨迹中的单独的部分。白点是“多普勒脉冲”,代表被跟踪物体的质心。

借助“多普勒脉冲”白点,我们可以很容易地看到,航迹是由相邻的段组成的,即段id 1942之后是1943,然后是1944,等等。

片段相邻的情况下允许我们使用移位来创建“新的”样本。
file

但是,由于每个音轨由不同数量的片段组成,因此从任何给定音轨生成的增补数目都会不同,这使我们无法使用常规的Pytorch Dataset 类。这里就需要依靠Pytorch中的IterableDataset 类从每个音轨生成数据流。


数据流管道设计

这三个对象的高级目标是创建一个_Segment对象流,它能够足够灵活地处理音轨和段,并且在代码中提供一致的语义:
class _Segment(Dict, ABC):
segment_id: Union[int, str]
output_array: np.ndarray
doppler_burst: np.ndarray
target_type: np.ndarray
segment_count: int

为此,我们创建了:

一个配置类,它将为一个特定的实验保存所有必要的超参数和环境变量——这实际上只是一个具有预定义键的简单字典。

一个DataDict类,它处理原始片段的加载,验证每一条轨迹,创建子轨迹以防止数据泄漏,并将数据转换为正确的格式,例如2D或3D,并为扩展做好准备

StreamingDataset类,是Pytorch IterableDataset的子类,处理模型的扩充和流段。


 config = Config(file_path=PATH_DATA,
  num_tracks=3,
  valratio=6,
  get_shifts=True,
  output_data_type='spectrogram',
  get_horizontal_flip=True,
  get_vertical_flip=True,
  mother_wavelet='cgau1',
  wavelet_scale=3,
  batch_size=50,
  tracks_in_memory=25,
  include_doppler=True,
  shift_segment=2)dataset = DataDict(config=config)
 
 train_dataset = StreamingDataset(dataset.train_data, config, shuffle=True)
 
 train_loader = DataLoader(train_dataset,batch_size=config['batch_size'])


DataDict实现

在DataDict中将片段处理为音轨,然后再处理为片段,为加速代码提供了很好的机会,特别是在数据验证、重新分割和轨创建都可以向量化的情况下。

我们使用了Numpy和Pandas中的一堆技巧和简洁的特性,大量使用了布尔矩阵来进行验证,并将scalogram/spectrogram 图转换应用到音轨中连接的片段上。代码太长,但你可以去最后的源代码地址中查看一下DataDict create_track_objects方法。

生成细分流

一旦将数据集转换为轨迹,下一个问题就是以更快的方式进行拆分和移动。在这里,Numpy提供了执行快速的,基于矩阵的操作和从一条轨迹快速生成一组新的片段所需的所有工具。

 def split_Nd_array(array: np.ndarray, nsplits: int) -> List[np.ndarray]:
  if array.ndim == 1:
  indices = range(0, len(array) - 31, nsplits)
  segments = [np.take(array, np.arange(i, i + 32), axis=0).copy() for i in indices]
  else:
  indices = range(0, array.shape[1] - 31, nsplits)
  segments = [np.take(array, np.arange(i, i + 32), axis=1).copy() for i in indices]
  return segments
 
 def create_new_segments_from_splits(segment: _Segment, nsplits: int) -> List[_Segment]:
  new_segments = []
  if segment['output_array'].shape[1] > 32:
  output_array = split_Nd_array(array=segment['output_array'], nsplits=nsplits)
  bursts = split_Nd_array(array=segment['doppler_burst'], nsplits=nsplits)
  new_segments.extend([_Segment(segment_id=f'{segment["segment_id"]}_{j}',
  output_array=array,
  doppler_burst=bursts[j],
  target_type=segment['target_type'],
  segment_count=1)
  for j, array in enumerate(output_array)])
 
  else:
  new_segments.append(segment)
  return new_segments


Pytorch IterableDataset

注:torch.utils.data.IterableDataset 是 PyTorch 1.2中新的数据集类

一旦音轨再次被分割成段,我们需要编写一个函数,每次增加一个音轨,并将新生成的段发送到流中,从流中从多个音轨生成成批的段。最后一点对于确保每个批的数据分布合理是至关重要的。

生成流数据集正是IterableDataset类的工作。它与Pytorch中的经典(Map)Dataset类的区别在于,对于IterableDataset,DataLoader调用next(iterable_Dataset),直到它构建了一个完整的批处理,而不是实现一个接收映射到数据集中某个项的索引的方法。


创建批次

在这个例子的基础上,我们创建了一个实现,它的核心进程是“process_tracks_shuffle”,以确保DataLoader提供的每个批处理都包含来自多个音轨的段的良好混合。我们通过设置tracks_in_memory超参数来实现这一点,该参数允许我们调整在生成新的流之前将处理多少条音轨并将其保存到工作内存中。

 def segments_generator(self, segment_list: _Segment) -> None:
  """
  Generates original and augmented segments from a track.
  """
  if self.config.get('get_shifts'):
  segment_list = create_new_segments_from_splits(segment_list, nsplits=self.config['shift_segment'])
  else:
  segment_list = create_new_segments_from_splits(segment_list, nsplits=32)
 
  if self.config.get('get_vertical_flip'):
  flips = create_flipped_segments(segment_list, flip_type='vertical')
  segment_list.extend(flips)
  if self.config.get('get_horizontal_flip'):
  flips = create_flipped_segments(segment_list, flip_type='horizontal')
  segment_list.extend(flips)
 
  for segment in segment_list:
  if self.config['output_data_type'] == 'scalogram':
  segment.assert_valid_scalogram()
  else:
  segment.assert_valid_spectrogram()
 
  self.segment_blocks.extend(segment_list)
  random.shuffle(self.segment_blocks)
 
  def process_tracks_shuffle(self):
  for i, track in enumerate(self.data):
  self.segments_generator(track)
  if i % self.config.get('tracks_in_memory', 100) == self.config.get('tracks_in_memory', 100):
  segment_blocks = self.segment_blocks
  self.segment_blocks = []
  random.shuffle(segment_blocks)
  yield segment_blocks
  segment_blocks = self.segment_blocks
  self.segment_blocks = []
  random.shuffle(segment_blocks)
  yield segment_blocks
 
  def shuffle_stream(self):
  return chain(self.process_tracks_shuffle())
 
 #     def linear_stream(self):
 #         return chain(self.segments_generator(track) for track in self.data)
 
  def __iter__(self):
  for segments in chain(self.shuffle_stream()):
  yield from segments


并行化

在进一步加速数据处理方面,我们没有利用通过在多个GPU并行化的处理来生成多个流。不过需要记住的一件事是,IterableDataset的并行化并不像标准Dataset类那样简单,因为仅仅用IterableDataset添加workers会导致每个worker获得数据的底层完整副本。


结论

在Pytorch中学习使用流数据是一次很好的学习经历,也是一次很好的编程挑战。这里通过改变我们对pytorch传统的dataset的组织的概念的理解,开启一种更有效地处理数据的方式。

众所周知,我们80%的时间都花在了数据清理和管道建立上。然而,我们不应将数据处理视为必须处理而又经常被忽略的工作,而去深入研究20%建模的“乐趣”。我们而应将管道和处理视为一个同样具有乐趣和关键性的工作。因为这是必要的,因为管道速度越快,运行的实验就越多,数据处理的越好,得到的结果就会越好。

作者:Adam Cohn
代码地址:https://github.com/ShaulSolomon/sota-mafat-radar/

本文地址:https://medium.com/definitely-not-sota-but-we-do-our-best/building-a-streaming-dataset-in-pytorch-8112760b028


相关推荐:

PyTorch 版 CenterNet 训练自己的数据集
【轻松学 Pytorch】自定义数据集制作与使用

  • 0
  • 0
  • 1512
收藏
暂无评论
Find me
大咖

一个大的公司

  • 18,119

    关注
  • 267

    获赞
  • 54

    精选文章
近期动态
  • 哈工大深圳研究生院CV汪,请原谅我这一生放纵不羁爱CV~
文章专栏
  • Awsome-Github 资源列表