AI/Computer Vision

[StarGAN] 코드 파악하기 - main.py

체봄 2021. 1. 4. 23:19

 

StarGAN version1의 main.py 코드를 파악해볼 것이다.

github.com/yunjey/stargan

 

yunjey/stargan

StarGAN - Official PyTorch Implementation (CVPR 2018) - yunjey/stargan

github.com

 

 

 

train 수행 코드

train을 수행할 때 main.py를 실행하되, --로 인자를 넘겨주는데 뒤에 나오는 모든 인자들이 config가 된다.

 

mode에 'train'를 적었으면 solver.py에 정의된 train() 함수를 실행하고,

mode에 'test'를 적었으면 solver.py에 정의된 test() 함수를 실행한다.

 

c_dim은 데이터셋에서 사용할 특성(attribute)의 수를 의미한다.

(StarGAN에서 기본적으로 CelebA 데이터셋으로부터 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young' 특성을 사용하기 때문에 default값이 5이다)

 

image_size는 모델에 들어갈 이미지의 크기를 의미한다. default는 128x128이다.

 

논문 발췌

g_conv_dim은 Generator 구조에서 첫번째 layer의 filter 수를 의미한다. 논문을 따라 default는 64이다.

g_repeat_num은 Generator 구조에서 Residual Block의 수를 의미한다. 논문을 따르면 default는 6이다.

 

논문 발췌

 d_conv_dim은 Discriminator 구조에서 첫번째 layer의 filter 수를 의미한다. 논문에 따르면 default는 64이다.

d_repeat_num은 Discriminator 구조에서 Output layer를 제외한 convolution layer의 수이다. 논문에 따르면 default는 6이다.

 

논문 발췌

lambda_gp는 adversarial loss를 구하는 데에 사용되는 gradient penalty 값을 의미하며 default는 10이다.

 

num_iters는 학습 과정에서 몇 번의 iteration을 돌 것인지를 나타내는 값이다. default로 200000번의 iteration을 수행한다.

 

논문 발췌

n_critic은 Discriminator가 몇 번 update되었을 때 Generator를 한번 update시킬 것인지를 의미하는 값이다. 

 

selected_attrs는 CelebA 데이터셋에서 사용할 특성들이다.

default로는 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young' 특성이 사용된다. 

 

test_iters는 모델 테스트를 위해 학습된 모델을 몇 번째 step에서 가져올 것인지를 의미한다.

다시 말해, 모델 학습 시에 model_save_step 인자의 default 값인 10000번째 iteration마다 학습 모델이 저장되는데, 몇 번째 iteration에서 저장된 학습 모델을 가져와 테스트를 할 것인지를 나타내는 값이다.

(num_iters의 default가 200000이기 때문에 test_iters의 default도 200000이다)

 

num_workers는 몇 개의 CPU 코어를 할당할 것인지를 나타내는 값이라고 한다.

좀 개념이 어려워서 jybaek.tistory.com/799에 잘 설명되어 있으니 참고하면 좋을 것 같다.

default는 1이지만 github.com/yunjey/stargan/issues/43와 같은 에러가 발생하는 경우 값을 0으로 설정하는게 도움이 될 수도 있다.

 

mode는 train을 할 것인지 test를 할 것인지를 결정하는 값이다. 

 

CelebA 데이터셋을 사용하는 경우 데이터셋이 저장되는 default 경로는 'data/celeba/images'이다.

CelebA를 사용하면 attribute 정보를 담고 있는 list_attr_celeba.txt 파일도 만들어줘야 하는데, 이 파일이 위치하는 경로를 attr_path에 써준다.

 

RaFD 데이터셋을 사용하는 경우 학습용 데이터셋이 저장되는 default 경로는 'data/RaFD/train'이다.

'train' 폴더 하위에 특성별로 폴더를 만들어 그 안에 이미지를 저장해야한다.

 

 

model_save_dir은 모델 학습 과정에서 model_save_step 인자 값에 해당하는 iteration 수마다 학습 모델이 저장되는 경로이다.

10000-D.ckpt, 10000-G.ckpt 형태의 모델들이 저장된다.

 

result_dir은 모델 테스트 결과가 저장되는 경로이다.

테스트 데이터셋을 각각의 특성으로 합성된 이미지들이 저장된다.

 

model_save_step은 모델 학습 과정에서 몇 번째 iteration마다 학습 모델을 저장할 것인지를 나타낸다.

default가 10000이므로 10000번째, 20000번째, ... 학습 모델들이 저장된다.

 

 

반응형