티스토리 뷰

AI/NLP

FAISS 설명 및 사용법

체봄 2022. 4. 28. 21:03

GitHub 링크: https://github.com/facebookresearch/faiss

 

GitHub - facebookresearch/faiss: A library for efficient similarity search and clustering of dense vectors.

A library for efficient similarity search and clustering of dense vectors. - GitHub - facebookresearch/faiss: A library for efficient similarity search and clustering of dense vectors.

github.com

 

 

 

FAISS 설명

: 유사한 벡터를 검색해서 가져오는 facebook의 라이브러리

  • L2 거리 (가장 작은값) 또는 내적 연산 (가장 큰값) 을 기반으로 유사한 벡터를 계산한다.
  • C++로 구현되었지만 python에서도 편하게 쓸 수 있다.
  • 벡터 차원 d는 보통 10~100으로 사용
  • 2개의 행렬(리스트) 필요
    1. database 행렬: (인덱싱되어야 하는) 모든 벡터가 저장됨, 검색을 여기서 수행 (크기: database_size x d)
    2. query 행렬: query 벡터들의 집합 (크기: query_num x d)
  • index (객체) 구축하기
    • 다양한 타입의 index가 있지만, 가장 간단한 것은 brute-force L2 거리로 검색을 수행하는 IndexFlatL2
    • 벡터마다 정수형 id를 저장할 수도 있다 (IndexFlatL2 제외)
    • 대부분의 index에서는 벡터 분포를 분석하기 위해 학습 단계가 필요 (IndexFlatL2에서는 스킵 가능)
    • 학습 단계까지 마치면, index에 대해서 add / search 연산이 가능
  • add 연산
    • index에 vector들(database 행렬)을 추가함
  • search 연산
    • 가장 기본 검색 방법은 k-nearest-neighbor search
    • index에서 query 행렬과 유사한 벡터를 검색해서 k개 반환

 

  • Index 비교
    • IndexFlat
      • encoding 안함
      • IndexFlatL2: L2 거리
      • IndexFlatIP: Inner Product
      • IndexFlatL2가 IndexFlatIP보다 정확한 듯 하다
    • IndexHNSW*
      • 인덱싱된 벡터들로 구축된 그래프를 기반으로 함
      • 그래프는 가능한 빨리 가장 가까운 이웃으로 수렴하는 방식으로 탐색
      • 벡터 제거는 그래프 구조를 파괴할 수 있으므로 불가
      • M: 그래프에 사용되는 이웃 수. 클수록 더 정확하고 더 많은 메모리 사용
      • efConstruction: add 시에 탐색하는 깊이
      • efSearch: search 시에 탐색하는 depth
      • IndexHNSWFlat, IndexHNSWSQ, IndexHNSWPQ, IndexHNSW2Level
    • IndexIVF*
      • Cell-probe 방법
      • 처리 속도를 높이지만 가장 가까운 이웃을 찾는 보장이 줄어들음 
      • IndexIVFPQ: 대규모 검색에 가장 유용한 구조
    • IndexLSH
      • LSH는 가장 널리 사용되는 Cell-probe 방법
      • 많은 해쉬 함수가 필요해서 많은 메모리가 필요
      • 바이너리 코드를 사용하는 Flat Index
      • database 벡터와 query 벡터는 바이너리 코드로 해싱되고 Hamming 거리로 비교된다
    • IndexScalarQuantizer
      • 16-bit float 인코딩 이용 시 정확도에 손실이 생길 수 있음
    • IndexPQ
      • 벡터는 몇 비트(8/12/16)의 하위 벡터로 분할됨
      • 차원 d는 m의 배수여야함
    • 결론: IndexFlatL2 - IndexHNSWFlat - IndexFlatIP 순으로 추천

 

 


 

설치

참고: https://github.com/facebookresearch/faiss/blob/main/INSTALL.md

 

conda로 설치하는 걸 추천한다고해서, Anaconda도 설치하였다. (Windows 환경에서 진행)

들어가자마자 나오는 가장 최신 버전을 설치했더니 네트워크 관련 오류가 발생해서, 지우고 archive에서 가장 최신 버전인 2018년 말에 나온 5.3.1 버전을 설치했다.

 

