faiss是Facebook开源相似性搜索库是目前最成熟的近邻搜索库,为密集向量提供高效的相似性搜索和聚类,支持10亿级向量搜索 faiss通过余弦距离公式,不直接提供余弦距离计算,而是提供欧式距离和点积L2正则后的向量点积结果是余弦距离,因此使用faiss计算余弦距离需要首先将输入归一化 找到定义的距离faiss.IndexFlatIP是内积 ;faiss.indexFlatL2是欧式距离
class FaissKNN: def __init__( self, reset_before=True, reset_after=True, index_init_fn=None, gpus=None ): self.reset() self.reset_before = reset_before self.reset_after = reset_after self.index_init_fn = ( **faiss.IndexFlatIP** if index_init_fn is None else index_init_fn ) if gpus is not None: if not isinstance(gpus, (list, tuple)): raise TypeError("gpus must be a list") if len(gpus) < 1: raise ValueError("gpus must have length greater than 0") self.gpus = gpus
改为内积距离 将输入向量除以其模具,内积距离为余弦距离
import torch.nn.functional as F train_embeddings = F.normalize(train_embeddings) test_embeddings = F.normalize(test_embeddings)