AI/Computer Vision

[StarGAN] 코드 파악하기 - data_loader.py

체봄 2021. 1. 10. 16:47

 

먼저 CelebA 데이터셋을 사용하는 경우에 필요한 CelebA 클래스가 정의되어 있다.

하지만 나는 커스텀 데이터셋을 사용할 것이고 이 경우 RaFD 데이터셋을 사용하는 방식과 동일하기 때문에 CelebA 클래스는 건너뛰고, get_loader() 함수만 알아본다.

 

 

get_loader() 함수에서는 data loader를 만들어 반환한다.

torchvision.transforms에 정의된 함수들을 집어넣기 위한 transform 리스트를 만든다.

논문 발췌

논문에서 training 과정에서 데이터 증가를 위해 horizontally flip한다고 쓰여있다. 따라서 'train' 모드이면 torchvision.transforms의 RandomHorizontalFlip() 함수를 리스트에 집어넣는다.

CenterCrop() 함수는 이미지의 중앙에서 crop_size x crop_size 크기로 crop한다. 그리고 Resize() 함수를 통해 image_size x image_size 크기로 resize한다. ToTensor() 함수로 텐서로 변형한 다음, Normalize() 함수에서 평균(mean)과 표준편차(std) 값을 지정하여 정규화를 수행한다. 마지막으로 Compose() 함수를 통해 transform 리스트에 집어넣은 여러 transform들을 하나로 묶어준다.

 

데이터셋이 CelebA이면 CelebA 객체를 만들고, 데이터셋이 RaFD(커스텀 데이터셋의 경우도 RaFD로 전달함)이면 ImageFolder 객체를 만든다.

github.com/yunjey/StarGAN/blob/master/jpg/RaFD.md

StarGAN에서 커스텀 데이터셋으로 학습을 할 때에는 위 링크에 나온 구조대로 각 특성 폴더 하위에 이미지를 저장해야 한다. 이러한 구조를 가진 데이터셋을 불러올 때 ImageFolder 라이브러리를 사용한다고 한다.

 

main.py 캡처
main.py 캡처

이 get_loader() 함수는 main.py에서 위와 같이 호출되며, rafd_image_dir에는 'train' 폴더의 경로가 들어간다. 

즉 get_loader() 함수의 첫번째 인자 image_dir에 이 'train' 폴더의 경로가 들어온다. 

ImageFolder 객체를 만들 때 인자에 'train' 폴더의 경로와 사용할 transform을 넘겨주는 것이다.

 

이제 마지막으로 DataLoader를 이용해 데이터를 실제로 불러오고 이를 반환한다. (인자에 대해 자세히 알고 싶으면 클릭)

 

 

 

반응형