[에러 해결] KeyError: 'src_texts'
KeyError: 'src_texts'
BART를 이용한 대화 생성 실험을 하려는데, 위 에러가 발생했다.
디버깅해보니 에러가 발생하는 원인은 Seq2SeqDataCollator 객체가 호출될 때 __call__() 함수에 넘어오는 batch 변수가 비어있기 때문에 'src_texts'라는 key를 찾을 수 없기 때문이었다.
batch 변수가 비어있는 이유는 다음과 같다.
data_collator를 선언하는 과정에서, RemoveColumnsCollator 객체가 호출될 때 _remove_columns() 함수가 실행된다. 이 함수는 현재 데이터에서 key가 self.signature_columns라는 리스트에 포함되지 않으면 제거하는 기능을 수행한다.
transformers에서 제공하는 코드를 그대로 사용했을 시 self.signature_columns 리스트에 들어있는 값은 'input_ids', 'attention_mask', 'token_type_ids', ..., 'labels', 'label', 'label_ids'이다.
그런데 Seq2SeqDataset 클래스를 이용하면 사용할 데이터의 key가 "src_texts", "tgt_texts", "id"로 설정된다.
이 key들은 self.signature_columns 리스트에 없는 값이기 때문에 _remove_columns() 함수 내에서 모두 제거된다. 그래서 batch 변수가 비어있게 되는 것이다.
따라서 해결 방법은 다음과 같다.
trainer_utils.py 파일을 보면 Trainer 클래스 내에 _set_signature_columns_if_needed() 함수가 있다.
여기서 self.signature_columns 리스트를 지정해주기 때문에, 사용할 데이터의 key를 넣어주면 된다.
함수 맨 아래에 다음 코드를 추가해주면 된다.
self._signature_columns += ["src_texts", "tgt_texts", "id"]