Published on

Pytorch Dataloader

Authors
  • avatar
    Name
    Inhwan Cho
    Twitter

Dataloader 사용하는 방법

  • 모든 Dataset 으로부터 DataLoader 를 생성할 수 있습니다. PyTorch의 DataLoader 는 배치 관리를 담당합니다.
  • DataLoader란 Dataset을 batch기반의 딥러닝모델 학습을 위해서 미니배치 형태로 만들어서 우리가 실제로 학습할 때 이용할 수 있게 형태를 만들어주는 기능을 합니다.
  • DataLoader를 통해 Dataset의 전체 데이터가 batch size로 slice되어 공급됩니다.
  • dataset을 input으로 넣어주면 여러 옵션(데이터 묶기, 섞기, 알아서 병렬처리)을 통해 batch를 만들어줍니다.
  • DataLoader는 iterator 형식으로 데이터에 접근 하도록 하며 batch_size나 shuffle 유무를 설정할 수 있다.

일반적인 사용 방법

from torch.utils.data import Dataloader

dataloader1 = Dataloader(
      dataset,
    batch_size = 1,
    shuffle = True,
)

dataloader2 = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
  • batch_size는 일반적으로 2의 배수를 사용합니다.(컴퓨터의 연산때문에 2의 배수로 해야 속도가 빠름)

  • shuffle 은 Epoch 마다 데이터셋을 섞어, 데이터가 학습되는 순서를 바꾸는 기능을 말합니다.

  • num_worker는 동시에 처리하는 프로세서의 수입니다. PC(특히 윈도우)에서는 default=0로 설정해야 오류가 나지 않습니다.

  • collate_fn 함수는 DataLoader 로부터 생성된 샘플 배치로 동작합니다. collate_fn 의 입력은 DataLoader 에 배치 크기(batch size)가 있는 배치(batch) 데이터이며, collate_fn 은 이를 미리 선언된 데이터 처리 파이프라인에 따라 처리합니다.(아래의 예시)

# batches가 1이 아닌 경우 이런식으로 세팅하여 DataLoader의 collate_fn에 넣어준다.
# 사용자 정의 collate_fn() 함수는 가변 길이 배치를 채우는 데 자주 사용됩니다.


def collate_batch(batch):
    data = [item[0] for item in batch]
    mask = [item[1] for item in batch]
    label = [item[2] for item in batch]
    return torch.LongTensor(data), torch.LongTensor(mask), torch.LongTensor(label)
train_dataloader = DataLoader(train_set, batch_size=32, num_workers=0, shuffle=True, collate_fn=collate_batch)