Release commit
Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
215
src/m3docrag.egg-info/PKG-INFO
Normal file
@@ -0,0 +1,215 @@
Metadata-Version: 2.2
Name: m3docrag
Version: 0.0.1
Summary: Multimodal Document Understanding with RAG
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: accelerate==1.1.0
Requires-Dist: loguru
Requires-Dist: requests
Requires-Dist: setuptools==69.5
Requires-Dist: transformers
Requires-Dist: tokenizers
Requires-Dist: flash-attn==2.5.8
Requires-Dist: bitsandbytes==0.43.1
Requires-Dist: safetensors
Requires-Dist: gpustat
Requires-Dist: icecream
Requires-Dist: pdf2image
Requires-Dist: numpy==1.26.4
Requires-Dist: torchvision
Requires-Dist: jsonlines
Requires-Dist: editdistance
Requires-Dist: einops
Requires-Dist: fire
Requires-Dist: peft
Requires-Dist: timm
Requires-Dist: sentencepiece
Requires-Dist: colpali-engine==0.3.1
Requires-Dist: easyocr
Requires-Dist: qwen-vl-utils
Requires-Dist: faiss-cpu
Requires-Dist: word2number
Requires-Dist: datasets>=3.0.0
Requires-Dist: python-dotenv

# M3DocRAG

Code for [M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding](https://m3docrag.github.io/)

by [Jaemin Cho](https://j-min.io/), [Debanjan Mahata](https://sites.google.com/a/ualr.edu/debanjan-mahata/), [Ozan İrsoy](https://wtimesx.com/), [Yujie He](https://scholar.google.com/citations?user=FbeAZGgAAAAJ&hl=en), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/)

# Summary

## Comparison with previous approaches

<img src='./assets/m3docrag_teaser.png' >

Comparison of multi-modal document understanding pipelines. Previous works focus on (a) **Single-page DocVQA**, which cannot handle many long documents, or (b) **Text-based RAG**, which ignores visual information. Our (c) **M3DocRAG** framework retrieves relevant documents and answers questions using multi-modal retrieval and MLM components, so it can efficiently handle many long documents while preserving visual information.

## M3DocRAG framework

<img src='./assets/method.png' >
Our **M3DocRAG** framework consists of three stages: (1) document embedding, (2) page retrieval, and (3) question answering.

- In (1) document embedding, we extract visual embeddings (with ColPali) to represent each page of every PDF document.

- In (2) page retrieval, we retrieve the top-K pages with the highest relevance (MaxSim scores) to the text query; a minimal scoring sketch follows this list. In an open-domain setting, we create approximate page indices for faster search.

- In (3) question answering, we conduct visual question answering with a multi-modal LM (e.g., Qwen2-VL) to obtain the final answer.
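To make stage (2) concrete, here is a minimal, self-contained MaxSim scoring sketch. The random tensors stand in for real ColPali page/query embeddings; only the scoring logic mirrors what `src/m3docrag/rag/base.py` computes.

```python
import torch

# Stand-ins for stage (1) outputs: multi-vector embeddings per page and per query.
n_pages, n_page_tokens, n_query_tokens, dim = 10, 1030, 16, 128
page_embs = torch.randn(n_pages, n_page_tokens, dim)
query_emb = torch.randn(n_query_tokens, dim)

# MaxSim: each query token is matched to its best page token,
# then scores are summed over query tokens, giving one score per page.
sim = torch.einsum("qd,ptd->pqt", query_emb, page_embs)  # [n_pages, n_query_tokens, n_page_tokens]
scores = sim.max(dim=-1).values.sum(dim=-1)              # [n_pages]
print(scores.topk(k=3).indices.tolist())                 # top-3 page indices
```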
# Setup

## Package

We assume conda has been installed.

```bash
git clone <REPO_URL>
cd m3docrag-release
pip install -e .

# Install Poppler (for pdf2image; check https://pdf2image.readthedocs.io/en/latest/installation.html for details)
# conda install -y poppler
# or
# apt-get install poppler-utils
```

## Code structure

```bash
examples/               # scripts to run PDF embedding / RAG
src/m3docrag/
    datasets/           # data loaders for existing datasets
    retrieval/          # retrieval model (e.g., ColPali)
    vqa/                # VQA model (e.g., Qwen2-VL)
    rag/                # RAG model that combines retrieval and VQA models
    utils/              # misc utility methods
m3docvqa/               # how to set up the M3DocVQA dataset
```
## Paths: Data, Embeddings, Model checkpoints, Outputs

```bash
# in .env
LOCAL_DATA_DIR="/job/datasets" # where to store data
LOCAL_EMBEDDINGS_DIR="/job/embeddings" # where to store embeddings
LOCAL_MODEL_DIR="/job/model" # where to store model checkpoints
LOCAL_OUTPUT_DIR="/job/output" # where to store model outputs
```

You can adjust the variables in [`.env`](.env) to change where data, embeddings, model checkpoints, and outputs are stored by default. They are loaded in [`src/m3docrag/utils/paths.py`](./src/m3docrag/utils/paths.py) via [python-dotenv](https://github.com/theskumar/python-dotenv).
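For reference, the loading pattern boils down to the following sketch (the actual contents of `paths.py` may differ slightly; `load_dotenv` is the standard python-dotenv idiom):

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env

# Fall back to the documented defaults when a variable is unset.
LOCAL_DATA_DIR = os.environ.get("LOCAL_DATA_DIR", "/job/datasets")
LOCAL_EMBEDDINGS_DIR = os.environ.get("LOCAL_EMBEDDINGS_DIR", "/job/embeddings")
LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR", "/job/model")
LOCAL_OUTPUT_DIR = os.environ.get("LOCAL_OUTPUT_DIR", "/job/output")
```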
## Download M3DocVQA dataset

Please see [m3docvqa/README.md](m3docvqa/README.md) for the download instructions.

## Download model checkpoints

By default, we use colpali-v1.2 for retrieval and Qwen2-VL-7B-Instruct for question answering.

At `$LOCAL_MODEL_DIR`, download the [colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2), [colpaligemma-3b-pt-448-base](https://huggingface.co/vidore/colpaligemma-3b-pt-448-base) and [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) checkpoints.

```bash
cd $LOCAL_MODEL_DIR

git clone https://huggingface.co/vidore/colpaligemma-3b-pt-448-base # ColPali backbone
git clone https://huggingface.co/vidore/colpali-v1.2 # ColPali adapter
git clone https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct # VQA
```
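After downloading, a quick sanity check that the retrieval checkpoints load is the `init()` helper from `src/m3docrag/retrieval/colpali.py` (a sketch; adjust the paths if `$LOCAL_MODEL_DIR` is not `/job/model`):

```python
from m3docrag.retrieval.colpali import init

model, processor = init(
    backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
    adapter_name_or_path="/job/model/colpali-v1.2",
)
print(type(model).__name__, type(processor).__name__)
```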
# Example usage

Below we describe example usage of M3DocRAG on the M3DocVQA dataset.

## 1. Extract PDF embeddings

```bash
DATASET_NAME="m3-docvqa"
RETRIEVAL_MODEL_TYPE="colpali"
RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base"
RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
SPLIT="dev"
EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT # where to save embeddings

accelerate launch --num_processes=1 --mixed_precision=bf16 examples/run_page_embedding.py \
    --use_retrieval \
    --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
    --data_name=$DATASET_NAME \
    --split=$SPLIT \
    --loop_unique_doc_ids=True \
    --output_dir=/job/embeddings/$EMBEDDING_NAME \
    --retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \
    --retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME
```

## 2. Indexing

```bash
DATASET_NAME="m3-docvqa"
RETRIEVAL_MODEL_TYPE="colpali"
RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
SPLIT="dev"
FAISS_INDEX_TYPE='ivfflat'
EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT
INDEX_NAME=$EMBEDDING_NAME"_pageindex_"$FAISS_INDEX_TYPE # where to save the resulting index
echo $EMBEDDING_NAME
echo $FAISS_INDEX_TYPE

python examples/run_indexing_m3docvqa.py \
    --use_retrieval \
    --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
    --data_name=$DATASET_NAME \
    --split=$SPLIT \
    --loop_unique_doc_ids=False \
    --embedding_name=$EMBEDDING_NAME \
    --faiss_index_type=$FAISS_INDEX_TYPE \
    --output_dir=/job/embeddings/$INDEX_NAME
```
## 3. RAG

```bash
BACKBONE_MODEL_NAME="Qwen2-VL-7B-Instruct"
RETRIEVAL_MODEL_TYPE="colpali"
RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base"
RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
EMBEDDING_NAME="colpali-v1.2_m3-docvqa_dev" # from Step 1 Embedding
SPLIT="dev"
DATASET_NAME="m3-docvqa"
FAISS_INDEX_TYPE='ivfflat'
N_RETRIEVAL_PAGES=1
INDEX_NAME="${EMBEDDING_NAME}_pageindex_$FAISS_INDEX_TYPE" # from Step 2 Indexing
OUTPUT_SAVE_NAME="${RETRIEVAL_ADAPTER_MODEL_NAME}_${BACKBONE_MODEL_NAME}_${DATASET_NAME}" # where to save RAG results
BITS=16 # BITS=4 for 4-bit quantization on low-memory GPUs

python examples/run_rag_m3docvqa.py \
    --use_retrieval \
    --retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
    --load_embedding=True \
    --split=$SPLIT \
    --bits=$BITS \
    --n_retrieval_pages=$N_RETRIEVAL_PAGES \
    --data_name=$DATASET_NAME \
    --model_name_or_path=$BACKBONE_MODEL_NAME \
    --embedding_name=$EMBEDDING_NAME \
    --retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \
    --retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME \
    --output_dir=/job/eval_outputs/$OUTPUT_SAVE_NAME
```
# Citation

Please cite our paper if you use our dataset and/or method in your projects.

```bibtex
@article{Cho2024M3DocRAG,
  author = {Jaemin Cho and Ozan İrsoy and Debanjan Mahata and Yujie He and Mohit Bansal},
  title  = {M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding},
  year   = {2024},
}
```
31
src/m3docrag.egg-info/SOURCES.txt
Normal file
@@ -0,0 +1,31 @@
README.md
pyproject.toml
src/m3docrag/__init__.py
src/m3docrag.egg-info/PKG-INFO
src/m3docrag.egg-info/SOURCES.txt
src/m3docrag.egg-info/dependency_links.txt
src/m3docrag.egg-info/requires.txt
src/m3docrag.egg-info/top_level.txt
src/m3docrag/datasets/__init__.py
src/m3docrag/datasets/m3_docvqa/__init__.py
src/m3docrag/datasets/m3_docvqa/common_utils.py
src/m3docrag/datasets/m3_docvqa/dataset.py
src/m3docrag/datasets/m3_docvqa/evaluate.py
src/m3docrag/rag/__init__.py
src/m3docrag/rag/base.py
src/m3docrag/rag/multimodal.py
src/m3docrag/rag/utils.py
src/m3docrag/retrieval/__init__.py
src/m3docrag/retrieval/colpali.py
src/m3docrag/utils/args.py
src/m3docrag/utils/distributed.py
src/m3docrag/utils/paths.py
src/m3docrag/utils/pdfs.py
src/m3docrag/utils/prompts.py
src/m3docrag/utils/tar.py
src/m3docrag/vqa/__init__.py
src/m3docrag/vqa/florence2.py
src/m3docrag/vqa/idefics2.py
src/m3docrag/vqa/idefics3.py
src/m3docrag/vqa/internvl2.py
src/m3docrag/vqa/qwen2.py
1
src/m3docrag.egg-info/dependency_links.txt
Normal file
@@ -0,0 +1 @@
28
src/m3docrag.egg-info/requires.txt
Normal file
@@ -0,0 +1,28 @@
accelerate==1.1.0
loguru
requests
setuptools==69.5
transformers
tokenizers
flash-attn==2.5.8
bitsandbytes==0.43.1
safetensors
gpustat
icecream
pdf2image
numpy==1.26.4
torchvision
jsonlines
editdistance
einops
fire
peft
timm
sentencepiece
colpali-engine==0.3.1
easyocr
qwen-vl-utils
faiss-cpu
word2number
datasets>=3.0.0
python-dotenv
1
src/m3docrag.egg-info/top_level.txt
Normal file
@@ -0,0 +1 @@
m3docrag
16
src/m3docrag/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
16
src/m3docrag/datasets/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
18
src/m3docrag/datasets/m3_docvqa/__init__.py
Normal file
@@ -0,0 +1,18 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from .dataset import M3DocVQADataset
from .evaluate import evaluate_predictions, evaluate_prediction_file
142
src/m3docrag/datasets/m3_docvqa/common_utils.py
Normal file
@@ -0,0 +1,142 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# https://github.com/allenai/multimodalqa/blob/master/baselines/common_utils.py

import json


ALL_QUESTION_TYPES = [
    'TextQ',
    'TableQ',
    'ImageQ',
    'ImageListQ',
    'Compose(TableQ,ImageListQ)',
    'Compose(TextQ,ImageListQ)',
    'Compose(ImageQ,TableQ)',
    'Compose(ImageQ,TextQ)',
    'Compose(TextQ,TableQ)',
    'Compose(TableQ,TextQ)',
    'Intersect(TableQ,TextQ)',
    'Intersect(ImageListQ,TableQ)',
    'Intersect(ImageListQ,TextQ)',
    'Compare(Compose(TableQ,ImageQ),TableQ)',
    'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))',
    'Compare(TableQ,Compose(TableQ,TextQ))',
]

TEXT_SINGLE_HOP_QUESTION_TYPES = [
    'TextQ',
]
TEXT_AS_FIRST_HOP_QUESTION_TYPES = [
    'Compare(TableQ,Compose(TableQ,TextQ))',
    'Compose(ImageQ,TextQ)',
    'Compose(TableQ,TextQ)',
    'Intersect(TableQ,TextQ)',
    'Intersect(ImageListQ,TextQ)',
]
TEXT_AS_SECOND_HOP_QUESTION_TYPES = [
    'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))',
    'Compose(TextQ,ImageListQ)',
    'Compose(TextQ,TableQ)',
]

TABLE_SINGLE_HOP_QUESTION_TYPES = [
    "TableQ"
]
TABLE_AS_FIRST_HOP_QUESTION_TYPES = [
    'Compose(ImageQ,TableQ)',
    'Compose(TextQ,TableQ)',
]
TABLE_AS_SECOND_HOP_QUESTION_TYPES = [
    'Compare(Compose(TableQ,ImageQ),TableQ)',
    'Compare(TableQ,Compose(TableQ,TextQ))',
    'Compose(TableQ,ImageListQ)',
    'Compose(TableQ,TextQ)',
    'Intersect(ImageListQ,TableQ)',
    'Intersect(TableQ,TextQ)',
]

IMAGE_SINGLE_HOP_QUESTION_TYPES = [
    'ImageQ',
    'ImageListQ'
]
IMAGE_AS_FIRST_HOP_QUESTION_TYPES = [
    'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))',
    'Compare(Compose(TableQ,ImageQ),TableQ)',
    'Compose(TableQ,ImageListQ)',
    'Compose(TextQ,ImageListQ)',
    'Intersect(ImageListQ,TableQ)',
]
IMAGE_AS_SECOND_HOP_QUESTION_TYPES = [
    'Compose(ImageQ,TableQ)',
    'Compose(ImageQ,TextQ)',
    'Intersect(ImageListQ,TextQ)',
]


# every question should be answered either as a single-hop question or as a two-hop question
assert set(TEXT_SINGLE_HOP_QUESTION_TYPES + TEXT_AS_SECOND_HOP_QUESTION_TYPES
           + TABLE_SINGLE_HOP_QUESTION_TYPES + TABLE_AS_SECOND_HOP_QUESTION_TYPES
           + IMAGE_SINGLE_HOP_QUESTION_TYPES + IMAGE_AS_SECOND_HOP_QUESTION_TYPES) == set(ALL_QUESTION_TYPES)
assert len(set(TEXT_SINGLE_HOP_QUESTION_TYPES) & set(TEXT_AS_SECOND_HOP_QUESTION_TYPES)) == 0
assert len(set(TABLE_SINGLE_HOP_QUESTION_TYPES) & set(TABLE_AS_SECOND_HOP_QUESTION_TYPES)) == 0
assert len(set(IMAGE_SINGLE_HOP_QUESTION_TYPES) & set(IMAGE_AS_SECOND_HOP_QUESTION_TYPES)) == 0

SINGLE_HOP_QUESTION_TYPES = TEXT_SINGLE_HOP_QUESTION_TYPES \
    + TABLE_SINGLE_HOP_QUESTION_TYPES \
    + IMAGE_SINGLE_HOP_QUESTION_TYPES
MULTI_HOP_QUESTION_TYPES = TEXT_AS_SECOND_HOP_QUESTION_TYPES \
    + TABLE_AS_SECOND_HOP_QUESTION_TYPES \
    + IMAGE_AS_SECOND_HOP_QUESTION_TYPES
# no duplicated multi-hop question types
assert len(MULTI_HOP_QUESTION_TYPES) == len(set(MULTI_HOP_QUESTION_TYPES))
# no duplication for the first hop
assert set(TEXT_AS_FIRST_HOP_QUESTION_TYPES + TABLE_AS_FIRST_HOP_QUESTION_TYPES + IMAGE_AS_FIRST_HOP_QUESTION_TYPES) \
    == set(MULTI_HOP_QUESTION_TYPES)
# single + multi = all
assert set(SINGLE_HOP_QUESTION_TYPES + MULTI_HOP_QUESTION_TYPES) == set(ALL_QUESTION_TYPES)


def process_question_for_implicit_decomp(question, question_type, hop=0, bridge_entity='', sep_token='[SEP]'):
    if isinstance(bridge_entity, (list, set)):
        bridge_entity = "; ".join(bridge_entity)
    return (
        f'{question_type} {sep_token} '
        f'HOP={hop} {sep_token} '
        f'{bridge_entity} {sep_token} '
        f'{question}')


def extract_numbers_from_str(s):
    numbers = []
    for token in s.split():
        try:
            num = int(token.replace(",", ""))
        except ValueError:
            try:
                num = float(token)
            except ValueError:
                num = None
        if num is not None:  # keep 0 as a valid number
            numbers.append(num)
    return numbers


def read_jsonl(filename):
    with open(filename, 'r') as f:
        data = [json.loads(l.strip()) for l in f.readlines()]
    return data
144
src/m3docrag/datasets/m3_docvqa/dataset.py
Normal file
@@ -0,0 +1,144 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from datasets import load_dataset
from pathlib import Path
import torch
import safetensors
import json
import jsonlines
from copy import deepcopy
from tqdm.auto import tqdm
import PIL
from typing import List
from loguru import logger

from m3docrag.utils.paths import LOCAL_DATA_DIR, LOCAL_EMBEDDINGS_DIR
from m3docrag.utils.pdfs import get_images_from_pdf


class M3DocVQADataset(torch.utils.data.Dataset):
    def __init__(self, args):
        self.args = args

        # e.g., /job/datasets/m3-docvqa
        local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name

        pdf_dir = local_data_dir / "splits" / f'pdfs_{args.split}'
        assert pdf_dir.exists(), pdf_dir
        self.pdf_dir = pdf_dir

        multimodalqa_data_dir = local_data_dir / "multimodalqa"

        mmqa_data_path = multimodalqa_data_dir / f"MMQA_{args.split}.jsonl"
        self.mmqa_data_path = mmqa_data_path

        data = []
        with jsonlines.open(mmqa_data_path) as reader:
            for i, obj in enumerate(reader):
                data.append(obj)
        logger.info(f"# Data {len(data)}")
        self.data = data

        split_supporting_doc_ids_path = local_data_dir / f"{args.split}_doc_ids.json"
        with open(split_supporting_doc_ids_path, 'r') as f:
            all_supporting_doc_ids = json.load(f)
        # dev: 3366
        # train: 24162
        logger.info(f"# supporting doc ids in split {args.split}: {len(all_supporting_doc_ids)}")
        self.all_supporting_doc_ids = all_supporting_doc_ids

    def __len__(self):
        if self.args.loop_unique_doc_ids:
            return len(self.all_supporting_doc_ids)

        if self.args.data_len is not None:
            return self.args.data_len

        return len(self.data)

    def load_all_embeddings(self):
        """Load all doc embeddings in memory"""

        emb_dir = Path(LOCAL_EMBEDDINGS_DIR) / self.args.embedding_name

        logger.info(f"Loading all doc embeddings from {emb_dir}")

        docid2embs = {}
        docid2lens = {}

        for idx in tqdm(range(len(self.all_supporting_doc_ids))):
            doc_id = self.all_supporting_doc_ids[idx]
            emb_path = Path(LOCAL_EMBEDDINGS_DIR) / self.args.embedding_name / f"{doc_id}.safetensors"
            assert emb_path.exists(), emb_path

            if self.args.retrieval_model_type == 'colpali':
                with safetensors.safe_open(emb_path, framework="pt", device='cpu') as f:
                    # [n_pages, n_tokens, dim]
                    doc_embs = f.get_tensor('embeddings')

                docid2embs[doc_id] = doc_embs.bfloat16()

        if self.args.retrieval_model_type == 'colpali':
            return docid2embs
        elif self.args.retrieval_model_type == 'colbert':
            return docid2embs, docid2lens

    def get_images_from_doc_id(self, doc_id: str) -> List[PIL.Image.Image]:
        pdf_path = self.pdf_dir / f"{doc_id}.pdf"
        page_images = get_images_from_pdf(pdf_path)
        return page_images

    def __getitem__(self, idx):
        if self.args.loop_unique_doc_ids:
            doc_id = self.all_supporting_doc_ids[idx]

            datum = {
                'doc_id': doc_id,
            }

            if self.args.retrieval_model_type == 'colpali':
                page_images = self.get_images_from_doc_id(doc_id)
                datum['images'] = page_images

            return datum

        # keys(['qid', 'question', 'answers', 'metadata', 'supporting_context'])
        datum = deepcopy(self.data[idx])

        supporting_doc_ids = []
        for obj in datum['supporting_context']:
            supporting_doc_ids.append(obj['doc_id'])
        datum['supporting_doc_ids'] = supporting_doc_ids

        return datum


if __name__ == '__main__':
    from m3docrag.utils.args import _example_args, parse_args

    args = parse_args(_example_args)

    dataset = M3DocVQADataset(
        args=args
    )
414
src/m3docrag/datasets/m3_docvqa/evaluate.py
Normal file
@@ -0,0 +1,414 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# https://github.com/allenai/multimodalqa/blob/master/baselines/evaluate.py

import json
import argparse
import re
import string
import numpy as np
from collections import Counter
from typing import List, Set, Tuple, Union
from scipy.optimize import linear_sum_assignment
from word2number.w2n import word_to_num
from .common_utils import *

# From here through _match_numbers_if_present was originally copied from the evaluation code of the DROP dataset:
# https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py

def _remove_articles(text: str) -> str:
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)


def _white_space_fix(text: str) -> str:
    return " ".join(text.split())


EXCLUDE = set(string.punctuation)


def _remove_punc(text: str) -> str:
    if not _is_number(text):
        return "".join(ch for ch in text if ch not in EXCLUDE)
    else:
        return text


def _lower(text: str) -> str:
    return text.lower()


def _tokenize(text: str) -> List[str]:
    return re.split(" |-", text)


def _normalize_answer(text: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    parts = [
        _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
        for token in _tokenize(text)
    ]
    parts = [part for part in parts if part.strip()]
    normalized = " ".join(parts).strip()
    return normalized


def _is_number(text: str) -> bool:
    try:
        float(text)
        return True
    except ValueError:
        return False


def _is_word_number(text: str) -> bool:
    try:
        word_to_num(text)
        return True
    except ValueError:
        return False


def _normalize_number(text: str) -> str:
    if _is_number(text):
        return str(float(text))
    # TODO: this is not included in the original DROP evaluation script; we need to have our own in the end anyway.
    elif _is_word_number(text):
        return str(float(word_to_num(text)))
    else:
        return text


def _answer_to_bags(
    answer: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[List[str], List[Set[str]]]:
    if isinstance(answer, (list, tuple)):
        raw_spans = answer
    else:
        raw_spans = [answer]
    normalized_spans: List[str] = []
    token_bags = []
    for raw_span in raw_spans:
        normalized_span = _normalize_answer(raw_span)
        normalized_spans.append(normalized_span)
        token_bags.append(set(normalized_span.split()))
    return normalized_spans, token_bags


def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
    """
    Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
    between them and gets maximum metric values over all the answers.
    """
    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            if _match_numbers_if_present(gold_item, pred_item):
                scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = max(max_scores[row], scores[row, column])
    return max_scores


def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
    intersection = len(gold_bag.intersection(predicted_bag))
    if not predicted_bag:
        precision = 1.0
    else:
        precision = intersection / float(len(predicted_bag))
    if not gold_bag:
        recall = 1.0
    else:
        recall = intersection / float(len(gold_bag))
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if not (precision == 0.0 and recall == 0.0)
        else 0.0
    )
    return f1


def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
    gold_numbers = set()
    predicted_numbers = set()
    for word in gold_bag:
        if _is_number(word):
            gold_numbers.add(word)
    for word in predicted_bag:
        if _is_number(word):
            predicted_numbers.add(word)
    if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
        return True
    return False


def list_em(predicted, gold):
    predicted_bags = _answer_to_bags(predicted)
    gold_bags = _answer_to_bags(gold)
    if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
        return 1.0
    else:
        return 0.0


def list_f1(predicted, gold):
    predicted_bags = _answer_to_bags(predicted)
    gold_bags = _answer_to_bags(gold)
    f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
    f1 = np.mean(f1_per_bag)
    f1 = round(f1, 2)
    return f1


def metric_max_over_ground_truths(metric_fn, prediction, gold_answers):
    scores_for_ground_truths = []
    for gold_answer in gold_answers:
        score = metric_fn(prediction, gold_answer)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)

def evaluate_predictions(predictions, gold_answers, example_types=None):
    """To support multiple gold annotations, `gold_answers` should be a list,
    with each item (either a string or a list) corresponding to one valid reference answer."""
    instance_eval_results = {}
    instance_eval_results_by_types = {}
    eval_funcs = {
        "list_em": list_em,
        "list_f1": list_f1
    }
    for qas_id in gold_answers:
        ref_answers = gold_answers[qas_id]
        if qas_id not in predictions:
            print(f"Missing prediction for question {qas_id}, and all scores for this question are set to zero")
            instance_eval_results[qas_id] = {
                metric: 0.0 for metric in eval_funcs.keys()
            }
        else:
            pred_answer = predictions[qas_id]
            instance_eval_results[qas_id] = {
                metric: metric_max_over_ground_truths(
                    func, pred_answer, ref_answers
                ) for metric, func in eval_funcs.items()
            }
        if example_types is not None:
            example_type = example_types[qas_id]
            if example_type not in instance_eval_results_by_types:
                instance_eval_results_by_types[example_type] = {}
            instance_eval_results_by_types[example_type][qas_id] = instance_eval_results[qas_id]

    eval_scores = {metric: np.mean([result[metric] for result in instance_eval_results.values()]) * 100
                   for metric in eval_funcs.keys()}

    if example_types is not None:
        eval_scores_by_types = {}
        for example_type, type_instance_eval_results in instance_eval_results_by_types.items():
            eval_scores_by_types[example_type] = {
                metric: np.mean([result[metric] for result in type_instance_eval_results.values()]) * 100 for metric in eval_funcs.keys()
            }
        return eval_scores, instance_eval_results, eval_scores_by_types
    else:
        return eval_scores, instance_eval_results


def eval_retrieval(qid2retrieval_results, gold_examples, recall_levels=[1, 2, 4, 5, 10]):
    """
    Calculate document recall metrics per query, including recall@k for various k values,
    and compute the average recall.

    Args:
        qid2retrieval_results (dict): Dictionary mapping query IDs (qid) to retrieval results.
            The retrieval results are a list of lists, where each sublist contains
            [doc_id, page, score], e.g., [['53e1d168375850acbaf134da744ffae0', 3, 2.3689956665039062]].
        gold_examples (list): List of ground truth examples with 'qid' and 'supporting_context'.
            Each 'supporting_context' is a list of dictionaries,
            e.g., [{'doc_id': '8513db80c11ea439ab11eba406ec00d9', 'doc_part': 'table'}].
        recall_levels (list): List of k values for recall@k, e.g., [1, 5, 10].

    Returns:
        dict: Dictionary containing recall per query at different recall levels,
            and the average recall for each level.
    """

    # Create a mapping from query ID to the supporting context (gold standard)
    qid2supporting_context = {
        example["qid"]: example['supporting_context'] for example in gold_examples
    }

    # Initialize dictionaries to store recall at different levels
    qid2recall_at_k = {k: {} for k in recall_levels}
    avg_recall_at_k = {k: 0.0 for k in recall_levels}

    # Loop through each query and its supporting context
    for qid, supporting_context in qid2supporting_context.items():
        # Get the retrieved documents for the current query
        retrieved_results = qid2retrieval_results.get(qid, [])

        # Extract the relevant doc_ids from supporting_context
        relevant_doc_ids = {ctx['doc_id'] for ctx in supporting_context}

        # Count total relevant documents for the query
        n_relevant = len(relevant_doc_ids)

        # For each recall level, calculate how many relevant documents were retrieved
        for k in recall_levels:
            # Get the top-k results (or fewer if there are fewer results)
            top_k_results = retrieved_results[:k]

            # Ensure doc_id uniqueness in top_k_results
            top_k_doc_ids = set(result[0] for result in top_k_results)

            # Count how many relevant documents are in the top-k results
            n_relevant_retrieved_at_k = len(top_k_doc_ids.intersection(relevant_doc_ids))

            # Calculate recall@k for the current query
            recall_at_k = n_relevant_retrieved_at_k / n_relevant if n_relevant > 0 else 0.0

            # Ensure recall is between 0 and 1
            recall_at_k = min(recall_at_k, 1.0)

            # Store recall@k for the current query
            qid2recall_at_k[k][qid] = recall_at_k

    # Calculate average recall@k across all queries
    for k in recall_levels:
        avg_recall_at_k[k] = sum(qid2recall_at_k[k].values()) / len(qid2recall_at_k[k]) if qid2recall_at_k[k] else 0.0

    # Return recall@k for each query and average recall@k
    return {
        "recall_per_qid_at_k": qid2recall_at_k,
        "average_recall_at_k": avg_recall_at_k
    }

def evaluate_prediction_file(prediction_path, gold_path='/job/datasets/m3-docvqa/MMQA_dev.jsonl'):
    if isinstance(prediction_path, dict):
        raw_prediction_json = prediction_path
    else:
        raw_prediction_json = json.load(open(prediction_path, encoding="utf-8"))
    examples = read_jsonl(gold_path)
    gold_answers, answer_modalities, hop_types, question_types = {}, {}, {}, {}

    qid2predicted_answers = {}
    qid2retrieval_results = {}
    for qid, data in raw_prediction_json.items():
        qid2predicted_answers[qid] = data['pred_answer'].strip()
        qid2retrieval_results[qid] = data['page_retrieval_results']

    eval_retrieval_results = eval_retrieval(qid2retrieval_results, examples)
    print('Average recall at K')
    print(eval_retrieval_results['average_recall_at_k'])

    predicted_answers = qid2predicted_answers

    all_scores = {}
    all_scores['overall'] = {}
    all_scores['modalities'] = {}
    all_scores['hop_types'] = {}
    all_scores['q_types'] = {}

    all_scores['average_recall_at_k'] = eval_retrieval_results['average_recall_at_k']

    for example in examples:
        qid = example["qid"]
        # Currently we only have one ground truth answer.
        # Even if there are multiple entries in example["answers"], the whole list should be regarded as one ref answer.
        # However, our script supports evaluation with multiple ref answers.
        # So, we will use an outer bracket here to pretend we have a list of ref answers.
        gold_answer = [str(item["answer"]) for item in example["answers"]]
        gold_answers[qid] = [gold_answer]
        answer_modality = set([item["modality"] for item in example["answers"]])
        assert len(answer_modality) == 1
        answer_modalities[qid] = answer_modality.pop()
        question_types[qid] = example["metadata"]["type"]
        hop_types[qid] = "Multi-hop" if example["metadata"]["type"] in MULTI_HOP_QUESTION_TYPES else "Single-hop"

    eval_scores, instance_eval_results = evaluate_predictions(predicted_answers, gold_answers)
    print("\n\nOverall result with different metrics: ")
    for metric, value in eval_scores.items():
        print(f"{metric}: {value}")
        all_scores['overall'][metric] = value

    modality_counts = Counter(answer_modalities.values())
    _, _, eval_scores_by_modalities = \
        evaluate_predictions(predicted_answers, gold_answers, answer_modalities)
    print("\n\nEval results for different modalities:")
    for answer_modality in sorted(eval_scores_by_modalities.keys()):
        result = eval_scores_by_modalities[answer_modality]
        print(f"{answer_modality}")
        print(f"# of examples: {modality_counts[answer_modality]}")

        all_scores['modalities'][answer_modality] = {}
        for metric, value in result.items():
            print(f"{metric}: {value}")
            all_scores['modalities'][answer_modality][metric] = value

    hop_type_counts = Counter(hop_types.values())
    _, _, eval_scores_by_hop_types = evaluate_predictions(predicted_answers, gold_answers, hop_types)
    print("\n\nType\tCount\tEM\tF1")
    for hop_type in sorted(eval_scores_by_hop_types.keys()):
        result = eval_scores_by_hop_types[hop_type]
        print(f"{hop_type}\t{hop_type_counts[hop_type]}\t{result['list_em']}\t{result['list_f1']}")

        all_scores['hop_types'][hop_type] = {}
        for metric, value in result.items():
            all_scores['hop_types'][hop_type][metric] = value

    question_type_counts = Counter(question_types.values())
    _, _, eval_scores_by_qtypes = evaluate_predictions(predicted_answers, gold_answers, question_types)
    print("\n\nType\tCount\tEM\tF1")
    for question_type in sorted(eval_scores_by_qtypes.keys()):
        result = eval_scores_by_qtypes[question_type]
        print(f"{question_type}\t{question_type_counts[question_type]}\t{result['list_em']}\t{result['list_f1']}")

        all_scores['q_types'][question_type] = {}
        for metric, value in result.items():
            all_scores['q_types'][question_type][metric] = value

    return all_scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="evaluate on M3DocVQA dataset")
    parser.add_argument(
        "--prediction_path",
        type=str,
        default="predictions.json",
        help="location of the prediction file",
    )
    parser.add_argument(
        "--gold_path",
        type=str,
        default="dev.json",
        help="location of the gold file",
    )
    args = parser.parse_args()
    evaluate_prediction_file(args.prediction_path, args.gold_path)
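To make the recall@k bookkeeping above concrete, here is a toy invocation of `eval_retrieval` (hypothetical IDs and scores):

```python
from m3docrag.datasets.m3_docvqa.evaluate import eval_retrieval

gold = [{"qid": "q1", "supporting_context": [{"doc_id": "d1"}, {"doc_id": "d2"}]}]
retrieved = {"q1": [["d1", 0, 9.1], ["d3", 2, 7.4], ["d2", 5, 6.0]]}

out = eval_retrieval(retrieved, gold, recall_levels=[1, 2])
# Top-1 contains d1 only -> recall@1 = 1/2; top-2 contains d1, d3 -> recall@2 = 1/2.
print(out["average_recall_at_k"])  # {1: 0.5, 2: 0.5}
```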
20
src/m3docrag/rag/__init__.py
Normal file
@@ -0,0 +1,20 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from .multimodal import MultimodalRAGModel
from .utils import reduce_embeddings
177
src/m3docrag/rag/base.py
Normal file
@@ -0,0 +1,177 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from typing import List, Dict, Tuple
from tqdm.auto import tqdm
import torch
import numpy as np

from .utils import get_top_k_pages, get_top_k_pages_single_page_from_each_doc


class RAGModelBase:
    def __init__(
        self,
        retrieval_model=None,
        qa_model=None,
        vqa_model=None,
    ):
        """Base class for RAG pipeline

        - retrieval_model: arbitrary retrieval model (e.g., ColPali / ColBERT)
        - qa_model: arbitrary text-only QA model (e.g., Llama 3)
        - vqa_model: arbitrary VQA model (e.g., InternVL2, GPT-4o)
        """
        self.retrieval_model = retrieval_model
        self.qa_model = qa_model
        self.vqa_model = vqa_model

        if self.retrieval_model is not None:
            self.retrieval_model.model.eval()
        if self.vqa_model is not None:
            self.vqa_model.model.eval()
        if self.qa_model is not None:
            self.qa_model.model.eval()

    def retrieve_pages_from_docs(
        self,
        query: str,
        docid2embs: Dict[str, torch.Tensor],
        docid2lens: Dict[str, torch.Tensor] = None,

        index=None,
        token2pageuid=None,
        all_token_embeddings=None,

        n_return_pages: int = 1,
        single_page_from_each_doc: bool = False,
        show_progress=False,
    ) -> List[Tuple]:
        """
        Given a text query and pre-extracted document embeddings,
        calculate similarity scores and return the top-n pages.

        Args:
            - query (str): a text query for the retrieval model
            - docid2embs (Dict[str, tensor]): collection of document embeddings
                key: document_id
                value: torch.Tensor of size (n_tokens, emb_dim)
            - index: FAISS index
            - n_return_pages (int): number of pages to return
            - single_page_from_each_doc (bool): if True, only a single page is retrieved from each PDF document.

        Return:
            retrieval_results
                [(doc_id, page_idx, score), ...]
        """

        if index is not None:
            # [n_query_tokens, dim]
            query_emb = self.retrieval_model.encode_queries([query])[0]
            query_emb = query_emb.cpu().float().numpy().astype(np.float32)

            # NN search
            k = n_return_pages
            D, I = index.search(query_emb, k)

            # Sum the MaxSim scores across all query tokens for each page
            final_page2scores = {}

            # Iterate over query tokens
            for q_idx, q_token_emb in enumerate(query_emb):
                # Initialize a dictionary to hold page relevance scores for this query token
                current_q_page2scores = {}

                for nn_idx in range(k):
                    found_nearest_doc_token_idx = I[q_idx, nn_idx]

                    page_uid = token2pageuid[found_nearest_doc_token_idx]  # Get the page UID for this token

                    # reconstruct the original score
                    doc_token_emb = all_token_embeddings[found_nearest_doc_token_idx]
                    score = (q_token_emb * doc_token_emb).sum()

                    # MaxSim: aggregate the highest similarity score for each query token per page
                    if page_uid not in current_q_page2scores:
                        current_q_page2scores[page_uid] = score
                    else:
                        current_q_page2scores[page_uid] = max(current_q_page2scores[page_uid], score)

                for page_uid, score in current_q_page2scores.items():
                    if page_uid in final_page2scores:
                        final_page2scores[page_uid] += score
                    else:
                        final_page2scores[page_uid] = score

            # Sort pages by their final relevance score
            sorted_pages = sorted(final_page2scores.items(), key=lambda x: x[1], reverse=True)

            # Get the top-k page candidates
            top_k_pages = sorted_pages[:k]

            # [(doc_id, page_idx, score), ...]
            sorted_results = []
            for page_uid, score in top_k_pages:
                # logger.info(f"{page_uid} with score {score}")

                # page_uid = f"{doc_id}_page{page_id}"
                doc_id = page_uid.split('_page')[0]
                page_idx = int(page_uid.split('_page')[-1])
                sorted_results.append((doc_id, page_idx, score.item()))

            return sorted_results

        docid2scores = {}
        for doc_id, doc_embs in tqdm(
            docid2embs.items(),
            total=len(docid2embs),
            disable=not show_progress,
            desc="Calculating similarity scores over documents"
        ):
            doc_lens = None
            if docid2lens is not None:
                doc_lens = docid2lens[doc_id]

            scores = self.retrieval_model.retrieve(
                query=query,
                doc_embeds=doc_embs,
                doc_lens=doc_lens,
                to_cpu=True,
                return_top_1=False
            )
            scores = scores.flatten().tolist()
            docid2scores[doc_id] = scores

        # find the pages with top scores
        if single_page_from_each_doc:
            return get_top_k_pages_single_page_from_each_doc(docid2scores=docid2scores, k=n_return_pages)
        else:
            return get_top_k_pages(docid2scores=docid2scores, k=n_return_pages)

    def run_qa(self):
        raise NotImplementedError

    def run_vqa(self):
        raise NotImplementedError
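For context on the `index is not None` branch above: it assumes a token-level FAISS index in which each row is one page-token embedding, with `token2pageuid` mapping row indices back to page UIDs. A minimal sketch of building such an index with faiss-cpu (random stand-in data; the repo's `run_indexing_m3docvqa.py` builds an IVFFlat variant rather than the exact index shown here):

```python
import faiss
import numpy as np

dim = 128
# Flatten all page-token embeddings into one matrix; remember each row's page UID.
all_token_embeddings = np.random.randn(5000, dim).astype(np.float32)
token2pageuid = [f"doc1_page{row % 10}" for row in range(5000)]

index = faiss.IndexFlatIP(dim)  # exact inner-product search
index.add(all_token_embeddings)

query_tokens = np.random.randn(16, dim).astype(np.float32)  # one row per query token
D, I = index.search(query_tokens, 10)  # top-10 nearest page tokens per query token
```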
51
src/m3docrag/rag/multimodal.py
Normal file
@@ -0,0 +1,51 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from .base import RAGModelBase

import torch

from m3docrag.vqa import VQAModel
from m3docrag.retrieval import ColPaliRetrievalModel


class MultimodalRAGModel(RAGModelBase):
    def __init__(
        self,
        retrieval_model: ColPaliRetrievalModel,
        vqa_model: VQAModel = None
    ):
        self.retrieval_model = retrieval_model
        self.vqa_model = vqa_model

        self.retrieval_model.model.eval()

        if self.vqa_model is not None and isinstance(self.vqa_model.model, torch.nn.Module):
            self.vqa_model.model.eval()

    def run_vqa(
        self,
        images,
        question,
    ) -> str:
        response = self.vqa_model.generate(images=images, question=question)
        assert isinstance(response, str), type(response)

        return response
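A hedged usage sketch of `MultimodalRAGModel` (the retrieval/VQA wrapper constructors and inputs are elided here; `docid2embs` and the page images come from the dataset utilities elsewhere in this commit):

```python
from m3docrag.rag import MultimodalRAGModel

# retrieval_model / vqa_model: instances of ColPaliRetrievalModel / VQAModel.
rag = MultimodalRAGModel(retrieval_model=retrieval_model, vqa_model=vqa_model)

# [(doc_id, page_idx, score), ...]
pages = rag.retrieve_pages_from_docs(query=question, docid2embs=docid2embs, n_return_pages=1)
answer = rag.run_vqa(images=page_images, question=question)
```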
128
src/m3docrag/rag/utils.py
Normal file
@@ -0,0 +1,128 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from tqdm.auto import tqdm


def reduce_embeddings(docid2embs, dim='page', show_progress=True):
    """Summarize document embeddings by reducing (averaging) specific dimensions

    Input embedding:
        [n_pages, n_tokens, emb_dim]

    Output embedding:

    - dim == 'page'
        [1, n_tokens, emb_dim]

    - dim == 'token'
        [1, n_pages, emb_dim]

    - dim == 'page_token'
        [1, 1, emb_dim]
    """

    assert dim in ['page', 'token', 'page_token'], f"{dim}"

    new_docid2embs = {}

    for doc_id in tqdm(
        list(docid2embs.keys()),
        disable=not show_progress
    ):
        # [n_pages, n_tokens, dim]
        embs = docid2embs[doc_id]

        emb_dim = embs.size(-1)

        if dim == 'page':
            # [n_tokens, dim]
            new_embs = embs.mean(dim=0)
        elif dim == 'token':
            # [n_pages, dim]
            new_embs = embs.mean(dim=1)
        elif dim == 'page_token':
            # [1, dim]
            new_embs = embs.mean(dim=0).mean(dim=0)

        new_docid2embs[doc_id] = new_embs.view(1, -1, emb_dim)

    return new_docid2embs


def get_top_k_pages(docid2scores: dict, k: int):
    """
    # Example usage:
    docid2scores = {
        "doc1": [10, 50, 30],
        "doc2": [40, 20, 60],
        "doc3": [70, 90]
    }

    k = 3
    top_k_pages = get_top_k_pages(docid2scores, k)
    print(top_k_pages)

    -> [('doc3', 1, 90), ('doc3', 0, 70), ('doc2', 2, 60)]
    """
    # Flatten the dictionary into a list of tuples (doc_id, page_index, score)
    flattened_scores = [
        (doc_id, page_index, score)
        for doc_id, scores in docid2scores.items()
        for page_index, score in enumerate(scores)
    ]

    # Sort by score in descending order
    flattened_scores.sort(key=lambda x: x[2], reverse=True)

    # Get the top-k entries
    top_k_pages = flattened_scores[:k]

    return top_k_pages


def get_top_k_pages_single_page_from_each_doc(docid2scores: dict, k: int):
    """
    # Example usage:
    docid2scores = {
        "doc1": [10, 50, 30],
        "doc2": [40, 20, 60],
        "doc3": [70, 90]
    }

    k = 2
    top_k_pages = get_top_k_pages_single_page_from_each_doc(docid2scores, k)
    print(top_k_pages)

    -> [('doc3', 1, 90), ('doc2', 2, 60)]
    """
    # First, get the highest-scoring page for each document
    highest_per_doc = [
        (doc_id, max(enumerate(scores), key=lambda x: x[1]))  # (doc_id, (page_index, score))
        for doc_id, scores in docid2scores.items()
    ]

    # Flatten the structure to (doc_id, page_index, score)
    highest_per_doc_flat = [(doc_id, page_index, score) for doc_id, (page_index, score) in highest_per_doc]

    # Sort by score in descending order
    highest_per_doc_flat.sort(key=lambda x: x[2], reverse=True)

    # Get the top-k entries
    top_k_pages = highest_per_doc_flat[:k]

    return top_k_pages
18
src/m3docrag/retrieval/__init__.py
Normal file
@@ -0,0 +1,18 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from .colpali import ColPaliRetrievalModel
303
src/m3docrag/retrieval/colpali.py
Normal file
@@ -0,0 +1,303 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
# from transformers import AutoProcessor
from PIL import Image
from typing import List

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.models import ColQwen2, ColQwen2Processor


def init(
    backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
    adapter_name_or_path="/job/model/colpali-v1.2",
    dtype=torch.bfloat16,
):
    """
    Load the ColPali model and processor from a (locally downloaded) HF checkpoint.

    Args:
        - backbone_name_or_path: downloaded from https://huggingface.co/vidore/colpaligemma-3b-pt-448-base
        - adapter_name_or_path: downloaded from https://huggingface.co/vidore/colpali-v1.2
    Returns:
        - model
        - processor
    """

    kwargs = {}
    model_class = ColPali
    processor_class = ColPaliProcessor
    if 'colqwen' in str(adapter_name_or_path):
        model_class = ColQwen2
        processor_class = ColQwen2Processor
        kwargs['attn_implementation'] = "flash_attention_2"

    model = model_class.from_pretrained(
        backbone_name_or_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        **kwargs
    ).eval()

    model.load_adapter(adapter_name_or_path)
    processor = processor_class.from_pretrained(adapter_name_or_path)

    return model, processor
def encode_images(
    model,
    processor,
    images: List[Image.Image],
    batch_size: int = 4,
    to_cpu: bool = False,
    use_tqdm: bool = False,
    collate_fn=None,
    return_doclens: bool = False
):
    """Create document embeddings with ColPali

    Args:
        model
        processor
        images (List[Image.Image]):
            page images (n_pages)
        batch_size (int, optional):
            batch size. Defaults to 4.
        to_cpu (bool, optional):
            whether to store embeddings in CPU tensors. Defaults to False.
        use_tqdm (bool, optional):
            whether to show a tqdm progress bar. Defaults to False.
        collate_fn (optional):
            custom collate_fn for the document dataloader. Defaults to None.
        return_doclens (bool, optional):
            whether to also return the number of tokens per page. Defaults to False.

    Returns:
        doc_embs: List[torch.Tensor]
            visual embeddings of documents (n_pages, n_tokens, n_dimensions)
        (optional) doclens:
            number of tokens per page
    """

    if collate_fn is None:
        collate_fn = processor.process_images

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    doc_embs = []
    if return_doclens:
        doclens = []
    if use_tqdm:
        dataloader = tqdm(dataloader)
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            if to_cpu:
                embeddings_doc = embeddings_doc.to("cpu")
        if return_doclens:
            # batch_doc is a dict, so index it rather than using attribute access
            _doclens = batch_doc["attention_mask"].squeeze(-1).sum(-1).tolist()
            doclens.extend(_doclens)
        doc_embs.extend(list(torch.unbind(embeddings_doc)))

    if return_doclens:
        return doc_embs, doclens
    else:
        return doc_embs
def encode_queries(
    model,
    processor,
    queries: List[str],
    batch_size: int = 4,
    to_cpu: bool = False,
    use_tqdm: bool = False,
    collate_fn=None,
):
    """Create query embeddings with ColPali

    Args:
        model
        processor
        queries (List[str]):
            text queries (n_queries,)
        batch_size (int, optional):
            batch size. Defaults to 4.
        to_cpu (bool, optional):
            whether to store embeddings in CPU tensors. Defaults to False.
        use_tqdm (bool, optional):
            whether to show a tqdm progress bar. Defaults to False.
        collate_fn (optional):
            custom collate_fn for the query dataloader. Defaults to None.
    Returns:
        query_embs: List[torch.Tensor]
            embeddings of queries (n_queries, n_tokens, n_dimensions)
    """

    if collate_fn is None:
        collate_fn = processor.process_queries

    # run inference - queries
    dataloader = DataLoader(
        queries,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    query_embs = []
    if use_tqdm:
        dataloader = tqdm(dataloader)
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
            if to_cpu:
                embeddings_query = embeddings_query.to("cpu")
        query_embs.extend(list(torch.unbind(embeddings_query)))
    return query_embs
def retrieve(
    model,
    processor,
    docs=None,
    query=None,
    doc_embeds=None,
    query_embeds=None,
    to_cpu=False,
    batch_size=1,
    use_tqdm=False,
    return_top_1=True
):
    """Retrieve the document pages most relevant to a text query with ColPali."""
    if doc_embeds is None:
        doc_embeds = encode_images(
            model, processor,
            images=docs,
            batch_size=batch_size,
            use_tqdm=use_tqdm,
            to_cpu=to_cpu,
        )

    if query_embeds is None:
        query_embeds = encode_queries(
            model, processor,
            queries=[query],
            batch_size=1,
            use_tqdm=use_tqdm,
            to_cpu=to_cpu,
        )

    qs = query_embeds
    ds = doc_embeds

    # match the query dtype to the document embeddings before scoring
    qs = [q.to(ds[0].dtype) for q in qs]

    # late-interaction (MaxSim) scores between each query and each page
    scores = processor.score_multi_vector(qs, ds)

    if return_top_1:
        return scores.argmax(axis=1)
    else:
        return scores
class ColPaliRetrievalModel:

    def __init__(self,
                 backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
                 adapter_name_or_path="/job/model/colpali-v1.2",
                 dtype=torch.bfloat16,
                 ):
        model, processor = init(backbone_name_or_path=backbone_name_or_path,
                                adapter_name_or_path=adapter_name_or_path,
                                dtype=dtype,
                                )
        self.model = model.eval()
        self.processor = processor

    def encode_queries(self,
                       queries: List[str],
                       batch_size: int = 4,
                       to_cpu: bool = False,
                       use_tqdm: bool = False,
                       collate_fn=None
                       ):
        return encode_queries(
            model=self.model,
            processor=self.processor,
            queries=queries,
            batch_size=batch_size,
            to_cpu=to_cpu,
            use_tqdm=use_tqdm,
            collate_fn=collate_fn)

    def encode_images(self,
                      images: List[Image.Image],
                      batch_size: int = 4,
                      to_cpu: bool = False,
                      use_tqdm: bool = False,
                      collate_fn=None,
                      return_doclens: bool = False
                      ):
        return encode_images(
            model=self.model,
            processor=self.processor,
            images=images,
            batch_size=batch_size,
            to_cpu=to_cpu,
            use_tqdm=use_tqdm,
            collate_fn=collate_fn,
            return_doclens=return_doclens,
        )

    def retrieve(self,
                 docs=None,
                 query=None,
                 doc_embeds=None,
                 doc_lens=None,
                 query_embeds=None,
                 to_cpu=False,
                 batch_size=1,
                 use_tqdm=False,
                 return_top_1=True
                 ):

        return retrieve(
            model=self.model,
            processor=self.processor,
            docs=docs,
            query=query,
            doc_embeds=doc_embeds,
            query_embeds=query_embeds,
            to_cpu=to_cpu,
            batch_size=batch_size,
            use_tqdm=use_tqdm,
            return_top_1=return_top_1
        )
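A minimal end-to-end sketch of the retrieval API above; the checkpoint paths are the defaults baked into `init()` and assume locally downloaded ColPali weights, and the page image filenames are hypothetical:

```python
from PIL import Image
from m3docrag.retrieval import ColPaliRetrievalModel

retriever = ColPaliRetrievalModel(
    backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
    adapter_name_or_path="/job/model/colpali-v1.2",
)
# hypothetical page images rendered from a PDF
pages = [Image.open(f"page_{i}.png") for i in range(3)]
doc_embeds = retriever.encode_images(images=pages, to_cpu=True)
# returns the index of the top-scoring page (return_top_1=True by default)
top_page = retriever.retrieve(query="What was the Q3 revenue?", doc_embeds=doc_embeds)
```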
75
src/m3docrag/utils/args.py
Normal file
@@ -0,0 +1,75 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List
import transformers
from icecream import ic


@dataclass
class TrainingArguments(transformers.TrainingArguments):

    # Data settings
    split: str = 'train'
    data_name: str = field(default='m3-docvqa', metadata={"help": "Local name to be stored at LOCAL_DATA_DIR"})
    data_len: Optional[int] = field(default=None, metadata={"help": "number of examples to subsample from dataset"})
    use_dummy_images: bool = field(default=False, metadata={"help": "if true, skip downloading images"})
    load_embedding: bool = False
    embedding_name: str = "colpali-v1.2_m3-docvqa_dev"

    max_pages: int = 20
    do_page_padding: bool = False

    # Retrieval settings
    retrieval_model_type: str = field(default='colpali', metadata={"choices": ['colpali', 'colbert']})
    use_retrieval: bool = True
    retrieval_only: bool = field(default=False, metadata={"help": "not running stage 2 (VQA)"})
    page_retrieval_type: str = 'logits'
    loop_unique_doc_ids: bool = field(default=False, metadata={"help": "if true, apply retrieval only on unique doc ids"})

    n_retrieval_pages: int = 1

    # Embedding indexing settings
    faiss_index_type: str = field(default='ivfflat', metadata={"choices": ['flatip', 'ivfflat', 'ivfpq']})

    # Local paths
    model_name_or_path: Optional[str] = field(default="Qwen2-VL-7B-Instruct")
    retrieval_model_name_or_path: Optional[str] = field(default="colpaligemma-3b-pt-448-base")
    retrieval_adapter_model_name_or_path: Optional[str] = field(default="colpali-v1.2")

    # Model settings
    bits: int = field(default=16, metadata={"help": "Floating point precision. Use '4' for 4-bit quantization to save memory"})

    # idefics2 settings
    do_image_splitting: bool = False


_example_arg_str = """
--output_dir=/job/outputs
--data_name=m3-docvqa
--use_retrieval=True
"""
_example_args = _example_arg_str.strip().split('\n')


def parse_args(args=None):
    parser = transformers.HfArgumentParser(TrainingArguments)
    parsed_args, remaining_args = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
    ic(remaining_args)

    return parsed_args
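Since `TrainingArguments` extends `transformers.TrainingArguments`, any HF training flag can be mixed with the custom ones. A small sketch using the example argument list defined above (`parse_args` falls back to `sys.argv` when `args` is None):

```python
from m3docrag.utils.args import parse_args

args = parse_args([
    "--output_dir=/job/outputs",
    "--data_name=m3-docvqa",
    "--use_retrieval=True",
])
print(args.retrieval_model_type)  # 'colpali'
```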
209
src/m3docrag/utils/distributed.py
Normal file
@@ -0,0 +1,209 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""
Utility functions to manage multi-device training
"""

import inspect
import logging
import os
import platform
import subprocess
import sys

import importlib.metadata

import gpustat
import torch.cuda
import torch.distributed
from loguru import logger


def world_size() -> int:
    """Returns the total number of processes in a distributed job (num_nodes x gpus_per_node).
    Returns 1 in a non-distributed job.
    """
    return int(os.environ.get("WORLD_SIZE", "1"))


def is_distributed() -> bool:
    """Returns True iff this is a distributed job (more than one process)."""
    return world_size() > 1


def local_rank() -> int:
    """Returns the local rank of the current process in a distributed job.
    Returns 0 (local primary) for non-distributed jobs.
    """
    return int(os.environ.get("LOCAL_RANK", "0"))


def global_rank() -> int:
    """Returns the global rank of the current process in a distributed job.
    Returns 0 (global primary) for non-distributed jobs.
    """
    return int(os.environ.get("RANK", local_rank()))


def barrier():
    """Synchronizes all processes, pinning the barrier to the GPU that matches this process's local rank."""
    torch.distributed.barrier(device_ids=[local_rank()])


def patch_module_loggers(module_namespace):
    modules = inspect.getmembers(module_namespace, predicate=inspect.ismodule)

    # check toplevel module
    if hasattr(module_namespace, "logger"):
        module_namespace.logger = logger
        logger.info(f"Patching logger: {module_namespace.__name__}")

    for _, mod in modules:
        if hasattr(mod, "logger"):
            mod.logger = logger
            logger.info(f"Patching logger: {mod.__name__}")


class InterceptLogHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        # Get corresponding Loguru level if it exists.
        level: str | int
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Find the caller from which the logged message originated.
        frame, depth = inspect.currentframe(), 0
        while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )


def configure_distributed_logging(
    rank_zero_level="INFO",
    rank_non_zero_level="WARNING",
):
    """Configures the logger to reduce noise in multi-node, multi-process evaluations (e.g. DeepSpeed).

    Args:
        rank_zero_level (str, optional): Log level on the zero-rank process. Defaults to "INFO".
        rank_non_zero_level (str, optional): Log level on non-zero-rank processes. Defaults to "WARNING".
    """

    logger.remove()
    rank = local_rank()
    format = "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> | <lvl>{level}</lvl> | rank={extra[rank]} | <c>{name}</c>:<c>{function}</c>:<c>{line}</c> - <lvl>{message}</lvl>\n{exception}"
    if rank != 0:
        logger.configure(
            extra={"rank": f"{global_rank()}:{rank}"},
            handlers=[
                {"sink": sys.stderr, "format": format, "level": rank_non_zero_level}
            ],
        )
    else:
        logger.configure(
            extra={"rank": f"{global_rank()}:{rank}"},
            handlers=[{"sink": sys.stdout, "format": format, "level": rank_zero_level}],
        )

    # Attempt to intercept normal logging in libs
    logging.basicConfig(handlers=[InterceptLogHandler()], level=0, force=True)


def get_cuda_version():
    """Get the installed CUDA version."""
    try:
        output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
        version_line = output.strip().split("\n")[-1]
        version = version_line.split(" ")[-1]
        return version
    except Exception as e:
        logger.info(f"Cannot detect CUDA version. Exception occurred: {e}")
        return "N/A"


def log_runtime_info():
    # Get Python runtime information
    python_version = sys.version
    python_implementation = platform.python_implementation()
    python_build = platform.python_build()

    # Get environment variables
    env_vars = os.environ

    # Get installed package versions
    installed_packages = [
        (d.metadata["Name"], d.version) for d in importlib.metadata.distributions()
    ]

    # Log diagnostics
    logger.info("Python Version: {}".format(python_version))
    logger.info("Python Implementation: {}".format(python_implementation))
    logger.info("Python Build: {}".format(python_build))

    logger.info(f"Environment Variables: {env_vars}")
    logger.info(f"Installed Packages: {installed_packages}")

    logger.info(f"CUDA version: {get_cuda_version()}")
    logger.info(f"Is CUDA available for Torch?: {torch.cuda.is_available()}")

    logger.info(f"World size: {world_size()}")


def local_rank_zero(func):
    """
    Decorator to execute a function only on the local zero-rank process. Useful for logging statistics.
    """

    def wrapper(*args, **kwargs):
        if local_rank() == 0:
            func(*args, **kwargs)

    return wrapper


def global_rank_zero(func):
    """
    Decorator to execute a function only on the global zero-rank process. Useful for logging statistics.
    """

    def wrapper(*args, **kwargs):
        if global_rank() == 0:
            func(*args, **kwargs)

    return wrapper


def print_gpu_stats():
    gpustat.cli.main([])


def supports_flash_attention(device_id=0):
    """Check if a GPU supports FlashAttention."""
    major, minor = torch.cuda.get_device_capability(device_id)

    # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    return is_sm8x or is_sm90
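A short sketch of how these helpers compose in a training entry point (all names are defined above; the decorated function body is a placeholder):

```python
from m3docrag.utils.distributed import (
    configure_distributed_logging, log_runtime_info, local_rank_zero
)

configure_distributed_logging()  # rank 0 logs at INFO, other ranks at WARNING

@local_rank_zero
def report():
    # runs once per node instead of once per process
    log_runtime_info()

report()
```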
24
src/m3docrag/utils/paths.py
Normal file
@@ -0,0 +1,24 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import os
from dotenv import load_dotenv
load_dotenv()

LOCAL_DATA_DIR = os.getenv("LOCAL_DATA_DIR", "/job/datasets")
LOCAL_EMBEDDINGS_DIR = os.getenv("LOCAL_EMBEDDINGS_DIR", "/job/embeddings")
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/job/model")
LOCAL_OUTPUT_DIR = os.getenv("LOCAL_OUTPUT_DIR", "/job/output")
75
src/m3docrag/utils/pdfs.py
Normal file
@@ -0,0 +1,75 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from pathlib import Path
from loguru import logger
from pdf2image import convert_from_path
from collections import Counter


def get_images_from_pdf(pdf_path, max_pages=None, dpi_resolution=144, save_dir='/tmp/', save_image=False, save_type='png', verbose=False):
    pdf_path = Path(pdf_path)
    assert pdf_path.exists(), f"{pdf_path} does not exist"

    pdf_fname = pdf_path.name

    images = convert_from_path(pdf_path, dpi=dpi_resolution)

    # PIL.PpmImagePlugin.PpmImageFile -> PIL.Image.Image
    images = [img.convert('RGB') for img in images]

    # optionally keep only the first max_pages pages
    if max_pages is not None:
        images = images[:max_pages]

    # resize to the most common image size so that pages can be stacked in a pytorch tensor;
    # PDFs (e.g., MMLongBench-Doc) mix different page sizes, such as:
    # 1224 x 1584
    # 1440 x 810
    # 1191 x 1684
    # 1440 x 1080
    # 1536 x 1152

    # 1) find the most common image size
    img_size_counter = Counter()
    for img in images:
        img_size_counter[img.size] += 1
    common_img_size, common_img_size_count = img_size_counter.most_common(1)[0]

    # 2) if pages have different sizes -> resize all pages to that image size
    if len(images) != common_img_size_count:
        logger.info(f"total: {len(images)} pages")
        logger.info(f"resizing to the most common image size: {common_img_size} with count: {common_img_size_count}")
        images = [img.resize(common_img_size) for img in images]

    if save_image:
        save_dir = Path(save_dir)
        save_dir.mkdir(exist_ok=True, parents=True)

        for page_index, page_image in enumerate(images):
            save_page_path = save_dir / f"{pdf_fname}_{page_index+1}.{save_type}"
            if not save_page_path.exists():
                page_image.save(save_page_path)
                if verbose:
                    logger.info(f"Page {page_index} saved at {save_page_path}")

    return images


if __name__ == '__main__':
    get_images_from_pdf(
        pdf_path="./multimodalqa_screenshots_pdfs/0df5cc80bcd2a27b91224d658ad3a7b5.pdf",
        save_dir='./tmp/',
        save_image=True
    )
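A quick usage sketch; the PDF filename is hypothetical, and Poppler must be installed for pdf2image to work (see the Setup section):

```python
from m3docrag.utils.pdfs import get_images_from_pdf

# returns a list of equally sized RGB PIL images, one per page
pages = get_images_from_pdf("report.pdf", dpi_resolution=144)
print(len(pages), pages[0].size)
```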
70
src/m3docrag/utils/prompts.py
Normal file
@@ -0,0 +1,70 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import string

binary_page_retrieval_template = """
question: $question
Does this page have the answer to the question?
Answer only with yes or no.
""".strip()

concat_page_retrieval_template = """
question: $question
Which page is most relevant to the question?
""".strip()

concat_page_retrieval_with_answer_template = """
Find the page most relevant to the question and answer: "$question"
""".strip()

concate_page_answer_template = """
Find the answer to the question: "$question"
""".strip()

short_answer_template = """
question: $question
output only answer.
""".strip()

long_answer_template = """
question: $question
Answer the question with detailed explanation.
""".strip()


text_rag_template = """
DOCUMENTS:
$documents

QUESTION:
$question

INSTRUCTIONS:
Answer the QUESTION using the DOCUMENTS text above. Simply output the answer only.

Answer:
"""


binary_page_retrieval_template = string.Template(binary_page_retrieval_template)
concat_page_retrieval_template = string.Template(concat_page_retrieval_template)
concat_page_retrieval_with_answer_template = string.Template(concat_page_retrieval_with_answer_template)
concate_page_answer_template = string.Template(concate_page_answer_template)
short_answer_template = string.Template(short_answer_template)
long_answer_template = string.Template(long_answer_template)
text_rag_template = string.Template(text_rag_template)
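These are `string.Template` objects, so prompts are produced with `substitute`; a small example:

```python
from m3docrag.utils.prompts import short_answer_template

prompt = short_answer_template.substitute(question="What year was the company founded?")
# -> 'question: What year was the company founded?\noutput only answer.'
```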
32
src/m3docrag/utils/tar.py
Normal file
@@ -0,0 +1,32 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import tarfile
from loguru import logger

def make_tarfile(source_dir, output_filename):
    logger.info(f"Compressing {source_dir} to {output_filename} ...")
    with tarfile.open(output_filename, "w:gz") as tar:
        tar.add(source_dir, arcname='.')
    logger.info("Compression done!")


def extract_tarfile(input_filename, target_dir):
    logger.info(f"Extracting {input_filename} to {target_dir} ...")
    with tarfile.open(input_filename) as f:
        f.extractall(target_dir)
    logger.info("Extraction done!")
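A round-trip sketch (the paths are illustrative):

```python
from m3docrag.utils.tar import make_tarfile, extract_tarfile

make_tarfile(source_dir="/job/embeddings", output_filename="/tmp/embeddings.tar.gz")
extract_tarfile(input_filename="/tmp/embeddings.tar.gz", target_dir="/job/embeddings_copy")
```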
144
src/m3docrag/vqa/__init__.py
Normal file
@@ -0,0 +1,144 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


from typing import Union, List
from pathlib import Path
import torch

from m3docrag.vqa import internvl2
from m3docrag.vqa import idefics2
from m3docrag.vqa import idefics3
from m3docrag.vqa import florence2
from m3docrag.vqa import qwen2

ALL_VQA_MODEL_TYPES = ['florence2', 'idefics2', 'internvl2', 'idefics3', 'qwen2']


def init(
    model_name_or_path: Union[str, Path],
    model_type: str,
    **kwargs
):

    if 'internvl2' == model_type.lower():
        return internvl2.init(
            model_name_or_path=model_name_or_path,
            **kwargs
        )
    elif 'idefics2' == model_type.lower():
        return idefics2.init(
            model_name_or_path=model_name_or_path,
            **kwargs
        )
    elif 'idefics3' == model_type.lower():
        return idefics3.init(
            model_name_or_path=model_name_or_path,
            **kwargs
        )
    elif 'florence2' == model_type.lower():
        return florence2.init(
            model_name_or_path=model_name_or_path,
            **kwargs
        )
    elif 'qwen2' == model_type.lower():
        return qwen2.init(
            model_name_or_path=model_name_or_path,
            **kwargs
        )
    else:
        raise NotImplementedError(f"{model_type} is unsupported. Supported: {ALL_VQA_MODEL_TYPES}")


def generate(
    model_type: str,
    model,
    processor,
    **kwargs
) -> List[str]:

    if 'internvl2' == model_type.lower():
        return internvl2.generate(
            model=model,
            processor=processor,
            **kwargs
        )
    elif 'idefics2' == model_type.lower():
        return idefics2.generate(
            model=model,
            processor=processor,
            **kwargs
        )
    elif 'idefics3' == model_type.lower():
        return idefics3.generate(
            model=model,
            processor=processor,
            **kwargs
        )
    elif 'florence2' == model_type.lower():
        return florence2.generate(
            model=model,
            processor=processor,
            **kwargs
        )
    elif 'qwen2' == model_type.lower():
        return qwen2.generate(
            model=model,
            processor=processor,
            **kwargs
        )
    else:
        raise NotImplementedError(f"{model_type} is unsupported. Supported: {ALL_VQA_MODEL_TYPES}")


class VQAModel:

    def __init__(self, model_name_or_path: Union[str, Path], model_type: str, **kwargs):

        model_loaded = init(model_name_or_path=model_name_or_path, model_type=model_type, **kwargs)
        model = model_loaded['model']
        if 'tokenizer' in model_loaded:
            processor = model_loaded['tokenizer']
        else:
            processor = model_loaded['processor']

        if isinstance(model, torch.nn.Module):
            model = model.eval()

        # greedy decoding
        if hasattr(model, 'generation_config'):
            model.generation_config.temperature = None
            model.generation_config.top_p = None
            model.generation_config.top_k = None

        self.model = model
        self.processor = processor
        self.model_type = model_type

    def generate(self, images, question) -> str:
        responses = generate(
            model_type=self.model_type,
            model=self.model,
            processor=self.processor,
            images=images,
            question=question,
        )
        assert isinstance(responses, list), responses

        out_text = responses[0]
        out_text = out_text.strip()

        return out_text
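A minimal sketch of this stage-3 answering API; the checkpoint path is hypothetical and should point to a locally downloaded model, and the page image is assumed to be the output of the retrieval stage:

```python
from PIL import Image
from m3docrag.vqa import VQAModel

vqa = VQAModel(model_name_or_path="/job/model/Qwen2-VL-7B-Instruct", model_type="qwen2")
answer = vqa.generate(images=[Image.open("retrieved_page.png")],
                      question="Who is the author?")
print(answer)
```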
169
src/m3docrag/vqa/florence2.py
Normal file
@@ -0,0 +1,169 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image

from typing import Union, List


def init(
    model_name_or_path,
    dtype=torch.bfloat16,
    **kwargs,
):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )
    model.eval()

    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )

    return {
        'model': model,
        'processor': processor
    }


def generate(
    model,
    processor,
    question,
    images
) -> List[str]:

    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        module = model.module
    else:
        module = model
    answer = generate_caption(module, processor=processor, images=images, text_input=question)
    return answer


@torch.no_grad()
def generate_caption(
    model,
    processor,
    images=None,
    text_input=None,
    input_ids=None,
    pixel_values=None,
    max_new_tokens=77,
    num_beams=1,
    do_sample=False,
    decode_text=True,
    **generate_kwargs
):

    if input_ids is None and pixel_values is None:
        if isinstance(images, Image.Image):
            images = [images]

        B = len(images)

        if isinstance(text_input, str):
            text_input = [text_input] * B

        inputs = processor(
            text=text_input,
            images=images,
            return_tensors="pt")

        input_ids = inputs["input_ids"]
        pixel_values = inputs["pixel_values"]

    p = next(iter(model.parameters()))
    device = p.device
    dtype = p.dtype

    generated_ids = model.generate(
        input_ids=input_ids.to(device),
        pixel_values=pixel_values.to(device, dtype),
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=do_sample,
        **generate_kwargs
    )
    if decode_text:
        out_text = decode_predictions(processor, generated_ids)

        return out_text
    else:
        return generated_ids


def decode_predictions(processor, generated_ids):
    B = len(generated_ids)

    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    out_text = []
    for i in range(B):
        out_text.append(generated_text[i].replace('<s>', '').replace('</s>', '').replace('<pad>', '').strip())
    return out_text


# def load_model_from_ckpt(
#     ckpt_dir,
#     load_abstractor=False,
#     num_input_tokens=576,
#     num_query_tokens=64,
#     proj_type='c-abs',
#     projection_dim=1024,
#     strict=False
# ):
#     florence_config = Florence2Config.from_pretrained(ckpt_dir)
#
#     if load_abstractor:
#         abstractor_config = AbstractorConfig(
#             num_input_tokens=num_input_tokens,
#             num_query_tokens=num_query_tokens,
#             proj_type=proj_type,
#             projection_dim=projection_dim,
#         )
#         florence_config.abstractor_config = abstractor_config
#
#     florence_config.vision_config.model_type = 'davit'
#
#     model = Florence2ForConditionalGeneration(config=florence_config)
#
#     ckpt_path = Path(ckpt_dir) / 'model.safetensors'
#     if ckpt_path.exists():
#         logger.info(f"loading checkpoints from {ckpt_path}")
#         state_dict = safetensors.torch.load_file(ckpt_path, device="cpu")
#     else:
#         ckpt_path = Path(ckpt_dir) / 'pytorch_model.bin'
#         logger.info(f"loading checkpoints from {ckpt_path}")
#         state_dict = torch.load(
#             ckpt_path,
#             map_location="cpu",
#         )
#
#     load_result = model.load_state_dict(state_dict, strict=strict)
#     logger.info(load_result)
#
#     return model
137
src/m3docrag/vqa/idefics2.py
Normal file
@@ -0,0 +1,137 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


import torch

from transformers import Idefics2ForConditionalGeneration
from transformers import Idefics2Processor
from transformers import BitsAndBytesConfig

from typing import Union, List


def init(
    model_name_or_path,
    do_image_splitting=True,
    bits=4,
    dtype=torch.bfloat16,
    **kwargs,
):
    if bits == 4:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=dtype
        )
    else:
        bnb_config = None
    model = Idefics2ForConditionalGeneration.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        quantization_config=bnb_config,
        low_cpu_mem_usage=True
    )
    model.eval()

    processor = Idefics2Processor.from_pretrained(
        model_name_or_path,
        do_image_splitting=do_image_splitting,
    )

    return {
        'model': model,
        'processor': processor
    }


def generate(
    model,
    processor,
    question,
    images
) -> List[str]:
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        module = model.module
    else:
        module = model

    messages = idefics2_create_message(images=images, question=question)

    examples = [{'images': images, 'messages': messages}]
    batch = idefics2_collate_fn(examples, processor)

    for k in batch:
        batch[k] = batch[k].to(module.device)

    generated_ids = module.generate(**batch, max_new_tokens=50, do_sample=False)
    answer = processor.batch_decode(
        generated_ids[:, batch.input_ids.size(1):],
        skip_special_tokens=True)
    return answer


def idefics2_collate_fn(examples,
                        processor,
                        model_max_length=3600,
                        is_train=False,
                        image_token_id=None,
                        ):

    if image_token_id is None:
        image_token_id = processor.tokenizer.additional_special_tokens_ids[processor.tokenizer.additional_special_tokens.index("<image>")]

    texts = []
    images = []
    for example in examples:
        prompt = processor.apply_chat_template(example['messages'], add_generation_prompt=not is_train)
        texts.append(prompt)
        images.append(example['images'])

    if is_train:
        batch = processor(text=texts, images=images, return_tensors="pt",
                          padding=True, truncation=True,
                          max_length=model_max_length)

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == image_token_id] = -100
        batch["labels"] = labels
    else:
        batch = processor(text=texts, images=images, return_tensors="pt",
                          padding=True,
                          truncation=True,
                          max_length=model_max_length
                          )

    return batch


def idefics2_create_message(images, question, is_train=False, target_text=None):
    content = []
    for page_i in range(len(images)):
        content += [{"type": "image"}]
    content += [{"type": "text", "text": question}]
    messages = [{"role": "user", "content": content}]

    if is_train:
        messages += [{"role": "assistant", "content": [{"type": "text", "text": target_text}]}]

    return messages
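For reference, the chat structure that `idefics2_create_message` produces for two page images looks like this (a sketch; the question string is illustrative):

```python
# idefics2_create_message(images=[img1, img2], question="Who signed the memo?")
# returns:
[{"role": "user",
  "content": [{"type": "image"},
              {"type": "image"},
              {"type": "text", "text": "Who signed the memo?"}]}]
```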
129
src/m3docrag/vqa/idefics3.py
Normal file
@@ -0,0 +1,129 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


import torch

from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers import BitsAndBytesConfig

from typing import Union, List


def init(
    model_name_or_path,
    bits=4,
    dtype=torch.bfloat16,
    **kwargs,
):
    if bits == 4:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=dtype
        )
    else:
        bnb_config = None

    model = AutoModelForVision2Seq.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        quantization_config=bnb_config,
        low_cpu_mem_usage=True
    )
    model.eval()

    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
    )

    return {
        'model': model,
        'processor': processor
    }


def generate(
    model,
    processor,
    question,
    images
) -> List[str]:
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        module = model.module
    else:
        module = model

    messages = idefics3_create_message(images=images, question=question)

    examples = [{'images': images, 'messages': messages}]
    batch = idefics3_collate_fn(examples, processor)

    for k in batch:
        batch[k] = batch[k].to(module.device)

    generated_ids = module.generate(**batch, max_new_tokens=50, do_sample=False)
    answer = processor.batch_decode(
        generated_ids[:, batch.input_ids.size(1):],
        skip_special_tokens=True)
    return answer


def idefics3_collate_fn(examples,
                        processor,
                        is_train=False,
                        image_token_id=None,
                        ):

    texts = []
    images = []
    for example in examples:
        prompt = processor.apply_chat_template(example['messages'], add_generation_prompt=not is_train)
        texts.append(prompt)
        images.append(example['images'])

    if is_train:
        batch = processor(text=texts, images=images, return_tensors="pt",
                          padding=True, truncation=True,
                          )

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        # image_token_id must be passed by the caller for training;
        # guard against the None default so label masking cannot silently misfire
        if image_token_id is not None:
            labels[labels == image_token_id] = -100
        batch["labels"] = labels
    else:
        batch = processor(text=texts, images=images, return_tensors="pt",
                          padding=True,
                          truncation=True,
                          )

    return batch


def idefics3_create_message(images, question, is_train=False, target_text=None):
    content = []
    for page_i in range(len(images)):
        content += [{"type": "image"}]
    content += [{"type": "text", "text": question}]
    messages = [{"role": "user", "content": content}]

    if is_train:
        messages += [{"role": "assistant", "content": [{"type": "text", "text": target_text}]}]

    return messages
288
src/m3docrag/vqa/internvl2.py
Normal file
@@ -0,0 +1,288 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
import torchvision.transforms as T
from pathlib import Path
# from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
from typing import Union, List

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):

    if isinstance(image_file, (str, Path)):
        image = Image.open(image_file).convert('RGB')
    elif isinstance(image_file, Image.Image):
        image = image_file
    else:
        raise TypeError(f"unsupported image input: {type(image_file)}")

    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

# def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
#     if bound:
#         start, end = bound[0], bound[1]
#     else:
#         start, end = -100000, 100000
#     start_idx = max(first_idx, round(start * fps))
#     end_idx = min(round(end * fps), max_frame)
#     seg_size = float(end_idx - start_idx) / num_segments
#     frame_indices = np.array([
#         int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
#         for idx in range(num_segments)
#     ])
#     return frame_indices

# def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
#     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
#     max_frame = len(vr) - 1
#     fps = float(vr.get_avg_fps())
#
#     pixel_values_list, num_patches_list = [], []
#     transform = build_transform(input_size=input_size)
#     frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
#     for frame_index in frame_indices:
#         img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
#         img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
#         pixel_values = [transform(tile) for tile in img]
#         pixel_values = torch.stack(pixel_values)
#         num_patches_list.append(pixel_values.shape[0])
#         pixel_values_list.append(pixel_values)
#     pixel_values = torch.cat(pixel_values_list)
#     return pixel_values, num_patches_list


def init(
    model_name_or_path,
    dtype=torch.bfloat16,
    use_flash_attn=True,
    **kwargs,
):
    model = AutoModel.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        use_flash_attn=use_flash_attn,
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False)

    return {
        'model': model,
        'processor': tokenizer
    }

def generate(
    model,
    processor,
    question,
    images
) -> List[str]:
    tokenizer = processor

    if len(images) == 1:
        pixel_values = load_image(
            images[0],
            max_num=12).to(torch.bfloat16).to(model.device)
    else:
        raise NotImplementedError

    generation_config = dict(max_new_tokens=1024, do_sample=False)

    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        module = model.module
    else:
        module = model

    answer = module.chat(tokenizer, pixel_values, question, generation_config)
    return [answer]


if __name__ == '__main__':
    # If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section.
    path = 'OpenGVLab/InternVL2-8B'
    model = AutoModel.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

    # set the max number of tiles in `max_num`
    pixel_values = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
    generation_config = dict(max_new_tokens=1024, do_sample=False)

    # pure-text conversation
    question = 'Hello, who are you?'
    response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    question = 'Can you tell me a story?'
    response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    # single-image single-round conversation
    question = '<image>\nPlease describe the image shortly.'
    response = model.chat(tokenizer, pixel_values, question, generation_config)
    print(f'User: {question}\nAssistant: {response}')

    # single-image multi-round conversation
    question = '<image>\nPlease describe the image in detail.'
    response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    question = 'Please write a poem according to the image.'
    response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    # multi-image multi-round conversation, combined images
    pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
    pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
    pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)

    question = '<image>\nDescribe the two images in detail.'
    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                                   history=None, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    question = 'What are the similarities and differences between these two images.'
    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                                   history=history, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    # multi-image multi-round conversation, separate images
    pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
    pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
    pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
    num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]

    question = 'Image-1: <image>\nImage-2: <image>\nDescribe the two images in detail.'
    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                                   num_patches_list=num_patches_list,
                                   history=None, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    question = 'What are the similarities and differences between these two images.'
    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                                   num_patches_list=num_patches_list,
                                   history=history, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    # batch inference, single image per sample
    pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
    pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
    num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]
    pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)

    questions = ['<image>\nDescribe the image in detail.'] * len(num_patches_list)
    responses = model.batch_chat(tokenizer, pixel_values,
                                 num_patches_list=num_patches_list,
                                 questions=questions,
                                 generation_config=generation_config)
    for question, response in zip(questions, responses):
        print(f'User: {question}\nAssistant: {response}')

    # video multi-round conversation
    # NOTE: this part needs the commented-out get_index/load_video helpers above
    # (and the decord package) to run.
    video_path = './examples/red-panda.mp4'
    pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
    pixel_values = pixel_values.to(torch.bfloat16).cuda()
    video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
    question = video_prefix + 'What is the red panda doing?'
    # Frame1: <image>\nFrame2: <image>\n...\nFrame8: <image>\n{question}
    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                                   num_patches_list=num_patches_list, history=None, return_history=True)
    print(f'User: {question}\nAssistant: {response}')

    question = 'Describe this video in detail. Don\'t repeat.'
    response, history = model.chat(tokenizer, pixel_values, question, generation_config,
                                   num_patches_list=num_patches_list, history=history, return_history=True)
    print(f'User: {question}\nAssistant: {response}')
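To make `dynamic_preprocess` concrete: a 896x448 image has aspect ratio 2.0, so the closest target ratio is (2, 1), yielding two 448x448 tiles plus a global thumbnail when `use_thumbnail=True`. A small sketch of the expected shapes:

```python
from PIL import Image
from m3docrag.vqa.internvl2 import dynamic_preprocess

img = Image.new("RGB", (896, 448))
tiles = dynamic_preprocess(img, image_size=448, use_thumbnail=True, max_num=12)
print(len(tiles))     # 3: two tiles + one thumbnail
print(tiles[0].size)  # (448, 448)
```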
100
src/m3docrag/vqa/qwen2.py
Normal file
@@ -0,0 +1,100 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import torch
from typing import List
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info


def init(
    model_name_or_path,
    dtype=torch.bfloat16,
    bits=16,
    attn_implementation="flash_attention_2",
    **kwargs,
):
    if bits == 4:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=dtype
        )
    else:
        bnb_config = None
    # Load the model in half-precision on the available device(s)
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_implementation,
        quantization_config=bnb_config,
        vision_config={"torch_dtype": dtype}
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(model_name_or_path)

    return {
        'model': model,
        'processor': processor
    }


def generate(
    model,
    processor,
    question,
    images
) -> List[str]:

    image_content = [{"type": "image", "image": "dummy_content"}] * len(images)

    messages = [
        {
            "role": "user",
            "content": image_content + [{"type": "text", "text": question}]
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=images,
        # videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    p = next(iter(model.parameters()))

    inputs = inputs.to(p.device)

    # Inference
    generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    assert isinstance(output_text, list), output_text

    return output_text
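And a matching sketch for this Qwen2-VL backend, calling the `init` and `generate` functions above directly (the checkpoint path and image filename are hypothetical):

```python
from PIL import Image
from m3docrag.vqa import qwen2

loaded = qwen2.init("/job/model/Qwen2-VL-7B-Instruct", bits=16)
answers = qwen2.generate(
    model=loaded['model'],
    processor=loaded['processor'],
    question="What is the total on this invoice?",
    images=[Image.open("invoice_page.png")],
)
print(answers[0])
```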