AI/Computer Vision

[StarGAN] 코드 파악하기 - solver.py (1)

체봄 2021. 1. 9. 23:55

 

solver.py는 코드가 길어서 글을 두 개로 나눠서 적을 것이다.

이번 글에서는 180번째 줄까지, 다시 말해 train() 함수 전까지 파악해본다.

 

 

Solver 클래스는 nn.Module을 상속받지 않는다.

Solver 객체를 호출할 때에는 매개변수에 celeba_loader, rafd_loader, config를 넘겨준다.

 

line 19 ~ 70은 거의 다 매개변수로 전달받은 값을 그대로 self에 저장하는 과정이므로 생략한다. 

 

 

build_model 함수에서는 주석에 써있듯 Generator와 Discriminator를 만든다.

Generator 객체는 self.G에, Discriminator 객체는 self.D에 할당한다.

 

 

StarGAN에서의 모든 모델은 학습 과정에 Adam optimizer를 사용한다. 인자에는 각각 Generator/Discriminator의 parameter들, learning rate, beta 값이 들어간다.

main.py 일부

main.py에서 learning rate와 beta 값에 대한 기본 정보를 확인할 수 있었다.

 

 

'모델명.to(장치)'는 사용중인 장치에 최적화된 형태로 모델을 변환하는 작업이라고 한다.

 

 

print_network() 함수는 인자로 모델과 모델의 이름을 전달받아 모델의 네트워크 정보를 출력하는 역할을 한다.

for문에서는 model의 모든 파라미터의 원소 수를 numel() 함수로 구해 num_params에 더한다.

print(model)을 하면 다음과 같이 출력된다.

                    ...

print(name)을 하면 다음과 같이 출력된다.

파라미터 수를 출력하면 다음과 같다.

 

 

restore_model() 함수는 이전에 학습하여 저장된 모델을 불러오는 역할을 한다.

정해준 iteration(=>'num_iters')만큼 학습이 완료되지 못하고 중간에 종료된 경우, 'resume_iters' 인자에 학습을 이어서 시작할 iteration 수를 지정해주면 해당 iteration부터 학습을 이어서 할 수 있다. 이 때, 그 iteration부터 무조건 시작할 수 있는 것이 아니라 그 iteration에 해당하는 저장된 모델이 있어야 한다. (예를 들어, 10000번마다 모델이 저장되는 상황에서 학습이 20300번째 iteration에서 종료되었다면, 'resume_iters'에 20300이 아니라 20000을 넘겨줘야 한다)

main.py 캡처

main.py에서 해당 인자를 보면 default값은 None이다. 'resume_iters' 인자를 따로 지정해주지 않으면 기본적으로 학습을 아예 처음부터 시작하게 된다.

main.py에서 'model_save_dir' 인자에 넘겨준 경로에 resume_iters에 해당하는 모델명을 붙여 G_path, D_path로 할당한다.

그리고 G_path, D_path의 모델에서 state_dict를 불러와 각각 self.G와 self.D에 저장한다.

 

 

build_tensorboard() 함수에서는 logger.py에 정의된 Logger 클래스 객체를 생성한다.

 

update_lr() 함수에서는 g_optimizer와 d_optimizer에서의 learning rate를 업데이트한다.

 

g_optimizer, d_optimizer의 gradient를 0으로 리셋한다.

 

out의 모든 원소들을 [0,1] 범위로 만들어 반환한다.

 

논문에서 빨간색 밑줄에 해당하는 부분이 gradient penalty에 해당한다.

l2 norm

l2 norm은 각 원소를 제곱하여 모두 더한 것에 루트를 씌워 구한다.

 

 

 

이제부터는 초보자의 입장에서 구조 파악이 어려울 수 있으니 직접 코드를 실습해보는 것을 추천한다. (내가 그랬다ㅎㅎ)

label2onehot() 함수와 create_labels() 함수를 함께 설명하겠다. create_labels() 함수는 dataset이 'RaFD'일 때의 코드만 표시하였다. 

먼저 create_labels() 함수에서는 모든 타겟 도메인 레이블을 생성하는 함수이다.

