StarGAN version1이 2018년에 나오고 지금은 시간이 꽤 지나 version2까지 나온 상황이다. 그 사이에 tensorflow가 기존 1 버전의 문법과는 다른 부분이 많은 2 버전이 출시되었다. tensorflow 2 버전 사용법에 관련해서는 www.tensorflow.org/guide/effective_tf2?hl=ko를 참조하면 도움이 될 것이다. 2 버전 사용 시 오류가 발생하는 곳이 logger.py이다. (1 버전을 사용하고 있다면 에러가 발생하지 않을 것이다) tf.summary.FileWriter() 함수는 tensorflow 1 버전에서만 사용 가능한 함수다. 따라서 2 버전에서 지원하는 tf.summary.create_file_writer() 함수로 수정해줘야 한다. 앞에 ..
-> 이전 글 [StarGAN] 코드 파악하기 - solver.py (1) 초보자라 틀린 부분이 있을 수 있습니다. 피드백 주시면 감사히 반영하겠습니다 :) solver.py는 코드가 길어서 글을 두 개로 나눠서 적을 것이다. Solver 클래스는 nn.Module을 상속받지 않는다. Solver 객 beausty23.tistory.com 이번 포스팅에서는 solver.py의 182번째 줄부터 파악해본다. 초보자의 입장에서 구조 파악이 어려운 부분이 꽤 있으니 해당 코드를 직접 실습해보면서 구조를 직접 확인해보는 것을 추천한다. 단일 데이터셋을 사용할 때 학습을 진행하는 train() 함수이다. (여러 데이터셋을 사용할 때의 train_multi() 함수는 다루지 않을 것이다) dataset 값에 따라 ..
먼저 CelebA 데이터셋을 사용하는 경우에 필요한 CelebA 클래스가 정의되어 있다. 하지만 나는 커스텀 데이터셋을 사용할 것이고 이 경우 RaFD 데이터셋을 사용하는 방식과 동일하기 때문에 CelebA 클래스는 건너뛰고, get_loader() 함수만 알아본다. get_loader() 함수에서는 data loader를 만들어 반환한다. torchvision.transforms에 정의된 함수들을 집어넣기 위한 transform 리스트를 만든다. 논문에서 training 과정에서 데이터 증가를 위해 horizontally flip한다고 쓰여있다. 따라서 'train' 모드이면 torchvision.transforms의 RandomHorizontalFlip() 함수를 리스트에 집어넣는다. CenterCr..
solver.py는 코드가 길어서 글을 두 개로 나눠서 적을 것이다. 이번 글에서는 180번째 줄까지, 다시 말해 train() 함수 전까지 파악해본다. Solver 클래스는 nn.Module을 상속받지 않는다. Solver 객체를 호출할 때에는 매개변수에 celeba_loader, rafd_loader, config를 넘겨준다. line 19 ~ 70은 거의 다 매개변수로 전달받은 값을 그대로 self에 저장하는 과정이므로 생략한다. build_model 함수에서는 주석에 써있듯 Generator와 Discriminator를 만든다. Generator 객체는 self.G에, Discriminator 객체는 self.D에 할당한다. StarGAN에서의 모든 모델은 학습 과정에 Adam optimizer를..
model.py의 코드를 파악해볼 것이다. 살펴보기에 앞서, 아래 블로그에 Pytorch를 이용한 딥러닝의 기본적인 틀이 소개되어 있으니 참고해보면 도움이 될 것이다. PyTorch로 딥러닝하기 — Intro 거창하게 “딥러닝하기”라는 제목을 달았지만, 알다시피 우리에게 딥러닝을 한다는 것은 딥러닝 framework를 잘 사용하기와 같은 의미입니다. medium.com ResidualBlock 먼저 Generator의 Bottleneck 부분에 사용되는 Residual Block 클래스가 선언되어 있다. ResidualBlock 인스턴스를 생성하면 __init__() 함수가 실행된다. 이 ResidualBlock 클래스는 nn.Module 클래스를 상속받는다. super( ~ ).__init__() 을 ..
StarGAN version1의 main.py 코드를 파악해볼 것이다. github.com/yunjey/stargan yunjey/stargan StarGAN - Official PyTorch Implementation (CVPR 2018) - yunjey/stargan github.com train을 수행할 때 main.py를 실행하되, --로 인자를 넘겨주는데 뒤에 나오는 모든 인자들이 config가 된다. mode에 'train'를 적었으면 solver.py에 정의된 train() 함수를 실행하고, mode에 'test'를 적었으면 solver.py에 정의된 test() 함수를 실행한다. c_dim은 데이터셋에서 사용할 특성(attribute)의 수를 의미한다. (StarGAN에서 기본적으로 Cel..
StarGAN 모델을 이용해 의류에 패턴을 합성하는 프로젝트를 진행하고 있다. 그래서 StarGAN에서의 기본 데이터셋인 사람 얼굴에 관련된 CelebA나 RaFD가 아닌, 의류에 관련된 커스텀 데이터셋을 사용하여 train 및 test하려고 한다. !apt-get install -y -qq software-properties-common python-software-properties module-init-tools !add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null !apt-get update -qq 2>&1 > /dev/null !apt-get -y install -qq google-drive-ocamlfuse fuse from go..
github.com/yunjey/stargan/issues/46 를 참고하였다. 구글 Colaboratory에서 진행한다. GPU는 필요 없다. !apt-get install -y -qq software-properties-common python-software-properties module-init-tools !add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null !apt-get update -qq 2>&1 > /dev/null !apt-get -y install -qq google-drive-ocamlfuse fuse from google.colab import auth auth.authenticate_user() from oauth2..