티스토리 뷰

RuntimeError: Error(s) in loading state_dict :
        Unexpected key(s) in state_dict: "bert.embeddings.position_ids"

 

GPU에서 train한 모델을 CPU에서 test하려고 하니 위와 같은 에러 메시지가 발생했다.

에러가 난 코드는 이 부분이었다.

model.load_state_dict(torch.load(state_save_path, map_location='cpu'))

state_dict를 불러오는 과정에서 서버 환경이 달라져서 key 값이 매칭이 되지 않아 발생한 에러다.

 

해결 방법은 load_state_dict()에 strict=False를 추가해주면 된다.

이를 추가해주면 state_dict를 불러올때 모든 key를 엄격(?)하게 불러오지 않고, 불러올 수 있는 것만 유동적으로 불러올 수 있다.

model.load_state_dict(torch.load(state_save_path, map_location='cpu'), strict=False)

 

 

 

참고 링크

반응형

댓글