아래 명령어들은 cmd창이 아니라 Anaconda prompt에서 실행하는걸 강력히 추천한다. (cmd창에서 하다가 끝내 안됐었음)

 

conda로 가상환경을 만들어준다.

conda create -n 가상환경이름 python=3.x	# 예시: conda create -n myvenv python=3.8

 

가상환경을 활성화한다.

conda activate 가상환경이름	# 예시: conda activate myvenv

 

faiss 라이브러리를 설치한다. 나는 cpu 버전으로 설치했다.

conda install -c pytorch faiss-cpu # CPU 버전
# conda install -c pytorch faiss-gpu # GPU 버전

 

jupyter notebook에서 사용 시 더보기 클릭

더보기

Anaconda 설치시 생긴 Anaconda3 폴더에 있는 Jupyter Notebook으로 바로 켜지 말고 Anaconda prompt에서 커맨드를 입력해서 실행하는게 좋다. 

jupyter notebook

 

그런 다음 jupyter notebook에서 위에서 생성한 가상환경을 연결해준다.

https://needjarvis.tistory.com/626 블로그에 잘 설명되어 있다 :)

 

import faiss

에러가 안 나면 설치 성공!

 


 

사용법

참고: https://github.com/facebookresearch/faiss/wiki/Getting-started

 

1. 벡터 생성

모든 벡터들의 집합 그리고 검색의 대상인 query 벡터들의 집합을 만든다.

이 때 벡터의 형태는 기본 리스트나 torch.Tensor이면 안되고, 무조건 np.array여야한다.

d = 64					# dimension of vector
num_total = 10000		# number of total vectors
num_query = 5			# number of query vectors

np.random.seed(1234)             # make reproducible

total_vectors = np.random.random((num_total, d)).astype('float32')
query_vectors = np.random.random((num_queries, d)).astype('float32')

 

2. Index (객체) 구축

Index를 구축하고, 모든 벡터들을 Index에 집어넣는다. 

Index = faiss.IndexFlatL2(d)
print(Index.is_trained)		# True
Index.add(total_vectors)	# add 연산
print(Index.ntotal)		# 10000

Index에는 여러 유형이 있는데, 가장 간단한 IndexFlatL2 Index를 사용하였다.

IndexFlatL2는 학습을 하지 않아도 되기 때문에, 학습을 하지 않았어도 is_trained 값이 True로 나온다.

 

3. 검색 수행

IndexFlatL2에서는 k-nearest-neighbor 검색 방법을 사용하므로 몇개의 유사한 벡터를 가져올 것인지에 대한 값인 k를 정해준다.

각 query_vector에 대해 Index로부터 유사한 벡터 k개를 검색한다.

이 때 반환 값은 L2 거리 값과 검색된 벡터의 Index에서의 정수 인덱스 값이다. (주의: Index는 faiss Index 객체를 의미하고, indexes는 우리가 흔히 말하는 정수 인덱스를 의미한다.)

k = 3
distances, indexes = Index.search(query_vectors, k)

print(distances)	# num_query x k
# [[7.011895  7.390568  7.478056 ]
#  [8.265501  8.38073   9.07618  ]
#  [7.980177  8.24932   8.45091  ]
#  [8.033938  8.254479  8.498654 ]
#  [7.9609275 8.029879  8.259771 ]]
print(indexes)
# [[ 72  84 160]
#  [ 13  26   2]
#  [101  68  76]
#  [ 49 152  18]
#  [ 41  51 199]]

 

검증을 위해 query 벡터들을 넣는 자리에 total_vectors의 앞단 5개 벡터를 넣어보았다.

distances, indexes = Index.search(total_vectors[:5], k)

print(distances)
# [[0.        7.8548703 8.563642 ]
#  [0.        7.848359  7.9348636]
#  [0.        7.3042192 7.663117 ]
#  [0.        7.7451077 8.46536  ]
#  [0.        7.7310977 8.006843 ]]
print(indexes)
# [[  0  78  39]
#  [  1  24  88]
#  [  2  13 101]
#  [  3  18   8]
#  [  4  18  52]]

total_vectors 전체가 Index에 들어가 있기 때문에, 검색 결과를 보면 각 Top1 인덱스는 자기 자신의 인덱스다.

자기 자신과의 거리는 0이므로, 각 Top1 거리 값도 0이다.

반응형

댓글