# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
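
"""Build a FAISS index over M3DocVQA page embeddings.

Loads precomputed ColPali or ColBERT token embeddings for every page in
M3DocVQA, flattens them into one token-level matrix, builds a FAISS index
(flat inner-product, IVFFlat, or IVFPQ), writes it to disk, and finally
runs a random example query scored with late-interaction (MaxSim).
"""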

from pathlib import Path

import faiss
import numpy as np
import torch
from loguru import logger
from tqdm.auto import tqdm

from m3docrag.datasets.m3_docvqa.dataset import M3DocVQADataset
from m3docrag.utils.args import parse_args


def main():
    args = parse_args()

    logger.info("Loading M3DocVQA")

    dataset = M3DocVQADataset(args)

    logger.info(f"Loading M3DocVQA -- all {args.retrieval_model_type} embeddings")

    if args.retrieval_model_type == "colpali":
        docid2embs = dataset.load_all_embeddings()
    elif args.retrieval_model_type == "colbert":
        docid2embs, docid2lens = dataset.load_all_embeddings()
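
    # ColPali stores one [n_pages, n_tokens, dim] tensor per document, while
    # ColBERT stores one flat [total_tokens, dim] tensor per document plus the
    # per-page token counts in docid2lens (see the e.g. shapes noted below).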

    # len(docid2embs)
    # docid2embs_page_reduced = reduce_embeddings(docid2embs, dim='page')
    # docid2embs_token_reduced = reduce_embeddings(docid2embs, dim='token')
    # docid2embs_page_token_reduced = reduce_embeddings(docid2embs, dim='page_token')

    # flat_doc_embs = []
    # for doc_id, doc_emb in docid2embs.items():
    #     flat_doc_embs += [doc_emb]
    # flat_doc_embs = torch.cat(flat_doc_embs, dim=0)
    # logger.info(flat_doc_embs.shape)
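
    # Index individual token embeddings rather than pooled page vectors.
    # d = 128 matches the token-embedding dimension of both retrieval model
    # types, and inner product (IP) matches the dot-product similarity used
    # for MaxSim scoring at the end of this script.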
    d = 128
    quantizer = faiss.IndexFlatIP(d)

    if args.faiss_index_type == "flatip":
        index = quantizer
    elif args.faiss_index_type == "ivfflat":
        ncentroids = 1024
        index = faiss.IndexIVFFlat(quantizer, d, ncentroids)
    else:
        nlist = 100
        m = 8
        index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
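
    # Unlike the exact flat index, the two IVF variants are approximate and
    # must be trained on the vectors before any can be added; the
    # index.train() call below covers all three cases.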

    logger.info("Flattening all PDF pages")

    all_token_embeddings = []
    token2pageuid = []
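
    # The loops below fill token2pageuid so that row i of the flattened
    # token-embedding matrix maps back to the "{doc_id}_page{page_id}" it
    # came from, letting nearest-neighbor hits be attributed to pages.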

    if args.retrieval_model_type == "colpali":
        for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
            # e.g., doc_emb - torch.Size([9, 1030, 128])
            for page_id in range(len(doc_emb)):
                page_emb = doc_emb[page_id].view(-1, d)
                all_token_embeddings.append(page_emb)

                page_uid = f"{doc_id}_page{page_id}"
                token2pageuid.extend([page_uid] * page_emb.shape[0])

    elif args.retrieval_model_type == "colbert":
        for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
            doc_lens = docid2lens[doc_id]

            # e.g., doc_emb - torch.Size([2089, 128])
            # e.g., doc_lens - tensor([258, 240, 251, 231, 229, 268, 235, 211, 166])

            all_token_embeddings.append(doc_emb)

            for page_id, page_len in enumerate(doc_lens):
                page_uid = f"{doc_id}_page{page_id}"
                token2pageuid.extend([page_uid] * page_len.item())

    logger.info(len(all_token_embeddings))

    all_token_embeddings = torch.cat(all_token_embeddings, dim=0)
    all_token_embeddings = all_token_embeddings.float().numpy()
    logger.info(all_token_embeddings.shape)
    logger.info(len(token2pageuid))

    logger.info("Creating index")

    index.train(all_token_embeddings)
    index.add(all_token_embeddings)
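
    # train() is a no-op for the flat IP index but required for the IVF variants.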

    Path(args.output_dir).mkdir(exist_ok=True)
    index_output_path = str(Path(args.output_dir) / "index.bin")
    logger.info(f"Saving index at {index_output_path}")

    faiss.write_index(index, index_output_path)
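
    # The saved index can later be reloaded with faiss.read_index(index_output_path).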

    logger.info("Running an example query")

    # Example query (should be np.float32)
    example_text_query_emb = np.random.randn(20, 128).astype(np.float32)

    # NN search
    k = 10
    D, I = index.search(example_text_query_emb, k)  # noqa: E741
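
    # For each of the 20 query-token vectors, search() returns the similarity
    # scores (D) and flat token indices (I) of its k nearest document tokens.
    # D is unused below; exact inner products are recomputed from the stored
    # embeddings instead.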

    # Sum the MaxSim scores across all query tokens for each page
    final_page2scores = {}

    # Iterate over query tokens
    for q_idx, query_emb in enumerate(example_text_query_emb):
        # Initialize a dictionary to hold page relevance scores for this query token
        current_q_page2scores = {}

        for nn_idx in range(k):
            found_nearest_doc_token_idx = I[q_idx, nn_idx]

            # Get the page UID for this token
            page_uid = token2pageuid[found_nearest_doc_token_idx]

            # Reconstruct the original similarity score
            doc_token_emb = all_token_embeddings[found_nearest_doc_token_idx]
            score = (query_emb * doc_token_emb).sum()

            # MaxSim: keep the highest similarity score for each query token per page
            if page_uid not in current_q_page2scores:
                current_q_page2scores[page_uid] = score
            else:
                current_q_page2scores[page_uid] = max(
                    current_q_page2scores[page_uid], score
                )

        for page_uid, score in current_q_page2scores.items():
            if page_uid in final_page2scores:
                final_page2scores[page_uid] += score
            else:
                final_page2scores[page_uid] = score
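
    # Summing each query token's best per-page similarity implements the
    # late-interaction (MaxSim) relevance score of ColBERT/ColPali. With only
    # k neighbors per token this approximates exact MaxSim: pages outside a
    # token's top-k contribute nothing for that token.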

    # Sort pages by their final relevance score
    sorted_pages = sorted(final_page2scores.items(), key=lambda x: x[1], reverse=True)

    # Get the top-k page candidates
    top_k_pages = sorted_pages[:k]

    # Output the top-k page UIDs and their scores
    logger.info("Top-k page candidates with scores:")
    for page_uid, score in top_k_pages:
        logger.info(f"{page_uid} with score {score}")


if __name__ == "__main__":
    main()
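
# Example invocation -- a sketch assuming parse_args() exposes the attributes
# used above as CLI flags (the exact names live in m3docrag.utils.args):
#
#   python this_script.py --retrieval_model_type colpali \
#       --faiss_index_type ivfflat --output_dir ./faiss_output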