Release commit
Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
examples/run_indexing_m3docvqa.py (new file, 171 lines)
@@ -0,0 +1,171 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

import faiss
import numpy as np
import torch
from loguru import logger
from tqdm.auto import tqdm

from m3docrag.datasets.m3_docvqa.dataset import M3DocVQADataset
from m3docrag.utils.args import parse_args


def main():
    args = parse_args()

    logger.info("Loading M3DocVQA")

    dataset = M3DocVQADataset(args)

    logger.info(f"Loading M3DocVQA -- all {args.retrieval_model_type} embeddings")

    if args.retrieval_model_type == "colpali":
        docid2embs = dataset.load_all_embeddings()
    elif args.retrieval_model_type == "colbert":
        docid2embs, docid2lens = dataset.load_all_embeddings()

    # len(docid2embs)
    # docid2embs_page_reduced = reduce_embeddings(docid2embs, dim='page')
    # docid2embs_token_reduced = reduce_embeddings(docid2embs, dim='token')
    # docid2embs_page_token_reduced = reduce_embeddings(docid2embs, dim='page_token')

    # flat_doc_embs = []
    # for doc_id, doc_emb in docid2embs.items():
    #     flat_doc_embs += [doc_emb]

    # flat_doc_embs = torch.cat(flat_doc_embs, dim=0)

    # logger.info(flat_doc_embs.shape)

    d = 128
    quantizer = faiss.IndexFlatIP(d)

    if args.faiss_index_type == "flatip":
        index = quantizer

    elif args.faiss_index_type == "ivfflat":
        ncentroids = 1024
        index = faiss.IndexIVFFlat(quantizer, d, ncentroids)
    else:
        nlist = 100
        m = 8
        index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
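    # Note: ncentroids / nlist / m above are illustrative defaults, not tuned values.
    # IndexFlatIP performs exact inner-product search, while IndexIVFFlat and
    # IndexIVFPQ are approximate and must be trained on representative vectors
    # before vectors are added (index.train() below is a no-op for the flat index).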

    logger.info("Flattening all PDF pages")

    all_token_embeddings = []
    token2pageuid = []
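    # token2pageuid[i] records which page produced the i-th row of the flattened
    # token-embedding matrix, so FAISS hits can be mapped back to pages below.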

    if args.retrieval_model_type == "colpali":
        for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
            # e.g., doc_emb - torch.Size([9, 1030, 128])
            for page_id in range(len(doc_emb)):
                page_emb = doc_emb[page_id].view(-1, d)
                all_token_embeddings.append(page_emb)

                page_uid = f"{doc_id}_page{page_id}"
                token2pageuid.extend([page_uid] * page_emb.shape[0])

    elif args.retrieval_model_type == "colbert":
        for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
            doc_lens = docid2lens[doc_id]

            # e.g., doc_emb - torch.Size([2089, 128])
            # e.g., doc_lens - tensor([258, 240, 251, 231, 229, 268, 235, 211, 166])

            all_token_embeddings.append(doc_emb)

            for page_id, page_len in enumerate(doc_lens):
                page_uid = f"{doc_id}_page{page_id}"
                token2pageuid.extend([page_uid] * page_len.item())

    logger.info(len(all_token_embeddings))

    all_token_embeddings = torch.cat(all_token_embeddings, dim=0)
    all_token_embeddings = all_token_embeddings.float().numpy()
    logger.info(all_token_embeddings.shape)
    logger.info(len(token2pageuid))

    logger.info("Creating index")

    index.train(all_token_embeddings)
    index.add(all_token_embeddings)

    Path(args.output_dir).mkdir(exist_ok=True)
    index_output_path = str(Path(args.output_dir) / "index.bin")
    logger.info(f"Saving index at {index_output_path}")

    faiss.write_index(index, index_output_path)
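    # The saved index can be reloaded later (as run_rag_m3docvqa.py does), e.g.:
    #   index = faiss.read_index(index_output_path)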

    logger.info("Running an example query")

    # Example query (should be np.float32)
    example_text_query_emb = np.random.randn(20, 128).astype(np.float32)

    # NN search
    k = 10
    D, I = index.search(example_text_query_emb, k)  # noqa: E741
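    # D holds the inner-product scores and I the flat token indices of the
    # k nearest document tokens for each query-token embedding; both have
    # shape (n_query_tokens, k).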

    # Sum the MaxSim scores across all query tokens for each page
    final_page2scores = {}

    # Iterate over query tokens
    for q_idx, query_emb in enumerate(example_text_query_emb):
        # Initialize a dictionary to hold page relevance scores for this query token
        current_q_page2scores = {}

        for nn_idx in range(k):
            found_nearest_doc_token_idx = I[q_idx, nn_idx]

            # Get the page UID for this token
            page_uid = token2pageuid[found_nearest_doc_token_idx]

            # Reconstruct the original similarity score
            doc_token_emb = all_token_embeddings[found_nearest_doc_token_idx]
            score = (query_emb * doc_token_emb).sum()
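            # For quantized indexes (e.g., IVFPQ), the scores in D are approximate;
            # recomputing the dot product from the stored float embeddings is exact.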

            # MaxSim: keep the highest similarity score for each query token per page
            if page_uid not in current_q_page2scores:
                current_q_page2scores[page_uid] = score
            else:
                current_q_page2scores[page_uid] = max(
                    current_q_page2scores[page_uid], score
                )

        for page_uid, score in current_q_page2scores.items():
            if page_uid in final_page2scores:
                final_page2scores[page_uid] += score
            else:
                final_page2scores[page_uid] = score

    # Sort pages by their final relevance score
    sorted_pages = sorted(final_page2scores.items(), key=lambda x: x[1], reverse=True)

    # Get the top-k page candidates
    top_k_pages = sorted_pages[:k]

    # Output the top-k page UIDs and their scores
    logger.info("Top-k page candidates with scores:")
    for page_uid, score in top_k_pages:
        logger.info(f"{page_uid} with score {score}")


if __name__ == "__main__":
    main()

examples/run_page_embedding.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

import accelerate
import safetensors.torch  # import the submodule so safetensors.torch.save_file is available
import torch
import transformers
from accelerate import Accelerator
from loguru import logger
from tqdm import tqdm

from m3docrag.datasets.m3_docvqa import M3DocVQADataset
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.args import parse_args
from m3docrag.utils.distributed import (
    barrier,
    global_rank,
    is_distributed,
    local_rank,
    log_runtime_info,
    print_gpu_stats,
)
from m3docrag.utils.paths import (
    LOCAL_DATA_DIR,
    LOCAL_MODEL_DIR,
)

logger.info(torch.__version__)
logger.info(transformers.__version__)
logger.info(accelerate.__version__)


def main():
    args = parse_args()

    log_runtime_info()
    print_gpu_stats()

    accelerator = Accelerator()

    if not is_distributed() or global_rank() == 0:
        logger.info(f"Process {global_rank()}:{local_rank()} - args {args}")

    if is_distributed():
        barrier()

    local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
    local_retrieval_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path
    )
    local_retrieval_adapter_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path
    )

    # Check that datasets / model checkpoints exist
    if not is_distributed() or global_rank() == 0:
        if not local_data_dir.exists():
            raise ValueError(f"Data directory {local_data_dir} does not exist")

        assert args.use_retrieval, args.use_retrieval

        if not local_retrieval_model_dir.exists():
            raise ValueError(
                f"Retrieval model directory {local_retrieval_model_dir} does not exist"
            )

        if args.retrieval_model_type == "colpali":
            if not local_retrieval_adapter_model_dir.exists():
                raise ValueError(
                    f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist"
                )

    if is_distributed():
        barrier()

    if args.retrieval_model_type == "colpali":
        colpali_model = ColPaliRetrievalModel(
            backbone_name_or_path=local_retrieval_model_dir,
            adapter_name_or_path=local_retrieval_adapter_model_dir,
        )
        retrieval_model = colpali_model

    if args.data_name == "m3-docvqa":
        dataset = M3DocVQADataset(args=args)

    def collate_fn(examples):
        out = {}
        if args.retrieval_model_type == "colpali":
            for k in ["doc_id", "images"]:
                out[k] = [ex[k] for ex in examples]
        return out
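    # collate_fn turns a list of example dicts into a dict of lists, e.g.
    # [{"doc_id": "a", ...}, {"doc_id": "b", ...}] -> {"doc_id": ["a", "b"], ...}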

    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=1,
        shuffle=False,
        batch_sampler=None,
        sampler=None,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
    )

    retrieval_model.model, data_loader = accelerator.prepare(
        retrieval_model.model, data_loader
    )

    all_results = []

    save_dir = Path(args.output_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Results will be saved at: {save_dir}")

    for i, datum in enumerate(tqdm(data_loader)):
        print(f"{i} / {len(data_loader)}")

        if args.data_name == "mp-docvqa":
            page_name = datum["page_name"][0]
            logger.info(page_name)
        else:
            doc_id = datum["doc_id"][0]
            logger.info(doc_id)

        if args.retrieval_model_type == "colpali":
            images = datum["images"][0]

            doc_embs = colpali_model.encode_images(
                images=images,
                batch_size=args.per_device_eval_batch_size,
                to_cpu=True,
                use_tqdm=False,
            )

            # [n_pages, n_tokens, emb_dim]
            doc_embs = torch.stack(doc_embs, dim=0)

        # Store embeddings as BF16 by default
        doc_embs = doc_embs.to(torch.bfloat16)
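        # BF16 halves storage vs. FP32 while keeping the FP32 exponent range;
        # the indexing script later casts the embeddings back to float32 before
        # handing them to FAISS, which expects float32 input.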

        logger.info(doc_embs.shape)
        if args.retrieval_model_type == "colpali":
            logger.info(doc_embs[0, 0, :5])

        # Save the embedding
        if args.data_name == "mp-docvqa":
            local_save_fname = f"{page_name}.safetensors"
        else:
            local_save_fname = f"{doc_id}.safetensors"
        local_save_path = save_dir / local_save_fname

        if args.retrieval_model_type == "colpali":
            safetensors.torch.save_file({"embeddings": doc_embs}, local_save_path)
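            # The embeddings can be loaded back later, e.g.:
            #   doc_embs = safetensors.torch.load_file(local_save_path)["embeddings"]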

        all_results.append({"save_path": local_save_path})

    logger.info(
        f"Process {global_rank()}:{local_rank()} Results correctly saved at {save_dir}"
    )

    if is_distributed():
        barrier()


if __name__ == "__main__":
    main()

examples/run_rag_m3docvqa.py (new file, 429 lines)
@@ -0,0 +1,429 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import datetime
import json
import time
from pathlib import Path

import accelerate
import pytz
import torch
import transformers
from accelerate import Accelerator
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm

from m3docrag.datasets.m3_docvqa import M3DocVQADataset, evaluate_prediction_file
from m3docrag.rag import MultimodalRAGModel
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.args import parse_args
from m3docrag.utils.distributed import (
    barrier,
    global_rank,
    is_distributed,
    local_rank,
    log_runtime_info,
    print_gpu_stats,
    supports_flash_attention,
)
from m3docrag.utils.paths import (
    LOCAL_DATA_DIR,
    LOCAL_EMBEDDINGS_DIR,
    LOCAL_MODEL_DIR,
)
from m3docrag.utils.prompts import short_answer_template
from m3docrag.utils.tar import extract_tarfile
from m3docrag.vqa import VQAModel


def run_model(
    rag_model: MultimodalRAGModel,
    datum,
    dataset: M3DocVQADataset,
    docid2embs: dict[str, torch.Tensor],
    docid2lens=None,
    index=None,
    token2pageuid=None,
    all_token_embeddings=None,
    n_return_pages=1,
    args=None,
):
    # Unwrap the single-example batch (batch size is always 1)
    batch = datum
    datum = {}
    for k, v in batch.items():
        datum[k] = v[0]

    query = datum["question"]

    out_dict = {}

    start = time.perf_counter()

    # Stage 1: Page retrieval
    # [(doc_id, page_idx, scores)...]
    top_n_page_retrieval_results = rag_model.retrieve_pages_from_docs(
        query=query,
        docid2embs=docid2embs,
        docid2lens=docid2lens,
        index=index,
        token2pageuid=token2pageuid,
        all_token_embeddings=all_token_embeddings,
        n_return_pages=n_return_pages,
        show_progress=True,
    )
    logger.info(top_n_page_retrieval_results)
    out_dict["page_retrieval_results"] = top_n_page_retrieval_results

    end = time.perf_counter()

    time_retrieval = end - start
    logger.info(f"time_retrieval: {time_retrieval}")

    start = time.perf_counter()

    if args.retrieval_only:
        pred_answer = ""
        out_dict["pred_answer"] = pred_answer
    else:
        # Stage 2: QA on the retrieved pages
        # Obtain images from the page retrieval results
        images = []
        for doc_id, page_idx, scores in top_n_page_retrieval_results:
            page_images = dataset.get_images_from_doc_id(doc_id)
            page_image = page_images[page_idx]
            images += [page_image]

        logger.info(len(images))

        # Run VQA
        if "florence" in args.model_name_or_path.lower():
            text_input = query
        else:
            text_input = short_answer_template.substitute({"question": query})
        pred_answer = rag_model.run_vqa(images=images, question=text_input)

        assert isinstance(pred_answer, str)
        out_dict["pred_answer"] = pred_answer

    end = time.perf_counter()

    time_qa = end - start
    logger.info(f"time_qa: {time_qa}")

    out_dict["time_retrieval"] = time_retrieval
    out_dict["time_qa"] = time_qa

    logger.info(query)
    logger.info(pred_answer)
    logger.info(datum["answers"])

    return out_dict


def evaluate(data_loader, rag_model, index=None, data_len=None, args=None, **kwargs):
    if data_len is not None:
        logger.info(f"eval on the first {data_len} items")

    # docid2embs = data_loader.dataset.load_all_embeddings()

    logger.info("Preparing doc indices")

    if args.retrieval_model_type == "colpali":
        docid2embs = data_loader.dataset.load_all_embeddings()

    # reduce_embeddings(docid2embs=docid2embs)
    # docid2embs_page_reduced = reduce_embeddings(docid2embs, dim='page')
    # docid2embs_token_reduced = reduce_embeddings(docid2embs, dim='token')
    # docid2embs_page_token_reduced = reduce_embeddings(docid2embs, dim='page_token')

    all_token_embeddings = []
    token2pageuid = []

    if args.retrieval_model_type == "colpali":
        for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)):
            # e.g., doc_emb - torch.Size([9, 1030, 128])
            for page_id in range(len(doc_emb)):
                page_emb = doc_emb[page_id].view(-1, 128)
                all_token_embeddings.append(page_emb)

                page_uid = f"{doc_id}_page{page_id}"
                token2pageuid.extend([page_uid] * page_emb.shape[0])

    logger.info(len(all_token_embeddings))

    all_token_embeddings = torch.cat(all_token_embeddings, dim=0)
    all_token_embeddings = all_token_embeddings.float().numpy()

    logger.info("Created flattened token embeddings / token2pageuid")

    qid2result = {}

    total_time_retrieval = 0
    total_time_qa = 0

    for batch_idx, batch in enumerate(tqdm(data_loader)):
        bs = len(batch["question"])

        # Single batch
        assert bs == 1
        qid = batch["qid"][0]

        with torch.no_grad():
            outputs = run_model(
                rag_model,
                batch,
                dataset=data_loader.dataset,
                docid2embs=docid2embs,
                # docid2embs=docid2embs_page_reduced,
                # docid2embs=docid2embs_token_reduced,
                # docid2embs=docid2embs_page_token_reduced,
                index=index,
                token2pageuid=token2pageuid,
                all_token_embeddings=all_token_embeddings,
                n_return_pages=args.n_retrieval_pages,
                args=args,
            )

        pred_answer = outputs["pred_answer"]
        assert isinstance(pred_answer, str), type(pred_answer)

        total_time_qa += outputs["time_qa"]
        total_time_retrieval += outputs["time_retrieval"]

        qid2result[qid] = outputs

    logger.info(total_time_qa)
    logger.info(total_time_retrieval)

    avg_time_qa = total_time_qa / len(data_loader)
    avg_time_retrieval = total_time_retrieval / len(data_loader)
    logger.info(avg_time_qa)
    logger.info(avg_time_retrieval)

    return qid2result


def main():
    args = parse_args()

    logger.info(torch.__version__)
    logger.info(transformers.__version__)
    logger.info(accelerate.__version__)

    log_runtime_info()
    print_gpu_stats()

    accelerator = Accelerator()

    if not is_distributed() or global_rank() == 0:
        logger.info(f"Process {global_rank()}:{local_rank()} - args {args}")

    local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
    local_embedding_dir = Path(LOCAL_EMBEDDINGS_DIR) / args.embedding_name
    local_model_dir = Path(LOCAL_MODEL_DIR) / args.model_name_or_path
    local_retrieval_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path
    )
    local_retrieval_adapter_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path
    )
    # local_ft_model_dir = Path(LOCAL_MODEL_DIR) / args.ft_model_name_or_path

    local_index_dir = (
        Path(LOCAL_EMBEDDINGS_DIR)
        / f"{args.embedding_name}_pageindex_{args.faiss_index_type}"
    )
    # local_answer_extraction_model_dir = Path(LOCAL_MODEL_DIR) / args.answer_extraction_model_name_or_path

    if is_distributed():
        barrier()

    if not is_distributed() or global_rank() == 0:
        if not local_data_dir.exists():
            raise ValueError(f"Data directory {local_data_dir} does not exist")

        if not local_embedding_dir.exists():
            raise ValueError(
                f"Embedding directory {local_embedding_dir} does not exist"
            )

        if local_model_dir.exists() or args.retrieval_only:
            logger.info("Model exists - pass")
        else:
            raise ValueError(f"Model directory {local_model_dir} does not exist")

        if args.use_retrieval:
            if not local_retrieval_model_dir.exists():
                raise ValueError(
                    f"Retrieval model directory {local_retrieval_model_dir} does not exist"
                )

            if not local_retrieval_adapter_model_dir.exists():
                raise ValueError(
                    f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist"
                )

            if not local_index_dir.exists():
                raise ValueError(f"Index directory {local_index_dir} does not exist")

    if is_distributed():
        barrier()

    # Create retrieval model (Stage 1)
    assert args.use_retrieval
    if args.retrieval_model_type == "colpali":
        colpali_model = ColPaliRetrievalModel(
            backbone_name_or_path=local_retrieval_model_dir,
            adapter_name_or_path=local_retrieval_adapter_model_dir,
        )
        retrieval_model = colpali_model

    logger.info(f"loaded retrieval model: {local_retrieval_model_dir}")

    # Create QA / VQA model (Stage 2)

    if args.retrieval_only:
        rag_model = MultimodalRAGModel(retrieval_model=retrieval_model, vqa_model=None)
        logger.info("skipping QA model")

    else:
        if "florence" in args.model_name_or_path.lower():
            model_type = "florence2"
        elif "idefics2" in args.model_name_or_path.lower():
            model_type = "idefics2"
        elif "idefics3" in args.model_name_or_path.lower():
            model_type = "idefics3"
        elif "internvl2" in args.model_name_or_path.lower():
            model_type = "internvl2"
        elif "qwen2" in args.model_name_or_path.lower():
            model_type = "qwen2"
        else:
            raise KeyError(f"model type unknown for: {args.model_name_or_path}")

        use_flash_attn = True
        attn_implementation = "flash_attention_2"
        if not supports_flash_attention():
            use_flash_attn = False
            attn_implementation = "eager"
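        # FlashAttention-2 generally requires an Ampere-or-newer GPU, so on older
        # hardware this falls back to the default "eager" attention implementation.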

        vqa_model = VQAModel(
            model_name_or_path=local_model_dir,
            model_type=model_type,
            bits=args.bits,
            use_flash_attn=use_flash_attn,
            attn_implementation=attn_implementation,
        )

        logger.info(f"loaded VQA model - {model_type}: {local_model_dir}")

        rag_model = MultimodalRAGModel(
            retrieval_model=retrieval_model, vqa_model=vqa_model
        )

    logger.info("Created RAG model")

    dataset = M3DocVQADataset(args=args)
    logger.info("loaded dataset")

    index = None
    if local_index_dir.exists():
        logger.info("Loading faiss index")
        import faiss

        index = faiss.read_index(str(local_index_dir / "index.bin"))
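        # This assumes the index was built by examples/run_indexing_m3docvqa.py over
        # the same embeddings: the flat token order in the index must match the
        # all_token_embeddings / token2pageuid arrays rebuilt in evaluate().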
        logger.info("Loading faiss index -- done")

    def list_collate_fn(batch):
        # Turn a list of example dicts into a dict of lists.
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}
        return batch
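    # e.g., [{"qid": 1, "question": "..."}] -> {"qid": [1], "question": ["..."]}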

    data_loader = DataLoader(
        dataset, batch_size=1, shuffle=False, collate_fn=list_collate_fn
    )

    if args.retrieval_only:
        retrieval_model.model, data_loader = accelerator.prepare(
            retrieval_model.model, data_loader
        )
    else:
        retrieval_model.model, data_loader, vqa_model.model = accelerator.prepare(
            retrieval_model.model, data_loader, vqa_model.model
        )

    eval_out = evaluate(
        data_loader=data_loader,
        rag_model=rag_model,
        index=index,
        data_len=args.data_len,
        args=args,
    )

    samples = eval_out

    EST = pytz.timezone("US/Eastern")
    experiment_date = (
        datetime.datetime.now().astimezone(EST).strftime("%Y-%m-%d_%H-%M-%S")
    )

    save_dir = Path(args.output_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Results will be saved at: {save_dir}")

    if args.retrieval_model_type == "colpali":
        ret_name = args.retrieval_adapter_model_name_or_path
    else:
        ret_name = args.retrieval_model_name_or_path

    if args.retrieval_only:
        pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}.json"
    else:
        pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}.json"
    results_file = save_dir / pred_save_fname
    with open(results_file, "w") as f:
        json.dump(samples, f, indent=4)

    logger.info(f"Prediction results saved at: {results_file}")

    # Evaluation
    all_eval_scores = evaluate_prediction_file(
        samples,
        dataset.mmqa_data_path,  # e.g., '/job/datasets/m3-docvqa/MMQA_dev.jsonl'
    )
    if args.retrieval_only:
        eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}_eval_results.json"
    else:
        eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}_eval_results.json"
    results_file = save_dir / eval_save_fname
    with open(results_file, "w") as f:
        json.dump(all_eval_scores, f, indent=4)

    logger.info(f"Evaluation results saved at: {results_file}")

    if is_distributed():
        barrier()


if __name__ == "__main__":
    main()