Release commit

Author: j-min
Date: 2025-01-30 17:04:56 -05:00
Committed by: oir
Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
parent e04aeadfb0
commit 27aac8d521

50 changed files with 5692 additions and 0 deletions

@@ -0,0 +1,171 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import faiss
import numpy as np
import torch
from loguru import logger
from tqdm.auto import tqdm
from m3docrag.datasets.m3_docvqa.dataset import M3DocVQADataset
from m3docrag.utils.args import parse_args
def main():
args = parse_args()
logger.info("Loading M3DocVQA")
dataset = M3DocVQADataset(args)
logger.info(f"Loading M3DocVQA -- all {args.retrieval_model_type} embeddings")
    if args.retrieval_model_type == "colpali":
        docid2embs = dataset.load_all_embeddings()
    elif args.retrieval_model_type == "colbert":
        docid2embs, docid2lens = dataset.load_all_embeddings()
    else:
        raise ValueError(
            f"Unsupported retrieval model type: {args.retrieval_model_type}"
        )
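    # Token embeddings are 128-dim for both ColPali and ColBERT.
    # flatip: exact (brute-force) inner-product search.
    # ivfflat: inverted-file index over ncentroids clusters; it searches only
    #   the nearest clusters and must be trained before vectors are added.
    # ivfpq: inverted file plus product quantization (m sub-vectors, 8 bits
    #   each) to compress the stored embeddings.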
d = 128
quantizer = faiss.IndexFlatIP(d)
if args.faiss_index_type == "flatip":
index = quantizer
elif args.faiss_index_type == "ivfflat":
ncentroids = 1024
index = faiss.IndexIVFFlat(quantizer, d, ncentroids)
else:
nlist = 100
m = 8
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
logger.info("Flattening all PDF pages")
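    # Flatten all pages' token embeddings into one [n_total_tokens, d] matrix;
    # token2pageuid maps each row index back to its "{doc_id}_page{page_id}" uid.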
all_token_embeddings = []
token2pageuid = []
if args.retrieval_model_type == "colpali":
for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
# e.g., doc_emb - torch.Size([9, 1030, 128])
for page_id in range(len(doc_emb)):
page_emb = doc_emb[page_id].view(-1, d)
all_token_embeddings.append(page_emb)
page_uid = f"{doc_id}_page{page_id}"
token2pageuid.extend([page_uid] * page_emb.shape[0])
elif args.retrieval_model_type == "colbert":
for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
doc_lens = docid2lens[doc_id]
# e.g., doc_emb - torch.Size([2089, 128])
# e.g., doc_lens - tensor([258, 240, 251, 231, 229, 268, 235, 211, 166])
all_token_embeddings.append(doc_emb)
for page_id, page_len in enumerate(doc_lens):
page_uid = f"{doc_id}_page{page_id}"
token2pageuid.extend([page_uid] * page_len.item())
    logger.info(f"num pages: {len(all_token_embeddings)}")
    all_token_embeddings = torch.cat(all_token_embeddings, dim=0)
    all_token_embeddings = all_token_embeddings.float().numpy()
    logger.info(f"all_token_embeddings shape: {all_token_embeddings.shape}")
    logger.info(f"num tokens: {len(token2pageuid)}")
logger.info("Creating index")
index.train(all_token_embeddings)
index.add(all_token_embeddings)
    Path(args.output_dir).mkdir(exist_ok=True, parents=True)
index_output_path = str(Path(args.output_dir) / "index.bin")
logger.info(f"Saving index at {index_output_path}")
faiss.write_index(index, index_output_path)
logger.info("Running an example query")
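    # Late-interaction (MaxSim) scoring: for each query token, keep its maximum
    # similarity per page, then sum those maxima over all query tokens. Since
    # only the k nearest document tokens are retrieved per query token, the
    # scores below approximate the exact MaxSim computed over all tokens.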
# Example query (should be np.float32)
example_text_query_emb = np.random.randn(20, 128).astype(np.float32)
# NN search
k = 10
D, I = index.search(example_text_query_emb, k) # noqa E741
# Sum the MaxSim scores across all query tokens for each document
final_page2scores = {}
# Iterate over query tokens
for q_idx, query_emb in enumerate(example_text_query_emb):
# Initialize a dictionary to hold document relevance scores
        current_q_page2scores = {}
for nn_idx in range(k):
found_nearest_doc_token_idx = I[q_idx, nn_idx]
page_uid = token2pageuid[
found_nearest_doc_token_idx
] # Get the document ID for this token
# reconstruct the original score
doc_token_emb = all_token_embeddings[found_nearest_doc_token_idx]
score = (query_emb * doc_token_emb).sum()
# MaxSim: aggregate the highest similarity score for each query token per document
            if page_uid not in current_q_page2scores:
                current_q_page2scores[page_uid] = score
            else:
                current_q_page2scores[page_uid] = max(
                    current_q_page2scores[page_uid], score
                )
        for page_uid, score in current_q_page2scores.items():
if page_uid in final_page2scores:
final_page2scores[page_uid] += score
else:
final_page2scores[page_uid] = score
# Sort documents by their final relevance score
sorted_pages = sorted(final_page2scores.items(), key=lambda x: x[1], reverse=True)
# Get the top-k document candidates
top_k_pages = sorted_pages[:k]
# Output the top-k document IDs and their scores
logger.info("Top-k page candidates with scores:")
for page_uid, score in top_k_pages:
logger.info(f"{page_uid} with score {score}")
if __name__ == "__main__":
main()

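For reference, a minimal usage sketch (not part of this commit) of loading the saved index back and querying it; `index.bin` matches the filename written above, and the random array stands in for real 128-dim float32 query-token embeddings:

import faiss
import numpy as np

# Load the token-level page index written by the script above.
index = faiss.read_index("index.bin")

# Stand-in for real query-token embeddings (20 tokens, 128 dims, float32).
query_emb = np.random.randn(20, 128).astype(np.float32)

# Retrieve the 10 nearest document tokens per query token.
D, I = index.search(query_emb, 10)
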
@@ -0,0 +1,180 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import accelerate
import safetensors.torch
import torch
import transformers
from accelerate import Accelerator
from loguru import logger
from tqdm import tqdm
from m3docrag.datasets.m3_docvqa import M3DocVQADataset
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.args import parse_args
from m3docrag.utils.distributed import (
barrier,
global_rank,
is_distributed,
local_rank,
log_runtime_info,
print_gpu_stats,
)
from m3docrag.utils.paths import (
LOCAL_DATA_DIR,
LOCAL_MODEL_DIR,
)
logger.info(torch.__version__)
logger.info(transformers.__version__)
logger.info(accelerate.__version__)
def main():
args = parse_args()
log_runtime_info()
print_gpu_stats()
accelerator = Accelerator()
if not is_distributed() or global_rank() == 0:
logger.info(f"Process {global_rank()}:{local_rank()} - args {args}")
if is_distributed():
barrier()
local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
local_retrieval_model_dir = (
Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path
)
local_retrieval_adapter_model_dir = (
Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path
)
# Download datasets / model checkpoints
if not is_distributed() or global_rank() == 0:
if not local_data_dir.exists():
raise ValueError(f"Data directory {local_data_dir} does not exist")
        assert args.use_retrieval, "use_retrieval must be enabled for embedding extraction"
if not local_retrieval_model_dir.exists():
raise ValueError(
f"Retrieval model directory {local_retrieval_model_dir} does not exist"
)
if args.retrieval_model_type == "colpali":
if not local_retrieval_adapter_model_dir.exists():
raise ValueError(
f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist"
)
if is_distributed():
barrier()
    if args.retrieval_model_type == "colpali":
        colpali_model = ColPaliRetrievalModel(
            backbone_name_or_path=local_retrieval_model_dir,
            adapter_name_or_path=local_retrieval_adapter_model_dir,
        )
        retrieval_model = colpali_model
    else:
        raise ValueError(
            f"Unsupported retrieval model type: {args.retrieval_model_type}"
        )
if args.data_name == "m3-docvqa":
dataset = M3DocVQADataset(args=args)
def collate_fn(examples):
out = {}
if args.retrieval_model_type == "colpali":
for k in ["doc_id", "images"]:
out[k] = [ex[k] for ex in examples]
return out
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
collate_fn=collate_fn,
batch_size=1,
shuffle=False,
batch_sampler=None,
sampler=None,
drop_last=False,
num_workers=args.dataloader_num_workers,
)
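    # batch_size=1: each step loads one full document (all of its page images).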
retrieval_model.model, data_loader = accelerator.prepare(
retrieval_model.model, data_loader
)
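    # accelerator.prepare moves the model to the local device and, when running
    # distributed, shards the dataloader across processes.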
all_results = []
save_dir = Path(args.output_dir)
save_dir.mkdir(exist_ok=True, parents=True)
logger.info(f"Results will be saved at: {save_dir}")
for i, datum in enumerate(tqdm(data_loader)):
        logger.info(f"{i} / {len(data_loader)}")
if args.data_name == "mp-docvqa":
page_name = datum["page_name"][0]
logger.info(page_name)
else:
doc_id = datum["doc_id"][0]
logger.info(doc_id)
if args.retrieval_model_type == "colpali":
images = datum["images"][0]
doc_embs = colpali_model.encode_images(
images=images,
batch_size=args.per_device_eval_batch_size,
to_cpu=True,
use_tqdm=False,
)
# [n_pages, n_tokens, emb_dim]
doc_embs = torch.stack(doc_embs, dim=0)
# Store embedding as BF16 by default
doc_embs = doc_embs.to(torch.bfloat16)
logger.info(doc_embs.shape)
if args.retrieval_model_type == "colpali":
logger.info(doc_embs[0, 0, :5])
# Save the embedding
if args.data_name == "mp-docvqa":
local_save_fname = f"{page_name}.safetensors"
else:
local_save_fname = f"{doc_id}.safetensors"
local_save_path = save_dir / local_save_fname
if args.retrieval_model_type == "colpali":
safetensors.torch.save_file({"embeddings": doc_embs}, local_save_path)
all_results.append({"save_path": local_save_path})
logger.info(
f"Process {global_rank()}:{local_rank()} Results correctly saved at {save_dir}"
)
if is_distributed():
barrier()
if __name__ == "__main__":
main()

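For reference, a minimal sketch (not part of this commit) of reading one of the saved per-document embedding files back; the "embeddings" key matches what the loop above writes, and the filename is hypothetical:

import safetensors.torch

# "doc_id.safetensors" is a placeholder for an actual saved document file.
tensors = safetensors.torch.load_file("doc_id.safetensors")
doc_embs = tensors["embeddings"]  # [n_pages, n_tokens, 128], stored as bfloat16
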
@@ -0,0 +1,429 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import datetime
import json
import time
from pathlib import Path
import accelerate
import pytz
import torch
import transformers
from accelerate import Accelerator
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm
from m3docrag.datasets.m3_docvqa import M3DocVQADataset, evaluate_prediction_file
from m3docrag.rag import MultimodalRAGModel
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.args import parse_args
from m3docrag.utils.distributed import (
barrier,
global_rank,
is_distributed,
local_rank,
log_runtime_info,
print_gpu_stats,
supports_flash_attention,
)
from m3docrag.utils.paths import (
LOCAL_DATA_DIR,
LOCAL_EMBEDDINGS_DIR,
LOCAL_MODEL_DIR,
)
from m3docrag.utils.prompts import short_answer_template
from m3docrag.utils.tar import extract_tarfile
from m3docrag.vqa import VQAModel
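# run_model: Stage 1 retrieves the top pages for one question; Stage 2 runs VQA
# on the retrieved page images unless args.retrieval_only is set.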
def run_model(
rag_model: MultimodalRAGModel,
datum,
dataset: M3DocVQADataset,
docid2embs: dict[str, torch.Tensor],
docid2lens=None,
index=None,
token2pageuid=None,
all_token_embeddings=None,
n_return_pages=1,
args=None,
):
    # The DataLoader yields length-1 lists for every field (batch_size=1);
    # unwrap them into single items.
    batch = datum
    datum = {}
    for k, v in batch.items():
        datum[k] = v[0]
query = datum["question"]
out_dict = {}
start = time.perf_counter()
# Stage 1: Page retrieval
# [(doc_id, page_idx, scores)...]
top_n_page_retrieval_results = rag_model.retrieve_pages_from_docs(
query=query,
docid2embs=docid2embs,
docid2lens=docid2lens,
index=index,
token2pageuid=token2pageuid,
all_token_embeddings=all_token_embeddings,
n_return_pages=n_return_pages,
show_progress=True,
)
logger.info(top_n_page_retrieval_results)
out_dict["page_retrieval_results"] = top_n_page_retrieval_results
end = time.perf_counter()
time_retrieval = end - start
logger.info(f"time_retrieval: {time_retrieval}")
start = time.perf_counter()
if args.retrieval_only:
pred_answer = ""
out_dict["pred_answer"] = pred_answer
else:
        # Stage 2: QA on the retrieved pages
# Obtain images from the page retrieval results
images = []
for doc_id, page_idx, scores in top_n_page_retrieval_results:
page_images = dataset.get_images_from_doc_id(doc_id)
page_image = page_images[page_idx]
images += [page_image]
        logger.info(f"num retrieved pages: {len(images)}")
# Run VQA
if "florence" in args.model_name_or_path.lower():
text_input = query
else:
text_input = short_answer_template.substitute({"question": query})
pred_answer = rag_model.run_vqa(images=images, question=text_input)
assert isinstance(pred_answer, str)
out_dict["pred_answer"] = pred_answer
end = time.perf_counter()
time_qa = end - start
logger.info(f"time_qa: {time_qa}")
out_dict["time_retrieval"] = time_retrieval
out_dict["time_qa"] = time_qa
logger.info(query)
logger.info(pred_answer)
logger.info(datum["answers"])
return out_dict
def evaluate(data_loader, rag_model, index=None, data_len=None, args=None, **kwargs):
if data_len is not None:
logger.info(f"eval on the first {data_len} items")
logger.info("Preparing doc indices")
    if args.retrieval_model_type == "colpali":
        docid2embs = data_loader.dataset.load_all_embeddings()
    else:
        raise ValueError(
            f"Unsupported retrieval model type: {args.retrieval_model_type}"
        )
all_token_embeddings = []
token2pageuid = []
if args.retrieval_model_type == "colpali":
for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
# e.g., doc_emb - torch.Size([9, 1030, 128])
for page_id in range(len(doc_emb)):
page_emb = doc_emb[page_id].view(-1, 128)
all_token_embeddings.append(page_emb)
page_uid = f"{doc_id}_page{page_id}"
token2pageuid.extend([page_uid] * page_emb.shape[0])
    logger.info(f"num pages: {len(all_token_embeddings)}")
all_token_embeddings = torch.cat(all_token_embeddings, dim=0)
all_token_embeddings = all_token_embeddings.float().numpy()
logger.info("Created flattened token embeddings / token2pageuid")
qid2result = {}
total_time_retrieval = 0
total_time_qa = 0
for batch_idx, batch in enumerate(tqdm(data_loader)):
bs = len(batch["question"])
# Single batch
assert bs == 1
qid = batch["qid"][0]
with torch.no_grad():
outputs = run_model(
rag_model,
batch,
dataset=data_loader.dataset,
docid2embs=docid2embs,
index=index,
token2pageuid=token2pageuid,
all_token_embeddings=all_token_embeddings,
n_return_pages=args.n_retrieval_pages,
args=args,
)
pred_answer = outputs["pred_answer"]
assert isinstance(pred_answer, str), type(pred_answer)
total_time_qa += outputs["time_qa"]
total_time_retrieval += outputs["time_retrieval"]
qid2result[qid] = outputs
    logger.info(f"total_time_qa: {total_time_qa}")
    logger.info(f"total_time_retrieval: {total_time_retrieval}")
    avg_time_qa = total_time_qa / len(data_loader)
    avg_time_retrieval = total_time_retrieval / len(data_loader)
    logger.info(f"avg_time_qa: {avg_time_qa}")
    logger.info(f"avg_time_retrieval: {avg_time_retrieval}")
return qid2result
def main():
args = parse_args()
logger.info(torch.__version__)
logger.info(transformers.__version__)
logger.info(accelerate.__version__)
log_runtime_info()
print_gpu_stats()
accelerator = Accelerator()
if not is_distributed() or global_rank() == 0:
logger.info(f"Process {global_rank()}:{local_rank()} - args {args}")
local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
local_embedding_dir = Path(LOCAL_EMBEDDINGS_DIR) / args.embedding_name
local_model_dir = Path(LOCAL_MODEL_DIR) / args.model_name_or_path
local_retrieval_model_dir = (
Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path
)
local_retrieval_adapter_model_dir = (
Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path
)
# local_ft_model_dir = Path(LOCAL_MODEL_DIR) / args.ft_model_name_or_path
local_index_dir = (
Path(LOCAL_EMBEDDINGS_DIR)
/ f"{args.embedding_name}_pageindex_{args.faiss_index_type}"
)
# local_answer_extraction_model_dir = Path(LOCAL_MODEL_DIR) / args.answer_extraction_model_name_or_path
if is_distributed():
barrier()
if not is_distributed() or global_rank() == 0:
if not local_data_dir.exists():
raise ValueError(f"Data directory {local_data_dir} does not exist")
if not local_embedding_dir.exists():
raise ValueError(
f"Embedding directory {local_embedding_dir} does not exist"
)
if local_model_dir.exists() or args.retrieval_only:
            logger.info("Model directory exists (or retrieval-only) - skipping check")
else:
raise ValueError(
f"Model directory {local_model_dir} does not exist"
)
if args.use_retrieval:
if not local_retrieval_model_dir.exists():
raise ValueError(
f"Retrieval model directory {local_retrieval_model_dir} does not exist"
)
if not local_retrieval_adapter_model_dir.exists():
raise ValueError(
f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist"
)
if not local_index_dir.exists():
raise ValueError(
f"Index directory {local_index_dir} does not exist"
)
if is_distributed():
barrier()
# Create Retrieval Model (Step 1)
assert args.use_retrieval
    if args.retrieval_model_type == "colpali":
        colpali_model = ColPaliRetrievalModel(
            backbone_name_or_path=local_retrieval_model_dir,
            adapter_name_or_path=local_retrieval_adapter_model_dir,
        )
        retrieval_model = colpali_model
    else:
        raise ValueError(
            f"Unsupported retrieval model type: {args.retrieval_model_type}"
        )
    logger.info(f"Loaded retrieval model: {local_retrieval_model_dir}")
# Create QA / VQA Model (Step 2)
if args.retrieval_only:
rag_model = MultimodalRAGModel(retrieval_model=retrieval_model, vqa_model=None)
logger.info("skipping QA model")
else:
if "florence" in args.model_name_or_path.lower():
model_type = "florence2"
elif "idefics2" in args.model_name_or_path.lower():
model_type = "idefics2"
elif "idefics3" in args.model_name_or_path.lower():
model_type = "idefics3"
elif "internvl2" in args.model_name_or_path.lower():
model_type = "internvl2"
elif "qwen2" in args.model_name_or_path.lower():
model_type = "qwen2"
else:
            raise ValueError(f"Unknown model type for: {args.model_name_or_path}")
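        # Prefer FlashAttention-2 when the GPU supports it; otherwise fall back
        # to the default (eager) attention implementation.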
use_flash_attn = True
attn_implementation = "flash_attention_2"
if not supports_flash_attention():
use_flash_attn = False
attn_implementation = "eager"
vqa_model = VQAModel(
model_name_or_path=local_model_dir,
model_type=model_type,
bits=args.bits,
use_flash_attn=use_flash_attn,
attn_implementation=attn_implementation,
)
logger.info(f"loaded VQA model - {model_type}: {local_model_dir}")
rag_model = MultimodalRAGModel(
retrieval_model=retrieval_model, vqa_model=vqa_model
)
logger.info("Created RAG model")
dataset = M3DocVQADataset(args=args)
logger.info("loaded dataset")
index = None
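    # The page-level FAISS index (index.bin) is built offline by the indexing
    # script in this commit from the same page embeddings.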
if local_index_dir.exists():
logger.info("Loading faiss index")
import faiss
index = faiss.read_index(str(local_index_dir / "index.bin"))
logger.info("Loading faiss index -- done")
def list_collate_fn(batch):
batch = {
k: [dic[k] for dic in batch] for k in batch[0]
} # List of dictionaries to dict of lists.
return batch
data_loader = DataLoader(
dataset, batch_size=1, shuffle=False, collate_fn=list_collate_fn
)
if args.retrieval_only:
retrieval_model.model, data_loader = accelerator.prepare(
retrieval_model.model, data_loader
)
else:
retrieval_model.model, data_loader, vqa_model.model = accelerator.prepare(
retrieval_model.model, data_loader, vqa_model.model
)
eval_out = evaluate(
data_loader=data_loader,
rag_model=rag_model,
index=index,
data_len=args.data_len,
args=args,
)
samples = eval_out
EST = pytz.timezone("US/Eastern")
experiment_date = (
datetime.datetime.now().astimezone(EST).strftime("%Y-%m-%d_%H-%M-%S")
)
save_dir = Path(args.output_dir)
save_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Results will be saved at: {save_dir}")
if args.retrieval_model_type == "colpali":
ret_name = args.retrieval_adapter_model_name_or_path
else:
ret_name = args.retrieval_model_name_or_path
if args.retrieval_only:
pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}.json"
else:
pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}.json"
results_file = save_dir / pred_save_fname
with open(results_file, "w") as f:
json.dump(samples, f, indent=4)
logger.info(f"Prediction results saved at: {results_file}")
# Evaluation
all_eval_scores = evaluate_prediction_file(
samples,
dataset.mmqa_data_path, # '/job/datasets/m3-docvqa/MMQA_dev.jsonl'
)
if args.retrieval_only:
eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}_eval_results.json"
else:
eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}_eval_results.json"
results_file = save_dir / eval_save_fname
with open(results_file, "w") as f:
json.dump(all_eval_scores, f, indent=4)
logger.info(f"Evaluation results saved at: {results_file}")
if is_distributed():
barrier()
if __name__ == "__main__":
main()