AI/NLP

Transformer의 generate() 함수 파라미터 이해하기

체봄 2021. 12. 2. 18:08

 

Huggingface의 transformers에서 generate() 함수에 쓰이는 파라미터들이 어떤 역할을 하는지 알아본다.

전체 코드는 해당 링크로부터 확인할 수 있다.

Huggingface 관련 문서

 

generate() 함수 파라미터

 

  • do_sample
    • True -> 랜덤하게 샘플링 (ex: Top-k Sampling, Top-p Sampling)
    • False -> 높은 확률의 토큰을 선택 (ex: Greedy Decoding, Beam Search)
  • early_stopping
    • Batch당 최소 num_beams개의 문장이 완료되면 Beam Search를 종료
  • num_beams
    • Beam Search 사이즈. Beam Search에서 확률이 높은 토큰 k개씩 선택하는데, 이 k 값을 의미함
    • 보통 5~10으로 설정
  • temperature
    • Temperature Scaling에 쓰이는 값으로 예측 토큰들의 확률 분포를 조정함.
    • 0에 가까울수록 => 확률 분포가 뾰족해짐 (확률이 큰 값은 더 크게, 작은 값은 더 작게) => 가장 확률이 높은 토큰을 선택하게 됨
    • 1 => 확률 분포 그대로
    • ∞로 커질수록 => 확률 분포가 평평해짐 (확률이 큰 값과 작은 값 사이의 차이가 줄어듦) => uniform 분포에 가까워지므로 랜덤하게 선택하게 됨
  • no_repeat_ngram_size
    • n-gram 단위로 반복되는 토큰을 무시하기 위함. 해당하는 n-gram의 등장 확률을 0으로 만들어 생성 시 제외시킴.
    • 보통 3으로 설정
    • (Beam Search에서 사용했을 때 요상한 글자가 자주 생성되었음. repetition_penalty를 쓰는게 나음)
  • repetition_penalty
    • 같은 토큰이 반복적으로 나오면 penalty를 적용해서 해당 토큰의 logit을 낮춰줌
    • 값은 1.0 이상이며, 값이 클 수록 penalty를 세게 적용
  • length_penalty
    • 누적 확률 사용 시, beam의 길이가 길수록 누적 확률이 작아지는 문제를 해소하기 위함 (확률 값은 0~1 범위이므로 이를 여러 번 곱할수록 누적 확률이 작아짐)
    • 보통 1.2로 설정
  • diversity_penalty 
    • Beam Search 사용 시, 각 beam group에서 생성된 토큰이 다른 beam에서도 생성되었으면 penalty를 적용하여 해당 토큰의 확률을 감소시킴
  • num_beam_groups
    • Beam Search로 나온 그룹들 중 다양성을 위해 가장 안 유사한걸 뽑음
    • num_beams(=Beam Search 사이즈) 값에서 나눌 수 있는 값이어야 함
  • num_return_sequences
    • 최종적으로 가장 높은 확률로 생성된 문장을 몇 개 반환할지 설정
    • num_beams 값보다 작거나 같아야 함
  • use_cache
    • decoding을 좀 더 효율적이고 빠르게하기 위해 cache를 사용
    • cache를 쓰지 않을 때 동작 
      1. 입력 토큰들 각각에 대해 hidden state를 계산
      2. 첫번째 출력 토큰 생성
      3. 첫번째 출력 토큰의 hidden state를 계산해 두번째 출력 토큰 생성
      4. 첫번째,두번째 출력 토큰의 hidden state를 계산해 세번째 출력 토큰 생성 (반복)
    • 출력 토큰의 hidden state는 항상 동일한데 매 생성마다 이 hidden state를 계산하는 과정이 중복됨
    • 따라서, cache를 사용하여 한번 계산된 hidden state는 cache에 저장해두고, 가장 최근에 생성된 출력 토큰에 대해서만 hidden state를 계산함
    • default 값은 True지만 학습 시에는 False로 설정

참고 : https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958

 

 


 

참고 링크 (읽어보길 추천 :D)

  1. https://blog.naver.com/PostView.nhn?blogId=sooftware&logNo=221809101199&from=search&redirect=Log&widgetTypeCall=true&directAccess=false
  2. https://ratsgo.github.io/nlpbook/docs/generation
  3. https://littlefoxdiary.tistory.com/46
반응형