티스토리 뷰
model.py의 코드를 파악해볼 것이다.
살펴보기에 앞서, 아래 블로그에 Pytorch를 이용한 딥러닝의 기본적인 틀이 소개되어 있으니 참고해보면 도움이 될 것이다.
ResidualBlock
먼저 Generator의 Bottleneck 부분에 사용되는 Residual Block 클래스가 선언되어 있다.
ResidualBlock 인스턴스를 생성하면 __init__() 함수가 실행된다. 이 ResidualBlock 클래스는 nn.Module 클래스를 상속받는다.
super( ~ ).__init__() 을 통해 상위 클래스인 nn.Module을 초기화한다.
논문에는 Convolution layer, Instance Normalization, ReLU만 나와 있으나 실제 코드에서는 그 뒤에 Convolution layer와 Instance Normalization이 추가된다.
여러 개의 layer를 순차적으로 거치기 때문에 torch.nn 모듈의 Sequential() 함수를 통해 모든 레이어들을 묶어주고, 이를 self.main에 할당한다.
ResidualBlock 인스턴스를 생성할 때 dim_in, dim_out을 매개변수로 받는데, 이는 각각 입력/출력 dimension을 의미한다. 논문에 따르면 dim_in과 dim_out이 256으로 설정된다. (N이 output dimension을 의미)
첫번째 Convolution layer를 거쳐 dimension이 dim_out이 되었으므로 다음으로 실행될 InstanceNorm2d 함수에는 dim_out을 입력으로 넣어주고 계속 진행한다.
Residual block의 핵심 기능은 여러 동작을 진행하면서 처음의 정보를 잃어갈 수 있기 때문에 처음의 정보를 더해주는 것이다. 이를 forward() 함수에서 정의한다.
forward() 함수를 사용 시에는 forward() 함수를 직접 호출하는 것이 아니라, 해당 클래스 객체에 대해 객체명(forward 함수의 매개변수) 형태로 사용하면 forward() 함수가 자동 호출이 된다.
예를 들어, model = ResidualBlock(dim_in, dim_out)으로 model 객체를 만들었을 때, ResidualBlock 클래스의 forward() 함수를 호출하려면 model(x)와 같이 사용하면 된다.
Generator
Generator 인스턴스를 생성할 때에는 첫번째 Convolution layer의 output dimension인 conv_dim, domain label의 수인 c_dim, Residual Block의 수인 repeat_num을 매개변수에 넘겨준다.
layers 리스트에 모든 레이어들을 append할 것이다. (왜 여기서는 Sequential() 함수를 안 이용하는지 의문이 든다면, 이용할 것이긴 하지만 안에 들어가는 layer가 너무 많아서 일단 리스트에 다 집어넣고 마지막에 Sequential() 함수에 리스트를 넘겨줄 것이다)
먼저 Down-sampling 과정에서는 Convolution을 위해 Conv2d() 함수를 사용하고, dimension은 커진다.
첫번째 Convolution layer의 입력 dimension을 3+c_dim으로 하는데, 이유는 아래에 나올 forward() 함수에서 설명한다.
Instance Normalization과 ReLU를 거친 다음, 이제 conv_dim 변수 대신 curr_dim 변수를 사용한다.
CONV, IN, ReLU 를 두 번 더 거친다. 이 때, output dimension은 input dimension의 2배가 된다.
다음으로 Bottleneck 부분에서는 맨 처음에 정의한 ResidualBlock 인스턴스를 repeat_num 수만큼 만들어 layers 리스트에 append하기만 하면 된다.
dim_in과 dim_out이 같다. 논문에 따르면 6개의 Residual Block이 append된다.
다음으로 Up-sampling 과정에서는 Deconvolution을 위해 ConvTranspose2d() 함수를 사용하고, dimension이 작아진다.
DECONV, IN, ReLU를 거치는 것을 두번 반복한다.
for문을 빠져나와 CONV, Tanh를 거친다.
이제 모든 레이어를 layers 리스트에 추가했으니 Sequential() 함수에 리스트를 전달하고 이를 self.main에 할당한다.
Generator 클래스의 forward() 함수의 매개변수에는 real image x와 target domain c가 들어온다. (solver.py에서 들어옴)
x에는 128x128 크기의 이미지가 16개 들어있으며, dimension은 3이다.
c는 아래와 같은 형태를 갖는다. (값은 달라짐)
c = c.view(c.size(0), c.size(1), 1, 1)에서 view() 함수는 size를 바꾸는 함수로, c의 형태가 다음과 같이 바뀐다.
c = c.repeat(1, 1, x.size(2), x.size(3))를 하면 다음과 같이 형태가 바뀐다.
repeat() 함수는 텐서 c의 데이터를 매개변수로 넘겨준 차원만큼 반복한다.
매개변수에 1을 쓴 부분의 차원은 원래 c와 동일하고, x.size(2)=x.size(3)=128이기 때문에 128번씩 반복된 것을 알 수 있다.
x = torch.cat([x, c], dim=1)를 하면 x는 다음과 같다.
torch.cat() 함수에서는 x에 [x, c]를 이어붙여 x에 저장한다.
x의 size가 [16, 3, 128, 128], c의 size가 [16, 7, 128, 128]이었으므로 이어붙인 결과는 [16, 10, 128, 128]이 된다.
그리고 self.main(x)를 호출하는데, 이는 Generator의 레이어들의 집합에 x를 입력으로 주는 것이다. 그래서 논문에서 Generator의 첫번째 layer의 입력 dimension이 3+c_dim인 이유가 이와 같다. forward() 함수에 들어온 초기 x의 dimension은 3이었지만, dimension이 7인 c와 torch.cat() 함수로 이어붙였기 때문에 3+c_dim이 된다.
앞에서 self.main에 Generator의 모든 layer들을 Sequential() 함수에 집어넣어 할당했었다. self.main(x)를 하면 self.main에 할당된 Sequential에 입력 값으로 x를 전달함으로써 Generator의 모든 layer들을 거치며 동작한다. 이 때, Sequential 안에 들어있는 ResidualBlock 객체에도 입력 값 x가 전달되어 ResidualBlock 클래스의 forward 함수가 호출되는 것 같다. 최종적으로 마지막 layer를 거친 값이 forward() 함수에서 반환된다.
self.main(x)에서 반환된 값(합성 이미지)은 초기 x(원본 이미지)와 동일한 size를 갖는다. 논문에 나와 있는 Generator의 마지막 layer의 shape와 동일하다.
Discriminator
Input Layer 부분이다.
Discriminator 클래스에서는 먼저 Convolution layer를 거치는데, 입력으로 평범한 RGB 이미지가 들어오므로 input dimension이 3이다.
Hidden Layer 부분이다.
repeat_num이 default 6으로 설정되어 있으므로 CONV, Leaky ReLU를 거치는 동작을 5번 반복한다. (i=1~5)
dimension이 점점 두배가 된다.
Output Layer 부분이다. Dsrc는 입력 이미지가 real 이미지인지 아니면 합성된 fake 이미지인지를 구별하여 나타내는 값이고, Dcls는 입력 이미지의 도메인 label을 나타낸다.
kernel_size를 (이미지 한변의 길이 / 2^repeat_num) 으로 할당하는데, 이 값은 두번째 Convolution layer의 kernel_size로 들어간다. 논문에 h/64로 표기된 값이 이 값에 해당한다.
그리고 Discriminator에서 지금까지 나왔던 레이어들을 Sequential() 함수에 넣고 self.main에 할당한다.
첫번째 Convolution layer에서 real/fake 여부만을 출력해야하므로 output dimension이 1이 된다.
두번째 Convolution layer에서는 도메인 label을 출력해야하므로 output dimension이 c_dim이 된다.
마지막으로 Discriminator의 forward() 함수이다.
x에는 진짜인지 가짜인지 판별할 이미지가 전달된다. self.main(x)를 하면 self.main에 할당한 Sequential()에 들어있는 레이어들이 x를 입력값으로 하여 순차적으로 동작을 진행한다. Hidden layer까지 모두 거치고 나면 그 값이 반환되어 h에 할당된다.
그런 다음 이 h를 self.conv1와 self.conv2의 인자로 각각 전달해줌으로써 h값에 대해 Convolution 레이어를 거쳐 Dsrc값과 Dcls값을 반환하여 각각 out_src, out_cls에 저장한다.
이 out_cls의 크기를 조정하여 out_src와 함께 반환한다.
참고
forward() : wikidocs.net/60036
view(), size() : hichoe95.tistory.com/26
repeat() : seducinghyeok.tistory.com/9
'AI > Computer Vision' 카테고리의 다른 글
[StarGAN] 코드 파악하기 - data_loader.py (2) | 2021.01.10 |
---|---|
[StarGAN] 코드 파악하기 - solver.py (1) (0) | 2021.01.09 |
[StarGAN] 코드 파악하기 - main.py (0) | 2021.01.04 |
[StarGAN] 커스텀 데이터셋으로 train/test하기 (0) | 2020.12.30 |
[StarGAN] pre-trained 모델로 커스텀 데이터셋 test하기 (0) | 2020.12.28 |