본문 바로가기

컴린이 일기장/Today I Learned

[TIL] Pytorch Dataloader - (batch) sampler, collate_fn

반응형

[주절주절]

오랜만에 쓰는 글... 
개강도 하고 부스트캠프 업무도 피크였다보니 그동안 글을 많이 못썼다.

오늘 부스트캠프 슬랙에 한 캠퍼님이 collate_fn의 역할이 무엇인지, 꼭 필요한지 모르겠다는 식의 질문을 남겨주셨는데, 나도 우리 베이스라인 코드를 작성하면서 비슷한 생각을 했었다. 그래서 오늘 조금 여유로운 김에 다른 마스터, 멘토님들이 달아주신 좋은 코멘트, 레퍼런스 참고해서 정리해보고자 한다.

 

[Today I Learned]

# Overview

https://hulk89.github.io/pytorch/2019/09/30/pytorch_dataset/

 

# sampler

- Dataset은 idx로 데이터를 가져오도록 설계 되었다. 이 때 Sampler는 이 idx 값을 컨트롤하는 방법이다.
- 따라서 sampler를 사용할 때는 shuffle 파라미터는 False가 되어야한다.
- __len__과 __iter__ 를 구현해 커스텀할 수 있고 미리 선언된 아래와 같은 Sampler들도 있다.

  • RandomSampler
  • SequentialSampler : 항상 동일한 순서로 샘플링
  • WeightRandomSampler

 

# batch_sampler

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

 

# collate_fn

- batch sampler로 묶인 이후에는, collate_fn을 호출해 batch로 묶는다.

collate_fn([dataset[i] for i in indices])

- dataset이 variable length이면 collate_fn을 꼭 사용해주어야 한다.


- 다음 사이트에 예시가 잘 나와있다. (https://hulk89.github.io/pytorch/2019/09/30/pytorch_dataset/)
요약해보자면, Dataset의 __getitem__이 매번 다른 길이(shape)의 텐서를 리턴하는 경우, batch_size를 2 이상으로 주기 위해서는 collate_fn 함수를 짜서 넣어주어야 한다. 위의 예시에서는 아래와 같은 collate_fn 함수를 사용했다.

def make_batch(samples):
    inputs = [sample['input'] for sample in samples]
    labels = [sample['label'] for sample in samples]
    padded_inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    return {'input': padded_inputs.contiguous(),
            'label': torch.stack(labels).contiguous()}

 

+ 마스터, 멘토님의 첨언

- shuffle, batch 외의 다양한 feeding 방법을 사용해야할 때가 있음. 다이나믹하게 학습 데이터가 변경될 때도.
- 당연히 dataset 안에 합칠 수 있지만 적절히 분리하여 만들어 둔 것!

 

[질문 노트]

-

 

 

반응형