diff --git a/README.md b/README.md new file mode 100644 index 0000000..bec2ebb --- /dev/null +++ b/README.md @@ -0,0 +1,178 @@ +# M3DocRAG + +Code for [M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding](https://m3docrag.github.io/) + +by [Jaemin Cho](https://j-min.io/)¹, [Debanjan Mahata](https://sites.google.com/a/ualr.edu/debanjan-mahata/)², [Ozan İrsoy](https://wtimesx.com/)², [Yujie He](https://scholar.google.com/citations?user=FbeAZGgAAAAJ&hl=en)², and [Mohit Bansal](https://www.cs.unc.edu/~mbansal/)¹ + +¹UNC Chapel Hill
+²Bloomberg
+
+# Summary
+
+## Abstract
+
+Document visual question answering (DocVQA) pipelines that answer questions from documents have broad applications. Existing methods focus on handling single-page documents with multi-modal language models (MLMs), or rely on text-based retrieval-augmented generation (RAG) that uses text extraction tools such as optical character recognition (OCR). However, there are difficulties in applying these methods in real-world scenarios: (a) questions often require information across different pages or documents, where MLMs cannot handle many long documents; (b) documents often have important information in visual elements such as figures, but text extraction tools ignore them. We introduce M3DocRAG, a novel multi-modal RAG framework that flexibly accommodates various document contexts (closed-domain and open-domain), question hops (single-hop and multi-hop), and evidence modalities (text, chart, figure, etc.). M3DocRAG finds relevant documents and answers questions using a multi-modal retriever and an MLM, so that it can efficiently handle single or many documents while preserving visual information. Since previous DocVQA datasets ask questions in the context of a specific document, we also present M3DocVQA, a new benchmark for evaluating open-domain DocVQA over 3,000+ PDF documents with 40,000+ pages. On three benchmarks (M3DocVQA/MMLongBench-Doc/MP-DocVQA), empirical results show that M3DocRAG with ColPali and Qwen2-VL 7B achieves superior performance over many strong baselines, including state-of-the-art performance on MP-DocVQA. We provide comprehensive analyses of different indexing strategies, MLMs, and retrieval models. Lastly, we qualitatively show that M3DocRAG can successfully handle various scenarios, such as when relevant information exists across multiple pages and when answer evidence only exists in images.
+
+## Comparison with previous approaches
+
+
+
+Comparison of multi-modal document understanding pipelines: Previous works focus on (a) **Single-page DocVQA**, which cannot handle many long documents, or (b) **Text-based RAG**, which ignores visual information. Our (c) **M3DocRAG** framework retrieves relevant documents and answers questions using multi-modal retrieval and MLM components, so that it can efficiently handle many long documents while preserving visual information.
+
+## M3DocRAG framework
+
+
+
+Our **M3DocRAG** framework consists of three stages: (1) document embedding, (2) page retrieval, and (3) question answering.
+- In (1) document embedding, we extract visual embeddings (with ColPali) to represent each page of all PDF documents.
+- In (2) page retrieval, we retrieve the top-K pages of high relevance (MaxSim scores) with text queries. In an open-domain setting, we create approximate page indices for faster search.
+- In (3) question answering, we conduct visual question answering with a multi-modal LM (e.g., Qwen2-VL) to obtain the final answer.
+
+
+# Setup
+
+## Package
+
+We assume Conda has been installed:
+
+```bash
+git clone
+cd m3docrag-release
+pip install -e .
+
+# Install Poppler (for pdf2image; check [https://pdf2image.readthedocs.io/en/latest/installation.html](https://pdf2image.readthedocs.io/en/latest/installation.html) for details)
+# conda install -y poppler
+# or
+# apt-get install poppler-utils
+```
+
+## Code structure
+
+```bash
+examples/             # scripts to run PDF embedding / RAG
+src/m3docrag/
+    datasets/         # data loader for existing datasets
+    retrieval/        # retrieval model (e.g., ColPali)
+    vqa/              # vqa model (e.g., Qwen2-VL)
+    rag/              # RAG model that combines retrieval and VQA models
+    utils/            # misc utility methods
+m3docvqa/             # how to set up the M3DocVQA dataset
+```
+
+## Paths: Data, Embeddings, Model checkpoints, Outputs
+
+```bash
+# in .env
+LOCAL_DATA_DIR="/job/datasets" # where to store data
+LOCAL_EMBEDDINGS_DIR="/job/embeddings" # where to store embeddings
+LOCAL_MODEL_DIR="/job/model" # where to store model checkpoints
+LOCAL_OUTPUT_DIR="/job/output" # where to store model outputs
+```
+
+You can adjust the variables in [`.env`](.env) to change where data/embeddings/model checkpoints/outputs are stored by default. They are loaded in [`src/m3docrag/utils/paths.py`](./src/m3docrag/utils/paths.py) via [python-dotenv](https://github.com/theskumar/python-dotenv).
+
+## Download the M3DocVQA dataset
+
+Please see [m3docvqa/README.md](m3docvqa/README.md) for the download instructions.
+
+## Download model checkpoints
+
+By default, we use ColPali-v1.2 for retrieval and Qwen2-VL-7B-Instruct for question answering.
+
+At `$LOCAL_MODEL_DIR`, download the [ColPali-v1.2](https://huggingface.co/vidore/colpali-v1.2), [colpaligemma-3b-pt-448-base](https://huggingface.co/vidore/colpaligemma-3b-pt-448-base), and [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) checkpoints.
+
+```bash
+cd $LOCAL_MODEL_DIR
+
+git clone https://huggingface.co/vidore/colpaligemma-3b-pt-448-base # ColPali backbone
+git clone https://huggingface.co/vidore/colpali-v1.2 # ColPali adapter
+git clone https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct # VQA
+```
+
+# Example usage
+
+Below, we describe example usage of M3DocRAG on the M3DocVQA dataset.
+
+## 1. Extract PDF embeddings
+
+```bash
+DATASET_NAME="m3-docvqa"
+RETRIEVAL_MODEL_TYPE="colpali"
+RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base"
+RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
+SPLIT="dev"
+EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT # where to save embeddings
+accelerate launch --num_processes=1 --mixed_precision=bf16 examples/run_page_embedding.py \
+    --use_retrieval \
+    --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
+    --data_name=$DATASET_NAME \
+    --split=$SPLIT \
+    --loop_unique_doc_ids=True \
+    --output_dir=/job/embeddings/$EMBEDDING_NAME \
+    --retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \
+    --retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME
+```
+
+## 2. Indexing
+
+```bash
+DATASET_NAME="m3-docvqa"
+RETRIEVAL_MODEL_TYPE="colpali"
+RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
+SPLIT="dev"
+FAISS_INDEX_TYPE='ivfflat'
+EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT
+INDEX_NAME=$EMBEDDING_NAME"_pageindex_"$FAISS_INDEX_TYPE # where to save resulting index
+echo $EMBEDDING_NAME
+echo $FAISS_INDEX_TYPE
+python examples/run_indexing_m3docvqa.py \
+    --use_retrieval \
+    --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
+    --data_name=$DATASET_NAME \
+    --split=$SPLIT \
+    --loop_unique_doc_ids=False \
+    --embedding_name=$EMBEDDING_NAME \
+    --faiss_index_type=$FAISS_INDEX_TYPE \
+    --output_dir=/job/embeddings/$INDEX_NAME
+```
+
+## 3. RAG
+
+```bash
+BACKBONE_MODEL_NAME="Qwen2-VL-7B-Instruct"
+RETRIEVAL_MODEL_TYPE="colpali"
+RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base"
+RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
+EMBEDDING_NAME="colpali-v1.2_m3-docvqa_dev" # from Step 1 Embedding
+SPLIT="dev"
+DATASET_NAME="m3-docvqa"
+FAISS_INDEX_TYPE='ivfflat'
+N_RETRIEVAL_PAGES=1
+INDEX_NAME="${EMBEDDING_NAME}_pageindex_$FAISS_INDEX_TYPE" # from Step 2 Indexing
+OUTPUT_SAVE_NAME="${RETRIEVAL_ADAPTER_MODEL_NAME}_${BACKBONE_MODEL_NAME}_${DATASET_NAME}" # where to save RAG results
+BITS=16 # BITS=4 for 4-bit quantization on low-memory GPUs
+python examples/run_rag_m3docvqa.py \
+    --use_retrieval \
+    --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
+    --load_embedding=True \
+    --split=$SPLIT \
+    --bits=$BITS \
+    --n_retrieval_pages=$N_RETRIEVAL_PAGES \
+    --data_name=$DATASET_NAME \
+    --model_name_or_path=$BACKBONE_MODEL_NAME \
+    --embedding_name=$EMBEDDING_NAME \
+    --retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \
+    --retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME \
+    --output_dir=/job/eval_outputs/$OUTPUT_SAVE_NAME
+```
+
+# Citation
+
+Please cite our paper if you use our dataset and/or method in your projects.
+
+```bibtex
+@article{Cho2024M3DocRAG,
+  author = {Jaemin Cho and Debanjan Mahata and Ozan İrsoy and Yujie He and Mohit Bansal},
+  title = {M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding},
+  year = {2024},
+}
+```
diff --git a/assets/data_collection.png b/assets/data_collection.png
new file mode 100644
index 0000000..6274485
Binary files /dev/null and b/assets/data_collection.png differ
diff --git a/assets/dataset.png b/assets/dataset.png
new file mode 100644
index 0000000..9ec7834
Binary files /dev/null and b/assets/dataset.png differ
diff --git a/assets/m3docrag_teaser.png b/assets/m3docrag_teaser.png
new file mode 100644
index 0000000..e8d3439
Binary files /dev/null and b/assets/m3docrag_teaser.png differ
diff --git a/assets/method.png b/assets/method.png
new file mode 100644
index 0000000..8254892
Binary files /dev/null and b/assets/method.png differ
diff --git a/examples/run_indexing_m3docvqa.py b/examples/run_indexing_m3docvqa.py
new file mode 100644
index 0000000..6b4b50a
--- /dev/null
+++ b/examples/run_indexing_m3docvqa.py
@@ -0,0 +1,171 @@
+# Copyright 2024 Bloomberg Finance L.P.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import faiss +import numpy as np +import torch +from loguru import logger +from tqdm.auto import tqdm + +from m3docrag.datasets.m3_docvqa.dataset import M3DocVQADataset +from m3docrag.utils.args import parse_args + + +def main(): + args = parse_args() + + logger.info("Loading M3DocVQA") + + dataset = M3DocVQADataset(args) + + logger.info(f"Loading M3DocVQA -- all {args.retrieval_model_type} embeddings") + + if args.retrieval_model_type == "colpali": + docid2embs = dataset.load_all_embeddings() + elif args.retrieval_model_type == "colbert": + docid2embs, docid2lens = dataset.load_all_embeddings() + + # len(docid2embs) + # docid2embs_page_reduced = reduce_embeddings(docid2embs, dim='page') + # docid2embs_token_reduced = reduce_embeddings(docid2embs, dim='token') + # docid2embs_page_token_reduced = reduce_embeddings(docid2embs, dim='page_token') + + # flat_doc_embs = [] + # for doc_id, doc_emb in docid2embs.items(): + # flat_doc_embs += [doc_emb] + + # flat_doc_embs = torch.cat(flat_doc_embs, dim=0) + + # logger.info(flat_doc_embs.shape) + + d = 128 + quantizer = faiss.IndexFlatIP(d) + + if args.faiss_index_type == "flatip": + index = quantizer + + elif args.faiss_index_type == "ivfflat": + ncentroids = 1024 + index = faiss.IndexIVFFlat(quantizer, d, ncentroids) + else: + nlist = 100 + m = 8 + index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8) + + logger.info("Flattening all PDF pages") + + all_token_embeddings = [] + token2pageuid = [] + + if args.retrieval_model_type == "colpali": + for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)): + # e.g., doc_emb - torch.Size([9, 1030, 128]) + for page_id in range(len(doc_emb)): + page_emb = doc_emb[page_id].view(-1, d) + all_token_embeddings.append(page_emb) + + page_uid = f"{doc_id}_page{page_id}" + token2pageuid.extend([page_uid] * page_emb.shape[0]) + + elif args.retrieval_model_type == "colbert": + for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)): + doc_lens = docid2lens[doc_id] + + # e.g., doc_emb - torch.Size([2089, 128]) + # e.g., doc_lens - tensor([258, 240, 251, 231, 229, 268, 235, 211, 166]) + + all_token_embeddings.append(doc_emb) + + for page_id, page_len in enumerate(doc_lens): + page_uid = f"{doc_id}_page{page_id}" + token2pageuid.extend([page_uid] * page_len.item()) + + logger.info(len(all_token_embeddings)) + + all_token_embeddings = torch.cat(all_token_embeddings, dim=0) + all_token_embeddings = all_token_embeddings.float().numpy() + logger.info(all_token_embeddings.shape) + logger.info(len(token2pageuid)) + + logger.info("Creating index") + + index.train(all_token_embeddings) + index.add(all_token_embeddings) + + Path(args.output_dir).mkdir(exist_ok=True) + index_output_path = str(Path(args.output_dir) / "index.bin") + logger.info(f"Saving index at {index_output_path}") + + faiss.write_index(index, index_output_path) + + logger.info("Running an example query") + + # Example query (should be np.float32) + example_text_query_emb = np.random.randn(20, 128).astype(np.float32) + + # NN search + k = 10 + D, I = index.search(example_text_query_emb, k) # noqa E741 + + # Sum the MaxSim scores across all query tokens for each document + final_page2scores = {} + + # Iterate over query tokens + for q_idx, query_emb in enumerate(example_text_query_emb): + # Initialize a dictionary to hold document relevance scores + 
curent_q_page2scores = {} + + for nn_idx in range(k): + found_nearest_doc_token_idx = I[q_idx, nn_idx] + + page_uid = token2pageuid[ + found_nearest_doc_token_idx + ] # Get the document ID for this token + + # reconstruct the original score + doc_token_emb = all_token_embeddings[found_nearest_doc_token_idx] + score = (query_emb * doc_token_emb).sum() + + # MaxSim: aggregate the highest similarity score for each query token per document + if page_uid not in curent_q_page2scores: + curent_q_page2scores[page_uid] = score + else: + curent_q_page2scores[page_uid] = max( + curent_q_page2scores[page_uid], score + ) + + for page_uid, score in curent_q_page2scores.items(): + if page_uid in final_page2scores: + final_page2scores[page_uid] += score + else: + final_page2scores[page_uid] = score + + # Sort documents by their final relevance score + sorted_pages = sorted(final_page2scores.items(), key=lambda x: x[1], reverse=True) + + # Get the top-k document candidates + top_k_pages = sorted_pages[:k] + + # Output the top-k document IDs and their scores + logger.info("Top-k page candidates with scores:") + for page_uid, score in top_k_pages: + logger.info(f"{page_uid} with score {score}") + + +if __name__ == "__main__": + main() diff --git a/examples/run_page_embedding.py b/examples/run_page_embedding.py new file mode 100644 index 0000000..6ad4659 --- /dev/null +++ b/examples/run_page_embedding.py @@ -0,0 +1,180 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import accelerate +import safetensors +import torch +import transformers +from accelerate import Accelerator +from loguru import logger +from tqdm import tqdm + +from m3docrag.datasets.m3_docvqa import M3DocVQADataset +from m3docrag.retrieval import ColPaliRetrievalModel +from m3docrag.utils.args import parse_args +from m3docrag.utils.distributed import ( + barrier, + global_rank, + is_distributed, + local_rank, + log_runtime_info, + print_gpu_stats, +) +from m3docrag.utils.paths import ( + LOCAL_DATA_DIR, + LOCAL_MODEL_DIR, +) + +logger.info(torch.__version__) +logger.info(transformers.__version__) +logger.info(accelerate.__version__) + + +def main(): + args = parse_args() + + log_runtime_info() + print_gpu_stats() + + accelerator = Accelerator() + + if not is_distributed() or global_rank() == 0: + logger.info(f"Process {global_rank()}:{local_rank()} - args {args}") + + if is_distributed(): + barrier() + + local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name + local_retrieval_model_dir = ( + Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path + ) + local_retrieval_adapter_model_dir = ( + Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path + ) + + # Download datasets / model checkpoints + if not is_distributed() or global_rank() == 0: + if not local_data_dir.exists(): + raise ValueError(f"Data directory {local_data_dir} does not exist") + + assert args.use_retrieval, args.use_retrieval + + if not local_retrieval_model_dir.exists(): + raise ValueError( + f"Retrieval model directory {local_retrieval_model_dir} does not exist" + ) + + if args.retrieval_model_type == "colpali": + if not local_retrieval_adapter_model_dir.exists(): + raise ValueError( + f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist" + ) + + if is_distributed(): + barrier() + + if args.retrieval_model_type == "colpali": + colpali_model = ColPaliRetrievalModel( + backbone_name_or_path=local_retrieval_model_dir, + adapter_name_or_path=local_retrieval_adapter_model_dir, + ) + retrieval_model = colpali_model + + if args.data_name == "m3-docvqa": + dataset = M3DocVQADataset(args=args) + + def collate_fn(examples): + out = {} + if args.retrieval_model_type == "colpali": + for k in ["doc_id", "images"]: + out[k] = [ex[k] for ex in examples] + return out + + data_loader = torch.utils.data.DataLoader( + dataset=dataset, + collate_fn=collate_fn, + batch_size=1, + shuffle=False, + batch_sampler=None, + sampler=None, + drop_last=False, + num_workers=args.dataloader_num_workers, + ) + + retrieval_model.model, data_loader = accelerator.prepare( + retrieval_model.model, data_loader + ) + + all_results = [] + + save_dir = Path(args.output_dir) + save_dir.mkdir(exist_ok=True, parents=True) + logger.info(f"Results will be saved at: {save_dir}") + + for i, datum in enumerate(tqdm(data_loader)): + print(f"{i} / {len(data_loader)}") + + if args.data_name == "mp-docvqa": + page_name = datum["page_name"][0] + logger.info(page_name) + else: + doc_id = datum["doc_id"][0] + logger.info(doc_id) + + if args.retrieval_model_type == "colpali": + images = datum["images"][0] + + doc_embs = colpali_model.encode_images( + images=images, + batch_size=args.per_device_eval_batch_size, + to_cpu=True, + use_tqdm=False, + ) + + # [n_pages, n_tokens, emb_dim] + doc_embs = torch.stack(doc_embs, dim=0) + + # Store embedding as BF16 by default + doc_embs = doc_embs.to(torch.bfloat16) + + logger.info(doc_embs.shape) + if 
args.retrieval_model_type == "colpali": + logger.info(doc_embs[0, 0, :5]) + + # Save the embedding + if args.data_name == "mp-docvqa": + local_save_fname = f"{page_name}.safetensors" + else: + local_save_fname = f"{doc_id}.safetensors" + local_save_path = save_dir / local_save_fname + + if args.retrieval_model_type == "colpali": + safetensors.torch.save_file({"embeddings": doc_embs}, local_save_path) + + all_results.append({"save_path": local_save_path}) + + logger.info( + f"Process {global_rank()}:{local_rank()} Results correctly saved at {save_dir}" + ) + + if is_distributed(): + barrier() + + +if __name__ == "__main__": + main() diff --git a/examples/run_rag_m3docvqa.py b/examples/run_rag_m3docvqa.py new file mode 100644 index 0000000..f1bbe6c --- /dev/null +++ b/examples/run_rag_m3docvqa.py @@ -0,0 +1,429 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import json +import time +from pathlib import Path + +import accelerate +import pytz +import torch +import transformers +from accelerate import Accelerator +from loguru import logger +from torch.utils.data import DataLoader +from tqdm import tqdm + +from m3docrag.datasets.m3_docvqa import M3DocVQADataset, evaluate_prediction_file +from m3docrag.rag import MultimodalRAGModel +from m3docrag.retrieval import ColPaliRetrievalModel +from m3docrag.utils.args import parse_args +from m3docrag.utils.distributed import ( + barrier, + global_rank, + is_distributed, + local_rank, + log_runtime_info, + print_gpu_stats, + supports_flash_attention, +) +from m3docrag.utils.paths import ( + LOCAL_DATA_DIR, + LOCAL_EMBEDDINGS_DIR, + LOCAL_MODEL_DIR, +) +from m3docrag.utils.prompts import short_answer_template +from m3docrag.utils.tar import extract_tarfile +from m3docrag.vqa import VQAModel + + +def run_model( + rag_model: MultimodalRAGModel, + datum, + dataset: M3DocVQADataset, + docid2embs: dict[str, torch.Tensor], + docid2lens=None, + index=None, + token2pageuid=None, + all_token_embeddings=None, + n_return_pages=1, + args=None, +): + # if type(datum['num_pages']) == list: + batch = datum + datum = {} + for k, v in batch.items(): + datum[k] = v[0] + + query = datum["question"] + + out_dict = {} + + start = time.perf_counter() + + # Stage 1: Page retrieval + # [(doc_id, page_idx, scores)...] 
+ top_n_page_retrieval_results = rag_model.retrieve_pages_from_docs( + query=query, + docid2embs=docid2embs, + docid2lens=docid2lens, + index=index, + token2pageuid=token2pageuid, + all_token_embeddings=all_token_embeddings, + n_return_pages=n_return_pages, + show_progress=True, + ) + logger.info(top_n_page_retrieval_results) + out_dict["page_retrieval_results"] = top_n_page_retrieval_results + + end = time.perf_counter() + + time_retrieval = end - start + logger.info(f"time_retrieval: {time_retrieval}") + + start = time.perf_counter() + + if args.retrieval_only: + pred_answer = "" + out_dict["pred_answer"] = pred_answer + else: + # Stage 2: QA on the retrived page + # Obtain images from the page retrieval results + images = [] + for doc_id, page_idx, scores in top_n_page_retrieval_results: + page_images = dataset.get_images_from_doc_id(doc_id) + page_image = page_images[page_idx] + images += [page_image] + + logger.info(len(images)) + + # Run VQA + if "florence" in args.model_name_or_path.lower(): + text_input = query + else: + text_input = short_answer_template.substitute({"question": query}) + pred_answer = rag_model.run_vqa(images=images, question=text_input) + + assert isinstance(pred_answer, str) + out_dict["pred_answer"] = pred_answer + + end = time.perf_counter() + + time_qa = end - start + logger.info(f"time_qa: {time_qa}") + + out_dict["time_retrieval"] = time_retrieval + out_dict["time_qa"] = time_qa + + logger.info(query) + logger.info(pred_answer) + logger.info(datum["answers"]) + + return out_dict + + +def evaluate(data_loader, rag_model, index=None, data_len=None, args=None, **kwargs): + if data_len is not None: + logger.info(f"eval on the first {data_len} items") + + # docid2embs = data_loader.dataset.load_all_embeddings() + + logger.info("Preparing doc indices") + + if args.retrieval_model_type == "colpali": + docid2embs = data_loader.dataset.load_all_embeddings() + + # reduce_embeddings(docid2embs=docid2embs) + # docid2embs_page_reudced = reduce_embeddings(docid2embs, dim='page') + # docid2embs_token_reudced = reduce_embeddings(docid2embs, dim='token') + # docid2embs_page_token_reudced = reduce_embeddings(docid2embs, dim='page_token') + + all_token_embeddings = [] + token2pageuid = [] + + if args.retrieval_model_type == "colpali": + for doc_id, doc_emb in tqdm(docid2embs.items(), total=len(docid2embs)): + # e.g., doc_emb - torch.Size([9, 1030, 128]) + for page_id in range(len(doc_emb)): + page_emb = doc_emb[page_id].view(-1, 128) + all_token_embeddings.append(page_emb) + + page_uid = f"{doc_id}_page{page_id}" + token2pageuid.extend([page_uid] * page_emb.shape[0]) + + logger.info(len(all_token_embeddings)) + + all_token_embeddings = torch.cat(all_token_embeddings, dim=0) + all_token_embeddings = all_token_embeddings.float().numpy() + + logger.info("Created flattened token embeddings / token2pageuid") + + qid2result = {} + + total_time_retrieval = 0 + total_time_qa = 0 + + for batch_idx, batch in enumerate(tqdm(data_loader)): + bs = len(batch["question"]) + + # Single batch + assert bs == 1 + qid = batch["qid"][0] + + with torch.no_grad(): + outputs = run_model( + rag_model, + batch, + dataset=data_loader.dataset, + docid2embs=docid2embs, + # docid2embs=docid2embs_page_reudced, + # docid2embs=docid2embs_token_reudced, + # docid2embs=docid2embs_page_token_reudced, + index=index, + token2pageuid=token2pageuid, + all_token_embeddings=all_token_embeddings, + n_return_pages=args.n_retrieval_pages, + args=args, + ) + + pred_answer = outputs["pred_answer"] + assert 
isinstance(pred_answer, str), type(pred_answer) + + total_time_qa += outputs["time_qa"] + total_time_retrieval += outputs["time_retrieval"] + + qid2result[qid] = outputs + + logger.info(total_time_qa) + logger.info(total_time_retrieval) + + avg_time_qa = total_time_qa / len(data_loader) + avg_time_retrieval = total_time_retrieval / len(data_loader) + logger.info(avg_time_qa) + logger.info(avg_time_retrieval) + + return qid2result + + +def main(): + args = parse_args() + + logger.info(torch.__version__) + logger.info(transformers.__version__) + logger.info(accelerate.__version__) + + log_runtime_info() + print_gpu_stats() + + accelerator = Accelerator() + + if not is_distributed() or global_rank() == 0: + logger.info(f"Process {global_rank()}:{local_rank()} - args {args}") + + local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name + local_embedding_dir = Path(LOCAL_EMBEDDINGS_DIR) / args.embedding_name + local_model_dir = Path(LOCAL_MODEL_DIR) / args.model_name_or_path + local_retrieval_model_dir = ( + Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path + ) + local_retrieval_adapter_model_dir = ( + Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path + ) + # local_ft_model_dir = Path(LOCAL_MODEL_DIR) / args.ft_model_name_or_path + + local_index_dir = ( + Path(LOCAL_EMBEDDINGS_DIR) + / f"{args.embedding_name}_pageindex_{args.faiss_index_type}" + ) + # local_answer_extraction_model_dir = Path(LOCAL_MODEL_DIR) / args.answer_extraction_model_name_or_path + + if is_distributed(): + barrier() + + if not is_distributed() or global_rank() == 0: + if not local_data_dir.exists(): + raise ValueError(f"Data directory {local_data_dir} does not exist") + + if not local_embedding_dir.exists(): + raise ValueError( + f"Embedding directory {local_embedding_dir} does not exist" + ) + + if local_model_dir.exists() or args.retrieval_only: + logger.info("Model exists - pass") + else: + raise ValueError( + f"Model directory {local_model_dir} does not exist" + ) + + if args.use_retrieval: + if not local_retrieval_model_dir.exists(): + raise ValueError( + f"Retrieval model directory {local_retrieval_model_dir} does not exist" + ) + + if not local_retrieval_adapter_model_dir.exists(): + raise ValueError( + f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist" + ) + + if not local_index_dir.exists(): + raise ValueError( + f"Index directory {local_index_dir} does not exist" + ) + + if is_distributed(): + barrier() + + # Create Retrieval Model (Step 1) + assert args.use_retrieval + if args.retrieval_model_type == "colpali": + colpali_model = ColPaliRetrievalModel( + backbone_name_or_path=local_retrieval_model_dir, + adapter_name_or_path=local_retrieval_adapter_model_dir, + ) + retrieval_model = colpali_model + + logger.info(f"loaded Retrieval model -: {local_retrieval_model_dir}") + + # Create QA / VQA Model (Step 2) + + if args.retrieval_only: + rag_model = MultimodalRAGModel(retrieval_model=retrieval_model, vqa_model=None) + logger.info("skipping QA model") + + else: + if "florence" in args.model_name_or_path.lower(): + model_type = "florence2" + elif "idefics2" in args.model_name_or_path.lower(): + model_type = "idefics2" + elif "idefics3" in args.model_name_or_path.lower(): + model_type = "idefics3" + elif "internvl2" in args.model_name_or_path.lower(): + model_type = "internvl2" + elif "qwen2" in args.model_name_or_path.lower(): + model_type = "qwen2" + else: + raise KeyError(f"model type unknown for: {args.model_name_or_path}") + + use_flash_attn = True + 
attn_implementation = "flash_attention_2" + if not supports_flash_attention(): + use_flash_attn = False + attn_implementation = "eager" + + vqa_model = VQAModel( + model_name_or_path=local_model_dir, + model_type=model_type, + bits=args.bits, + use_flash_attn=use_flash_attn, + attn_implementation=attn_implementation, + ) + + logger.info(f"loaded VQA model - {model_type}: {local_model_dir}") + + rag_model = MultimodalRAGModel( + retrieval_model=retrieval_model, vqa_model=vqa_model + ) + + logger.info("Created RAG model") + + dataset = M3DocVQADataset(args=args) + logger.info("loaded dataset") + + index = None + if local_index_dir.exists(): + logger.info("Loading faiss index") + import faiss + + index = faiss.read_index(str(local_index_dir / "index.bin")) + logger.info("Loading faiss index -- done") + + def list_collate_fn(batch): + batch = { + k: [dic[k] for dic in batch] for k in batch[0] + } # List of dictionaries to dict of lists. + return batch + + data_loader = DataLoader( + dataset, batch_size=1, shuffle=False, collate_fn=list_collate_fn + ) + + if args.retrieval_only: + retrieval_model.model, data_loader = accelerator.prepare( + retrieval_model.model, data_loader + ) + else: + retrieval_model.model, data_loader, vqa_model.model = accelerator.prepare( + retrieval_model.model, data_loader, vqa_model.model + ) + + eval_out = evaluate( + data_loader=data_loader, + rag_model=rag_model, + index=index, + data_len=args.data_len, + args=args, + ) + + samples = eval_out + + EST = pytz.timezone("US/Eastern") + experiment_date = ( + datetime.datetime.now().astimezone(EST).strftime("%Y-%m-%d_%H-%M-%S") + ) + + save_dir = Path(args.output_dir) + save_dir.mkdir(exist_ok=True, parents=True) + logger.info("Results will be saved at:", save_dir) + + if args.retrieval_model_type == "colpali": + ret_name = args.retrieval_adapter_model_name_or_path + else: + ret_name = args.retrieval_model_name_or_path + + if args.retrieval_only: + pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}.json" + else: + pred_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}.json" + results_file = save_dir / pred_save_fname + with open(results_file, "w") as f: + json.dump(samples, f, indent=4) + + logger.info(f"Prediction results saved at: {results_file}") + + # Evaluation + all_eval_scores = evaluate_prediction_file( + samples, + dataset.mmqa_data_path, # '/job/datasets/m3-docvqa/MMQA_dev.jsonl' + ) + if args.retrieval_only: + eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{experiment_date}_eval_results.json" + else: + eval_save_fname = f"{ret_name}_{args.faiss_index_type}_ret{args.n_retrieval_pages}_{args.model_name_or_path}_{experiment_date}_eval_results.json" + results_file = save_dir / eval_save_fname + with open(results_file, "w") as f: + json.dump(all_eval_scores, f, indent=4) + + logger.info(f"Evaluation results saved at: {results_file}") + + if is_distributed(): + barrier() + + +if __name__ == "__main__": + main() diff --git a/m3docvqa/.gitignore b/m3docvqa/.gitignore new file mode 100644 index 0000000..496f5fc --- /dev/null +++ b/m3docvqa/.gitignore @@ -0,0 +1,167 @@ +.vscode/ +notebooks/node_modules/ +notebooks/package-lock.json +dataset/*.jsonl +dataset/img +dataset/cache +dataset/img_features +baselines/data/ +baselines/output/ +baselines/image_qa/training_stats/ +baselines/image_qa/checkpoints/ +baselines/image_qa/analysis/ +deps/vilbert-multi-task/data/ 
+deps/vilbert-multi-task/save/ +deps/vilbert-multi-task/multi_task_model* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ + +# DATA +data/ + +# for MacOS +*.DS_Store + +*.png +dataset +multimodalqa_screenshots +*.json +*.pdf +*.jsonl* + +setup_BCOS_README.md +*.ipynb + +.env +READDME copy.md \ No newline at end of file diff --git a/m3docvqa/README.md b/m3docvqa/README.md new file mode 100644 index 0000000..4b75289 --- /dev/null +++ b/m3docvqa/README.md @@ -0,0 +1,200 @@ +# M3DocVQA + +Dataset generation package for [M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding.](https://m3docrag.github.io/) + +## Summary + +M3DocVQA (Multi-modal Multi-page Multi-Document Visual Question Answering) is a new benchmark for evaluating open-domain DocVQA over 3,000+ PDF documents with 40,000+ pages. M3DocVQA significantly raises the challenge of DocVQA to answering questions from a large document corpus (Sec. 3). By extending the MultimodalQA dataset’s closed-domain context to an open-domain setting, M3DocVQA introduces 2,441 multi-hop questions spanning 3,368 PDF documents, which collectively contain over 41,005 pages of diverse multi-modal content, including text, images, and tables. The dataset generated presents real-world challenges by requiring models to navigate complex reasoning paths across pages and within various types of document elements, better reflecting the intricacies of document understanding. + + + +Comparison of existing DocVQA datasets (left: e.g., DocVQA) and the generated `M3DocVQA` dataset (right). 
In contrast to previous DocVQA datasets that have questions that are specific to a single provided PDF (e.g., `What was the gross profit in the year 2009?`), M3DocVQA contains information-seeking questions that benchmark open-domain question answering capabilities across more than `3,000 PDF documents` (i.e., `40,000+ pages`). + + + +We extend the question-answer pairs from a short-context VQA dataset to a more complex setting that includes: +1. PDF documents. +2. Open-domain contexts. + +We first collect the URLs of all supporting contexts (Wikipedia documents) of individual questions of [MultimodalQA](https://github.com/allenai/multimodalqa). This tool then creates PDF versions from their URLs by rendering them in a Chromium web browser. + +## M3DocVQA Dataset Creation Pipeline + +This part of the repository provides scripts to create the `M3DocVQA` dataset, including functionalities to download Wikipedia pages as PDFs, check and clean corrupted PDFs, extract images, and organize files into directories for training and evaluation. + +### Overview + +The scripts allows users to: +- Download Wikipedia pages in either PDF or PNG format. +- Verify and clean downloaded PDFs. +- Extract images from PDFs. +- Organize files into directories based on split information for training/evaluation. + +## Installation + +``` +git clone +cd /m3docvqa +``` + +### Install Python Package + +We used Python 3.10. + +```bash +pip install -e . +``` + +### Setup Playwright + +```bash +# e.g., download browsers, ffmpeg, etc. +playwright install +playwright install-deps +``` + +### Test the Package +```bash +pytest tests +``` + +**Note**: The tests might fail if `poppler-utils` is not installed on your system. You need to make sure you have `poppler-utils` installed for `pdf2image`. Please refer to these [detailed instructions](https://pdf2image.readthedocs.io/en/latest/installation.html). + +### Additional Setup +Ensure the required directories and metadata files are available before running the scripts. Continue as directed to get the required data. + +## Usage + +The main script (`main.py`) supports several actions, each of which targets a specific portion of the dataset creation process. + +### Command Structure +```bash +python main.py [options] +``` + +### Available Actions +- `download_pdfs`: Download PDFs from URLs provided in the metadata. +- `check_pdfs`: Verify if the downloaded PDFs are valid. +- `extract_images`: Extract images from the pages of the downloaded PDFs. +- `organize_files`: Organize downloaded PDFs into specified directory splits. +- `download_mmqa`: Download and decompress the MMQA dataset. +- `generate_wiki_mapping`: Generate a mapping of 'id' to 'url' from multiple JSONL files. + +## Steps for Generating the M3DocVQA Dataset + +### Step 1: Download the MultiModalQA Dataset +Use the `download_mmqa` action to download and decompress the MultiModalQA dataset files. + +```bash +python main.py download_mmqa --output_dir=./multimodalqa +``` + +Output: +Decompressed JSONL files +```bash +MMQA_train.jsonl +MMQA_dev.jsonl +MMQA_texts.jsonl +MMQA_images.jsonl +MMQA_tables.jsonl +``` + +These files will be stored in the `./multimodalqa/` directory. + +### Step 2: Generate Wiki Mapping +Use the `generate_wiki_mapping` action to create a mapping of `id` to `url` from the downloaded JSONL files. 
+
+```bash
+python main.py generate_wiki_mapping --text=./multimodalqa/MMQA_texts.jsonl --image=./multimodalqa/MMQA_images.jsonl --table=./multimodalqa/MMQA_tables.jsonl --output=./id_url_mapping.jsonl
+```
+
+Output:
+
+A JSONL file `id_url_mapping.jsonl` containing the ID and corresponding URL mappings.
+
+### Step 3: Download Wikipedia Articles as PDFs
+Use the `download_pdfs` action to download Wikipedia articles in PDF format based on the generated mapping.
+
+```bash
+python main.py download_pdfs --metadata_path=./id_url_mapping.jsonl --pdf_dir=./pdfs --result_log_path=./download_results.jsonl --first_n=10 --supporting_doc_ids_per_split=./supporting_doc_ids_per_split.json --split=dev
+```
+
+Options:
+- `--metadata_path`: Path to the `id_url_mapping.jsonl` file.
+- `--pdf_dir`: Directory to save the downloaded PDFs.
+- `--result_log_path`: Path to log the download results.
+- `--first_n`: Downloads only the first N PDFs (useful for testing). **Do not use this option when downloading all the PDFs.**
+- `--supporting_doc_ids_per_split`: Path to a JSON file containing the document IDs for each split. `dev` is the default split, as all of the experimental results in the `M3DocRAG` paper were reported on the `dev` split. To download the PDFs in the `train` split instead, pass `--split=train`; to download the PDFs in all splits, pass `--split=all`.
+
+Output:
+
+- PDF files for Wikipedia articles, saved in the `./pdfs/` directory.
+- A `download_results.jsonl` file logging the status of each download.
+
+### Step 4: Check PDF Integrity
+Use the `check_pdfs` action to verify the integrity of the downloaded PDFs.
+
+```bash
+python main.py check_pdfs --pdf_dir=./pdfs
+```
+
+Output:
+
+Identifies and logs corrupted or unreadable PDFs.
+
+### Step 5: Organize Files into Splits
+Use the `organize_files` action to organize the downloaded PDFs into specific splits (e.g., `train`, `dev`) based on a split information file.
+
+```bash
+python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=dev --split_metadata_file=./multimodalqa/MMQA_dev.jsonl
+```
+
+If the `train` split is needed:
+
+```bash
+python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=train --split_metadata_file=./multimodalqa/MMQA_train.jsonl
+```
+
+Output:
+
+- PDFs organized into directories in `./splits/pdfs_train/` and `./splits/pdfs_dev/`.
+- Files that store the document IDs of each split: `./train_doc_ids.json` and `./dev_doc_ids.json`.
+
+**Note** - In the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper, we only use the `dev` split for our experiments.
+
+### Step 6: Extract Images from PDFs
+Use the `extract_images` action to extract images from the downloaded PDFs. A PNG image of each page of the PDFs is extracted. These images are used both for `retrieval` with `ColPali/ColQwen` and for `question answering` with the multi-modal language models mentioned in the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper.
+
+```bash
+python main.py extract_images --pdf_dir=./splits/pdfs_dev/ --image_dir=./images/images_dev
+```
+
+Output:
+
+Extracted images from the PDFs in the dev split are saved in the `./images/images_dev` directory.
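+
+To quickly sanity-check the extraction, you can compare the extracted page images against the PDFs they came from. Below is a minimal sketch (not part of the package, shown only for illustration); it assumes the directory layout from the steps above and the `<doc_id>_<page_number>.png` naming used by `get_images_from_pdf`.
+
+```python
+# Hypothetical sanity check for the extracted page images (not part of main.py).
+from collections import defaultdict
+from pathlib import Path
+
+pdf_dir = Path("./splits/pdfs_dev")      # PDFs organized in Step 5
+image_dir = Path("./images/images_dev")  # page images extracted in Step 6
+
+# Page images are saved as "<doc_id>_<page_number>.png" (pages are 1-indexed).
+pages_per_doc = defaultdict(int)
+for png in image_dir.glob("*.png"):
+    doc_id = png.stem.rsplit("_", 1)[0]
+    pages_per_doc[doc_id] += 1
+
+# Flag PDFs that produced no page images (e.g., corrupted or empty downloads).
+missing = [pdf.stem for pdf in pdf_dir.glob("*.pdf") if pages_per_doc.get(pdf.stem, 0) == 0]
+print(f"{len(pages_per_doc)} documents have extracted page images")
+print(f"{len(missing)} PDFs produced no page images: {missing[:5]}")
+```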
+ +After following these steps, your dataset directory structure will look like this: + +``` +./ +|-- multimodalqa/ +|   |-- MMQA_train.jsonl +|   |-- MMQA_dev.jsonl +|   |-- MMQA_texts.jsonl +|   |-- MMQA_images.jsonl +|   |-- MMQA_tables.jsonl +|-- id_url_mapping.jsonl +|-- dev_doc_ids.json +|-- train_doc_ids.json +|-- supporting_doc_ids_per_split.json +|-- download_results.jsonl +|-- pdfs/ +|   |-- .pdf +|   |-- .pdf +|-- images/ +|-- |--images_dev/ +|   | |-- .png +| | |-- .png +|-- splits/ +|   |-- pdfs_dev/ +|   |   |-- .pdf +|   |   |-- .pdf +``` diff --git a/m3docvqa/main.py b/m3docvqa/main.py new file mode 100644 index 0000000..adc5619 --- /dev/null +++ b/m3docvqa/main.py @@ -0,0 +1,201 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Main Script for M3DocVQA Dataset Creation Pipeline. + +This script orchestrates downloading PDFs or PNGs, checking for corrupted PDFs, extracting images, +organizing them into directories, downloading/decompressing MMQA data, and creating wiki links mapping. + +Usage: + python main.py [other options] + +Actions: + - download_pdfs: Download PDFs from URLs provided in metadata. + - check_pdfs: Verify if the downloaded PDFs are valid. + - extract_images: Extract images from the pages of downloaded PDFs. + - organize_files: Organize downloaded PDFs into specified directory splits. + - download_mmqa: Download and decompress the MMQA dataset. + - generate_wiki_mapping: Generate a mapping of 'id' to 'url' from multiple JSONL files. + +Example: + python main.py generate_wiki_mapping --text=MMQA_texts.jsonl --image=MMQA_images.jsonl --table=MMQA_tables.jsonl --output=id_url_mapping.jsonl +""" + +import fire +import json +import jsonlines +from pathlib import Path +from m3docvqa.downloader import download_wiki_page +from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf +from m3docvqa.split_utils import create_split_dirs +from m3docvqa.mmqa_downloader import download_and_decompress_mmqa +from m3docvqa.wiki_mapper import generate_wiki_links_mapping +from loguru import logger +from tqdm.auto import tqdm + + +def _prepare_download( + metadata_path: Path | str, + output_dir: Path | str, + first_n: int, + doc_ids: set, + ) -> tuple[list[str], list[Path]]: + """Prepare URLs and save paths for downloading. + + Args: + metadata_path (Path): Path to the metadata JSONL file. + output_dir (str): Directory where files will be saved. + first_n (int): Maximum number of entries to process. + doc_ids (set): Set of doc ids to filter for downloading. + + Returns: + tuple[list[str], list[Path]]: URLs and save paths for downloading. 
+ """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + urls, save_paths = [], [] + with jsonlines.open(metadata_path) as reader: + for line in reader: + if len(urls) == first_n: + break + + doc_id = line.get("id") + if doc_ids and doc_id not in doc_ids: + continue + + url = line.get("url") + save_path = output_dir / f"{doc_id}.pdf" + if not is_pdf_downloaded(save_path): + urls.append(url) + save_paths.append(save_path) + + return urls, save_paths + + +def download_pdfs( + metadata_path: Path | str, + pdf_dir: Path | str, + result_log_path: Path | str, + supporting_doc_ids_per_split: Path | str, + first_n: int = -1, + proc_id: int = 0, + n_proc: int = 1, + split: str = 'dev', + ): + """Download Wikipedia pages as PDFs.""" + # Load document ids for the specified split + if supporting_doc_ids_per_split: + with open(supporting_doc_ids_per_split, "r") as f: + doc_ids_per_split = json.load(f) + split_doc_ids = { + "train": set(doc_ids_per_split.get("train", [])), + "dev": set(doc_ids_per_split.get("dev", [])), + "all": set(doc_ids_per_split.get("train", []) + doc_ids_per_split.get("dev", [])) + } + if split not in split_doc_ids: + raise ValueError(f"Invalid or missing split. Expected one of {split_doc_ids.keys()}") + doc_ids = split_doc_ids.get(split, split_doc_ids.get("all")) + logger.info(f"Downloading documents for split: {split} with {len(doc_ids)} document IDs.") + + urls, save_paths = _prepare_download(metadata_path, pdf_dir, first_n, doc_ids) + logger.info(f"Starting download of {len(urls)} PDFs to {pdf_dir}") + download_results = download_wiki_page(urls, save_paths, "pdf", result_log_path, proc_id, n_proc) + logger.info(f"Download completed with {sum(download_results)} successful downloads out of {len(urls)}") + + +def check_pdfs(pdf_dir: str, proc_id: int = 0, n_proc: int = 1): + """Verifies the integrity of downloaded PDFs.""" + corrupted_paths = [] + total_checked, corrupted_count = 0, 0 + + pdf_files = list(Path(pdf_dir).glob("*.pdf")) + for pdf_path in tqdm(pdf_files, disable=(proc_id != 0), desc="Checking PDFs"): + total_checked += 1 + if not is_pdf_downloaded(pdf_path) or not is_pdf_clean(pdf_path): + corrupted_paths.append(pdf_path) + corrupted_count += 1 + + logger.info(f"Checked {total_checked} PDFs: {corrupted_count} corrupted files.") + if corrupted_paths: + logger.warning(f"Corrupted PDFs: {corrupted_paths}") + + +def extract_images(pdf_dir: str, image_dir: str, save_type='png'): + """Extracts images from downloaded PDFs.""" + Path(image_dir).mkdir(parents=True, exist_ok=True) + + pdf_files = list(Path(pdf_dir).glob("*.pdf")) + if not pdf_files: + logger.warning(f"No PDFs found in {pdf_dir} for image extraction.") + return + + logger.info(f"Starting image extraction from {len(pdf_files)} PDFs in {pdf_dir}.") + + for pdf_path in tqdm(pdf_files, desc="Extracting images", unit="PDF"): + get_images_from_pdf(pdf_path, save_dir=image_dir, save_type=save_type) + logger.info(f"Images extracted from PDFs in {pdf_dir}") + + +def organize_files(all_pdf_dir: Path | str, target_dir_base: Path | str, split_metadata_file: str | Path, split: str): + """Organizes PDFs into directory splits based on split information file.""" + create_split_dirs( + all_pdf_dir=all_pdf_dir, + target_dir_base=target_dir_base, + split_metadata_file=split_metadata_file, + split=split, + ) + logger.info(f"Files organized for {split} split: in {target_dir_base}") + + +def download_mmqa(output_dir: str): + """Downloads and decompresses the MMQA dataset. 
+ + Args: + output_dir (str): Directory where the MMQA files will be downloaded and decompressed. + """ + logger.info(f"Starting MMQA dataset download to {output_dir}") + download_and_decompress_mmqa(output_dir) + logger.info(f"MMQA dataset downloaded and decompressed successfully in {output_dir}") + + +def generate_wiki_mapping(text: str, image: str, table: str, output: str = "id_url_mapping.jsonl"): + """Generates a mapping of 'id' to 'url' from multiple JSONL files. + + Args: + text (str): Path to the JSONL file containing text data from multimodalqa dataset with 'id' and 'url' fields. + image (str): Path to the JSONL file containing image data from multimodalqa dataset with 'id' and 'url' fields. + table (str): Path to the JSONL file containing table data from multimodalqa dataset with 'id' and 'url' fields. + output (str): Path to save the output JSONL file. Defaults to 'id_url_mapping.jsonl'. + """ + logger.info("Starting wiki mapping generation...") + generate_wiki_links_mapping(text_file=text, image_file=image, table_file=table, output_file=output) + logger.info(f"Wiki mapping successfully saved to {output}") + + +def main(): + fire.Fire({ + "download_mmqa": download_mmqa, + "generate_wiki_mapping": generate_wiki_mapping, + "download_pdfs": download_pdfs, + "check_pdfs": check_pdfs, + "extract_images": extract_images, + "organize_files": organize_files, + }) + + +if __name__ == "__main__": + main() diff --git a/m3docvqa/pyproject.toml b/m3docvqa/pyproject.toml new file mode 100644 index 0000000..dcbd760 --- /dev/null +++ b/m3docvqa/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=69.5"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[project] +name = "m3docvqa" +version = "0.0.1" +description = "M3DocVQA - Dataset package for M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding." +readme = "README.md" +requires-python = ">=3.10" +classifiers = ["Programming Language :: Python :: 3"] +dependencies = [ + "loguru", + "jsonlines", + "fire", + "pytest-playwright", + "figure", + "pdf2image", + "pillow", + "numpy<2.0.0", + "pdfrw", + "tqdm", + "reportlab", # only used in the test cases +] + +[tool.ruff] +target-version = "py310" + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "I"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/m3docvqa/src/m3docvqa/downloader.py b/m3docvqa/src/m3docvqa/downloader.py new file mode 100644 index 0000000..ef3ec8a --- /dev/null +++ b/m3docvqa/src/m3docvqa/downloader.py @@ -0,0 +1,126 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Downloader Module for M3DocVQA + +This module provides functions to download Wikipedia pages in either PDF or PNG format +for the M3DocVQA dataset. It uses Playwright to load and capture the pages in a headless +browser environment and saves each page in the specified format. 
+ +Functions: + - _download_wiki_page: Downloads a single Wikipedia page as a PDF or PNG. + - download_wiki_page: Manages the downloading of multiple Wikipedia pages. +""" + +from playwright.sync_api import sync_playwright +from loguru import logger +from pathlib import Path +import jsonlines +from tqdm.auto import tqdm +from m3docvqa.pdf_utils import is_pdf_downloaded + + +def _download_wiki_page(args: tuple[int, int, str, str, str, int]) -> tuple[bool, Exception | None]: + """Download a single Wikipedia page as a PDF or PNG using Playwright. + + Args: + args (Tuple[int, int, str, str, str, int]): Contains order in batch, total count, URL, save path, + save type ('pdf' or 'png'), and process ID. + + Returns: + Tuple[bool, Optional[Exception]]: A tuple where the first element is a boolean indicating success, + and the second element is an exception if an error occurred, or None otherwise. + """ + order_i, total, url, save_path, save_type, proc_id = args + + if is_pdf_downloaded(save_path): + if proc_id == 0: + logger.info(f"{order_i} / {total} - {save_path} already downloaded") + return True, None + + try: + with sync_playwright() as p: + browser = p.chromium.launch(headless=True) + context = browser.new_context(ignore_https_errors=True) + page = context.new_page() + page.set_default_timeout(30000) # 30 seconds timeout + + page.goto(url) + if save_type == 'png': + page.screenshot(path=save_path, full_page=True) + elif save_type == 'pdf': + page.emulate_media(media="screen") + page.pdf(path=save_path) + + browser.close() + + return True, None + except Exception as error: + logger.warning(f"Failed to download {url} as {save_type}. Error: {error}") + return False, error + + +def download_wiki_page( + urls: list[str], + save_paths: list[str], + save_type: str, + result_jsonl_path: str, + proc_id: int = 0, + n_proc: int = 1 +) -> list[bool]: + """Download multiple Wikipedia pages and log progress. + + Args: + urls (List[str]): List of Wikipedia URLs to download. + save_paths (List[str]): List of paths where each downloaded file will be saved. + save_type (str): File type to save each page as ('pdf' or 'png'). + result_jsonl_path (str): Path to the JSONL file where download results will be logged. + proc_id (int, optional): Process ID for parallel processing. Defaults to 0. + n_proc (int, optional): Total number of processes running in parallel. Defaults to 1. + + Returns: + List[bool]: A list of booleans indicating whether each download was successful. 
+ """ + total = len(urls) + all_args = [(i, total, url, str(save_path), save_type, proc_id) + for i, (url, save_path) in enumerate(zip(urls, save_paths))] + + pbar = tqdm(total=len(all_args), ncols=100, disable=not (proc_id == 0)) + + results = [] + n_downloaded = 0 + + # Log results to a JSONL file + with jsonlines.open(result_jsonl_path, 'w') as writer: + for args in all_args: + downloaded, error = _download_wiki_page(args) + + if downloaded: + n_downloaded += 1 + + pbar.set_description(f"Process: {proc_id}/{n_proc} - Downloaded: {n_downloaded}/{total}") + pbar.update(1) + + results.append(downloaded) + writer.write({ + 'downloaded': downloaded, + 'args': [arg if not isinstance(arg, Path) else str(arg) for arg in args], + 'error': str(error) if error else None + }) + + pbar.close() + return results diff --git a/m3docvqa/src/m3docvqa/mmqa_downloader.py b/m3docvqa/src/m3docvqa/mmqa_downloader.py new file mode 100644 index 0000000..fcb2d53 --- /dev/null +++ b/m3docvqa/src/m3docvqa/mmqa_downloader.py @@ -0,0 +1,119 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Downloads the portion of the multimodalqa dataset from https://github.com/allenai/multimodalqa/tree/master/dataset +that is useful for creating the m3docvqa dataset. +""" +import gzip +import requests +from loguru import logger +from pathlib import Path + + +def download_file(url: str, output_path: str) -> None: + """Downloads a file from a given URL and saves it to the specified output path. + + Args: + url (str): The URL of the file to download. + output_path (str): The path where the downloaded file will be saved. + + Raises: + requests.exceptions.RequestException: If the file could not be downloaded. + """ + try: + response = requests.get(url, stream=True) + response.raise_for_status() + with open(output_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logger.info(f"File downloaded successfully: {output_path}") + except requests.exceptions.RequestException as e: + logger.error(f"Failed to download file from {url}: {e}") + raise + +def decompress_gz_file(input_path: str | Path, output_path: str | Path) -> None: + """ + Decompresses a `.gz` file into its original format. + + Args: + input_path (str | Path): Path to the `.gz` file. + output_path (str | Path): Path where the decompressed file will be written. + + Raises: + ValueError: If the input path does not exist or is not a file. + """ + input_path = Path(input_path) + output_path = Path(output_path) + + if not input_path.is_file(): + raise ValueError(f"The input file {input_path} does not exist or is not a file.") + + with gzip.open(input_path, "rb") as f_in, open(output_path, "wb") as f_out: + f_out.write(f_in.read()) + logger.info(f"Decompressed {input_path} to {output_path}") + +def download_and_decompress_mmqa(output_directory: str | Path) -> None: + """ + Downloads and decompresses the MultiModalQA dataset files into the specified directory. 
+ + Args: + output_directory (str | Path): The directory where the files will be stored. + + Steps: + 1. Creates the output directory if it doesn't exist. + 2. Downloads the `.jsonl.gz` files. + 3. Decompresses each `.gz` file into its `.jsonl` format. + 4. Removes the `.gz` files after decompression. + + Raises: + requests.exceptions.RequestException: If any of the files could not be downloaded. + """ + # Define base URL and file names + base_url = "https://github.com/allenai/multimodalqa/raw/refs/heads/master/dataset/" + files = [ + "MMQA_texts.jsonl.gz", + "MMQA_tables.jsonl.gz", + "MMQA_images.jsonl.gz", + "MMQA_dev.jsonl.gz", + "MMQA_train.jsonl.gz", + ] + + output_directory = Path(output_directory) + + # Ensure the output directory exists + if not output_directory.exists(): + output_directory.mkdir(parents=True, exist_ok=True) + logger.info(f"Created output directory: {output_directory}") + + for file_name in files: + compressed_path = output_directory / file_name + decompressed_path = output_directory / file_name.replace(".gz", "") + + try: + # Step 1: Download the file + logger.info(f"Downloading {file_name}...") + download_file(base_url + file_name, compressed_path) + + # Step 2: Decompress the file + logger.info(f"Decompressing {file_name}...") + decompress_gz_file(compressed_path, decompressed_path) + + # Step 3: Remove the compressed file + compressed_path.unlink() + logger.info(f"Removed compressed file: {compressed_path}") + except Exception as e: + logger.error(f"Error processing {file_name}: {e}") + raise diff --git a/m3docvqa/src/m3docvqa/pdf_utils.py b/m3docvqa/src/m3docvqa/pdf_utils.py new file mode 100644 index 0000000..1806930 --- /dev/null +++ b/m3docvqa/src/m3docvqa/pdf_utils.py @@ -0,0 +1,126 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""PDF Utilities Module for M3DocVQA. + +This module provides utility functions for managing and processing PDF files in the M3DocVQA dataset. +It includes functions for checking if a PDF has been downloaded, verifying if a PDF is clean (not corrupted), +and extracting images from PDF pages. + +Functions: + - is_pdf_downloaded: Checks if a given PDF file exists and can be opened without errors. + - is_pdf_clean: Checks if a PDF file is clean (not corrupted) and can be read without issues. + - get_images_from_pdf: Extracts images from each page of a PDF and optionally saves them in a specified directory. +""" + +from pdf2image import convert_from_path +from PIL import Image +from pdfrw import PdfReader +from pathlib import Path +from io import BytesIO +from loguru import logger + + +def is_pdf_downloaded(pdf_path: str) -> bool: + """Check if the PDF file exists and can be opened without errors. + + Args: + pdf_path (str): Path to the PDF file. + + Returns: + bool: True if the PDF file is downloaded and accessible; False otherwise. 
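+
+    Example (illustrative; the path is hypothetical):
+        >>> is_pdf_downloaded("./pdfs/Example.pdf")
+        True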
+ """ + try: + with open(pdf_path, "rb") as f: + f.read(1) # Attempt to read a byte to verify file exists and is accessible + return True + except Exception as e: + logger.trace(f"Failed to open PDF at {pdf_path}: {e}") + return False + + +def is_pdf_clean(pdf_path: str) -> bool: + """Verify if a PDF file is clean (not corrupted) and can be read without errors. + + Args: + pdf_path (str): Path to the PDF file. + + Returns: + bool: True if the PDF file is clean and readable; False otherwise. + """ + try: + with open(pdf_path, "rb") as f: + idata = f.read() + ibuffer = BytesIO(idata) + PdfReader(ibuffer) # Attempt to read the PDF structure for validity + return True + except Exception as error: + logger.warning(f"PDF at {pdf_path} is corrupted or unreadable: {error}") + return False + + +def get_images_from_pdf( + pdf_path: str, + save_dir: str = None, + max_pages: int = None, + dpi_resolution: int = 144, + save_type: str = 'png' +) -> list[Image.Image]: + """Extract images from each page of a PDF and optionally save them to a directory. + + Args: + pdf_path (str): Path to the PDF file. + save_dir (str, optional): Directory where images will be saved. If None, images are not saved. Defaults to None. + max_pages (int, optional): Maximum number of pages to process. If None, all pages are processed. Defaults to None. + dpi_resolution (int, optional): Resolution for image extraction. Defaults to 144. + save_type (str, optional): Image file type to save as ('png', 'jpg', etc.). Defaults to 'png'. + + Returns: + list[Image.Image]: A list of images extracted from each page of the PDF. + """ + pdf_path_obj = Path(pdf_path) + assert pdf_path_obj.exists(), f"PDF file {pdf_path} does not exist." + + out_images = [] + + # Create save directory if saving images is enabled + if save_dir: + save_dir_path = Path(save_dir) + save_dir_path.mkdir(exist_ok=True, parents=True) + + try: + # Convert PDF to images using pdf2image + images = convert_from_path(pdf_path, dpi=dpi_resolution) + logger.info(f"PDF {pdf_path} has {len(images)} pages.") + + # Limit the number of pages processed if max_pages is set + if max_pages: + images = images[:max_pages] + + for page_index, image in enumerate(images): + out_images.append(image) + + # Save image if save directory is specified + if save_dir: + save_page_path = save_dir_path / f"{pdf_path_obj.stem}_{page_index + 1}.{save_type}" + if not save_page_path.exists(): + image.save(save_page_path) + logger.info(f"Saved page {page_index + 1} as image at {save_page_path}") + + except Exception as e: + logger.error(f"Error extracting images from PDF {pdf_path}: {e}") + + return out_images diff --git a/m3docvqa/src/m3docvqa/split_utils.py b/m3docvqa/src/m3docvqa/split_utils.py new file mode 100644 index 0000000..c7dbcc5 --- /dev/null +++ b/m3docvqa/src/m3docvqa/split_utils.py @@ -0,0 +1,93 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Split Utilities Module for M3DocVQA. 
+ +This module provides utilities for organizing PDF files into split directories (e.g., train, dev) +and compressing these directories using functions from compression_utils. + +Functions: + - create_split_dirs: Copies specified PDF files into separate split directories (e.g., train, dev). + - compress_split_directory: Compresses the split directory into a `.tar.gz` archive. +""" + +from pathlib import Path +import shutil +import json +import jsonlines +from loguru import logger + +def create_split_dirs( + all_pdf_dir: str | Path, + target_dir_base: str | Path, + split_metadata_file: str | Path, + split: str +) -> None: + """Copies specified PDF files into a target directory based on a given split. + + Args: + all_pdf_dir (Union[str, Path]): Path to the directory containing all downloaded PDF files. + target_dir_base (Union[str, Path]): Base directory where the split-specific directory will be created. + split_metadata_file (Union[str, Path]): Path to the metadata JSONL file for the split. + split (str): Split type ('train' or 'dev'). + + Raises: + FileNotFoundError: If the JSONL metadata file does not exist. + ValueError: If the split is not 'train' or 'dev'. + """ + # Validate split type + if split not in {"train", "dev"}: + raise ValueError(f"Invalid split: {split}. Expected 'train' or 'dev'.") + + all_pdf_dir = Path(all_pdf_dir) + target_dir = Path(target_dir_base) / f'pdfs_{split}' + target_dir.mkdir(parents=True, exist_ok=True) + + # Validate metadata file + split_metadata_file = Path(split_metadata_file) + if not split_metadata_file.exists(): + raise FileNotFoundError(f"Metadata file for split '{split}' not found: {split_metadata_file}") + + # Load all doc IDs for the split + split_doc_ids = [] + with jsonlines.open(split_metadata_file) as reader: + for obj in reader: + split_doc_ids.extend(doc['doc_id'] for doc in obj['supporting_context']) + + # Remove duplicates and log the count + split_doc_ids = sorted(set(split_doc_ids)) + logger.info(f"Split {split} -> # supporting context: {len(split_doc_ids)}") + + # Save the split-specific IDs to a JSON file + split_doc_ids_output_path = Path(f'./{split}_doc_ids.json') + with open(split_doc_ids_output_path, 'w') as f: + json.dump(split_doc_ids, f, indent=4) + logger.info(f"Split {split} -> saved doc IDs at {split_doc_ids_output_path}") + + # Copy PDF files to the target directory + missing_files = [] + for doc_id in split_doc_ids: + pdf_file = all_pdf_dir / f"{doc_id}.pdf" + if pdf_file.exists(): + shutil.copy(pdf_file, target_dir / pdf_file.name) + else: + missing_files.append(pdf_file) + + if missing_files: + logger.warning(f"Warning: {len(missing_files)} files are missing and will be skipped.") + for missing_file in missing_files: + logger.warning(f" Missing: {missing_file}") + diff --git a/m3docvqa/src/m3docvqa/wiki_mapper.py b/m3docvqa/src/m3docvqa/wiki_mapper.py new file mode 100644 index 0000000..c826062 --- /dev/null +++ b/m3docvqa/src/m3docvqa/wiki_mapper.py @@ -0,0 +1,140 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Wiki Mapper. + +This module provides functionality to parse multimodalqa JSONL files that has been already downloaded that contains 'id' and 'url' mappings, +merge them into a single mapping, and save the result to a JSONL file. + +Each JSONL file should contain one JSON object per line with the following structure: +{ + "title": "Article Title", + "url": "https://en.wikipedia.org/wiki/Article_Title", + "id": "unique_id", + "text": "Text description of the article." +} +""" + +import json +from pathlib import Path +from loguru import logger + + +def parse_jsonl(file_path: str | Path) -> dict[str, str]: + """Parses a JSONL file from the multimodalqa dataset to extract a mapping of 'id' to 'url'. + + Args: + file_path (str | Path): Path to the JSONL file. + + Returns: + dict[str, str]: A dictionary mapping each 'id' to its corresponding 'url'. + + Raises: + FileNotFoundError: If the JSONL file does not exist. + ValueError: If the file contains invalid JSON lines. + """ + file_path = Path(file_path) + if not file_path.is_file(): + logger.error(f"File not found: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") + + id_url_mapping = {} + try: + with file_path.open("r") as file: + for line in file: + data = json.loads(line.strip()) + entry_id = data.get("id") + url = data.get("url") + if entry_id and url: + id_url_mapping[entry_id] = url + except json.JSONDecodeError as e: + logger.error(f"Error decoding JSON in file {file_path}: {e}") + raise ValueError(f"Invalid JSON in file {file_path}: {e}") + + logger.info(f"Parsed {len(id_url_mapping)} entries from {file_path}") + return id_url_mapping + + +def merge_mappings(mappings: list[dict[str, str]]) -> dict[str, str]: + """Merges multiple mappings into a single dictionary. + + Args: + mappings (list[dict[str, str]]): A list of dictionaries containing 'id' to 'url' mappings. + + Returns: + dict[str, str]: A merged dictionary containing all 'id' to 'url' mappings. + """ + merged_mapping = {} + for mapping in mappings: + merged_mapping.update(mapping) + logger.info(f"Merged {len(mappings)} mappings with a total of {len(merged_mapping)} entries.") + return merged_mapping + + +def save_mapping_to_jsonl(mapping: dict[str, str], output_file: str | Path) -> None: + """Saves the 'id'-to-'url' mapping to a JSONL file. + + Args: + mapping (dict[str, str]): The dictionary containing 'id' to 'url' mappings. + output_file (str | Path): The path to the output JSONL file. + + Raises: + IOError: If the file cannot be written. + """ + output_file = Path(output_file) + try: + with output_file.open("w") as file: + for entry_id, url in mapping.items(): + json.dump({"id": entry_id, "url": url}, file) + file.write("\n") + logger.info(f"Mapping saved to {output_file}") + except IOError as e: + logger.error(f"Error writing to file {output_file}: {e}") + raise + + +def generate_wiki_links_mapping( + text_file: str | Path, image_file: str | Path, table_file: str | Path, output_file: str | Path = "id_url_mapping.jsonl" +) -> None: + """Orchestrates the process of parsing input files, merging mappings, and saving the result to JSONL. + + Args: + text_file (str | Path): Path to the JSONL file containing text data with 'id' and 'url' fields. + image_file (str | Path): Path to the JSONL file containing image data with 'id' and 'url' fields. 
+ table_file (str | Path): Path to the JSONL file containing table data with 'id' and 'url' fields. + output_file (str | Path): Path to save the output JSONL file. Defaults to 'id_url_mapping.jsonl'. + + Raises: + Exception: If any part of the pipeline fails. + """ + try: + # Parse input files + logger.info("Parsing JSONL files...") + text_mapping = parse_jsonl(text_file) + image_mapping = parse_jsonl(image_file) + table_mapping = parse_jsonl(table_file) + + # Merge mappings + logger.info("Merging mappings...") + merged_mapping = merge_mappings([text_mapping, image_mapping, table_mapping]) + + # Save the merged mapping + logger.info("Saving merged mapping to output file...") + save_mapping_to_jsonl(merged_mapping, output_file) + logger.info(f"Mapping successfully generated and saved to {output_file}") + except Exception as e: + logger.error(f"Error generating wiki links mapping: {e}") + raise diff --git a/m3docvqa/tests/test_downloader.py b/m3docvqa/tests/test_downloader.py new file mode 100644 index 0000000..454ea39 --- /dev/null +++ b/m3docvqa/tests/test_downloader.py @@ -0,0 +1,105 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from unittest.mock import patch, MagicMock +from pathlib import Path +import jsonlines +from m3docvqa.downloader import _download_wiki_page, download_wiki_page + + +@pytest.fixture +def test_urls_and_paths(tmp_path): + """Fixture to provide sample URLs and save paths for testing.""" + urls = ["https://en.wikipedia.org/wiki/SamplePage1", "https://en.wikipedia.org/wiki/SamplePage2"] + save_paths = [str(tmp_path / "sample1.pdf"), str(tmp_path / "sample2.pdf")] + return urls, save_paths + + +@patch("m3docvqa.downloader.sync_playwright") +def test__download_wiki_page_pdf(mock_playwright, tmp_path): + """Test downloading a single page as a PDF.""" + url = "https://en.wikipedia.org/wiki/SamplePage" + save_path = tmp_path / "sample.pdf" + args = (0, 1, url, str(save_path), 'pdf', 0) + + # Mock Playwright behavior + mock_browser = MagicMock() + mock_context = MagicMock() + mock_page = MagicMock() + mock_playwright.return_value.__enter__.return_value.chromium.launch.return_value = mock_browser + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + # Call the function + downloaded, error = _download_wiki_page(args) + + # Assertions + assert downloaded is True + assert error is None + mock_page.goto.assert_called_once_with(url) + mock_page.pdf.assert_called_once_with(path=str(save_path)) + + +@patch("m3docvqa.downloader.sync_playwright") +def test__download_wiki_page_png(mock_playwright, tmp_path): + """Test downloading a single page as a PNG.""" + url = "https://en.wikipedia.org/wiki/SamplePage" + save_path = tmp_path / "sample.png" + args = (0, 1, url, str(save_path), 'png', 0) + + # Mock Playwright behavior + mock_browser = MagicMock() + mock_context = MagicMock() + mock_page = MagicMock() + 
mock_playwright.return_value.__enter__.return_value.chromium.launch.return_value = mock_browser + mock_browser.new_context.return_value = mock_context + mock_context.new_page.return_value = mock_page + + # Call the function + downloaded, error = _download_wiki_page(args) + + # Assertions + assert downloaded is True + assert error is None + mock_page.goto.assert_called_once_with(url) + mock_page.screenshot.assert_called_once_with(path=str(save_path), full_page=True) + + +@patch("m3docvqa.downloader._download_wiki_page") +def test_download_wiki_page_batch(mock_download_wiki_page, tmp_path, test_urls_and_paths): + """Test batch downloading multiple Wikipedia pages.""" + urls, save_paths = test_urls_and_paths + result_jsonl_path = tmp_path / "download_results.jsonl" + + # Mock individual downloads to always succeed + mock_download_wiki_page.side_effect = [(True, None), (True, None)] + + # Call the function + results = download_wiki_page(urls, save_paths, 'pdf', str(result_jsonl_path), proc_id=0, n_proc=1) + + # Assertions + assert results == [True, True] + assert result_jsonl_path.exists() + + # Check JSONL log entries + with jsonlines.open(result_jsonl_path, 'r') as reader: + log_entries = list(reader) + assert len(log_entries) == 2 + assert log_entries[0]['downloaded'] is True + assert log_entries[0]['error'] is None + assert log_entries[1]['downloaded'] is True + assert log_entries[1]['error'] is None diff --git a/m3docvqa/tests/test_pdf_utils.py b/m3docvqa/tests/test_pdf_utils.py new file mode 100644 index 0000000..8a12f40 --- /dev/null +++ b/m3docvqa/tests/test_pdf_utils.py @@ -0,0 +1,86 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf +from pathlib import Path +from PIL import Image +from reportlab.pdfgen import canvas # For creating sample PDFs + + +@pytest.fixture +def sample_pdf(tmp_path) -> Path: + """Create a temporary sample PDF file for testing.""" + pdf_path = tmp_path / "sample.pdf" + c = canvas.Canvas(str(pdf_path)) + c.drawString(100, 100, "Sample PDF text for testing.") # Add sample text to the PDF + c.save() + return pdf_path + + +@pytest.fixture +def corrupted_pdf(tmp_path) -> Path: + """Create a temporary, corrupted PDF file for testing.""" + pdf_path = tmp_path / "corrupted.pdf" + pdf_path.write_bytes(b"%PDF-1.4 corrupted content") # Write incomplete/corrupted PDF content + return pdf_path + + +def test_is_pdf_downloaded_existing_pdf(sample_pdf): + """Test is_pdf_downloaded on a valid, existing PDF.""" + assert is_pdf_downloaded(str(sample_pdf)) is True, "Expected PDF to be recognized as downloaded." + + +def test_is_pdf_downloaded_nonexistent_pdf(tmp_path): + """Test is_pdf_downloaded on a non-existent PDF file.""" + non_existent_pdf = tmp_path / "non_existent.pdf" + assert is_pdf_downloaded(str(non_existent_pdf)) is False, "Expected non-existent PDF to be marked as not downloaded." 
+ + +def test_is_pdf_clean_valid_pdf(sample_pdf): + """Test is_pdf_clean on a valid, clean PDF.""" + assert is_pdf_clean(str(sample_pdf)) is True, "Expected PDF to be recognized as clean." + + +def test_is_pdf_clean_corrupted_pdf(corrupted_pdf): + """Test is_pdf_clean on a corrupted PDF.""" + assert is_pdf_clean(str(corrupted_pdf)) is False, "Expected corrupted PDF to be marked as not clean." + + +def test_get_images_from_pdf_extract_images(sample_pdf, tmp_path): + """Test get_images_from_pdf to ensure it extracts images correctly.""" + image_dir = tmp_path / "images" + images = get_images_from_pdf(str(sample_pdf), save_dir=str(image_dir), dpi_resolution=72, save_type='png') + + # Verify that at least one image was extracted + assert len(images) > 0, "Expected at least one image to be extracted from the PDF." + + # Verify that images were saved to the directory + saved_images = list(image_dir.glob("*.png")) + assert len(saved_images) == len(images), "Expected number of saved images to match the number of extracted images." + + # Verify that the saved image files exist and are valid + for image_path in saved_images: + with Image.open(image_path) as img: + assert img.format == "PNG", "Expected saved image to be in PNG format." + + +def test_get_images_from_pdf_no_save_dir(sample_pdf): + """Test get_images_from_pdf without saving images, only returning them as a list.""" + images = get_images_from_pdf(str(sample_pdf), save_dir=None, dpi_resolution=72) + assert len(images) > 0, "Expected at least one image to be returned without saving." + assert all(isinstance(image, Image.Image) for image in images), "Expected all returned items to be PIL Image objects." + diff --git a/m3docvqa/tests/test_split_utils.py b/m3docvqa/tests/test_split_utils.py new file mode 100644 index 0000000..ef2bbd3 --- /dev/null +++ b/m3docvqa/tests/test_split_utils.py @@ -0,0 +1,107 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from pathlib import Path +import shutil +import json +import jsonlines +from unittest.mock import MagicMock, patch +from m3docvqa.split_utils import create_split_dirs + + +@pytest.fixture +def mock_pdf_directory(tmp_path): + # Create a temporary directory for PDFs + pdf_dir = tmp_path / "pdfs" + pdf_dir.mkdir() + # Add some mock PDF files + (pdf_dir / "doc1.pdf").write_text("PDF content for doc1") + (pdf_dir / "doc2.pdf").write_text("PDF content for doc2") + return pdf_dir + + +@pytest.fixture +def mock_metadata_file(tmp_path): + # Create a temporary metadata file in JSONL format + metadata_file = tmp_path / "MMQA_train.jsonl" + data = [ + {"supporting_context": [{"doc_id": "doc1"}]}, + {"supporting_context": [{"doc_id": "doc2"}]} + ] + with jsonlines.open(metadata_file, mode='w') as writer: + writer.write_all(data) + return metadata_file + + +@pytest.fixture +def mock_target_directory(tmp_path): + return tmp_path / "target" + + +def test_create_split_dirs(mock_pdf_directory, mock_metadata_file, mock_target_directory): + """Test the create_split_dirs function.""" + # Prepare the split directory + split = "train" + + # Call the function to create split directories + create_split_dirs( + all_pdf_dir=mock_pdf_directory, + target_dir_base=mock_target_directory, + split_metadata_file=mock_metadata_file, + split=split + ) + + # Assert that the target directory exists and contains the expected PDF files + target_dir = mock_target_directory / f"pdfs_{split}" + assert target_dir.exists(), f"Directory {target_dir} was not created" + assert (target_dir / "doc1.pdf").exists(), "doc1.pdf was not copied" + assert (target_dir / "doc2.pdf").exists(), "doc2.pdf was not copied" + + +def test_create_split_dirs_missing_pdf(mock_metadata_file, mock_target_directory): + """Test create_split_dirs when PDF files are missing.""" + # Prepare the split directory + split = "train" + all_pdf_dir = Path("non_existing_pdf_dir") + + # Call the function and verify that the missing PDFs are handled correctly + create_split_dirs( + all_pdf_dir=all_pdf_dir, + target_dir_base=mock_target_directory, + split_metadata_file=mock_metadata_file, + split=split + ) + + target_dir = mock_target_directory / f"pdfs_{split}" + assert target_dir.exists(), f"Directory {target_dir} was not created" + assert not (target_dir / "doc1.pdf").exists(), "doc1.pdf should not exist" + assert not (target_dir / "doc2.pdf").exists(), "doc2.pdf should not exist" + + +@pytest.mark.parametrize("split, expected_error", [ + ("test_split", ValueError), # Invalid split type + (None, ValueError), # Missing split +]) +def test_create_split_dirs_invalid_split_type(mock_pdf_directory, mock_metadata_file, mock_target_directory, split, expected_error): + """Test invalid split types in create_split_dirs.""" + with pytest.raises(expected_error): + create_split_dirs( + all_pdf_dir=mock_pdf_directory, + target_dir_base=mock_target_directory, + split_metadata_file=mock_metadata_file, + split=split + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a094ed9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[build-system] +requires = ["setuptools>=69.5"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[project] +name = "m3docrag" +version = "0.0.1" +description = "Multimodal Document Understanding with RAG" +readme = "README.md" +requires-python = ">=3.10" +classifiers = ["Programming Language :: Python :: 3"] +dependencies = [ + 
"accelerate==1.1.0", + "loguru", + "requests", + "setuptools==69.5", + "transformers", + "tokenizers", + "flash-attn==2.5.8", + "bitsandbytes==0.43.1", + "safetensors", + "gpustat", + "icecream", + "pdf2image", + "numpy==1.26.4", + "torchvision", + "jsonlines", + "editdistance", + "einops", + "fire", + "peft", + "timm", + "sentencepiece", + "colpali-engine==0.3.1", + "easyocr", + "qwen-vl-utils", + "faiss-cpu", + "word2number", + "datasets>=3.0.0", + "python-dotenv", +] + +[tool.ruff] +target-version = "py310" + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "I"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] \ No newline at end of file diff --git a/src/m3docrag.egg-info/PKG-INFO b/src/m3docrag.egg-info/PKG-INFO new file mode 100644 index 0000000..b005234 --- /dev/null +++ b/src/m3docrag.egg-info/PKG-INFO @@ -0,0 +1,215 @@ +Metadata-Version: 2.2 +Name: m3docrag +Version: 0.0.1 +Summary: Multimodal Document Understanding with RAG +Classifier: Programming Language :: Python :: 3 +Requires-Python: >=3.10 +Description-Content-Type: text/markdown +Requires-Dist: accelerate==1.1.0 +Requires-Dist: loguru +Requires-Dist: requests +Requires-Dist: setuptools==69.5 +Requires-Dist: transformers +Requires-Dist: tokenizers +Requires-Dist: flash-attn==2.5.8 +Requires-Dist: bitsandbytes==0.43.1 +Requires-Dist: safetensors +Requires-Dist: gpustat +Requires-Dist: icecream +Requires-Dist: pdf2image +Requires-Dist: numpy==1.26.4 +Requires-Dist: torchvision +Requires-Dist: jsonlines +Requires-Dist: editdistance +Requires-Dist: einops +Requires-Dist: fire +Requires-Dist: peft +Requires-Dist: timm +Requires-Dist: sentencepiece +Requires-Dist: colpali-engine==0.3.1 +Requires-Dist: easyocr +Requires-Dist: qwen-vl-utils +Requires-Dist: faiss-cpu +Requires-Dist: word2number +Requires-Dist: datasets>=3.0.0 +Requires-Dist: python-dotenv + +# M3DocRAG + +Code for [M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding](https://m3docrag.github.io/) + +by [Jaemin Cho](https://j-min.io/), [Debanjan Mahata](https://sites.google.com/a/ualr.edu/debanjan-mahata/), [Ozan İrsoy](https://wtimesx.com/), [Yujie He](https://scholar.google.com/citations?user=FbeAZGgAAAAJ&hl=en), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/) + +# Summary + +## Comparison with previous approches + + + +Comparison of multi-modal document understanding pipelines. Previous works focus on (a) **Single-page DocVQA** that cannot handle many long documents or (b) **Text-based RAG** that ignores visual information. Our (c) **M3DocRAG** framework retrieves relevant documents and answers questions using multi-modal retrieval and MLM components, so that it can efficiently handle many long documents while preserving visual information. + +## M3DocRAG framework + + + +Our **M3DocRAG** framework consists of three stages: (1) document embedding, (2) page retrieval, and (3) question answering. +- In (1) document embedding, we extract visual embedding (with ColPali) to represent each page from all PDF documents. +- In (2) page retrieval, we retrieve the top-K pages of high relevance (MaxSim scores) with text queries. In an open-domain setting, we create approximate page indices for faster search. +- In (3) question answering, we conduct visual question answering with multi-modal LM (e.g. Qwen2-VL) to obtain the final answer. + + +# Setup + +## Package + +We assume conda has been installed + +```bash +git clone +cd m3docrag-release +pip install -e . 
+ +# Install Poppler (for pdf2image; check https://pdf2image.readthedocs.io/en/latest/installation.html for details) +# conda install -y poppler +# or +# apt-get install poppler-utils +``` + +## Code structure + +```bash +examples/ # scripts to run PDF embedding / RAG +src/m3docrag/ + datasets/ # data loader for existing datasets + retrieval/ # retrieval model (e.g., ColPaLi) + vqa/ # vqa model (e.g., Qwen2-VL) + rag/ # RAG model that combines retrieval and vqa models + utils/ # misc utility methods +m3docvqa/ # how to setup m3docvqa dataset +``` +## Paths: Data, Embeddings, Model checkpoints, Outputs + +```bash +# in .env +LOCAL_DATA_DIR="/job/datasets" # where to store data +LOCAL_EMBEDDINGS_DIR="/job/embeddings" # where to store embeddings +LOCAL_MODEL_DIR="/job/model" # where to store model checkpoints +LOCAL_OUTPUT_DIR="/job/output" # where to store model outputs +``` + +You can adjust variables in [`.env`](.env) to change where to store data/embedding/model checkpoint/outputs by default. They are loaded in [`src/m3docrag/utils/paths.py`](./src/m3docrag/utils/paths.py) via [python-dotenv](https://github.com/theskumar/python-dotenv). + + +## Download M3DocVQA dataset + +Please see [m3docvqa/README.md](m3docvqa/README.md) for the download instruction. + +## Donwload model checkpoints + +By default, we use colpali-v1.2 for retrival and Qwen2-VL-7B-Instruct for question answering. + +At `$LOCAL_MODEL_DIR`, download [colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2), [colpaligemma-3b-mix-448-base](https://huggingface.co/vidore/colpaligemma-3b-mix-448-base) and [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) checkpoints. + +```bash +cd $LOCAL_MODEL_DIR + +git clone https://huggingface.co/vidore/colpaligemma-3b-pt-448-base # ColPali backbone +git clone https://huggingface.co/vidore/colpali-v1.2 # ColPali adapter +git clone https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct # VQA +``` + + + + + +# Example usage + +Below we describe example usage of M3DocRAG on M3DocVQA dataset. + + +## 1. Extract PDF embeddings + +```bash +DATASET_NAME="m3-docvqa" +RETRIEVAL_MODEL_TYPE="colpali" +RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base" +RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2" +SPLIT="dev" +EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT # where to save embeddings +accelerate launch --num_processes=1 --mixed_precision=bf16 examples/run_page_embedding.py \ + --use_retrieval \ + --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \ + --data_name=$DATASET_NAME \ + --split=$SPLIT \ + --loop_unique_doc_ids=True \ + --output_dir=/job/embeddings/$EMBEDDING_NAME \ + --retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \ + --retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME +``` + +## 2. Indexing + +```bash +DATASET_NAME="m3-docvqa" +RETRIEVAL_MODEL_TYPE="colpali" +RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2" +SPLIT="dev" +FAISS_INDEX_TYPE='ivfflat' +EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT +INDEX_NAME=$EMBEDDING_NAME"_pageindex_"$FAISS_INDEX_TYPE # where to save resulting index +echo $EMBEDDING_NAME +echo $FAISS_INDEX_TYPE +python examples/run_indexing_m3docvqa.py \ + --use_retrieval \ + --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \ + --data_name=$DATASET_NAME \ + --split=$SPLIT \ + --loop_unique_doc_ids=False \ + --embedding_name=$EMBEDDING_NAME \ + --faiss_index_type=$FAISS_INDEX_TYPE \ + --output_dir=/job/embeddings/$INDEX_NAME +``` + +## 3. 
RAG + +```bash +BACKBONE_MODEL_NAME="Qwen2-VL-7B-Instruct" +RETRIEVAL_MODEL_TYPE="colpali" +RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base" +RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2" +EMBEDDING_NAME="colpali-v1.2_m3-docvqa_dev" # from Step 1 Embedding +SPLIT="dev" +DATASET_NAME="m3-docvqa" +FAISS_INDEX_TYPE='ivfflat' +N_RETRIEVAL_PAGES=1 +INDEX_NAME="${EMBEDDING_NAME}_pageindex_$FAISS_INDEX_TYPE" # from Step 2 Indexing +OUTPUT_SAVE_NAME="${RETRIEVAL_ADAPTER_MODEL_NAME}_${BACKBONE_MODEL_NAME}_${DATASET_NAME}" # where to save RAG results +BITS=16 # BITS=4 for 4-bit qunaitzation in low memory GPUs +python examples/run_rag_m3docvqa.py \ + --use_retrieval \ + --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \ + --load_embedding=True \ + --split=$SPLIT \ + --bits=$BITS \ + --n_retrieval_pages=$N_RETRIEVAL_PAGES \ + --data_name=$DATASET_NAME \ + --model_name_or_path=$BACKBONE_MODEL_NAME \ + --embedding_name=$EMBEDDING_NAME \ + --retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \ + --retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME \ + --output_dir=/job/eval_outputs/$OUTPUT_SAVE_NAME +``` + + +# Citation + +Please cite our paper if you use our dataset and/or method in your projects. + + +```bibtex +@article{Cho2024M3DocRAG, + author = {Jaemin Cho and Ozan İrsoy and Debanjan Mahata and Yujie He and Mohit Bansal}, + title = {M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding}, + year = {2024}, +} +``` diff --git a/src/m3docrag.egg-info/SOURCES.txt b/src/m3docrag.egg-info/SOURCES.txt new file mode 100644 index 0000000..83cd2eb --- /dev/null +++ b/src/m3docrag.egg-info/SOURCES.txt @@ -0,0 +1,31 @@ +README.md +pyproject.toml +src/m3docrag/__init__.py +src/m3docrag.egg-info/PKG-INFO +src/m3docrag.egg-info/SOURCES.txt +src/m3docrag.egg-info/dependency_links.txt +src/m3docrag.egg-info/requires.txt +src/m3docrag.egg-info/top_level.txt +src/m3docrag/datasets/__init__.py +src/m3docrag/datasets/m3_docvqa/__init__.py +src/m3docrag/datasets/m3_docvqa/common_utils.py +src/m3docrag/datasets/m3_docvqa/dataset.py +src/m3docrag/datasets/m3_docvqa/evaluate.py +src/m3docrag/rag/__init__.py +src/m3docrag/rag/base.py +src/m3docrag/rag/multimodal.py +src/m3docrag/rag/utils.py +src/m3docrag/retrieval/__init__.py +src/m3docrag/retrieval/colpali.py +src/m3docrag/utils/args.py +src/m3docrag/utils/distributed.py +src/m3docrag/utils/paths.py +src/m3docrag/utils/pdfs.py +src/m3docrag/utils/prompts.py +src/m3docrag/utils/tar.py +src/m3docrag/vqa/__init__.py +src/m3docrag/vqa/florence2.py +src/m3docrag/vqa/idefics2.py +src/m3docrag/vqa/idefics3.py +src/m3docrag/vqa/internvl2.py +src/m3docrag/vqa/qwen2.py \ No newline at end of file diff --git a/src/m3docrag.egg-info/dependency_links.txt b/src/m3docrag.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/m3docrag.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/m3docrag.egg-info/requires.txt b/src/m3docrag.egg-info/requires.txt new file mode 100644 index 0000000..2929b2b --- /dev/null +++ b/src/m3docrag.egg-info/requires.txt @@ -0,0 +1,28 @@ +accelerate==1.1.0 +loguru +requests +setuptools==69.5 +transformers +tokenizers +flash-attn==2.5.8 +bitsandbytes==0.43.1 +safetensors +gpustat +icecream +pdf2image +numpy==1.26.4 +torchvision +jsonlines +editdistance +einops +fire +peft +timm +sentencepiece +colpali-engine==0.3.1 +easyocr +qwen-vl-utils +faiss-cpu +word2number +datasets>=3.0.0 +python-dotenv diff --git a/src/m3docrag.egg-info/top_level.txt 
b/src/m3docrag.egg-info/top_level.txt new file mode 100644 index 0000000..33985ce --- /dev/null +++ b/src/m3docrag.egg-info/top_level.txt @@ -0,0 +1 @@ +m3docrag diff --git a/src/m3docrag/__init__.py b/src/m3docrag/__init__.py new file mode 100644 index 0000000..a4610ea --- /dev/null +++ b/src/m3docrag/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + diff --git a/src/m3docrag/datasets/__init__.py b/src/m3docrag/datasets/__init__.py new file mode 100644 index 0000000..a4610ea --- /dev/null +++ b/src/m3docrag/datasets/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + diff --git a/src/m3docrag/datasets/m3_docvqa/__init__.py b/src/m3docrag/datasets/m3_docvqa/__init__.py new file mode 100644 index 0000000..b572372 --- /dev/null +++ b/src/m3docrag/datasets/m3_docvqa/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from .dataset import M3DocVQADataset +from .evaluate import evaluate_predictions, evaluate_prediction_file \ No newline at end of file diff --git a/src/m3docrag/datasets/m3_docvqa/common_utils.py b/src/m3docrag/datasets/m3_docvqa/common_utils.py new file mode 100644 index 0000000..42da5d5 --- /dev/null +++ b/src/m3docrag/datasets/m3_docvqa/common_utils.py @@ -0,0 +1,142 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# https://github.com/allenai/multimodalqa/blob/master/baselines/common_utils.py + +import json + + +ALL_QUESTION_TYPES = [ + 'TextQ', + 'TableQ', + 'ImageQ', + 'ImageListQ', + 'Compose(TableQ,ImageListQ)', + 'Compose(TextQ,ImageListQ)', + 'Compose(ImageQ,TableQ)', + 'Compose(ImageQ,TextQ)', + 'Compose(TextQ,TableQ)', + 'Compose(TableQ,TextQ)', + 'Intersect(TableQ,TextQ)', + 'Intersect(ImageListQ,TableQ)', + 'Intersect(ImageListQ,TextQ)', + 'Compare(Compose(TableQ,ImageQ),TableQ)', + 'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))', + 'Compare(TableQ,Compose(TableQ,TextQ))', +] + +TEXT_SINGLE_HOP_QUESTION_TYPES = [ + 'TextQ', +] +TEXT_AS_FIRST_HOP_QUESTION_TYPES = [ + 'Compare(TableQ,Compose(TableQ,TextQ))', + 'Compose(ImageQ,TextQ)', + 'Compose(TableQ,TextQ)', + 'Intersect(TableQ,TextQ)', + 'Intersect(ImageListQ,TextQ)', +] +TEXT_AS_SECOND_HOP_QUESTION_TYPES = [ + 'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))', + 'Compose(TextQ,ImageListQ)', + 'Compose(TextQ,TableQ)', +] + +TABLE_SINGLE_HOP_QUESTION_TYPES = [ + "TableQ" +] +TABLE_AS_FIRST_HOP_QUESTION_TYPES = [ + 'Compose(ImageQ,TableQ)', + 'Compose(TextQ,TableQ)', +] +TABLE_AS_SECOND_HOP_QUESTION_TYPES = [ + 'Compare(Compose(TableQ,ImageQ),TableQ)', + 'Compare(TableQ,Compose(TableQ,TextQ))', + 'Compose(TableQ,ImageListQ)', + 'Compose(TableQ,TextQ)', + 'Intersect(ImageListQ,TableQ)', + 'Intersect(TableQ,TextQ)', +] + +IMAGE_SINGLE_HOP_QUESTION_TYPES = [ + 'ImageQ', + 'ImageListQ' +] +IMAGE_AS_FIRST_HOP_QUESTION_TYPES = [ + 'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))', + 'Compare(Compose(TableQ,ImageQ),TableQ)', + 'Compose(TableQ,ImageListQ)', + 'Compose(TextQ,ImageListQ)', + 'Intersect(ImageListQ,TableQ)', +] +IMAGE_AS_SECOND_HOP_QUESTION_TYPES = [ + 'Compose(ImageQ,TableQ)', + 'Compose(ImageQ,TextQ)', + 'Intersect(ImageListQ,TextQ)', +] + + +# every question should be answered either as a single hop question, or two-hop question +assert set(TEXT_SINGLE_HOP_QUESTION_TYPES + TEXT_AS_SECOND_HOP_QUESTION_TYPES + + TABLE_SINGLE_HOP_QUESTION_TYPES + TABLE_AS_SECOND_HOP_QUESTION_TYPES + + IMAGE_SINGLE_HOP_QUESTION_TYPES + IMAGE_AS_SECOND_HOP_QUESTION_TYPES) == set(ALL_QUESTION_TYPES) +assert len(set(TEXT_SINGLE_HOP_QUESTION_TYPES) & set(TEXT_AS_SECOND_HOP_QUESTION_TYPES)) == 0 +assert len(set(TABLE_SINGLE_HOP_QUESTION_TYPES) & set(TABLE_AS_SECOND_HOP_QUESTION_TYPES)) == 0 +assert len(set(IMAGE_SINGLE_HOP_QUESTION_TYPES) & set(IMAGE_AS_SECOND_HOP_QUESTION_TYPES)) == 0 + +SINGLE_HOP_QUESTION_TYPES = TEXT_SINGLE_HOP_QUESTION_TYPES \ + + TABLE_SINGLE_HOP_QUESTION_TYPES \ + + IMAGE_SINGLE_HOP_QUESTION_TYPES +MULTI_HOP_QUESTION_TYPES = TEXT_AS_SECOND_HOP_QUESTION_TYPES \ + + TABLE_AS_SECOND_HOP_QUESTION_TYPES + \ + IMAGE_AS_SECOND_HOP_QUESTION_TYPES +# no duplicated multi-hop question types +assert len(MULTI_HOP_QUESTION_TYPES) == len(set(MULTI_HOP_QUESTION_TYPES)) +# no duplication for the first hop +assert set(TEXT_AS_FIRST_HOP_QUESTION_TYPES + TABLE_AS_FIRST_HOP_QUESTION_TYPES + IMAGE_AS_FIRST_HOP_QUESTION_TYPES) \ + == set(MULTI_HOP_QUESTION_TYPES) +# single + multi = all +assert set(SINGLE_HOP_QUESTION_TYPES + MULTI_HOP_QUESTION_TYPES) == set(ALL_QUESTION_TYPES) + + +def process_question_for_implicit_decomp(question, question_type, hop=0, bridge_entity='', sep_token='[SEP]'): + if isinstance(bridge_entity, list) or isinstance(bridge_entity, set): + bridge_entity = "; 
".join(bridge_entity) + return ( + f'{question_type} {sep_token} ' + f'HOP={hop} {sep_token} ' + f'{bridge_entity} {sep_token} ' + f'{question}') + + +def extract_numbers_from_str(s): + numbers = [] + for token in s.split(): + try: + num = int(token.replace(",", "")) + except: + try: + num = float(token) + except: + num = None + if num: + numbers.append(num) + return numbers + + +def read_jsonl(filename): + with open(filename, 'r') as f: + data = [json.loads(l.strip()) for l in f.readlines()] + return data \ No newline at end of file diff --git a/src/m3docrag/datasets/m3_docvqa/dataset.py b/src/m3docrag/datasets/m3_docvqa/dataset.py new file mode 100644 index 0000000..c63ddd8 --- /dev/null +++ b/src/m3docrag/datasets/m3_docvqa/dataset.py @@ -0,0 +1,144 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from datasets import load_dataset +from pathlib import Path +import torch +import safetensors +import json +import jsonlines +from copy import deepcopy +from tqdm.auto import tqdm +import PIL +from typing import List +from loguru import logger + +from m3docrag.utils.paths import LOCAL_DATA_DIR, LOCAL_EMBEDDINGS_DIR +from m3docrag.utils.pdfs import get_images_from_pdf + +class M3DocVQADataset(torch.utils.data.Dataset): + def __init__(self, args): + + self.args = args + + # e.g., /job/datasets/m3-docvqa + local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name + + pdf_dir = local_data_dir / "splits" / f'pdfs_{args.split}' + assert pdf_dir.exists(), pdf_dir + self.pdf_dir = pdf_dir + + multimodalqa_data_dir = local_data_dir / "multimodalqa" + + mmqa_data_path = multimodalqa_data_dir / f"MMQA_{args.split}.jsonl" + self.mmqa_data_path = mmqa_data_path + + data = [] + with jsonlines.open(mmqa_data_path) as reader: + for i, obj in enumerate(reader): + data.append(obj) + logger.info(f"# Data {len(data)}") + self.data = data + + split_supporting_doc_ids_path = local_data_dir / f"{args.split}_doc_ids.json" + with open(split_supporting_doc_ids_path, 'r') as f: + all_supporting_doc_ids = json.load(open(split_supporting_doc_ids_path)) + # dev: 3366 + # train: 24162 + logger.info(f"# supporting doc ids in split {args.split}: {len(all_supporting_doc_ids)}") + self.all_supporting_doc_ids = all_supporting_doc_ids + + def __len__(self): + if self.args.loop_unique_doc_ids: + return len(self.all_supporting_doc_ids) + + if self.args.data_len is not None: + return self.args.data_len + + return len(self.data) + + def load_all_embeddings(self): + """Load all doc embeddings in memory""" + + emb_dir = Path(LOCAL_EMBEDDINGS_DIR) / self.args.embedding_name + + logger.info(f"Loading all doc embeddings from {emb_dir}") + + docid2embs = {} + docid2lens = {} + + for idx in tqdm(range(len(self.all_supporting_doc_ids))): + + doc_id = self.all_supporting_doc_ids[idx] + emb_path = Path(LOCAL_EMBEDDINGS_DIR) / self.args.embedding_name / f"{doc_id}.safetensors" + assert emb_path.exists(), emb_path + + if self.args.retrieval_model_type == 'colpali': 
+ + with safetensors.safe_open(emb_path, framework="pt", device='cpu') as f: + + # [n_pages, n_tokens, dim] + doc_embs = f.get_tensor('embeddings') + + docid2embs[doc_id] = doc_embs.bfloat16() + + if self.args.retrieval_model_type == 'colpali': + return docid2embs + elif self.args.retrieval_model_type == 'colbert': + return docid2embs, docid2lens + + + def get_images_from_doc_id(self, doc_id: str) -> List[PIL.Image.Image]: + pdf_path = self.pdf_dir / f"{doc_id}.pdf" + page_images = get_images_from_pdf(pdf_path) + return page_images + + + def __getitem__(self, idx): + if self.args.loop_unique_doc_ids: + doc_id = self.all_supporting_doc_ids[idx] + + datum = { + 'doc_id': doc_id, + } + + if self.args.retrieval_model_type == 'colpali': + page_images = self.get_images_from_doc_id(doc_id) + datum['images'] = page_images + + return datum + + # keys(['qid', 'question', 'answers', 'metadata', 'supporting_context']) + datum = deepcopy(self.data[idx]) + + supporting_doc_ids = [] + for obj in datum['supporting_context']: + + supporting_doc_ids.append(obj['doc_id']) + datum['supporting_doc_ids'] = supporting_doc_ids + + return datum + + + +if __name__ == '__main__': + from m3docrag.utils.args import _example_args, parse_args + + args = parse_args(_example_args) + + dataset = M3DocVQADataset( + args=args + ) \ No newline at end of file diff --git a/src/m3docrag/datasets/m3_docvqa/evaluate.py b/src/m3docrag/datasets/m3_docvqa/evaluate.py new file mode 100644 index 0000000..7f3a6a2 --- /dev/null +++ b/src/m3docrag/datasets/m3_docvqa/evaluate.py @@ -0,0 +1,414 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +# https://github.com/allenai/multimodalqa/blob/master/baselines/evaluate.py + +import json +import argparse +import re +import string +import numpy as np +from collections import Counter +from typing import List, Set, Tuple, Union +from scipy.optimize import linear_sum_assignment +from word2number.w2n import word_to_num +from .common_utils import * + +# From here through _match_numbers_if_present was originally copied from the evaluation code of DROP dataset: +# https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py + +def _remove_articles(text: str) -> str: + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + return re.sub(regex, " ", text) + + +def _white_space_fix(text: str) -> str: + return " ".join(text.split()) + + +EXCLUDE = set(string.punctuation) + + +def _remove_punc(text: str) -> str: + if not _is_number(text): + return "".join(ch for ch in text if ch not in EXCLUDE) + else: + return text + + +def _lower(text: str) -> str: + return text.lower() + + +def _tokenize(text: str) -> List[str]: + return re.split(" |-", text) + + +def _normalize_answer(text: str) -> str: + """Lower text and remove punctuation, articles and extra whitespace.""" + + parts = [ + _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token))))) + for token in _tokenize(text) + ] + parts = [part for part in parts if part.strip()] + normalized = " ".join(parts).strip() + return normalized + + +def _is_number(text: str) -> bool: + try: + float(text) + return True + except ValueError: + return False + + +def _is_word_number(text: str) -> bool: + try: + word_to_num(text) + return True + except ValueError: + return False + + +def _normalize_number(text: str) -> str: + if _is_number(text): + return str(float(text)) + #TODO: this is not included in the original drop evaluation script, we need to have our own in the end anyways. + elif _is_word_number(text): + return str(float(word_to_num(text))) + else: + return text + + +def _answer_to_bags( + answer: Union[str, List[str], Tuple[str, ...]] +) -> Tuple[List[str], List[Set[str]]]: + if isinstance(answer, (list, tuple)): + raw_spans = answer + else: + raw_spans = [answer] + normalized_spans: List[str] = [] + token_bags = [] + for raw_span in raw_spans: + normalized_span = _normalize_answer(raw_span) + normalized_spans.append(normalized_span) + token_bags.append(set(normalized_span.split())) + return normalized_spans, token_bags + + +def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]: + """ + Takes gold and predicted answer sets and first finds the optimal 1-1 alignment + between them and gets maximum metric values over all the answers. 
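+    The optimal alignment is found with the Hungarian algorithm
+    (scipy.optimize.linear_sum_assignment) applied to the negated F1 score matrix.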
+ """ + scores = np.zeros([len(gold), len(predicted)]) + for gold_index, gold_item in enumerate(gold): + for pred_index, pred_item in enumerate(predicted): + if _match_numbers_if_present(gold_item, pred_item): + scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item) + row_ind, col_ind = linear_sum_assignment(-scores) + + max_scores = np.zeros([max(len(gold), len(predicted))]) + for row, column in zip(row_ind, col_ind): + max_scores[row] = max(max_scores[row], scores[row, column]) + return max_scores + + +def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float: + intersection = len(gold_bag.intersection(predicted_bag)) + if not predicted_bag: + precision = 1.0 + else: + precision = intersection / float(len(predicted_bag)) + if not gold_bag: + recall = 1.0 + else: + recall = intersection / float(len(gold_bag)) + f1 = ( + (2 * precision * recall) / (precision + recall) + if not (precision == 0.0 and recall == 0.0) + else 0.0 + ) + return f1 + + +def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool: + gold_numbers = set() + predicted_numbers = set() + for word in gold_bag: + if _is_number(word): + gold_numbers.add(word) + for word in predicted_bag: + if _is_number(word): + predicted_numbers.add(word) + if (not gold_numbers) or gold_numbers.intersection(predicted_numbers): + return True + return False + + + +def list_em(predicted, gold): + predicted_bags = _answer_to_bags(predicted) + gold_bags = _answer_to_bags(gold) + if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): + return 1.0 + else: + return 0.0 + + +def list_f1(predicted, gold): + predicted_bags = _answer_to_bags(predicted) + gold_bags = _answer_to_bags(gold) + f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1]) + f1 = np.mean(f1_per_bag) + f1 = round(f1, 2) + return f1 + + +def metric_max_over_ground_truths(metric_fn, prediction, gold_answers): + scores_for_ground_truths = [] + for gold_answer in gold_answers: + score = metric_fn(prediction, gold_answer) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate_predictions(predictions, gold_answers, example_types=None): + """To support multiple gold annotations, `gold_answers` should be a list, + with each item (either a string or a list) corresponding to one valid reference answer.""" + instance_eval_results = {} + instance_eval_results_by_types = {} + eval_funcs = { + "list_em": list_em, + "list_f1": list_f1 + } + for qas_id in gold_answers: + ref_answers = gold_answers[qas_id] + if qas_id not in predictions: + print(f"Missing prediction for question {qas_id}, and all scores for this question are set to zero") + instance_eval_results[qas_id] = { + metric: 0.0 for metric in eval_funcs.keys() + } + else: + pred_answer = predictions[qas_id] + instance_eval_results[qas_id] = { + metric: metric_max_over_ground_truths( + func, pred_answer, ref_answers + ) for metric, func in eval_funcs.items() + } + if example_types is not None: + example_type = example_types[qas_id] + if example_type not in instance_eval_results_by_types: + instance_eval_results_by_types[example_type] = {} + instance_eval_results_by_types[example_type][qas_id] = instance_eval_results[qas_id] + + eval_scores = {metric: np.mean([result[metric] for result in instance_eval_results.values()]) * 100 + for metric in eval_funcs.keys()} + + if example_types is not None: + eval_scores_by_types = {} + for example_type, type_instance_eval_results in instance_eval_results_by_types.items(): 
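+            # Average each metric over this example type's instances and report it as a percentage.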
+ eval_scores_by_types[example_type] = { + metric: np.mean([result[metric] for result in type_instance_eval_results.values()]) * 100 for metric in eval_funcs.keys() + } + return eval_scores, instance_eval_results, eval_scores_by_types + else: + return eval_scores, instance_eval_results + + +def eval_retrieval(qid2retrieval_results, gold_examples, recall_levels=[1, 2, 4, 5, 10]): + """ + Calculate document recall metrics per query, including recall@k for various k values, + and compute the average recall. + + Args: + qid2retrieval_results (dict): Dictionary mapping query IDs (qid) to retrieval results. + The retrieval results are a list of lists, where each sublist contains + [doc_id, page, score], e.g., [['53e1d168375850acbaf134da744ffae0', 3, 2.3689956665039062]]. + gold_examples (list): List of ground truth examples with 'qid' and 'supporting_context'. + Each 'supporting_context' is a list of dictionaries, + e.g., [{'doc_id': '8513db80c11ea439ab11eba406ec00d9', 'doc_part': 'table'}]. + recall_levels (list): List of k values for recall@k, e.g., [1, 5, 10]. + + Returns: + dict: Dictionary containing recall per query at different recall levels, + and the average recall for each level. + """ + + # Create a mapping from query ID to the supporting context (gold standard) + qid2supporting_context = { + example["qid"]: example['supporting_context'] for example in gold_examples + } + + # Initialize dictionaries to store recall at different levels + qid2recall_at_k = {k: {} for k in recall_levels} + avg_recall_at_k = {k: 0.0 for k in recall_levels} + + # Loop through each query and its supporting context + for qid, supporting_context in qid2supporting_context.items(): + # Get the retrieved documents for the current query + retrieved_results = qid2retrieval_results.get(qid, []) + + # Extract the relevant doc_ids from supporting_context + relevant_doc_ids = {ctx['doc_id'] for ctx in supporting_context} + + # Count total relevant documents for the query + n_relevant = len(relevant_doc_ids) + + # For each recall level, calculate how many relevant documents were retrieved + for k in recall_levels: + # Get the top-k results (or fewer if there are fewer results) + top_k_results = retrieved_results[:k] + + # Ensure doc_id uniqueness in top_k_results + top_k_doc_ids = set(result[0] for result in top_k_results) + + # Count how many relevant documents are in the top-k results + n_relevant_retrieved_at_k = len(top_k_doc_ids.intersection(relevant_doc_ids)) + + # Calculate recall@k for the current query + recall_at_k = n_relevant_retrieved_at_k / n_relevant if n_relevant > 0 else 0.0 + + # Ensure recall is between 0 and 1 + recall_at_k = min(recall_at_k, 1.0) + + # Store recall@k for the current query + qid2recall_at_k[k][qid] = recall_at_k + + # Calculate average recall@k across all queries + for k in recall_levels: + avg_recall_at_k[k] = sum(qid2recall_at_k[k].values()) / len(qid2recall_at_k[k]) if qid2recall_at_k[k] else 0.0 + + # Return recall@k for each query and average recall@k + return { + "recall_per_qid_at_k": qid2recall_at_k, + "average_recall_at_k": avg_recall_at_k + } + + +def evaluate_prediction_file(prediction_path, gold_path='/job/datasets/m3-docvqa/MMQA_dev.jsonl'): + # predicted_answers = json.load(open(prediction_path, encoding="utf-8")) + + if isinstance(prediction_path, dict): + raw_prediction_json = prediction_path + else: + raw_prediction_json = json.load(open(prediction_path, encoding="utf-8")) + examples = read_jsonl(gold_path) + gold_answers, answer_modalities, hop_types, 
question_types = {}, {}, {}, {}
+
+    qid2predicted_answers = {}
+    qid2retrieval_results = {}
+    for qid, data in raw_prediction_json.items():
+        qid2predicted_answers[qid] = data['pred_answer'].strip()
+        qid2retrieval_results[qid] = data['page_retrieval_results']
+
+    eval_retrieval_results = eval_retrieval(qid2retrieval_results, examples)
+    print('Average recall at K')
+    print(eval_retrieval_results['average_recall_at_k'])
+
+    predicted_answers = qid2predicted_answers
+
+    all_scores = {}
+    all_scores['overall'] = {}
+    all_scores['modalities'] = {}
+    all_scores['hop_types'] = {}
+    all_scores['q_types'] = {}
+
+    all_scores['average_recall_at_k'] = eval_retrieval_results['average_recall_at_k']
+
+    for example in examples:
+        qid = example["qid"]
+        # Currently there is only one ground-truth answer per question.
+        # Even if there are multiple entries in example["answers"], the whole list is regarded as one reference answer.
+        # However, this script supports evaluation with multiple reference answers,
+        # so we wrap the list in an outer bracket to treat it as a list of reference answers.
+        gold_answer = [str(item["answer"]) for item in example["answers"]]
+        gold_answers[qid] = [gold_answer]
+        answer_modality = set([item["modality"] for item in example["answers"]])
+        assert len(answer_modality) == 1
+        answer_modalities[qid] = answer_modality.pop()
+        question_types[qid] = example["metadata"]["type"]
+        hop_types[qid] = "Multi-hop" if example["metadata"]["type"] in MULTI_HOP_QUESTION_TYPES else "Single-hop"
+
+    eval_scores, instance_eval_results = evaluate_predictions(predicted_answers, gold_answers)
+    print("\n\nOverall result with different metrics:")
+    for metric, value in eval_scores.items():
+        print(f"{metric}: {value}")
+        all_scores['overall'][metric] = value
+
+    modality_counts = Counter(answer_modalities.values())
+    _, _, eval_scores_by_modalities = \
+        evaluate_predictions(predicted_answers, gold_answers, answer_modalities)
+    print("\n\nEval results for different modalities:")
+    for answer_modality in sorted(eval_scores_by_modalities.keys()):
+        result = eval_scores_by_modalities[answer_modality]
+        print(f"{answer_modality}")
+        print(f"# of examples: {modality_counts[answer_modality]}")
+
+        all_scores['modalities'][answer_modality] = {}
+        for metric, value in result.items():
+            print(f"{metric}: {value}")
+            all_scores['modalities'][answer_modality][metric] = value
+
+    hop_type_counts = Counter(hop_types.values())
+    _, _, eval_scores_by_hop_types = evaluate_predictions(predicted_answers, gold_answers, hop_types)
+    print("\n\nType\tCount\tEM\tF1")
+    for hop_type in sorted(eval_scores_by_hop_types.keys()):
+        result = eval_scores_by_hop_types[hop_type]
+        print(f"{hop_type}\t{hop_type_counts[hop_type]}\t{result['list_em']}\t{result['list_f1']}")
+
+        all_scores['hop_types'][hop_type] = {}
+        for metric, value in result.items():
+            all_scores['hop_types'][hop_type][metric] = value
+
+    question_type_counts = Counter(question_types.values())
+    _, _, eval_scores_by_qtypes = evaluate_predictions(predicted_answers, gold_answers, question_types)
+    print("\n\nType\tCount\tEM\tF1")
+    for question_type in sorted(eval_scores_by_qtypes.keys()):
+        result = eval_scores_by_qtypes[question_type]
+        print(f"{question_type}\t{question_type_counts[question_type]}\t{result['list_em']}\t{result['list_f1']}")
+
+        all_scores['q_types'][question_type] = {}
+        for metric, value in result.items():
+            all_scores['q_types'][question_type][metric] = value
+
+    # return eval_scores
+    return all_scores
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="evaluate DocVQA predictions against gold answers")
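+    # Illustrative shape of the prediction file consumed above (keys are the ones read
+    # by evaluate_prediction_file; the concrete values below are only examples):
+    # {
+    #   "<qid>": {
+    #     "pred_answer": "1996",
+    #     "page_retrieval_results": [["<doc_id>", 3, 2.37], ...]
+    #   }
+    # }
+    # Example invocation (script name and paths are illustrative):
+    #   python evaluate.py --prediction_path predictions.json --gold_path MMQA_dev.jsonl
+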
parser.add_argument( + "--prediction_path", + type=str, + default="predictions.json", + help="location of the prediction file", + ) + parser.add_argument( + "--gold_path", + type=str, + default="dev.json", + help="location of the gold file", + ) + args = parser.parse_args() + evaluate_prediction_file(args.prediction_path, args.gold_path) \ No newline at end of file diff --git a/src/m3docrag/rag/__init__.py b/src/m3docrag/rag/__init__.py new file mode 100644 index 0000000..5cca15a --- /dev/null +++ b/src/m3docrag/rag/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + + +from .multimodal import MultimodalRAGModel +from .utils import reduce_embeddings \ No newline at end of file diff --git a/src/m3docrag/rag/base.py b/src/m3docrag/rag/base.py new file mode 100644 index 0000000..46fa7c2 --- /dev/null +++ b/src/m3docrag/rag/base.py @@ -0,0 +1,177 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + + +from typing import List, Dict, Tuple +from tqdm.auto import tqdm +import torch +import numpy as np + +from .utils import get_top_k_pages, get_top_k_pages_single_page_from_each_doc + +class RAGModelBase: + def __init__( + self, + retrieval_model=None, + qa_model=None, + vqa_model=None, + ): + """Base class for RAG pipeline + + - retrieval_model: arbitrary retrieval model (e.g., ColPali / ColBERT) + - qa_model: arbitrary text-only QA model (e.g., LLama3) + - vqa_model: arbitrary VQA model (e.g., InternVL2, GPT4-o) + + """ + self.retrieval_model = retrieval_model + self.qa_model = qa_model + self.vqa_model = vqa_model + + if self.retrieval_model is not None: + self.retrieval_model.model.eval() + if self.vqa_model is not None: + self.vqa_model.model.eval() + if self.qa_model is not None: + self.qa_model.model.eval() + + + def retrieve_pages_from_docs( + self, + query: str, + docid2embs: Dict[str, torch.Tensor], + docid2lens: Dict[str, torch.Tensor] = None, + + index = None, + token2pageuid = None, + all_token_embeddings = None, + + n_return_pages: int = 1, + single_page_from_each_doc: bool = False, + show_progress=False, + ) -> List[Tuple]: + """ + Given text query and pre-extracted document embedding, + calculate similarity scores and return top-n pages + + Args: + - query (str): a text query to call retrieval model + - docid2embs (Dict[str, tensor]): collection of document embeddings + key: document_id + value: torch.tensor of size (n_tokens, emb_dim) + - index: faiss index + - n_return_pages (int): number of pages to return + - single_page_from_each_doc (bool): if true, only single page is retrieved from each PDF document. + + Return: + retrieval_results + [(doc_id, page_idx, scores)...] + """ + + + if index is not None: + + # [n_query_tokens, dim] + query_emb = self.retrieval_model.encode_queries([query])[0] + query_emb = query_emb.cpu().float().numpy().astype(np.float32) + + # NN search + k = n_return_pages + D, I = index.search(query_emb, k) + + # Sum the MaxSim scores across all query tokens for each document + final_page2scores = {} + + # Iterate over query tokens + for q_idx, query_emb in enumerate(query_emb): + + # Initialize a dictionary to hold document relevance scores + curent_q_page2scores = {} + + for nn_idx in range(k): + found_nearest_doc_token_idx = I[q_idx, nn_idx] + + page_uid = token2pageuid[found_nearest_doc_token_idx] # Get the document ID for this token + + # reconstruct the original score + doc_token_emb = all_token_embeddings[found_nearest_doc_token_idx] + score = (query_emb * doc_token_emb).sum() + + # MaxSim: aggregate the highest similarity score for each query token per document + if page_uid not in curent_q_page2scores: + curent_q_page2scores[page_uid] = score + else: + curent_q_page2scores[page_uid] = max(curent_q_page2scores[page_uid], score) + + + for page_uid, score in curent_q_page2scores.items(): + if page_uid in final_page2scores: + final_page2scores[page_uid] += score + else: + final_page2scores[page_uid] = score + + # Sort documents by their final relevance score + sorted_pages = sorted(final_page2scores.items(), key=lambda x: x[1], reverse=True) + + # Get the top-k document candidates + top_k_pages = sorted_pages[:k] + + + # [(doc_id, page_idx, scores)...] 
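+            # Worked example of the aggregation above: with 2 query tokens whose best
+            # similarities on page A are 0.9 and 0.7, page A's final score is 0.9 + 0.7 = 1.6
+            # (sum over query tokens of the per-token MaxSim score).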
+ + sorted_results = [] + for page_uid, score in top_k_pages: + # logger.info(f"{page_uid} with score {score}") + + # page_uid = f"{doc_id}_page{page_id}" + doc_id = page_uid.split('_page')[0] + page_idx = int(page_uid.split('_page')[-1]) + sorted_results.append((doc_id, page_idx, score.item())) + + return sorted_results + + docid2scores = {} + for doc_id, doc_embs in tqdm( + docid2embs.items(), + total=len(docid2embs), + disable = not show_progress, + desc=f"Calculating similarity score over documents" + ): + doc_lens = None + if docid2lens is not None: + doc_lens = docid2lens[doc_id] + + scores = self.retrieval_model.retrieve( + query=query, + doc_embeds=doc_embs, + doc_lens=doc_lens, + to_cpu=True, + return_top_1=False + ) + scores = scores.flatten().tolist() + docid2scores[doc_id] = scores + + # find the pages with top scores + if single_page_from_each_doc: + return get_top_k_pages_single_page_from_each_doc(docid2scores=docid2scores, k=n_return_pages) + else: + return get_top_k_pages(docid2scores=docid2scores, k=n_return_pages) + + + def run_qa(self): + raise NotImplementedError + + def run_vqa(self): + raise NotImplementedError \ No newline at end of file diff --git a/src/m3docrag/rag/multimodal.py b/src/m3docrag/rag/multimodal.py new file mode 100644 index 0000000..2f31377 --- /dev/null +++ b/src/m3docrag/rag/multimodal.py @@ -0,0 +1,51 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + + +from .base import RAGModelBase + +import torch + +from m3docrag.vqa import VQAModel +from m3docrag.retrieval import ColPaliRetrievalModel + + +class MultimodalRAGModel(RAGModelBase): + def __init__( + self, + retrieval_model: ColPaliRetrievalModel, + vqa_model: VQAModel = None + ): + self.retrieval_model = retrieval_model + self.vqa_model = vqa_model + + self.retrieval_model.model.eval() + + if self.vqa_model is not None and isinstance(self.vqa_model.model, torch.nn.Module): + self.vqa_model.model.eval() + + + def run_vqa( + self, + images, + question, + ) -> str: + + response = self.vqa_model.generate(images=images, question=question) + assert isinstance(response, str), type(response) + + return response \ No newline at end of file diff --git a/src/m3docrag/rag/utils.py b/src/m3docrag/rag/utils.py new file mode 100644 index 0000000..6bad6ee --- /dev/null +++ b/src/m3docrag/rag/utils.py @@ -0,0 +1,128 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + + +from tqdm.auto import tqdm + + +def reduce_embeddings(docid2embs, dim='page', show_progress=True): + """Summarize document embedding by reducing (averaging) specific dimensions + + Input embedding: + [n_pages, n_tokens, emb_dim] + + Output embedding: + + - reduction_dim == 'page' + [1, n_tokens, emb_dim] + + - reduction_dim == 'token' + [1, n_pages, emb_dim] + + - reduction_dim == 'page_token' + [1, 1, emb_dim] + """ + + assert dim in ['page', 'token', 'page_token'], f"{dim}" + + new_docid2embs = {} + + for doc_id in tqdm( + list(docid2embs.keys()), + disable=not show_progress + ): + # [n_pages, n_tokens, dim] + embs = docid2embs[doc_id] + + emb_dim = embs.size(-1) + + if dim == 'page': + # [n_tokens, dim] + new_embs = embs.mean(dim=0) + elif dim == 'token': + # [n_pages, dim] + new_embs = embs.mean(dim=1) + elif dim == 'page_token': + # [1, dim] + new_embs = embs.mean(dim=0).mean(dim=0) + + new_docid2embs[doc_id] = new_embs.view(1, -1, emb_dim) + + return new_docid2embs + + +def get_top_k_pages(docid2scores: dict, k: int): + """ + # Example usage: + docid2scores = { + "doc1": [10, 50, 30], + "doc2": [40, 20, 60], + "doc3": [70, 90] + } + + k = 3 + top_k_pages = get_top_k_pages(docid2scores, k) + print(top_k_pages) + + -> [('doc3', 1, 90), ('doc3', 0, 70), ('doc2', 2, 60)] + """ + # Flatten the dictionary into a list of tuples (doc_id, page_index, score) + flattened_scores = [ + (doc_id, page_index, score) + for doc_id, scores in docid2scores.items() + for page_index, score in enumerate(scores) + ] + + # Sort by score in descending order + flattened_scores.sort(key=lambda x: x[2], reverse=True) + + # Get the top-k entries + top_k_pages = flattened_scores[:k] + + return top_k_pages + + +def get_top_k_pages_single_page_from_each_doc(docid2scores: dict, k: int): + """ + # Example usage: + docid2scores = { + "doc1": [10, 50, 30], + "doc2": [40, 20, 60], + "doc3": [70, 90] + } + + k = 2 + top_k_pages = get_top_k_pages_single_page_from_each_doc(docid2scores, k) + print(top_k_pages) + + -> [('doc3', 1, 90), ('doc2', 2, 60)] + """ + # First, get the highest scoring page for each document + highest_per_doc = [ + (doc_id, max(enumerate(scores), key=lambda x: x[1])) # (doc_id, (page_index, score)) + for doc_id, scores in docid2scores.items() + ] + + # Flatten the structure to (doc_id, page_index, score) + highest_per_doc_flat = [(doc_id, page_index, score) for doc_id, (page_index, score) in highest_per_doc] + + # Sort by score in descending order + highest_per_doc_flat.sort(key=lambda x: x[2], reverse=True) + + # Get the top-k entries + top_k_pages = highest_per_doc_flat[:k] + + return top_k_pages diff --git a/src/m3docrag/retrieval/__init__.py b/src/m3docrag/retrieval/__init__.py new file mode 100644 index 0000000..58d7b68 --- /dev/null +++ b/src/m3docrag/retrieval/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + + +from .colpali import ColPaliRetrievalModel \ No newline at end of file diff --git a/src/m3docrag/retrieval/colpali.py b/src/m3docrag/retrieval/colpali.py new file mode 100644 index 0000000..c60d131 --- /dev/null +++ b/src/m3docrag/retrieval/colpali.py @@ -0,0 +1,303 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +# from transformers import AutoProcessor +from PIL import Image +from typing import List + +from colpali_engine.models import ColPali, ColPaliProcessor +from colpali_engine.models import ColQwen2, ColQwen2Processor + +def init( + backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base", + adapter_name_or_path= "/job/model/colpali-v1.2", + dtype=torch.bfloat16, +): + """ + Load ColPali Model and Processor from (locally downloaded) HF checkpoint. + + Args: + - backbone_model_name_or_path: downloaded from https://huggingface.co/vidore/colpaligemma-3b-pt-448-base + - adapter_name_or_path: downloaded from https://huggingface.co/vidore/colpali-v1.2 + Return: + - model + - processor + """ + + kwargs = {} + model_class = ColPali + processor_class = ColPaliProcessor + if 'colqwen' in str(adapter_name_or_path): + model_class = ColQwen2 + processor_class = ColQwen2Processor + kwargs['attn_implementation'] = "flash_attention_2" + + model = model_class.from_pretrained( + backbone_name_or_path, + torch_dtype=dtype, + low_cpu_mem_usage=True, + **kwargs + ).eval() + + model.load_adapter(adapter_name_or_path) + processor = processor_class.from_pretrained(adapter_name_or_path) + + return model, processor + + +def encode_images( + model, + processor, + images: List[Image.Image], + batch_size: int = 4, + to_cpu: bool = False, + use_tqdm: bool = False, + collate_fn=None, + return_doclens: bool = False + ): + """Create document embeddings with ColPali + + Args: + model + processor + images (List[Image.Image]) + (n_pages) + batch_size (int, optional): + batch size. Defaults to 4. + to_cpu (bool, optional): + whether to save embeddings in cpu tensors. Defaults to False. + use_tqdm (bool, optional): + whether to show tqdm progress bar. Defaults to False. + collate_fn (_type_, optional): + custom collate_fn for document dataloader. Defaults to None. + return_doclens (bool, optional): + whether to output the number of pages. Defaults to False. 
+
+    Returns:
+        doc_embs: List[torch.tensor]
+            visual embedding of documents (n_pages, n_tokens, n_dimension)
+        (optional) doclens
+            number of non-padded tokens in each page
+    """
+
+    if collate_fn is None:
+        collate_fn = processor.process_images
+
+    # run inference - docs
+    dataloader = DataLoader(
+        images,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=collate_fn,
+    )
+
+    doc_embs = []
+    if return_doclens:
+        doclens = []
+    if use_tqdm:
+        dataloader = tqdm(dataloader)
+    for batch_doc in dataloader:
+        with torch.no_grad():
+            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
+            embeddings_doc = model(**batch_doc)
+        if to_cpu:
+            embeddings_doc = embeddings_doc.to("cpu")
+        if return_doclens:
+            _doclens = batch_doc['attention_mask'].squeeze(-1).sum(-1).tolist()
+            doclens.extend(_doclens)
+        doc_embs.extend(list(torch.unbind(embeddings_doc)))
+
+    if return_doclens:
+        return doc_embs, doclens
+    else:
+        return doc_embs
+
+
+def encode_queries(
+    model,
+    processor,
+    queries: List[str],
+    batch_size: int = 4,
+    to_cpu: bool = False,
+    use_tqdm: bool = False,
+    collate_fn=None,
+):
+    """Create query embeddings with ColPali
+
+    Args:
+        model
+        processor
+        queries (List[str]):
+            text queries (n_queries,)
+        batch_size (int, optional):
+            batch size. Defaults to 4.
+        to_cpu (bool, optional):
+            whether to save embeddings in cpu tensors. Defaults to False.
+        use_tqdm (bool, optional):
+            whether to show tqdm progress bar. Defaults to False.
+        collate_fn (_type_, optional):
+            custom collate_fn for the query dataloader. Defaults to None.
+    Returns:
+        query_embs: List[torch.tensor]
+            embedding of queries (n_queries, n_tokens, n_dimension)
+    """
+
+    if collate_fn is None:
+        collate_fn = processor.process_queries
+
+    # run inference - queries
+    dataloader = DataLoader(
+        queries,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=collate_fn
+    )
+
+    query_embs = []
+    if use_tqdm:
+        dataloader = tqdm(dataloader)
+    for batch_query in dataloader:
+        with torch.no_grad():
+            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
+            embeddings_query = model(**batch_query)
+        if to_cpu:
+            embeddings_query = embeddings_query.to("cpu")
+        query_embs.extend(list(torch.unbind(embeddings_query)))
+    return query_embs
+
+
+def retrieve(
+    model,
+    processor,
+    docs=None,
+    query=None,
+    doc_embeds=None,
+    query_embeds=None,
+    to_cpu=False,
+    batch_size=1,
+    use_tqdm=False,
+    return_top_1=True
+):
+    """Score a text query against ColPali page embeddings.
+
+    Returns the index of the top-1 page when `return_top_1` is True,
+    otherwise the full (n_queries, n_pages) score matrix.
+    """
+    if doc_embeds is None:
+        doc_embeds = encode_images(
+            model, processor,
+            images=docs,
+            batch_size=batch_size,
+            use_tqdm=use_tqdm,
+            to_cpu=to_cpu,
+        )
+
+    if query_embeds is None:
+        query_embeds = encode_queries(
+            model, processor,
+            queries=[query],
+            batch_size=1,
+            use_tqdm=use_tqdm,
+            to_cpu=to_cpu,
+        )
+
+    qs = query_embeds
+    ds = doc_embeds
+
+    qs = [q.to(ds[0].dtype) for q in qs]
+
+    scores = processor.score_multi_vector(qs, ds)
+
+    if return_top_1:
+        return scores.argmax(axis=1)
+    else:
+        return scores
+
+
+class ColPaliRetrievalModel:
+
+    def __init__(self,
+                 backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
+                 adapter_name_or_path="/job/model/colpali-v1.2",
+                 dtype=torch.bfloat16,
+                 ):
+        model, processor = init(backbone_name_or_path=backbone_name_or_path,
+                                adapter_name_or_path=adapter_name_or_path,
+                                dtype=dtype,
+                                )
+        self.model = model.eval()
+        self.processor = processor
+
+    def encode_queries(self,
+                       queries: List[str],
+                       batch_size: int = 4,
+                       to_cpu: bool = False,
+                       use_tqdm: bool = False,
+                       collate_fn=None
+                       ):
+        return encode_queries(
+
model=self.model, + processor=self.processor, + queries=queries, + batch_size=batch_size, + to_cpu=to_cpu, + use_tqdm=use_tqdm, + collate_fn=collate_fn) + + + def encode_images(self, + images: List[Image.Image], + batch_size: int = 4, + to_cpu: bool = False, + use_tqdm: bool = False, + collate_fn=None, + return_doclens: bool = False + ): + return encode_images( + model=self.model, + processor=self.processor, + images=images, + batch_size=batch_size, + to_cpu=to_cpu, + use_tqdm=use_tqdm, + collate_fn=collate_fn, + return_doclens=return_doclens, + ) + + def retrieve(self, + docs=None, + query=None, + doc_embeds=None, + doc_lens=None, + query_embeds=None, + to_cpu=False, + batch_size=1, + use_tqdm=False, + return_top_1=True + ): + + return retrieve( + model=self.model, + processor=self.processor, + docs=docs, + query=query, + doc_embeds=doc_embeds, + query_embeds=query_embeds, + to_cpu=to_cpu, + batch_size=batch_size, + use_tqdm=use_tqdm, + return_top_1=return_top_1 + ) \ No newline at end of file diff --git a/src/m3docrag/utils/args.py b/src/m3docrag/utils/args.py new file mode 100644 index 0000000..468a011 --- /dev/null +++ b/src/m3docrag/utils/args.py @@ -0,0 +1,75 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import Dict, Optional, Sequence, List +import transformers +from icecream import ic + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + + # Data settings + split: str = 'train' + data_name: str = field(default='m3-docvqa', metadata={"help": "Local name to be stored at LOCAL_DATA_DIR"}) + data_len: int = field(default=None, metadata={"help": "number of examples to subsample from dataset"}) + use_dummy_images: bool = field(default=False, metadata={"help": "if true, skip downloading images"}) + load_embedding: bool = False + embedding_name: str = "colpali-v1.2_m3-docvqa_dev" + + max_pages: int = 20 + do_page_padding: bool = False + + # Retrieval settings + retrieval_model_type: str = field(default='colpali', metadata={"choices": ['colpali', 'colbert']}) + use_retrieval: bool = True + retrieval_only: bool = field(default=False, metadata={"help": "not running stage 2 (VQA)"}) + page_retrieval_type: str = 'logits' + loop_unique_doc_ids: bool = field(default=False, metadata={"help": "if true, apply retrieval only on unique doc ids"}) + + n_retrieval_pages: int = 1 + + + # Embedding indexing settings + faiss_index_type: str = field(default='ivfflat', metadata={"choices": ['flatip', 'ivfflat', 'ivfpq']}) + + + # Local paths + model_name_or_path: Optional[str] = field(default="Qwen2-VL-7B-Instruct") + retrieval_model_name_or_path: Optional[str] = field(default="colpaligemma-3b-pt-448-base") + retrieval_adapter_model_name_or_path: Optional[str] = field(default="colpali-v1.2") + + # Model settings + bits: int = field(default=16, metadata={"help": "Floating point precision. 
Use '4' for 4-bit quantization to save memory"}) + + # idefics2 settings + do_image_splitting: bool = False + + +_example_arg_str = """ +--output_dir=/job/outputs +--data_name=m3-docvqa +--use_retrieval=True +""" +_example_args = _example_arg_str.strip().split('\n') + + +def parse_args(args=None): + parser = transformers.HfArgumentParser(TrainingArguments) + parsed_args, remaining_args = parser.parse_args_into_dataclasses(args, return_remaining_strings=True) + ic(remaining_args) + + return parsed_args \ No newline at end of file diff --git a/src/m3docrag/utils/distributed.py b/src/m3docrag/utils/distributed.py new file mode 100644 index 0000000..16ea50d --- /dev/null +++ b/src/m3docrag/utils/distributed.py @@ -0,0 +1,209 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Utility functions to manage multi-device training +""" + +import inspect +import logging +import os +import platform +import subprocess +import sys + +# import pkg_resources +import importlib.metadata +import torch.cuda +import torch.distributed +from loguru import logger + + +def world_size() -> int: + """Returns the total number of processes in a distributed job (num_nodes x gpus_per_node). + Returns 1 in a non-distributed job. + """ + return int(os.environ.get("WORLD_SIZE", "1")) + + +def is_distributed() -> bool: + """Returns True iff this is a distributed job (more than one process).""" + return world_size() > 1 + + +def local_rank() -> int: + """Returns the local rank of the current process in a distributed job. + Returns 0 (local primary) for non-distributed jobs. + """ + return int(os.environ.get("LOCAL_RANK", "0")) + + +def global_rank() -> int: + """Returns the global rank of the current process in a distributed job. + Returns 0 (global primary) for non-distributed jobs. + """ + return int(os.environ.get("RANK", local_rank())) + + +def barrier(): + """Synchronizes all processes. Set GPU with local_rank to perform barrier used by this process.""" + torch.distributed.barrier(device_ids=[local_rank()]) + + +def patch_module_loggers(module_namespace): + modules = inspect.getmembers(module_namespace, predicate=inspect.ismodule) + + # check toplevel module + if hasattr(module_namespace, "logger"): + module_namespace.logger = logger + logger.info(f"Patching logger: {module_namespace.__name__}") + + for _, mod in modules: + if hasattr(mod, "logger"): + mod.logger = logger + logger.info(f"Patching logger: {mod.__name__}") + + +class InterceptLogHandler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + # Get corresponding Loguru level if it exists. + level: str | int + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the logged message. 
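+        # Walk up the call stack past the logging module's own frames so that
+        # loguru attributes the record to the original call site, not this handler.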
+ frame, depth = inspect.currentframe(), 0 + while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__): + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + +def configure_distributed_logging( + rank_zero_level="INFO", + rank_non_zero_level="WARNING", +): + """This method configures logger to reduce noise in multi-node, multi-process evaluations (e.g. DeepSpeed)_summary_ + Args: + rank_zero_level (str, optional): Log level on zero rank process. Defaults to "INFO". + rank_non_zero_level (str, optional): Log level on non-zero rank processes. Defaults to "WARNING". + """ + + logger.remove() + rank = local_rank() + format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | rank={extra[rank]} | {name}:{function}:{line} - {message}\n{exception}" + if rank != 0: + logger.configure( + extra={"rank": f"{global_rank()}:{rank}"}, + handlers=[ + {"sink": sys.stderr, "format": format, "level": rank_non_zero_level} + ], + ) + else: + logger.configure( + extra={"rank": f"{global_rank()}:{rank}"}, + handlers=[{"sink": sys.stdout, "format": format, "level": rank_zero_level}], + ) + + # Attempt to intercept normal logging in libs + logging.basicConfig(handlers=[InterceptLogHandler()], level=0, force=True) + + +def get_cuda_version(): + """Get the installed CUDA version.""" + try: + output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8") + version_line = output.strip().split("\n")[-1] + version = version_line.split(" ")[-1] + return version + except Exception as e: + logger.info(f"Cannot detect CUDA version. Exception occured: {e}") + return "N/A" + + +def log_runtime_info(): + # Get Python runtime information + python_version = sys.version + python_implementation = platform.python_implementation() + python_build = platform.python_build() + + # Get environment variables + env_vars = os.environ + + # Get installed package versions + installed_packages = [ + # (d.project_name, d.version) for d in pkg_resources.working_set + (d.metadata["Name"], d.version) for d in importlib.metadata.distributions() + ] + + + # logger.info diagnostics + logger.info("Python Version: {}".format(python_version)) + logger.info("Python Implementation: {}".format(python_implementation)) + logger.info("Python Build: {}".format(python_build)) + + logger.info(f"Environment Variables: {env_vars}") + logger.info(f"Installed Packages: {installed_packages}") + + logger.info(f"CUDA version: {get_cuda_version()}") + logger.info(f"Is CUDA available for Torch?: {torch.cuda.is_available()}") + + logger.info(f"World size: {world_size()}") + + +def local_rank_zero(func): + """ + Decorator to execute function only in local zero rank. Can be useful for logging statistics. + """ + + def wrapper(*args, **kwargs): + if local_rank() == 0: + func(*args, **kwargs) + + return wrapper + + +def global_rank_zero(func): + """ + Decorator to execute function only in global zero rank. Can be useful for logging statistics. 
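+
+    Example (illustrative):
+
+        @global_rank_zero
+        def log_dataset_stats():
+            logger.info("printed once per job, on the global rank-0 process")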
+ """ + + def wrapper(*args, **kwargs): + if global_rank() == 0: + func(*args, **kwargs) + + return wrapper + + + +import gpustat +def print_gpu_stats(): + gpustat.cli.main([]) + +def supports_flash_attention(device_id=0): + """Check if a GPU supports FlashAttention.""" + major, minor = torch.cuda.get_device_capability(device_id) + + # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0) + is_sm8x = major == 8 and minor >= 0 + is_sm90 = major == 9 and minor == 0 + + return is_sm8x or is_sm90 \ No newline at end of file diff --git a/src/m3docrag/utils/paths.py b/src/m3docrag/utils/paths.py new file mode 100644 index 0000000..2b3a6b0 --- /dev/null +++ b/src/m3docrag/utils/paths.py @@ -0,0 +1,24 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from dotenv import load_dotenv +load_dotenv() + +LOCAL_DATA_DIR = os.getenv("LOCAL_DATA_DIR", "/job/datasets") +LOCAL_EMBEDDINGS_DIR = os.getenv("LOCAL_EMBEDDINGS_DIR", "/job/embeddings") +LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/job/model") +LOCAL_OUTPUT_DIR = os.getenv("LOCAL_OUTPUT_DIR", "/job/output") \ No newline at end of file diff --git a/src/m3docrag/utils/pdfs.py b/src/m3docrag/utils/pdfs.py new file mode 100644 index 0000000..def4856 --- /dev/null +++ b/src/m3docrag/utils/pdfs.py @@ -0,0 +1,75 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + + +from pathlib import Path +from loguru import logger +from pdf2image import convert_from_path +from collections import Counter + + +def get_images_from_pdf(pdf_path, max_pages=None, dpi_resolution=144, save_dir='/tmp/', save_image=False, save_type='png', verbose=False): + pdf_path = Path(pdf_path) + assert pdf_path.exists(), f"{pdf_path} does not exist" + + pdf_fname = pdf_path.name + + images = convert_from_path(pdf_path, dpi=dpi_resolution) + + # PIL.PpmImagePlugin.PpmImageFile -> PIL.Image.Image + images = [img.convert('RGB') for img in images] + + # resizing to the most common image size so that we can stack in pytorch tensor + # PDFs (e.g., MMLongBench-Doc) have different image sizes + # width=1,224 and height=1,584 + # 1440, 810 + # 1191, 1684 + # 1440, 1080 + # 1536, 1152 + + # 1) find the most common image size + img_size_counter = Counter() + for img in images: + img_size_counter[img.size] += 1 + common_img_size, common_img_size_count = img_size_counter.most_common(1)[0] + + # 2) if pages have different sizes -> resize all pages to that image size + if len(images) != common_img_size_count: + logger.info(f"total: {len(images)} pages") + logger.info(f"resizing to the most common image size: {common_img_size} with count: {common_img_size_count}") + images = [img.resize(common_img_size) for img in images] + + if save_image: + save_dir = Path(save_dir) + save_dir.mkdir(exist_ok=True, parents=True) + + for page_index, page_image in enumerate(images): + save_page_path = save_dir / f"{pdf_fname}_{page_index+1}.{save_type}" + if not save_page_path.exists(): + page_image.save(save_page_path) + if verbose: + logger.info(f"Page {page_index} saved at {save_page_path}") + + return images + + + +if __name__ == '__main__': + get_images_from_pdf( + pdf_path="./multimodalqa_screenshots_pdfs/0df5cc80bcd2a27b91224d658ad3a7b5.pdf", + save_dir='./tmp/', + save_image=True + ) diff --git a/src/m3docrag/utils/prompts.py b/src/m3docrag/utils/prompts.py new file mode 100644 index 0000000..b16a903 --- /dev/null +++ b/src/m3docrag/utils/prompts.py @@ -0,0 +1,70 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import string + +binary_page_retrieval_template = """ +question: $question +Does this page have the answer to the question? +Answer only with yes or no. +""".strip() + +concat_page_retrieval_template = """ +question: $question +Which page is most relevant to the question? +""".strip() + +concat_page_retrieval_with_answer_template = """ +Find the page most relevant to the question and answer: "$question" +""".strip() + +concate_page_answer_template = """ +Find the answer to the question: "$question" +""".strip() + +short_answer_template = """ +question: $question +output only answer. +""".strip() + +long_answer_template = """ +question: $question +Answer the question with detailed explanation. 
+""".strip() + + +text_rag_template = """ +DOCUMENTS: +$documents + +QUESTION: +$question + +INSTRUCTIONS: +Answer the QUESTION using the DOCUMENTS text above. Simply output the answer only. + +Answer: +""" + + + +binary_page_retrieval_template = string.Template(binary_page_retrieval_template) +concat_page_retrieval_template = string.Template(concat_page_retrieval_template) +concat_page_retrieval_with_answer_template = string.Template(concat_page_retrieval_with_answer_template) +concate_page_answer_template = string.Template(concate_page_answer_template) +short_answer_template = string.Template(short_answer_template) +long_answer_template = string.Template(long_answer_template) +text_rag_template = string.Template(text_rag_template) \ No newline at end of file diff --git a/src/m3docrag/utils/tar.py b/src/m3docrag/utils/tar.py new file mode 100644 index 0000000..6e7021b --- /dev/null +++ b/src/m3docrag/utils/tar.py @@ -0,0 +1,32 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import tarfile +from loguru import logger + +def make_tarfile(source_dir, output_filename): + logger.info(f"Compressing {source_dir} to {output_filename} ...") + with tarfile.open(output_filename, "w:gz") as tar: + tar.add(source_dir, arcname='.') + logger.info(f"Compression done!") + + +def extract_tarfile(input_filename, target_dir): + logger.info(f"Extracting {input_filename} to {target_dir} ...") + with tarfile.open(input_filename) as f: + f.extractall(target_dir) + logger.info(f"Extraction done!") \ No newline at end of file diff --git a/src/m3docrag/vqa/__init__.py b/src/m3docrag/vqa/__init__.py new file mode 100644 index 0000000..83cb44a --- /dev/null +++ b/src/m3docrag/vqa/__init__.py @@ -0,0 +1,144 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + + + +from typing import Union, List +from pathlib import Path +import torch + +from m3docrag.vqa import internvl2 +from m3docrag.vqa import idefics2 +from m3docrag.vqa import idefics3 +from m3docrag.vqa import florence2 +from m3docrag.vqa import qwen2 + +ALL_VQA_MODEL_TYPES = ['florence2', 'idefics2', 'internvl2', 'idefics3', 'qwen2'] + +def init( + model_name_or_path: Union[str, Path], + model_type: str, + **kwargs +): + + if 'internvl2' == model_type.lower(): + return internvl2.init( + model_name_or_path=model_name_or_path, + **kwargs + ) + elif 'idefics2' == model_type.lower(): + return idefics2.init( + model_name_or_path=model_name_or_path, + **kwargs + ) + elif 'idefics3' == model_type.lower(): + return idefics3.init( + model_name_or_path=model_name_or_path, + **kwargs + ) + elif 'florence2' == model_type.lower(): + return florence2.init( + model_name_or_path=model_name_or_path, + **kwargs + ) + elif 'qwen2' == model_type.lower(): + return qwen2.init( + model_name_or_path=model_name_or_path, + **kwargs + ) + else: + raise NotImplementedError(f"{model_type} is unsupported. Supported: {ALL_VQA_MODEL_TYPES}") + + +def generate( + model_type: str, + model, + processor, + **kwargs +) -> List[str]: + + if 'internvl2' == model_type.lower(): + return internvl2.generate( + model=model, + processor=processor, + **kwargs + ) + elif 'idefics2' == model_type.lower(): + return idefics2.generate( + model=model, + processor=processor, + **kwargs + ) + elif 'idefics3' == model_type.lower(): + return idefics3.generate( + model=model, + processor=processor, + **kwargs + ) + elif 'florence2' == model_type.lower(): + return florence2.generate( + model=model, + processor=processor, + **kwargs + ) + elif 'qwen2' == model_type.lower(): + return qwen2.generate( + model=model, + processor=processor, + **kwargs + ) + else: + raise NotImplementedError(f"{model_type} is unsupported. Supported: {ALL_VQA_MODEL_TYPES}") + + +class VQAModel: + + def __init__(self, model_name_or_path: Union[str, Path], model_type: str, **kwargs): + + model_loaded = init(model_name_or_path=model_name_or_path, model_type=model_type, **kwargs) + model = model_loaded['model'] + if 'tokenizer' in model_loaded: + processor = model_loaded['tokenizer'] + else: + processor = model_loaded['processor'] + + if isinstance(model, torch.nn.Module): + model = model.eval() + + # greedy decoding + if hasattr(model, 'generation_config'): + model.generation_config.temperature=None + model.generation_config.top_p=None + model.generation_config.top_k=None + + self.model = model + self.processor = processor + self.model_type = model_type + + def generate(self, images, question) -> str: + responses = generate( + model_type=self.model_type, + model=self.model, + processor=self.processor, + images=images, + question=question, + ) + assert isinstance(responses, list), responses + + out_text = responses[0] + out_text = out_text.strip() + + return out_text diff --git a/src/m3docrag/vqa/florence2.py b/src/m3docrag/vqa/florence2.py new file mode 100644 index 0000000..38eceea --- /dev/null +++ b/src/m3docrag/vqa/florence2.py @@ -0,0 +1,169 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch +from transformers import AutoProcessor, AutoModelForCausalLM +from PIL import Image + +from typing import Union, List + + +def init( + model_name_or_path, + dtype=torch.bfloat16, + # model_max_length=None, + **kwargs, +): + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + torch_dtype=dtype, + low_cpu_mem_usage=True, + trust_remote_code=True + ) + model.eval() + + processor = AutoProcessor.from_pretrained( + model_name_or_path, + trust_remote_code=True, + # model_max_length=model_max_length + ) + + return { + 'model': model, + 'processor': processor + } + + +def generate( + model, + processor, + question, + images +) -> List[str]: + + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + module = model.module + else: + module = model + answer = generate_caption(module, processor=processor, images=images, text_input=question) + return answer + + + +@torch.no_grad() +def generate_caption( + model, + processor, + images=None, + # task_prompt="", + text_input=None, + input_ids=None, + pixel_values=None, + max_new_tokens=77, + num_beams=1, + do_sample=False, + decode_text=True, + **generate_kwargs +): + + if input_ids is None and pixel_values is None: + if isinstance(images, Image.Image): + images = [images] + + B = len(images) + + if isinstance(text_input, str): + text_input = [text_input] * B + + inputs = processor( + text=text_input, + images=images, + return_tensors="pt") + + input_ids = inputs["input_ids"] + pixel_values = inputs["pixel_values"] + + p = next(iter(model.parameters())) + device = p.device + dtype = p.dtype + + generated_ids = model.generate( + input_ids=input_ids.to(device), + pixel_values=pixel_values.to(device, dtype), + max_new_tokens=max_new_tokens, + num_beams=num_beams, + do_sample=do_sample, + **generate_kwargs + ) + if decode_text: + out_text = decode_predictions(processor, generated_ids) + + return out_text + else: + return generated_ids + +def decode_predictions(processor, generated_ids): + B = len(generated_ids) + + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + out_text = [] + for i in range(B): + out_text.append(generated_text[i].replace('', '').replace('', '').replace('', '').strip()) + return out_text + +# def load_model_from_ckpt( +# ckpt_dir, +# load_abstractor=False, +# num_input_tokens=576, +# num_query_tokens=64, +# proj_type='c-abs', +# projection_dim=1024, +# strict=False +# ): +# florence_config = Florence2Config.from_pretrained(ckpt_dir) + +# if load_abstractor: +# abstractor_config = AbstractorConfig( +# num_input_tokens=num_input_tokens, +# num_query_tokens=num_query_tokens, +# proj_type=proj_type, +# projection_dim=projection_dim, +# ) +# florence_config.abstractor_config = abstractor_config + +# florence_config.vision_config.model_type = 'davit' + +# model = Florence2ForConditionalGeneration(config=florence_config) + + +# ckpt_path = Path(ckpt_dir) / 'model.safetensors' +# if ckpt_path.exists(): +# logger.info(f"loading checkpoints from {ckpt_path}") +# state_dict = safetensors.torch.load_file(ckpt_path, 
device="cpu") + +# else: +# ckpt_path = Path(ckpt_dir) / 'pytorch_model.bin' +# logger.info(f"loading checkpoints from {ckpt_path}") +# state_dict = torch.load( +# ckpt_path, +# map_location="cpu", +# ) + +# load_result = model.load_state_dict(state_dict, strict=strict) +# logger.info(load_result) + +# return model + diff --git a/src/m3docrag/vqa/idefics2.py b/src/m3docrag/vqa/idefics2.py new file mode 100644 index 0000000..6aff2b9 --- /dev/null +++ b/src/m3docrag/vqa/idefics2.py @@ -0,0 +1,137 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + +import torch + +from transformers import Idefics2ForConditionalGeneration +from transformers import Idefics2Processor +from transformers import BitsAndBytesConfig + +from typing import Union, List + +def init( + model_name_or_path, + do_image_splitting=True, + bits=4, + dtype=torch.bfloat16, + **kwargs, +): + if bits == 4: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=dtype + ) + else: + bnb_config = None + model = Idefics2ForConditionalGeneration.from_pretrained( + model_name_or_path, + torch_dtype=dtype, + quantization_config=bnb_config, + low_cpu_mem_usage=True + ) + model.eval() + + processor = Idefics2Processor.from_pretrained( + model_name_or_path, + do_image_splitting=do_image_splitting, + # size= {"longest_edge": image_size, "shortest_edge": 378} + ) + + return { + 'model': model, + 'processor': processor + } + +def generate( + model, + processor, + question, + images +) -> List[str]: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + module = model.module + else: + module = model + + messages = idefics2_create_message(images=images, question=question) + + examples = [{'images': images, 'messages': messages}] + batch = idefics2_collate_fn(examples, processor) + + for k in batch: + batch[k] = batch[k].to(module.device) + + generated_ids = module.generate(**batch, max_new_tokens=50, do_sample=False) + answer = processor.batch_decode( + generated_ids[:, batch.input_ids.size(1):], + skip_special_tokens=True) + return answer + + + + +def idefics2_collate_fn(examples, + processor, + model_max_length=3600, + is_train=False, + image_token_id=None, + ): + + if image_token_id is None: + image_token_id = processor.tokenizer.additional_special_tokens_ids[processor.tokenizer.additional_special_tokens.index("")] + + texts = [] + images = [] + for example in examples: + prompt = processor.apply_chat_template(example['messages'], add_generation_prompt=not is_train) + texts.append(prompt) + images.append(example['images']) + + if is_train: + batch = processor(text=texts, images=images, return_tensors="pt", + padding=True, truncation=True, + max_length=model_max_length) + + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + labels[labels == image_token_id] = -100 + batch["labels"] = labels + else: + batch = 
processor(text=texts, images=images, return_tensors="pt", + padding=True, + truncation=True, + max_length=model_max_length + ) + + return batch + + + +def idefics2_create_message(images, question, is_train=False, target_text=None): + content = [] + for page_i in range(len(images)): + # content += [{"type": "text", "text": f"page {page_i}: "}] + content += [{"type": "image"}] + content += [{"type": "text", "text": question}] + messages = [{"role": "user", "content": content}] + + if is_train: + messages += [{"role": "assistant", "content": [{"type": "text", "text": target_text}]}] + + return messages diff --git a/src/m3docrag/vqa/idefics3.py b/src/m3docrag/vqa/idefics3.py new file mode 100644 index 0000000..174024a --- /dev/null +++ b/src/m3docrag/vqa/idefics3.py @@ -0,0 +1,129 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + +import torch + +from transformers import AutoProcessor, AutoModelForVision2Seq +from transformers import BitsAndBytesConfig + +from typing import Union, List + +def init( + model_name_or_path, + bits=4, + dtype=torch.bfloat16, + **kwargs, +): + if bits == 4: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=dtype + ) + else: + bnb_config = None + + model = AutoModelForVision2Seq.from_pretrained( + model_name_or_path, + torch_dtype=dtype, + quantization_config=bnb_config, + low_cpu_mem_usage=True + ) + model.eval() + + processor = AutoProcessor.from_pretrained( + model_name_or_path, + ) + + return { + 'model': model, + 'processor': processor + } + +def generate( + model, + processor, + question, + images +) -> List[str]: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + module = model.module + else: + module = model + + messages = idefics3_create_message(images=images, question=question) + + examples = [{'images': images, 'messages': messages}] + batch = idefics3_collate_fn(examples, processor) + + for k in batch: + batch[k] = batch[k].to(module.device) + + generated_ids = module.generate(**batch, max_new_tokens=50, do_sample=False) + answer = processor.batch_decode( + generated_ids[:, batch.input_ids.size(1):], + skip_special_tokens=True) + return answer + + +def idefics3_collate_fn(examples, + processor, + is_train=False, + image_token_id=None, + ): + + texts = [] + images = [] + for example in examples: + prompt = processor.apply_chat_template(example['messages'], add_generation_prompt=not is_train) + texts.append(prompt) + images.append(example['images']) + + if is_train: + batch = processor(text=texts, images=images, return_tensors="pt", + padding=True, truncation=True, + # max_length=model_max_length + ) + + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + labels[labels == image_token_id] = -100 + batch["labels"] = labels + else: + batch = processor(text=texts, images=images, return_tensors="pt", + padding=True, + 
truncation=True, + # max_length=model_max_length + ) + + return batch + + + +def idefics3_create_message(images, question, is_train=False, target_text=None): + content = [] + for page_i in range(len(images)): + # content += [{"type": "text", "text": f"page {page_i}: "}] + content += [{"type": "image"}] + content += [{"type": "text", "text": question}] + messages = [{"role": "user", "content": content}] + + if is_train: + messages += [{"role": "assistant", "content": [{"type": "text", "text": target_text}]}] + + return messages diff --git a/src/m3docrag/vqa/internvl2.py b/src/m3docrag/vqa/internvl2.py new file mode 100644 index 0000000..88590f4 --- /dev/null +++ b/src/m3docrag/vqa/internvl2.py @@ -0,0 +1,288 @@ +# Copyright 2024 Bloomberg Finance L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch +import torchvision.transforms as T +from pathlib import Path +# from decord import VideoReader, cpu +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer +from typing import Union, List + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + return transform + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + +def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in 
range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + +def load_image(image_file, input_size=448, max_num=12): + + if isinstance(image_file, (str, Path)): + image = Image.open(image_file).convert('RGB') + elif isinstance(image_file, Image.Image): + image = image_file + + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + +# def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): +# if bound: +# start, end = bound[0], bound[1] +# else: +# start, end = -100000, 100000 +# start_idx = max(first_idx, round(start * fps)) +# end_idx = min(round(end * fps), max_frame) +# seg_size = float(end_idx - start_idx) / num_segments +# frame_indices = np.array([ +# int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) +# for idx in range(num_segments) +# ]) +# return frame_indices + +# def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32): +# vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) +# max_frame = len(vr) - 1 +# fps = float(vr.get_avg_fps()) + +# pixel_values_list, num_patches_list = [], [] +# transform = build_transform(input_size=input_size) +# frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) +# for frame_index in frame_indices: +# img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') +# img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num) +# pixel_values = [transform(tile) for tile in img] +# pixel_values = torch.stack(pixel_values) +# num_patches_list.append(pixel_values.shape[0]) +# pixel_values_list.append(pixel_values) +# pixel_values = torch.cat(pixel_values_list) +# return pixel_values, num_patches_list + + +def init( + model_name_or_path, + dtype=torch.bfloat16, + use_flash_attn=True, + **kwargs, +): + model = AutoModel.from_pretrained( + model_name_or_path, + torch_dtype=dtype, + low_cpu_mem_usage=True, + trust_remote_code=True, + use_flash_attn=use_flash_attn, + ) + model.eval() + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False) + + return { + 'model': model, + # 'tokenizer': tokenizer + 'processor': tokenizer + } + +def generate( + model, + # tokenizer, + processor, + question, + images +) -> List[str]: + tokenizer = processor + + if len(images) == 1: + pixel_values = load_image( + images[0], + max_num=12).to(torch.bfloat16).to(model.device) + else: + raise NotImplementedError + + generation_config = dict(max_new_tokens=1024, do_sample=False) + + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + module = model.module + else: + module = model + + answer = module.chat(tokenizer, pixel_values, question, generation_config) + return [answer] + + +if __name__ == '__main__': + # If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section. 
+    path = 'OpenGVLab/InternVL2-8B'
+    model = AutoModel.from_pretrained(
+        path,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True).eval().cuda()
+    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
+
+    # set the max number of tiles in `max_num`
+    pixel_values = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
+    generation_config = dict(max_new_tokens=1024, do_sample=False)
+
+    # pure-text conversation
+    question = 'Hello, who are you?'
+    response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    question = 'Can you tell me a story?'
+    response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    # single-image single-round conversation
+    question = '<image>\nPlease describe the image shortly.'
+    response = model.chat(tokenizer, pixel_values, question, generation_config)
+    print(f'User: {question}\nAssistant: {response}')
+
+    # single-image multi-round conversation
+    question = '<image>\nPlease describe the image in detail.'
+    response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    question = 'Please write a poem according to the image.'
+    response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    # multi-image multi-round conversation, combined images
+    pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
+    pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
+    pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
+
+    question = '<image>\nDescribe the two images in detail.'
+    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
+                                   history=None, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    question = 'What are the similarities and differences between these two images.'
+    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
+                                   history=history, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    # multi-image multi-round conversation, separate images
+    pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
+    pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
+    pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
+    num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]
+
+    question = 'Image-1: <image>\nImage-2: <image>\nDescribe the two images in detail.'
+    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
+                                   num_patches_list=num_patches_list,
+                                   history=None, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    question = 'What are the similarities and differences between these two images.'
+    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
+                                   num_patches_list=num_patches_list,
+                                   history=history, return_history=True)
+    print(f'User: {question}\nAssistant: {response}')
+
+    # batch inference, single image per sample
+    pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
+    pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
+    num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]
+    pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
+
+    questions = ['<image>\nDescribe the image in detail.'] * len(num_patches_list)
+    responses = model.batch_chat(tokenizer, pixel_values,
+                                 num_patches_list=num_patches_list,
+                                 questions=questions,
+                                 generation_config=generation_config)
+    for question, response in zip(questions, responses):
+        print(f'User: {question}\nAssistant: {response}')
+
+    # video multi-round conversation
+    # NOTE: this example relies on the decord-based `load_video` helper that is
+    # commented out near the top of this file; restore that helper (and install
+    # decord) before uncommenting the lines below.
+    # video_path = './examples/red-panda.mp4'
+    # pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
+    # pixel_values = pixel_values.to(torch.bfloat16).cuda()
+    # video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
+    # question = video_prefix + 'What is the red panda doing?'
+    # # Frame1: <image>\nFrame2: <image>\n...\nFrame8: <image>\n{question}
+    # response, history = model.chat(tokenizer, pixel_values, question, generation_config,
+    #                                num_patches_list=num_patches_list, history=None, return_history=True)
+    # print(f'User: {question}\nAssistant: {response}')
+
+    # question = 'Describe this video in detail. Don\'t repeat.'
+    # response, history = model.chat(tokenizer, pixel_values, question, generation_config,
+    #                                num_patches_list=num_patches_list, history=history, return_history=True)
+    # print(f'User: {question}\nAssistant: {response}')
diff --git a/src/m3docrag/vqa/qwen2.py b/src/m3docrag/vqa/qwen2.py
new file mode 100644
index 0000000..a37dd21
--- /dev/null
+++ b/src/m3docrag/vqa/qwen2.py
@@ -0,0 +1,100 @@
+# Copyright 2024 Bloomberg Finance L.P.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from typing import List
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
+from qwen_vl_utils import process_vision_info
+
+
+def init(
+    model_name_or_path,
+    dtype=torch.bfloat16,
+    bits=16,
+    attn_implementation="flash_attention_2",
+    **kwargs,
+):
+    # Optional 4-bit quantization via bitsandbytes; any other value of `bits`
+    # loads the model directly in `dtype`.
+    if bits == 4:
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_compute_dtype=dtype
+        )
+    else:
+        bnb_config = None
+    # Load the model in half-precision on the available device(s)
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        model_name_or_path,
+        torch_dtype=dtype,
+        low_cpu_mem_usage=True,
+        attn_implementation=attn_implementation,
+        quantization_config=bnb_config,
+        vision_config={"torch_dtype": dtype}
+    )
+    model.eval()
+    processor = AutoProcessor.from_pretrained(model_name_or_path)
+
+    return {
+        'model': model,
+        'processor': processor
+    }
+
+def generate(
+    model,
+    processor,
+    question,
+    images
+) -> List[str]:
+
+    # One image placeholder per page; the PIL images themselves are passed to
+    # the processor below, so the chat template only needs dummy entries here.
+    image_content = [{"type": "image", "image": "dummy_content"}] * len(images)
+
+    messages = [
+        {
+            "role": "user",
+            "content": image_content + [{"type": "text", "text": question}]
+        }
+    ]
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    # image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=images,
+        # videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+
+    # Move inputs to the same device as the model weights.
+    p = next(iter(model.parameters()))
+    inputs = inputs.to(p.device)
+
+    # Inference
+    generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    assert isinstance(output_text, list), output_text
+
+    return output_text
\ No newline at end of file
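Both VQA backends above expose the same two-function interface: `init(...)` returns a `{'model': ..., 'processor': ...}` dict, and `generate(model, processor, question, images)` returns a list of answer strings. This is how the question-answering stage consumes the pages returned by retrieval. The snippet below is a minimal usage sketch, not part of the diff above; the `Qwen/Qwen2-VL-7B-Instruct` checkpoint name, the page-image filenames, and the question are illustrative assumptions.

```python
# Minimal sketch of calling the qwen2 VQA backend directly.
# Assumptions: the package was installed with `pip install -e .`, the
# Qwen2-VL-7B-Instruct checkpoint is reachable on the Hugging Face Hub,
# and page_*.png are page images produced upstream (e.g., by retrieval).
from PIL import Image

from m3docrag.vqa import qwen2

backend = qwen2.init("Qwen/Qwen2-VL-7B-Instruct")  # {'model': ..., 'processor': ...}

pages = [Image.open("page_12.png"), Image.open("page_37.png")]  # hypothetical retrieved pages
question = "Which year had the highest revenue according to the chart?"

answers = qwen2.generate(
    model=backend["model"],
    processor=backend["processor"],
    question=question,
    images=pages,
)
print(answers[0])
```

The `internvl2` backend follows the same `init`/`generate` convention, so the two can be swapped by changing only the imported module; note that its `generate` currently handles a single page image and raises `NotImplementedError` for more.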