Release commit

Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
j-min
2025-01-30 17:04:56 -05:00
committed by oir
parent e04aeadfb0
commit 27aac8d521
50 changed files with 5692 additions and 0 deletions


@@ -0,0 +1,429 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
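
# End-to-end M3DocVQA evaluation: retrieve pages with a ColPali retriever,
# answer questions with a VQA model, and score the resulting prediction file.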
import datetime
import json
import time
from pathlib import Path
import accelerate
import pytz
import torch
import transformers
from accelerate import Accelerator
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm
from m3docrag.datasets.m3_docvqa import M3DocVQADataset, evaluate_prediction_file
from m3docrag.rag import MultimodalRAGModel
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.args import parse_args
from m3docrag.utils.distributed import (
barrier,
global_rank,
is_distributed,
local_rank,
log_runtime_info,
print_gpu_stats,
supports_flash_attention,
)
from m3docrag.utils.paths import (
LOCAL_DATA_DIR,
LOCAL_EMBEDDINGS_DIR,
LOCAL_MODEL_DIR,
)
from m3docrag.utils.prompts import short_answer_template
from m3docrag.utils.tar import extract_tarfile
from m3docrag.vqa import VQAModel


def run_model(
rag_model: MultimodalRAGModel,
datum,
dataset: M3DocVQADataset,
docid2embs: dict[str, torch.Tensor],
docid2lens=None,
index=None,
token2pageuid=None,
all_token_embeddings=None,
n_return_pages=1,
args=None,
):
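    # One retrieve-then-answer step: retrieve the top pages for a single
    # question, then (unless running retrieval-only) answer it with the VQA model.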
# if type(datum['num_pages']) == list:
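    # The DataLoader uses batch_size=1 with a list collate, so unwrap each
    # single-element list into a plain value.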
batch = datum
datum = {}
for k, v in batch.items():
datum[k] = v[0]
query = datum["question"]
out_dict = {}
start = time.perf_counter()
# Stage 1: Page retrieval
# [(doc_id, page_idx, scores)...]
top_n_page_retrieval_results = rag_model.retrieve_pages_from_docs(
query=query,
docid2embs=docid2embs,
docid2lens=docid2lens,
index=index,
token2pageuid=token2pageuid,
all_token_embeddings=all_token_embeddings,
n_return_pages=n_return_pages,
show_progress=True,
)
logger.info(top_n_page_retrieval_results)
out_dict["page_retrieval_results"] = top_n_page_retrieval_results
end = time.perf_counter()
time_retrieval = end - start
logger.info(f"time_retrieval: {time_retrieval}")
start = time.perf_counter()
if args.retrieval_only:
pred_answer = ""
out_dict["pred_answer"] = pred_answer
else:
        # Stage 2: QA on the retrieved page(s)
# Obtain images from the page retrieval results
images = []
for doc_id, page_idx, scores in top_n_page_retrieval_results:
page_images = dataset.get_images_from_doc_id(doc_id)
page_image = page_images[page_idx]
images += [page_image]
logger.info(len(images))
# Run VQA
if "florence" in args.model_name_or_path.lower():
text_input = query
else:
text_input = short_answer_template.substitute({"question": query})
pred_answer = rag_model.run_vqa(images=images, question=text_input)
assert isinstance(pred_answer, str)
out_dict["pred_answer"] = pred_answer
end = time.perf_counter()
time_qa = end - start
logger.info(f"time_qa: {time_qa}")
out_dict["time_retrieval"] = time_retrieval
out_dict["time_qa"] = time_qa
logger.info(query)
logger.info(pred_answer)
logger.info(datum["answers"])
return out_dict


def evaluate(data_loader, rag_model, index=None, data_len=None, args=None, **kwargs):
if data_len is not None:
logger.info(f"eval on the first {data_len} items")
# docid2embs = data_loader.dataset.load_all_embeddings()
logger.info("Preparing doc indices")
if args.retrieval_model_type == "colpali":
docid2embs = data_loader.dataset.load_all_embeddings()
# reduce_embeddings(docid2embs=docid2embs)
    # docid2embs_page_reduced = reduce_embeddings(docid2embs, dim='page')
    # docid2embs_token_reduced = reduce_embeddings(docid2embs, dim='token')
    # docid2embs_page_token_reduced = reduce_embeddings(docid2embs, dim='page_token')
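    # Flatten every page's token embeddings into one (num_tokens, dim) matrix so
    # retrieval can score all pages at once; token2pageuid maps each token row
    # back to the page it came from.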
all_token_embeddings = []
token2pageuid = []
if args.retrieval_model_type == "colpali":
for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
# e.g., doc_emb - torch.Size([9, 1030, 128])
for page_id in range(len(doc_emb)):
page_emb = doc_emb[page_id].view(-1, 128)
all_token_embeddings.append(page_emb)
page_uid = f"{doc_id}_page{page_id}"
token2pageuid.extend([page_uid] * page_emb.shape[0])
logger.info(len(all_token_embeddings))
all_token_embeddings = torch.cat(all_token_embeddings, dim=0)
all_token_embeddings = all_token_embeddings.float().numpy()
logger.info("Created flattened token embeddings / token2pageuid")
qid2result = {}
total_time_retrieval = 0
total_time_qa = 0
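    # Answer one question per step (batch_size=1), collecting predictions and
    # per-stage retrieval/QA latency.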
for batch_idx, batch in enumerate(tqdm(data_loader)):
bs = len(batch["question"])
# Single batch
assert bs == 1
qid = batch["qid"][0]
with torch.no_grad():
outputs = run_model(
rag_model,
batch,
dataset=data_loader.dataset,
docid2embs=docid2embs,
                # docid2embs=docid2embs_page_reduced,
                # docid2embs=docid2embs_token_reduced,
                # docid2embs=docid2embs_page_token_reduced,
index=index,
token2pageuid=token2pageuid,
all_token_embeddings=all_token_embeddings,
n_return_pages=args.n_retrieval_pages,
args=args,
)
pred_answer = outputs["pred_answer"]
assert isinstance(pred_answer, str), type(pred_answer)
total_time_qa += outputs["time_qa"]
total_time_retrieval += outputs["time_retrieval"]
qid2result[qid] = outputs
    logger.info(f"total_time_qa: {total_time_qa}")
    logger.info(f"total_time_retrieval: {total_time_retrieval}")
    avg_time_qa = total_time_qa / len(data_loader)
    avg_time_retrieval = total_time_retrieval / len(data_loader)
    logger.info(f"avg_time_qa: {avg_time_qa}")
    logger.info(f"avg_time_retrieval: {avg_time_retrieval}")
return qid2result


def main():
args = parse_args()
logger.info(torch.__version__)
logger.info(transformers.__version__)
logger.info(accelerate.__version__)
log_runtime_info()
print_gpu_stats()
accelerator = Accelerator()
if not is_distributed() or global_rank() == 0:
logger.info(f"Process {global_rank()}:{local_rank()} - args {args}")
local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
local_embedding_dir = Path(LOCAL_EMBEDDINGS_DIR) / args.embedding_name
local_model_dir = Path(LOCAL_MODEL_DIR) / args.model_name_or_path
local_retrieval_model_dir = (
Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path
)
local_retrieval_adapter_model_dir = (
Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path
)
# local_ft_model_dir = Path(LOCAL_MODEL_DIR) / args.ft_model_name_or_path
local_index_dir = (
Path(LOCAL_EMBEDDINGS_DIR)
/ f"{args.embedding_name}_pageindex_{args.faiss_index_type}"
)
# local_answer_extraction_model_dir = Path(LOCAL_MODEL_DIR) / args.answer_extraction_model_name_or_path
if is_distributed():
barrier()
if not is_distributed() or global_rank() == 0:
if not local_data_dir.exists():
raise ValueError(f"Data directory {local_data_dir} does not exist")
if not local_embedding_dir.exists():
raise ValueError(
f"Embedding directory {local_embedding_dir} does not exist"
)
if local_model_dir.exists() or args.retrieval_only:
logger.info("Model exists - pass")
else:
raise ValueError(
f"Model directory {local_model_dir} does not exist"
)
if args.use_retrieval:
if not local_retrieval_model_dir.exists():
raise ValueError(
f"Retrieval model directory {local_retrieval_model_dir} does not exist"
)
if not local_retrieval_adapter_model_dir.exists():
raise ValueError(
f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist"
)
if not local_index_dir.exists():
raise ValueError(
f"Index directory {local_index_dir} does not exist"
)
if is_distributed():
barrier()
# Create Retrieval Model (Step 1)
assert args.use_retrieval
if args.retrieval_model_type == "colpali":
colpali_model = ColPaliRetrievalModel(
backbone_name_or_path=local_retrieval_model_dir,
adapter_name_or_path=local_retrieval_adapter_model_dir,
)
retrieval_model = colpali_model
logger.info(f"loaded Retrieval model -: {local_retrieval_model_dir}")
# Create QA / VQA Model (Step 2)
if args.retrieval_only:
rag_model = MultimodalRAGModel(retrieval_model=retrieval_model, vqa_model=None)
logger.info("skipping QA model")
else:
if "florence" in args.model_name_or_path.lower():
model_type = "florence2"
elif "idefics2" in args.model_name_or_path.lower():
model_type = "idefics2"
elif "idefics3" in args.model_name_or_path.lower():
model_type = "idefics3"
elif "internvl2" in args.model_name_or_path.lower():
model_type = "internvl2"
elif "qwen2" in args.model_name_or_path.lower():
model_type = "qwen2"
else:
raise KeyError(f"model type unknown for: {args.model_name_or_path}")
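        # Use FlashAttention-2 when the GPU supports it; otherwise fall back to
        # eager attention.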
use_flash_attn = True
attn_implementation = "flash_attention_2"
if not supports_flash_attention():
use_flash_attn = False
attn_implementation = "eager"
vqa_model = VQAModel(
model_name_or_path=local_model_dir,
model_type=model_type,
bits=args.bits,
use_flash_attn=use_flash_attn,
attn_implementation=attn_implementation,
)
logger.info(f"loaded VQA model - {model_type}: {local_model_dir}")
rag_model = MultimodalRAGModel(
retrieval_model=retrieval_model, vqa_model=vqa_model
)
logger.info("Created RAG model")
dataset = M3DocVQADataset(args=args)
logger.info("loaded dataset")
index = None
if local_index_dir.exists():
logger.info("Loading faiss index")
import faiss
index = faiss.read_index(str(local_index_dir / "index.bin"))
logger.info("Loading faiss index -- done")
def list_collate_fn(batch):
batch = {
k: [dic[k] for dic in batch] for k in batch[0]
} # List of dictionaries to dict of lists.
return batch
data_loader = DataLoader(
dataset, batch_size=1, shuffle=False, collate_fn=list_collate_fn
)
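    # accelerator.prepare() places the models and the data loader on the right
    # device(s), so the same script runs on single- or multi-GPU setups.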
if args.retrieval_only:
retrieval_model.model, data_loader = accelerator.prepare(
retrieval_model.model, data_loader
)
else:
retrieval_model.model, data_loader, vqa_model.model = accelerator.prepare(
retrieval_model.model, data_loader, vqa_model.model
)
eval_out = evaluate(
data_loader=data_loader,
rag_model=rag_model,
index=index,
data_len=args.data_len,
args=args,
)
samples = eval_out
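    # Timestamp (US/Eastern) keeps each run's output filenames unique.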
EST = pytz.timezone("US/Eastern")
experiment_date = (
datetime.datetime.now().astimezone(EST).strftime("%Y-%m-%d_%H-%M-%S")
)
save_dir = Path(args.output_dir)
save_dir.mkdir(exist_ok=True, parents=True)
logger.info("Results will be saved at:", save_dir)
if args.retrieval_model_type == "colpali":
ret_name = args.retrieval_adapter_model_name_or_path
else:
ret_name = args.retrieval_model_name_or_path
if args.retrieval_only:
pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}.json"
else:
pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}.json"
results_file = save_dir / pred_save_fname
with open(results_file, "w") as f:
json.dump(samples, f, indent=4)
logger.info(f"Prediction results saved at: {results_file}")
# Evaluation
all_eval_scores = evaluate_prediction_file(
samples,
dataset.mmqa_data_path, # '/job/datasets/m3-docvqa/MMQA_dev.jsonl'
)
if args.retrieval_only:
eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}_eval_results.json"
else:
eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}_eval_results.json"
results_file = save_dir / eval_save_fname
with open(results_file, "w") as f:
json.dump(all_eval_scores, f, indent=4)
logger.info(f"Evaluation results saved at: {results_file}")
if is_distributed():
barrier()


if __name__ == "__main__":
main()