매개변수 c_org는 한 batch를 가져왔을 때 batch_size개의 이미지들의 실제 도메인 레이블을 담고 있는 tensor이다. (batch_size의 default 값은 16)

값 확인

(내 커스텀 데이터셋에 대해서 실제 값을 확인해보았는데, 내 데이터셋의 경우 도메인의 수가 7이기 때문에 도메인 레이블이 0~6의 값을 갖고 있다)

for문에서 i는 0부터 c_dim-1의 값을 가지고, label2onehot() 함수를 호출한다. 이 때, c_org.size(0)은 batch_size의 값과 같다. torch.ones(c_org.size(0))는 크기가 16이며 값이 모두 1.인 tensor를 생성한다. 따라서 torch.ones(c_org.size(0))*i는 크기가 16이며 값이 모두 i인 tensor를 생성한다는 말이다.

값 확인

(i가 2라면 값이 모두 2.으로 초기화된 tensor를 생성하는 것을 확인할 수 있다)

이제 label2onehot() 함수에 생성된 텐서와 특성의 수 c_dim을 넘겨준다.

label2onehot() 함수를 살펴본다. 

매개변수 labels에는 크기가 batch_size이고 값이 모두 i인 텐서가 넘어온다. 크기가 batch_size x dim(c_dim)이고 값이 모두 0인 tensor를 만들어 out에 할당한다. 

np.arange(batch_size)는 0부터 batch_size-1까지의 값이 순차적으로 저장된 배열이고, labels.long()은 labels에서 값이 2.와 같았던 것을 모두 2와 같이 변환한 텐서이다.

값 확인

out[np.arange(batch_size), labels.long()] = 1를 하면 다음과 같이 된다.

값 확인

(모두 값이 0이었던 텐서 out에서 [0,2], [1,2], [2,2], ..., [15,2]번째 값을 1로 변경한 것이다)

이 텐서 out을 반환하여 create_labels() 함수에서 c_trg에 저장된다. c_trg는 device에 최적화된 형태로 변환되어 c_trg_list 리스트에 저장된다. 

이 과정들이 for문에서 c_dim번 반복되며 최종적인 c_trg_list를 출력해보면 다음과 같다.

...

StarGAN의 결과 이미지를 보면 각 테스트 이미지에 대해 가능한 모든 도메인으로(자기 자신의 도메인 포함) 모두 변환한 결과가 출력된다. 따라서 각 이미지마다 변환할 모든 타겟 도메인을 create_labels() 함수를 통해 생성하는 것이다.

 

 

classification_loss() 함수는 Domain classification loss를 구하는 함수이다. 

파라미터 logit에는 Discriminator의 출력 값 중 하나인 Dcls, 즉 Discriminator의 인풋 이미지의 domain classification 값이 들어오고, 파라미터 target에는 원본 도메인의 레이블 또는 랜덤으로 생성된 타겟 도메인의 레이블이 들어온다.

즉, logit에 원본 이미지에 대한 domain classification 값이 들어오면 target에는 원본 도메인 레이블이 들어오고, logit에 합성 이미지(가짜 이미지)에 대한 domain classification 값이 들어오면 target에는 합성 이미지의 도메인 레이블이 들어온다.

한 마디로 Discriminator를 통해 예측한 입력 이미지의 domain과 입력 이미지의 실제 domain 사이의 loss를 구하는 과정이다. cross_entropy() 함수를 사용한다.

 

논문 발췌

논문에 따르면, 원본 이미지에 대한 domain classification loss는 Discriminator를 최적화하기 위함이며, 합성 이미지에 대한 domain classification loss는 Generator를 최적화하기 위해 사용된다.

 

 

 

-> 다음 글

 

[StarGAN] 코드 파악하기 - solver.py (2)

초보자라 틀린 부분이 있을 수 있습니다. 피드백 주시면 감사히 반영하겠습니다 :) 이번 글에서는 solver.py의 182번째 줄부터 파악해본다. 초보자의 입장에서 구조 파악이 어려운 부분이 꽤 있으니

beausty23.tistory.com

 

 

 

반응형