Release commit

Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
j-min
2025-01-30 17:04:56 -05:00
committed by oir
parent e04aeadfb0
commit 27aac8d521
50 changed files with 5692 additions and 0 deletions


@ -0,0 +1,215 @@
Metadata-Version: 2.2
Name: m3docrag
Version: 0.0.1
Summary: Multimodal Document Understanding with RAG
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: accelerate==1.1.0
Requires-Dist: loguru
Requires-Dist: requests
Requires-Dist: setuptools==69.5
Requires-Dist: transformers
Requires-Dist: tokenizers
Requires-Dist: flash-attn==2.5.8
Requires-Dist: bitsandbytes==0.43.1
Requires-Dist: safetensors
Requires-Dist: gpustat
Requires-Dist: icecream
Requires-Dist: pdf2image
Requires-Dist: numpy==1.26.4
Requires-Dist: torchvision
Requires-Dist: jsonlines
Requires-Dist: editdistance
Requires-Dist: einops
Requires-Dist: fire
Requires-Dist: peft
Requires-Dist: timm
Requires-Dist: sentencepiece
Requires-Dist: colpali-engine==0.3.1
Requires-Dist: easyocr
Requires-Dist: qwen-vl-utils
Requires-Dist: faiss-cpu
Requires-Dist: word2number
Requires-Dist: datasets>=3.0.0
Requires-Dist: python-dotenv
# M3DocRAG
Code for [M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding](https://m3docrag.github.io/)
by [Jaemin Cho](https://j-min.io/), [Debanjan Mahata](https://sites.google.com/a/ualr.edu/debanjan-mahata/), [Ozan İrsoy](https://wtimesx.com/), [Yujie He](https://scholar.google.com/citations?user=FbeAZGgAAAAJ&hl=en), [Mohit Bansal](https://www.cs.unc.edu/~mbansal/)
# Summary
## Comparison with previous approaches
<img src='./assets/m3docrag_teaser.png' >
Comparison of multi-modal document understanding pipelines. Previous works focus on (a) **Single-page DocVQA**, which cannot handle many long documents, or (b) **Text-based RAG**, which ignores visual information. Our (c) **M3DocRAG** framework retrieves relevant documents and answers questions with multi-modal retrieval and multi-modal LM (MLM) components, so it can efficiently handle many long documents while preserving visual information.
## M3DocRAG framework
<img src='./assets/method.png' >
Our **M3DocRAG** framework consists of three stages: (1) document embedding, (2) page retrieval, and (3) question answering.
- In (1) document embedding, we extract visual embeddings (with ColPali) to represent each page of all PDF documents.
- In (2) page retrieval, we retrieve the top-K pages with the highest relevance (MaxSim scores) to the text query. In the open-domain setting, we create approximate page indices for faster search.
- In (3) question answering, we conduct visual question answering with a multi-modal LM (e.g., Qwen2-VL) to obtain the final answer.
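To make the three stages concrete, here is a minimal Python sketch built on the classes shipped in this repo (`ColPaliRetrievalModel`, `MultimodalRAGModel`, `get_images_from_pdf`). The `VQAModel` constructor arguments, the example PDF path, and the question are illustrative assumptions based on the defaults in `src/m3docrag/utils/args.py`; the supported end-to-end entry points are the scripts under `examples/` described below.
```python
import torch

from m3docrag.rag import MultimodalRAGModel
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.pdfs import get_images_from_pdf
from m3docrag.vqa import VQAModel

# (1) Document embedding: one [n_pages, n_tokens, dim] tensor per PDF (ColPali page embeddings)
retrieval_model = ColPaliRetrievalModel(
    backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
    adapter_name_or_path="/job/model/colpali-v1.2",
)
pages = get_images_from_pdf("/job/datasets/m3-docvqa/splits/pdfs_dev/example.pdf")  # illustrative path
docid2embs = {"example": torch.stack(retrieval_model.encode_images(images=pages, to_cpu=True))}

# (2) Page retrieval: MaxSim scores between query tokens and page tokens; keep the top page
vqa_model = VQAModel(model_name_or_path="/job/model/Qwen2-VL-7B-Instruct")  # constructor args assumed
rag = MultimodalRAGModel(retrieval_model=retrieval_model, vqa_model=vqa_model)
question = "Who is the author of this report?"
doc_id, page_idx, score = rag.retrieve_pages_from_docs(
    query=question, docid2embs=docid2embs, n_return_pages=1
)[0]

# (3) Question answering: ask the multi-modal LM (Qwen2-VL) about the retrieved page image
answer = rag.run_vqa(images=[pages[page_idx]], question=question)
print(doc_id, page_idx, score, answer)
```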
# Setup
## Package
We assume that conda has already been installed.
```bash
git clone <REPO_URL>
cd m3docrag-release
pip install -e .
# Install Poppler (for pdf2image; check https://pdf2image.readthedocs.io/en/latest/installation.html for details)
# conda install -y poppler
# or
# apt-get install poppler-utils
```
## Code structure
```bash
examples/ # scripts to run PDF embedding / RAG
src/m3docrag/
datasets/ # data loader for existing datasets
retrieval/ # retrieval model (e.g., ColPali)
vqa/ # vqa model (e.g., Qwen2-VL)
rag/ # RAG model that combines retrieval and vqa models
utils/ # misc utility methods
m3docvqa/ # how to setup m3docvqa dataset
```
## Paths: Data, Embeddings, Model checkpoints, Outputs
```bash
# in .env
LOCAL_DATA_DIR="/job/datasets" # where to store data
LOCAL_EMBEDDINGS_DIR="/job/embeddings" # where to store embeddings
LOCAL_MODEL_DIR="/job/model" # where to store model checkpoints
LOCAL_OUTPUT_DIR="/job/output" # where to store model outputs
```
You can adjust the variables in [`.env`](.env) to change where data, embeddings, model checkpoints, and outputs are stored by default. They are loaded in [`src/m3docrag/utils/paths.py`](./src/m3docrag/utils/paths.py) via [python-dotenv](https://github.com/theskumar/python-dotenv).
## Download M3DocVQA dataset
Please see [m3docvqa/README.md](m3docvqa/README.md) for the download instructions.
## Download model checkpoints
By default, we use colpali-v1.2 for retrieval and Qwen2-VL-7B-Instruct for question answering.
At `$LOCAL_MODEL_DIR`, download the [colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2), [colpaligemma-3b-pt-448-base](https://huggingface.co/vidore/colpaligemma-3b-pt-448-base), and [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) checkpoints.
```bash
cd $LOCAL_MODEL_DIR
git clone https://huggingface.co/vidore/colpaligemma-3b-pt-448-base # ColPali backbone
git clone https://huggingface.co/vidore/colpali-v1.2 # ColPali adapter
git clone https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct # VQA
```
# Example usage
Below, we describe example usage of M3DocRAG on the M3DocVQA dataset.
## 1. Extract PDF embeddings
```bash
DATASET_NAME="m3-docvqa"
RETRIEVAL_MODEL_TYPE="colpali"
RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base"
RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
SPLIT="dev"
EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT # where to save embeddings
accelerate launch --num_processes=1 --mixed_precision=bf16 examples/run_page_embedding.py \
--use_retrieval \
--retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
--data_name=$DATASET_NAME \
--split=$SPLIT \
--loop_unique_doc_ids=True \
--output_dir=/job/embeddings/$EMBEDDING_NAME \
--retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \
--retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME
```
## 2. Indexing
```bash
DATASET_NAME="m3-docvqa"
RETRIEVAL_MODEL_TYPE="colpali"
RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
SPLIT="dev"
FAISS_INDEX_TYPE='ivfflat'
EMBEDDING_NAME=$RETRIEVAL_ADAPTER_MODEL_NAME"_"$DATASET_NAME"_"$SPLIT
INDEX_NAME=$EMBEDDING_NAME"_pageindex_"$FAISS_INDEX_TYPE # where to save resulting index
echo $EMBEDDING_NAME
echo $FAISS_INDEX_TYPE
python examples/run_indexing_m3docvqa.py \
--use_retrieval \
--retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
--data_name=$DATASET_NAME \
--split=$SPLIT \
--loop_unique_doc_ids=False \
--embedding_name=$EMBEDDING_NAME \
--faiss_index_type=$FAISS_INDEX_TYPE \
--output_dir=/job/embeddings/$INDEX_NAME
```
## 3. RAG
```bash
BACKBONE_MODEL_NAME="Qwen2-VL-7B-Instruct"
RETRIEVAL_MODEL_TYPE="colpali"
RETRIEVAL_MODEL_NAME="colpaligemma-3b-pt-448-base"
RETRIEVAL_ADAPTER_MODEL_NAME="colpali-v1.2"
EMBEDDING_NAME="colpali-v1.2_m3-docvqa_dev" # from Step 1 Embedding
SPLIT="dev"
DATASET_NAME="m3-docvqa"
FAISS_INDEX_TYPE='ivfflat'
N_RETRIEVAL_PAGES=1
INDEX_NAME="${EMBEDDING_NAME}_pageindex_$FAISS_INDEX_TYPE" # from Step 2 Indexing
OUTPUT_SAVE_NAME="${RETRIEVAL_ADAPTER_MODEL_NAME}_${BACKBONE_MODEL_NAME}_${DATASET_NAME}" # where to save RAG results
BITS=16 # BITS=4 for 4-bit quantization on low-memory GPUs
python examples/run_rag_m3docvqa.py \
--use_retrieval \
--retrieval_model_type=$RETRIEVAL_MODEL_TYPE \
--load_embedding=True \
--split=$SPLIT \
--bits=$BITS \
--n_retrieval_pages=$N_RETRIEVAL_PAGES \
--data_name=$DATASET_NAME \
--model_name_or_path=$BACKBONE_MODEL_NAME \
--embedding_name=$EMBEDDING_NAME \
--retrieval_model_name_or_path=$RETRIEVAL_MODEL_NAME \
--retrieval_adapter_model_name_or_path=$RETRIEVAL_ADAPTER_MODEL_NAME \
--output_dir=/job/eval_outputs/$OUTPUT_SAVE_NAME
```
# Citation
Please cite our paper if you use our dataset and/or method in your projects.
```bibtex
@article{Cho2024M3DocRAG,
author = {Jaemin Cho and Ozan İrsoy and Debanjan Mahata and Yujie He and Mohit Bansal},
title = {M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding},
year = {2024},
}
```


@ -0,0 +1,31 @@
README.md
pyproject.toml
src/m3docrag/__init__.py
src/m3docrag.egg-info/PKG-INFO
src/m3docrag.egg-info/SOURCES.txt
src/m3docrag.egg-info/dependency_links.txt
src/m3docrag.egg-info/requires.txt
src/m3docrag.egg-info/top_level.txt
src/m3docrag/datasets/__init__.py
src/m3docrag/datasets/m3_docvqa/__init__.py
src/m3docrag/datasets/m3_docvqa/common_utils.py
src/m3docrag/datasets/m3_docvqa/dataset.py
src/m3docrag/datasets/m3_docvqa/evaluate.py
src/m3docrag/rag/__init__.py
src/m3docrag/rag/base.py
src/m3docrag/rag/multimodal.py
src/m3docrag/rag/utils.py
src/m3docrag/retrieval/__init__.py
src/m3docrag/retrieval/colpali.py
src/m3docrag/utils/args.py
src/m3docrag/utils/distributed.py
src/m3docrag/utils/paths.py
src/m3docrag/utils/pdfs.py
src/m3docrag/utils/prompts.py
src/m3docrag/utils/tar.py
src/m3docrag/vqa/__init__.py
src/m3docrag/vqa/florence2.py
src/m3docrag/vqa/idefics2.py
src/m3docrag/vqa/idefics3.py
src/m3docrag/vqa/internvl2.py
src/m3docrag/vqa/qwen2.py


@ -0,0 +1 @@


@ -0,0 +1,28 @@
accelerate==1.1.0
loguru
requests
setuptools==69.5
transformers
tokenizers
flash-attn==2.5.8
bitsandbytes==0.43.1
safetensors
gpustat
icecream
pdf2image
numpy==1.26.4
torchvision
jsonlines
editdistance
einops
fire
peft
timm
sentencepiece
colpali-engine==0.3.1
easyocr
qwen-vl-utils
faiss-cpu
word2number
datasets>=3.0.0
python-dotenv


@ -0,0 +1 @@
m3docrag

16 src/m3docrag/__init__.py Normal file

@ -0,0 +1,16 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


@ -0,0 +1,16 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


@ -0,0 +1,18 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from .dataset import M3DocVQADataset
from .evaluate import evaluate_predictions, evaluate_prediction_file


@ -0,0 +1,142 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# https://github.com/allenai/multimodalqa/blob/master/baselines/common_utils.py
import json
ALL_QUESTION_TYPES = [
'TextQ',
'TableQ',
'ImageQ',
'ImageListQ',
'Compose(TableQ,ImageListQ)',
'Compose(TextQ,ImageListQ)',
'Compose(ImageQ,TableQ)',
'Compose(ImageQ,TextQ)',
'Compose(TextQ,TableQ)',
'Compose(TableQ,TextQ)',
'Intersect(TableQ,TextQ)',
'Intersect(ImageListQ,TableQ)',
'Intersect(ImageListQ,TextQ)',
'Compare(Compose(TableQ,ImageQ),TableQ)',
'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))',
'Compare(TableQ,Compose(TableQ,TextQ))',
]
TEXT_SINGLE_HOP_QUESTION_TYPES = [
'TextQ',
]
TEXT_AS_FIRST_HOP_QUESTION_TYPES = [
'Compare(TableQ,Compose(TableQ,TextQ))',
'Compose(ImageQ,TextQ)',
'Compose(TableQ,TextQ)',
'Intersect(TableQ,TextQ)',
'Intersect(ImageListQ,TextQ)',
]
TEXT_AS_SECOND_HOP_QUESTION_TYPES = [
'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))',
'Compose(TextQ,ImageListQ)',
'Compose(TextQ,TableQ)',
]
TABLE_SINGLE_HOP_QUESTION_TYPES = [
"TableQ"
]
TABLE_AS_FIRST_HOP_QUESTION_TYPES = [
'Compose(ImageQ,TableQ)',
'Compose(TextQ,TableQ)',
]
TABLE_AS_SECOND_HOP_QUESTION_TYPES = [
'Compare(Compose(TableQ,ImageQ),TableQ)',
'Compare(TableQ,Compose(TableQ,TextQ))',
'Compose(TableQ,ImageListQ)',
'Compose(TableQ,TextQ)',
'Intersect(ImageListQ,TableQ)',
'Intersect(TableQ,TextQ)',
]
IMAGE_SINGLE_HOP_QUESTION_TYPES = [
'ImageQ',
'ImageListQ'
]
IMAGE_AS_FIRST_HOP_QUESTION_TYPES = [
'Compare(Compose(TableQ,ImageQ),Compose(TableQ,TextQ))',
'Compare(Compose(TableQ,ImageQ),TableQ)',
'Compose(TableQ,ImageListQ)',
'Compose(TextQ,ImageListQ)',
'Intersect(ImageListQ,TableQ)',
]
IMAGE_AS_SECOND_HOP_QUESTION_TYPES = [
'Compose(ImageQ,TableQ)',
'Compose(ImageQ,TextQ)',
'Intersect(ImageListQ,TextQ)',
]
# every question should be answered either as a single-hop question or as a two-hop question
assert set(TEXT_SINGLE_HOP_QUESTION_TYPES + TEXT_AS_SECOND_HOP_QUESTION_TYPES
+ TABLE_SINGLE_HOP_QUESTION_TYPES + TABLE_AS_SECOND_HOP_QUESTION_TYPES
+ IMAGE_SINGLE_HOP_QUESTION_TYPES + IMAGE_AS_SECOND_HOP_QUESTION_TYPES) == set(ALL_QUESTION_TYPES)
assert len(set(TEXT_SINGLE_HOP_QUESTION_TYPES) & set(TEXT_AS_SECOND_HOP_QUESTION_TYPES)) == 0
assert len(set(TABLE_SINGLE_HOP_QUESTION_TYPES) & set(TABLE_AS_SECOND_HOP_QUESTION_TYPES)) == 0
assert len(set(IMAGE_SINGLE_HOP_QUESTION_TYPES) & set(IMAGE_AS_SECOND_HOP_QUESTION_TYPES)) == 0
SINGLE_HOP_QUESTION_TYPES = TEXT_SINGLE_HOP_QUESTION_TYPES \
+ TABLE_SINGLE_HOP_QUESTION_TYPES \
+ IMAGE_SINGLE_HOP_QUESTION_TYPES
MULTI_HOP_QUESTION_TYPES = TEXT_AS_SECOND_HOP_QUESTION_TYPES \
+ TABLE_AS_SECOND_HOP_QUESTION_TYPES + \
IMAGE_AS_SECOND_HOP_QUESTION_TYPES
# no duplicated multi-hop question types
assert len(MULTI_HOP_QUESTION_TYPES) == len(set(MULTI_HOP_QUESTION_TYPES))
# no duplication for the first hop
assert set(TEXT_AS_FIRST_HOP_QUESTION_TYPES + TABLE_AS_FIRST_HOP_QUESTION_TYPES + IMAGE_AS_FIRST_HOP_QUESTION_TYPES) \
== set(MULTI_HOP_QUESTION_TYPES)
# single + multi = all
assert set(SINGLE_HOP_QUESTION_TYPES + MULTI_HOP_QUESTION_TYPES) == set(ALL_QUESTION_TYPES)
def process_question_for_implicit_decomp(question, question_type, hop=0, bridge_entity='', sep_token='[SEP]'):
if isinstance(bridge_entity, list) or isinstance(bridge_entity, set):
bridge_entity = "; ".join(bridge_entity)
return (
f'{question_type} {sep_token} '
f'HOP={hop} {sep_token} '
f'{bridge_entity} {sep_token} '
f'{question}')
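# Example (illustrative values):
#   process_question_for_implicit_decomp(
#       "Who won?", "Compose(ImageQ,TextQ)", hop=1, bridge_entity="John Doe")
#   -> 'Compose(ImageQ,TextQ) [SEP] HOP=1 [SEP] John Doe [SEP] Who won?'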
def extract_numbers_from_str(s):
numbers = []
for token in s.split():
try:
num = int(token.replace(",", ""))
except ValueError:
try:
num = float(token)
except ValueError:
num = None
if num is not None:  # keep zeros; a bare truthiness check would drop 0 and 0.0
numbers.append(num)
return numbers
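# Example (illustrative):
#   extract_numbers_from_str("won 3 games by 12.5 points")
#   -> [3, 12.5]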
def read_jsonl(filename):
with open(filename, 'r') as f:
data = [json.loads(l.strip()) for l in f.readlines()]
return data


@ -0,0 +1,144 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from datasets import load_dataset
from pathlib import Path
import torch
import safetensors
import json
import jsonlines
from copy import deepcopy
from tqdm.auto import tqdm
import PIL
from typing import List
from loguru import logger
from m3docrag.utils.paths import LOCAL_DATA_DIR, LOCAL_EMBEDDINGS_DIR
from m3docrag.utils.pdfs import get_images_from_pdf
class M3DocVQADataset(torch.utils.data.Dataset):
def __init__(self, args):
self.args = args
# e.g., /job/datasets/m3-docvqa
local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
pdf_dir = local_data_dir / "splits" / f'pdfs_{args.split}'
assert pdf_dir.exists(), pdf_dir
self.pdf_dir = pdf_dir
multimodalqa_data_dir = local_data_dir / "multimodalqa"
mmqa_data_path = multimodalqa_data_dir / f"MMQA_{args.split}.jsonl"
self.mmqa_data_path = mmqa_data_path
data = []
with jsonlines.open(mmqa_data_path) as reader:
for i, obj in enumerate(reader):
data.append(obj)
logger.info(f"# Data {len(data)}")
self.data = data
split_supporting_doc_ids_path = local_data_dir / f"{args.split}_doc_ids.json"
with open(split_supporting_doc_ids_path, 'r') as f:
all_supporting_doc_ids = json.load(f)  # avoid re-opening the file
# dev: 3366
# train: 24162
logger.info(f"# supporting doc ids in split {args.split}: {len(all_supporting_doc_ids)}")
self.all_supporting_doc_ids = all_supporting_doc_ids
def __len__(self):
if self.args.loop_unique_doc_ids:
return len(self.all_supporting_doc_ids)
if self.args.data_len is not None:
return self.args.data_len
return len(self.data)
def load_all_embeddings(self):
"""Load all doc embeddings in memory"""
emb_dir = Path(LOCAL_EMBEDDINGS_DIR) / self.args.embedding_name
logger.info(f"Loading all doc embeddings from {emb_dir}")
docid2embs = {}
docid2lens = {}
for idx in tqdm(range(len(self.all_supporting_doc_ids))):
doc_id = self.all_supporting_doc_ids[idx]
emb_path = Path(LOCAL_EMBEDDINGS_DIR) / self.args.embedding_name / f"{doc_id}.safetensors"
assert emb_path.exists(), emb_path
if self.args.retrieval_model_type == 'colpali':
with safetensors.safe_open(emb_path, framework="pt", device='cpu') as f:
# [n_pages, n_tokens, dim]
doc_embs = f.get_tensor('embeddings')
docid2embs[doc_id] = doc_embs.bfloat16()
if self.args.retrieval_model_type == 'colpali':
return docid2embs
elif self.args.retrieval_model_type == 'colbert':
return docid2embs, docid2lens
def get_images_from_doc_id(self, doc_id: str) -> List[PIL.Image.Image]:
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
page_images = get_images_from_pdf(pdf_path)
return page_images
def __getitem__(self, idx):
if self.args.loop_unique_doc_ids:
doc_id = self.all_supporting_doc_ids[idx]
datum = {
'doc_id': doc_id,
}
if self.args.retrieval_model_type == 'colpali':
page_images = self.get_images_from_doc_id(doc_id)
datum['images'] = page_images
return datum
# keys(['qid', 'question', 'answers', 'metadata', 'supporting_context'])
datum = deepcopy(self.data[idx])
supporting_doc_ids = []
for obj in datum['supporting_context']:
supporting_doc_ids.append(obj['doc_id'])
datum['supporting_doc_ids'] = supporting_doc_ids
return datum
if __name__ == '__main__':
from m3docrag.utils.args import _example_args, parse_args
args = parse_args(_example_args)
dataset = M3DocVQADataset(
args=args
)


@ -0,0 +1,414 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# https://github.com/allenai/multimodalqa/blob/master/baselines/evaluate.py
import json
import argparse
import re
import string
import numpy as np
from collections import Counter
from typing import List, Set, Tuple, Union
from scipy.optimize import linear_sum_assignment
from word2number.w2n import word_to_num
from .common_utils import *
# From here through _match_numbers_if_present was originally copied from the evaluation code of DROP dataset:
# https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
def _remove_articles(text: str) -> str:
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def _white_space_fix(text: str) -> str:
return " ".join(text.split())
EXCLUDE = set(string.punctuation)
def _remove_punc(text: str) -> str:
if not _is_number(text):
return "".join(ch for ch in text if ch not in EXCLUDE)
else:
return text
def _lower(text: str) -> str:
return text.lower()
def _tokenize(text: str) -> List[str]:
return re.split(" |-", text)
def _normalize_answer(text: str) -> str:
"""Lower text and remove punctuation, articles and extra whitespace."""
parts = [
_white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
for token in _tokenize(text)
]
parts = [part for part in parts if part.strip()]
normalized = " ".join(parts).strip()
return normalized
def _is_number(text: str) -> bool:
try:
float(text)
return True
except ValueError:
return False
def _is_word_number(text: str) -> bool:
try:
word_to_num(text)
return True
except ValueError:
return False
def _normalize_number(text: str) -> str:
if _is_number(text):
return str(float(text))
# TODO: this is not included in the original DROP evaluation script; we need our own version eventually.
elif _is_word_number(text):
return str(float(word_to_num(text)))
else:
return text
def _answer_to_bags(
answer: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[List[str], List[Set[str]]]:
if isinstance(answer, (list, tuple)):
raw_spans = answer
else:
raw_spans = [answer]
normalized_spans: List[str] = []
token_bags = []
for raw_span in raw_spans:
normalized_span = _normalize_answer(raw_span)
normalized_spans.append(normalized_span)
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
"""
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
between them and gets maximum metric values over all the answers.
"""
scores = np.zeros([len(gold), len(predicted)])
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if _match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))])
for row, column in zip(row_ind, col_ind):
max_scores[row] = max(max_scores[row], scores[row, column])
return max_scores
def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
intersection = len(gold_bag.intersection(predicted_bag))
if not predicted_bag:
precision = 1.0
else:
precision = intersection / float(len(predicted_bag))
if not gold_bag:
recall = 1.0
else:
recall = intersection / float(len(gold_bag))
f1 = (
(2 * precision * recall) / (precision + recall)
if not (precision == 0.0 and recall == 0.0)
else 0.0
)
return f1
def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
gold_numbers = set()
predicted_numbers = set()
for word in gold_bag:
if _is_number(word):
gold_numbers.add(word)
for word in predicted_bag:
if _is_number(word):
predicted_numbers.add(word)
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
return True
return False
def list_em(predicted, gold):
predicted_bags = _answer_to_bags(predicted)
gold_bags = _answer_to_bags(gold)
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
return 1.0
else:
return 0.0
def list_f1(predicted, gold):
predicted_bags = _answer_to_bags(predicted)
gold_bags = _answer_to_bags(gold)
f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
f1 = np.mean(f1_per_bag)
f1 = round(f1, 2)
return f1
def metric_max_over_ground_truths(metric_fn, prediction, gold_answers):
scores_for_ground_truths = []
for gold_answer in gold_answers:
score = metric_fn(prediction, gold_answer)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate_predictions(predictions, gold_answers, example_types=None):
"""To support multiple gold annotations, `gold_answers` should be a list,
with each item (either a string or a list) corresponding to one valid reference answer."""
instance_eval_results = {}
instance_eval_results_by_types = {}
eval_funcs = {
"list_em": list_em,
"list_f1": list_f1
}
for qas_id in gold_answers:
ref_answers = gold_answers[qas_id]
if qas_id not in predictions:
print(f"Missing prediction for question {qas_id}, and all scores for this question are set to zero")
instance_eval_results[qas_id] = {
metric: 0.0 for metric in eval_funcs.keys()
}
else:
pred_answer = predictions[qas_id]
instance_eval_results[qas_id] = {
metric: metric_max_over_ground_truths(
func, pred_answer, ref_answers
) for metric, func in eval_funcs.items()
}
if example_types is not None:
example_type = example_types[qas_id]
if example_type not in instance_eval_results_by_types:
instance_eval_results_by_types[example_type] = {}
instance_eval_results_by_types[example_type][qas_id] = instance_eval_results[qas_id]
eval_scores = {metric: np.mean([result[metric] for result in instance_eval_results.values()]) * 100
for metric in eval_funcs.keys()}
if example_types is not None:
eval_scores_by_types = {}
for example_type, type_instance_eval_results in instance_eval_results_by_types.items():
eval_scores_by_types[example_type] = {
metric: np.mean([result[metric] for result in type_instance_eval_results.values()]) * 100 for metric in eval_funcs.keys()
}
return eval_scores, instance_eval_results, eval_scores_by_types
else:
return eval_scores, instance_eval_results
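# Example (illustrative):
#   predictions = {"q1": "Toronto"}
#   gold_answers = {"q1": [["Toronto"]]}   # one reference answer, itself a list of spans
#   eval_scores, _ = evaluate_predictions(predictions, gold_answers)
#   -> eval_scores == {'list_em': 100.0, 'list_f1': 100.0}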
def eval_retrieval(qid2retrieval_results, gold_examples, recall_levels=[1, 2, 4, 5, 10]):
"""
Calculate document recall metrics per query, including recall@k for various k values,
and compute the average recall.
Args:
qid2retrieval_results (dict): Dictionary mapping query IDs (qid) to retrieval results.
The retrieval results are a list of lists, where each sublist contains
[doc_id, page, score], e.g., [['53e1d168375850acbaf134da744ffae0', 3, 2.3689956665039062]].
gold_examples (list): List of ground truth examples with 'qid' and 'supporting_context'.
Each 'supporting_context' is a list of dictionaries,
e.g., [{'doc_id': '8513db80c11ea439ab11eba406ec00d9', 'doc_part': 'table'}].
recall_levels (list): List of k values for recall@k, e.g., [1, 5, 10].
Returns:
dict: Dictionary containing recall per query at different recall levels,
and the average recall for each level.
"""
# Create a mapping from query ID to the supporting context (gold standard)
qid2supporting_context = {
example["qid"]: example['supporting_context'] for example in gold_examples
}
# Initialize dictionaries to store recall at different levels
qid2recall_at_k = {k: {} for k in recall_levels}
avg_recall_at_k = {k: 0.0 for k in recall_levels}
# Loop through each query and its supporting context
for qid, supporting_context in qid2supporting_context.items():
# Get the retrieved documents for the current query
retrieved_results = qid2retrieval_results.get(qid, [])
# Extract the relevant doc_ids from supporting_context
relevant_doc_ids = {ctx['doc_id'] for ctx in supporting_context}
# Count total relevant documents for the query
n_relevant = len(relevant_doc_ids)
# For each recall level, calculate how many relevant documents were retrieved
for k in recall_levels:
# Get the top-k results (or fewer if there are fewer results)
top_k_results = retrieved_results[:k]
# Ensure doc_id uniqueness in top_k_results
top_k_doc_ids = set(result[0] for result in top_k_results)
# Count how many relevant documents are in the top-k results
n_relevant_retrieved_at_k = len(top_k_doc_ids.intersection(relevant_doc_ids))
# Calculate recall@k for the current query
recall_at_k = n_relevant_retrieved_at_k / n_relevant if n_relevant > 0 else 0.0
# Ensure recall is between 0 and 1
recall_at_k = min(recall_at_k, 1.0)
# Store recall@k for the current query
qid2recall_at_k[k][qid] = recall_at_k
# Calculate average recall@k across all queries
for k in recall_levels:
avg_recall_at_k[k] = sum(qid2recall_at_k[k].values()) / len(qid2recall_at_k[k]) if qid2recall_at_k[k] else 0.0
# Return recall@k for each query and average recall@k
return {
"recall_per_qid_at_k": qid2recall_at_k,
"average_recall_at_k": avg_recall_at_k
}
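# Example (illustrative):
#   qid2retrieval_results = {"q1": [["docA", 3, 12.5], ["docB", 0, 9.1]]}
#   gold_examples = [{"qid": "q1", "supporting_context": [{"doc_id": "docA"}, {"doc_id": "docB"}]}]
#   eval_retrieval(qid2retrieval_results, gold_examples)["average_recall_at_k"]
#   -> {1: 0.5, 2: 1.0, 4: 1.0, 5: 1.0, 10: 1.0}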
def evaluate_prediction_file(prediction_path, gold_path='/job/datasets/m3-docvqa/MMQA_dev.jsonl'):
# predicted_answers = json.load(open(prediction_path, encoding="utf-8"))
if isinstance(prediction_path, dict):
raw_prediction_json = prediction_path
else:
raw_prediction_json = json.load(open(prediction_path, encoding="utf-8"))
examples = read_jsonl(gold_path)
gold_answers, answer_modalities, hop_types, question_types = {}, {}, {}, {}
qid2predicted_answers = {}
qid2retrieval_results = {}
for qid, data in raw_prediction_json.items():
qid2predicted_answers[qid] = data['pred_answer'].strip()
qid2retrieval_results[qid] = data['page_retrieval_results']
eval_retrieval_results = eval_retrieval(qid2retrieval_results, examples)
print('Average recall at K')
print(eval_retrieval_results['average_recall_at_k'])
predicted_answers = qid2predicted_answers
all_scores = {}
all_scores['overall'] = {}
all_scores['modalities'] = {}
all_scores['hop_types'] = {}
all_scores['q_types'] = {}
all_scores['average_recall_at_k'] = eval_retrieval_results['average_recall_at_k']
for example in examples:
qid = example["qid"]
# Currently we only have one ground truth answer.
# Even if there are multiple entries in example["answers"], the whole list should be regarded as one ref answer.
# However, our script supports evaluation with multiple ref answers.
# So, we will use an outer bracket here to pretend we have a list of ref answers.
gold_answer = [str(item["answer"]) for item in example["answers"]]
gold_answers[qid] = [gold_answer]
answer_modality = set([item["modality"] for item in example["answers"]])
assert len(answer_modality) == 1
answer_modalities[qid] = answer_modality.pop()
question_types[qid] = example["metadata"]["type"]
hop_types[qid] = "Multi-hop" if example["metadata"]["type"] in MULTI_HOP_QUESTION_TYPES else "Single-hop"
eval_scores, instance_eval_results = evaluate_predictions(predicted_answers, gold_answers)
print("\n\nOverall result with different metrics: ")
for metric, value in eval_scores.items():
print(f"{metric}: {value}")
all_scores['overall'][metric] = value
modality_counts = Counter(answer_modalities.values())
_, _, eval_scores_by_modalities = \
evaluate_predictions(predicted_answers, gold_answers, answer_modalities)
print("\n\nEval results for different modalities:")
for answer_modality in sorted(eval_scores_by_modalities.keys()):
result = eval_scores_by_modalities[answer_modality]
print(f"{answer_modality}")
print(f"# of examples: {modality_counts[answer_modality]}")
all_scores['modalities'][answer_modality] = {}
for metric, value in result.items():
print(f"{metric}: {value}")
all_scores['modalities'][answer_modality][metric] = value
hop_type_counts = Counter(hop_types.values())
_, _, eval_scores_by_hop_types = evaluate_predictions(predicted_answers, gold_answers, hop_types)
print("\n\nType\tCount\tEM\tF1")
for hop_type in sorted(eval_scores_by_hop_types.keys()):
result = eval_scores_by_hop_types[hop_type]
print(f"{hop_type}\t{hop_type_counts[hop_type]}\t{result['list_em']}\t{result['list_f1']}")
all_scores['hop_types'][hop_type] = {}
for metric, value in result.items():
all_scores['hop_types'][hop_type][metric] = value
question_type_counts = Counter(question_types.values())
_, _, eval_scores_by_qtypes = evaluate_predictions(predicted_answers, gold_answers, question_types)
print("\n\nType\tCount\tEM\tF1")
for question_type in sorted(eval_scores_by_qtypes.keys()):
result = eval_scores_by_qtypes[question_type]
print(f"{question_type}\t{question_type_counts[question_type]}\t{result['list_em']}\t{result['list_f1']}")
all_scores['q_types'][question_type] = {}
for metric, value in result.items():
all_scores['q_types'][question_type][metric] = value
# return eval_scores
return all_scores
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="evaluate on drop dataset")
parser.add_argument(
"--prediction_path",
type=str,
default="predictions.json",
help="location of the prediction file",
)
parser.add_argument(
"--gold_path",
type=str,
default="dev.json",
help="location of the gold file",
)
args = parser.parse_args()
evaluate_prediction_file(args.prediction_path, args.gold_path)


@ -0,0 +1,20 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from .multimodal import MultimodalRAGModel
from .utils import reduce_embeddings

177 src/m3docrag/rag/base.py Normal file

@ -0,0 +1,177 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from typing import List, Dict, Tuple
from tqdm.auto import tqdm
import torch
import numpy as np
from .utils import get_top_k_pages, get_top_k_pages_single_page_from_each_doc
class RAGModelBase:
def __init__(
self,
retrieval_model=None,
qa_model=None,
vqa_model=None,
):
"""Base class for RAG pipeline
- retrieval_model: arbitrary retrieval model (e.g., ColPali / ColBERT)
- qa_model: arbitrary text-only QA model (e.g., Llama 3)
- vqa_model: arbitrary VQA model (e.g., InternVL2, GPT-4o)
"""
self.retrieval_model = retrieval_model
self.qa_model = qa_model
self.vqa_model = vqa_model
if self.retrieval_model is not None:
self.retrieval_model.model.eval()
if self.vqa_model is not None:
self.vqa_model.model.eval()
if self.qa_model is not None:
self.qa_model.model.eval()
def retrieve_pages_from_docs(
self,
query: str,
docid2embs: Dict[str, torch.Tensor],
docid2lens: Dict[str, torch.Tensor] = None,
index = None,
token2pageuid = None,
all_token_embeddings = None,
n_return_pages: int = 1,
single_page_from_each_doc: bool = False,
show_progress=False,
) -> List[Tuple]:
"""
Given a text query and pre-extracted document embeddings,
calculate similarity scores and return the top-n pages.
Args:
- query (str): a text query to call retrieval model
- docid2embs (Dict[str, tensor]): collection of document embeddings
key: document_id
value: torch.tensor of size (n_tokens, emb_dim)
- index: faiss index
- n_return_pages (int): number of pages to return
- single_page_from_each_doc (bool): if true, only single page is retrieved from each PDF document.
Return:
retrieval_results
[(doc_id, page_idx, scores)...]
"""
if index is not None:
# [n_query_tokens, dim]
query_emb = self.retrieval_model.encode_queries([query])[0]
query_emb = query_emb.cpu().float().numpy().astype(np.float32)
# NN search
k = n_return_pages
D, I = index.search(query_emb, k)
# Sum the MaxSim scores across all query tokens for each document
final_page2scores = {}
# Iterate over query tokens
for q_idx, query_token_emb in enumerate(query_emb):
# Initialize a dictionary to hold page relevance scores for this query token
current_q_page2scores = {}
for nn_idx in range(k):
found_nearest_doc_token_idx = I[q_idx, nn_idx]
page_uid = token2pageuid[found_nearest_doc_token_idx] # Get the page ID for this token
# reconstruct the original score
doc_token_emb = all_token_embeddings[found_nearest_doc_token_idx]
score = (query_token_emb * doc_token_emb).sum()
# MaxSim: keep the highest similarity score for each query token per page
if page_uid not in current_q_page2scores:
current_q_page2scores[page_uid] = score
else:
current_q_page2scores[page_uid] = max(current_q_page2scores[page_uid], score)
for page_uid, score in current_q_page2scores.items():
if page_uid in final_page2scores:
final_page2scores[page_uid] += score
else:
final_page2scores[page_uid] = score
# Sort documents by their final relevance score
sorted_pages = sorted(final_page2scores.items(), key=lambda x: x[1], reverse=True)
# Get the top-k document candidates
top_k_pages = sorted_pages[:k]
# [(doc_id, page_idx, scores)...]
sorted_results = []
for page_uid, score in top_k_pages:
# logger.info(f"{page_uid} with score {score}")
# page_uid = f"{doc_id}_page{page_id}"
doc_id = page_uid.split('_page')[0]
page_idx = int(page_uid.split('_page')[-1])
sorted_results.append((doc_id, page_idx, score.item()))
return sorted_results
docid2scores = {}
for doc_id, doc_embs in tqdm(
docid2embs.items(),
total=len(docid2embs),
disable = not show_progress,
desc=f"Calculating similarity score over documents"
):
doc_lens = None
if docid2lens is not None:
doc_lens = docid2lens[doc_id]
scores = self.retrieval_model.retrieve(
query=query,
doc_embeds=doc_embs,
doc_lens=doc_lens,
to_cpu=True,
return_top_1=False
)
scores = scores.flatten().tolist()
docid2scores[doc_id] = scores
# find the pages with top scores
if single_page_from_each_doc:
return get_top_k_pages_single_page_from_each_doc(docid2scores=docid2scores, k=n_return_pages)
else:
return get_top_k_pages(docid2scores=docid2scores, k=n_return_pages)
def run_qa(self):
raise NotImplementedError
def run_vqa(self):
raise NotImplementedError


@ -0,0 +1,51 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from .base import RAGModelBase
import torch
from m3docrag.vqa import VQAModel
from m3docrag.retrieval import ColPaliRetrievalModel
class MultimodalRAGModel(RAGModelBase):
def __init__(
self,
retrieval_model: ColPaliRetrievalModel,
vqa_model: VQAModel = None
):
self.retrieval_model = retrieval_model
self.vqa_model = vqa_model
self.retrieval_model.model.eval()
if self.vqa_model is not None and isinstance(self.vqa_model.model, torch.nn.Module):
self.vqa_model.model.eval()
def run_vqa(
self,
images,
question,
) -> str:
response = self.vqa_model.generate(images=images, question=question)
assert isinstance(response, str), type(response)
return response

128 src/m3docrag/rag/utils.py Normal file

@ -0,0 +1,128 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from tqdm.auto import tqdm
def reduce_embeddings(docid2embs, dim='page', show_progress=True):
"""Summarize document embedding by reducing (averaging) specific dimensions
Input embedding:
[n_pages, n_tokens, emb_dim]
Output embedding:
- reduction_dim == 'page'
[1, n_tokens, emb_dim]
- reduction_dim == 'token'
[1, n_pages, emb_dim]
- reduction_dim == 'page_token'
[1, 1, emb_dim]
"""
assert dim in ['page', 'token', 'page_token'], f"{dim}"
new_docid2embs = {}
for doc_id in tqdm(
list(docid2embs.keys()),
disable=not show_progress
):
# [n_pages, n_tokens, dim]
embs = docid2embs[doc_id]
emb_dim = embs.size(-1)
if dim == 'page':
# [n_tokens, dim]
new_embs = embs.mean(dim=0)
elif dim == 'token':
# [n_pages, dim]
new_embs = embs.mean(dim=1)
elif dim == 'page_token':
# [1, dim]
new_embs = embs.mean(dim=0).mean(dim=0)
new_docid2embs[doc_id] = new_embs.view(1, -1, emb_dim)
return new_docid2embs
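# Example (shapes only; assumes torch is available):
#   docid2embs = {"doc1": torch.randn(3, 10, 128)}   # [n_pages, n_tokens, dim]
#   reduce_embeddings(docid2embs, dim='page')["doc1"].shape
#   -> torch.Size([1, 10, 128])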
def get_top_k_pages(docid2scores: dict, k: int):
"""
# Example usage:
docid2scores = {
"doc1": [10, 50, 30],
"doc2": [40, 20, 60],
"doc3": [70, 90]
}
k = 3
top_k_pages = get_top_k_pages(docid2scores, k)
print(top_k_pages)
-> [('doc3', 1, 90), ('doc3', 0, 70), ('doc2', 2, 60)]
"""
# Flatten the dictionary into a list of tuples (doc_id, page_index, score)
flattened_scores = [
(doc_id, page_index, score)
for doc_id, scores in docid2scores.items()
for page_index, score in enumerate(scores)
]
# Sort by score in descending order
flattened_scores.sort(key=lambda x: x[2], reverse=True)
# Get the top-k entries
top_k_pages = flattened_scores[:k]
return top_k_pages
def get_top_k_pages_single_page_from_each_doc(docid2scores: dict, k: int):
"""
# Example usage:
docid2scores = {
"doc1": [10, 50, 30],
"doc2": [40, 20, 60],
"doc3": [70, 90]
}
k = 2
top_k_pages = get_top_k_pages_single_page_from_each_doc(docid2scores, k)
print(top_k_pages)
-> [('doc3', 1, 90), ('doc2', 2, 60)]
"""
# First, get the highest scoring page for each document
highest_per_doc = [
(doc_id, max(enumerate(scores), key=lambda x: x[1])) # (doc_id, (page_index, score))
for doc_id, scores in docid2scores.items()
]
# Flatten the structure to (doc_id, page_index, score)
highest_per_doc_flat = [(doc_id, page_index, score) for doc_id, (page_index, score) in highest_per_doc]
# Sort by score in descending order
highest_per_doc_flat.sort(key=lambda x: x[2], reverse=True)
# Get the top-k entries
top_k_pages = highest_per_doc_flat[:k]
return top_k_pages


@ -0,0 +1,18 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from .colpali import ColPaliRetrievalModel


@ -0,0 +1,303 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
# from transformers import AutoProcessor
from PIL import Image
from typing import List
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.models import ColQwen2, ColQwen2Processor
def init(
backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
adapter_name_or_path= "/job/model/colpali-v1.2",
dtype=torch.bfloat16,
):
"""
Load ColPali Model and Processor from (locally downloaded) HF checkpoint.
Args:
- backbone_name_or_path: downloaded from https://huggingface.co/vidore/colpaligemma-3b-pt-448-base
- adapter_name_or_path: downloaded from https://huggingface.co/vidore/colpali-v1.2
Return:
- model
- processor
"""
kwargs = {}
model_class = ColPali
processor_class = ColPaliProcessor
if 'colqwen' in str(adapter_name_or_path):
model_class = ColQwen2
processor_class = ColQwen2Processor
kwargs['attn_implementation'] = "flash_attention_2"
model = model_class.from_pretrained(
backbone_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
**kwargs
).eval()
model.load_adapter(adapter_name_or_path)
processor = processor_class.from_pretrained(adapter_name_or_path)
return model, processor
def encode_images(
model,
processor,
images: List[Image.Image],
batch_size: int = 4,
to_cpu: bool = False,
use_tqdm: bool = False,
collate_fn=None,
return_doclens: bool = False
):
"""Create document embeddings with ColPali
Args:
model
processor
images (List[Image.Image])
(n_pages)
batch_size (int, optional):
batch size. Defaults to 4.
to_cpu (bool, optional):
whether to save embeddings in cpu tensors. Defaults to False.
use_tqdm (bool, optional):
whether to show tqdm progress bar. Defaults to False.
collate_fn (_type_, optional):
custom collate_fn for document dataloader. Defaults to None.
return_doclens (bool, optional):
whether to output the number of pages. Defaults to False.
Returns:
doc_embs: List[torch.tensor]
visual embedding of documents (n_pages, n_tokens, n_dimension)
(optional) doclens
number of pages
"""
if collate_fn is None:
collate_fn = processor.process_images
# run inference - docs
dataloader = DataLoader(
images,
batch_size=batch_size,
shuffle=False,
collate_fn=collate_fn,
)
doc_embs = []
if return_doclens:
doclens = []
if use_tqdm:
dataloader = tqdm(dataloader)
for batch_doc in dataloader:
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
if to_cpu:
embeddings_doc = embeddings_doc.to("cpu")
if return_doclens:
_doclens = batch_doc["attention_mask"].squeeze(-1).sum(-1).tolist()
doclens.extend(_doclens)
doc_embs.extend(list(torch.unbind(embeddings_doc)))
if return_doclens:
return doc_embs, doclens
else:
return doc_embs
def encode_queries(
model,
processor,
queries: List[str],
batch_size: int = 4,
to_cpu: bool = False,
use_tqdm: bool = False,
collate_fn=None,
):
"""Create query embeddings with ColPali
Args:
model
processor
queries (List[str]):
text queries (n_queries,)
batch_size (int, optional):
batch size. Defaults to 4.
to_cpu (bool, optional):
whether to save embeddings in cpu tensors. Defaults to False.
use_tqdm (bool, optional):
whether to show tqdm progress bar. Defaults to False.
collate_fn (_type_, optional):
custom collate_fn for document dataloader. Defaults to None.
Returns:
query_embs: List[torch.tensor]
embedding of queries (n_queries, n_tokens, n_dimension)
"""
if collate_fn is None:
collate_fn = processor.process_queries
# run inference - queries
dataloader = DataLoader(
queries,
batch_size=batch_size,
shuffle=False,
collate_fn=collate_fn
)
query_embs = []
if use_tqdm:
dataloader = tqdm(dataloader)
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
if to_cpu:
embeddings_query = embeddings_query.to("cpu")
query_embs.extend(list(torch.unbind(embeddings_query)))
return query_embs
def retrieve(
model,
processor,
docs=None,
query=None,
doc_embeds=None,
query_embeds=None,
to_cpu=False,
batch_size=1,
use_tqdm=False,
return_top_1=True
):
"""Find the right document image with colpali
"""
if doc_embeds is None:
doc_embeds = encode_images(
model, processor,
images=docs,
batch_size=batch_size,
use_tqdm=use_tqdm,
to_cpu=to_cpu,
)
if query_embeds is None:
query_embeds = encode_queries(
model, processor,
queries=[query],
batch_size=1,
use_tqdm=use_tqdm,
to_cpu=to_cpu,
)
qs = query_embeds
ds = doc_embeds
qs = [q.to(ds[0].dtype) for q in qs]
scores = processor.score_multi_vector(qs, ds)
if return_top_1:
return scores.argmax(axis=1)
else:
return scores
class ColPaliRetrievalModel:
def __init__(self,
backbone_name_or_path="/job/model/colpaligemma-3b-pt-448-base",
adapter_name_or_path= "/job/model/colpali-v1.2",
dtype=torch.bfloat16,
):
model, processor = init(backbone_name_or_path=backbone_name_or_path,
adapter_name_or_path=adapter_name_or_path,
dtype=dtype,
)
self.model = model.eval()
self.processor = processor
def encode_queries(self,
queries: List[str],
batch_size: int = 4,
to_cpu: bool = False,
use_tqdm: bool = False,
collate_fn=None
):
return encode_queries(
model=self.model,
processor=self.processor,
queries=queries,
batch_size=batch_size,
to_cpu=to_cpu,
use_tqdm=use_tqdm,
collate_fn=collate_fn)
def encode_images(self,
images: List[Image.Image],
batch_size: int = 4,
to_cpu: bool = False,
use_tqdm: bool = False,
collate_fn=None,
return_doclens: bool = False
):
return encode_images(
model=self.model,
processor=self.processor,
images=images,
batch_size=batch_size,
to_cpu=to_cpu,
use_tqdm=use_tqdm,
collate_fn=collate_fn,
return_doclens=return_doclens,
)
def retrieve(self,
docs=None,
query=None,
doc_embeds=None,
doc_lens=None,
query_embeds=None,
to_cpu=False,
batch_size=1,
use_tqdm=False,
return_top_1=True
):
return retrieve(
model=self.model,
processor=self.processor,
docs=docs,
query=query,
doc_embeds=doc_embeds,
query_embeds=query_embeds,
to_cpu=to_cpu,
batch_size=batch_size,
use_tqdm=use_tqdm,
return_top_1=return_top_1
)
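# Example (illustrative; assumes the checkpoints from the README are downloaded under /job/model
# and page_images is a list of PIL page images):
#   retrieval_model = ColPaliRetrievalModel()
#   scores = retrieval_model.retrieve(docs=page_images, query="Who wrote this report?", return_top_1=False)
#   scores has shape [1, n_pages]; scores.argmax(axis=1) gives the best-matching page index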


@ -0,0 +1,75 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List
import transformers
from icecream import ic
@dataclass
class TrainingArguments(transformers.TrainingArguments):
# Data settings
split: str = 'train'
data_name: str = field(default='m3-docvqa', metadata={"help": "Local name to be stored at LOCAL_DATA_DIR"})
data_len: int = field(default=None, metadata={"help": "number of examples to subsample from dataset"})
use_dummy_images: bool = field(default=False, metadata={"help": "if true, skip downloading images"})
load_embedding: bool = False
embedding_name: str = "colpali-v1.2_m3-docvqa_dev"
max_pages: int = 20
do_page_padding: bool = False
# Retrieval settings
retrieval_model_type: str = field(default='colpali', metadata={"choices": ['colpali', 'colbert']})
use_retrieval: bool = True
retrieval_only: bool = field(default=False, metadata={"help": "not running stage 2 (VQA)"})
page_retrieval_type: str = 'logits'
loop_unique_doc_ids: bool = field(default=False, metadata={"help": "if true, apply retrieval only on unique doc ids"})
n_retrieval_pages: int = 1
# Embedding indexing settings
faiss_index_type: str = field(default='ivfflat', metadata={"choices": ['flatip', 'ivfflat', 'ivfpq']})
# Local paths
model_name_or_path: Optional[str] = field(default="Qwen2-VL-7B-Instruct")
retrieval_model_name_or_path: Optional[str] = field(default="colpaligemma-3b-pt-448-base")
retrieval_adapter_model_name_or_path: Optional[str] = field(default="colpali-v1.2")
# Model settings
bits: int = field(default=16, metadata={"help": "Floating point precision. Use '4' for 4-bit quantization to save memory"})
# idefics2 settings
do_image_splitting: bool = False
_example_arg_str = """
--output_dir=/job/outputs
--data_name=m3-docvqa
--use_retrieval=True
"""
_example_args = _example_arg_str.strip().split('\n')
def parse_args(args=None):
parser = transformers.HfArgumentParser(TrainingArguments)
parsed_args, remaining_args = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
ic(remaining_args)
return parsed_args


@ -0,0 +1,209 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""
Utility functions to manage multi-device training
"""
import inspect
import logging
import os
import platform
import subprocess
import sys
# import pkg_resources
import importlib.metadata
import torch.cuda
import torch.distributed
from loguru import logger
def world_size() -> int:
"""Returns the total number of processes in a distributed job (num_nodes x gpus_per_node).
Returns 1 in a non-distributed job.
"""
return int(os.environ.get("WORLD_SIZE", "1"))
def is_distributed() -> bool:
"""Returns True iff this is a distributed job (more than one process)."""
return world_size() > 1
def local_rank() -> int:
"""Returns the local rank of the current process in a distributed job.
Returns 0 (local primary) for non-distributed jobs.
"""
return int(os.environ.get("LOCAL_RANK", "0"))
def global_rank() -> int:
"""Returns the global rank of the current process in a distributed job.
Returns 0 (global primary) for non-distributed jobs.
"""
return int(os.environ.get("RANK", local_rank()))
def barrier():
"""Synchronizes all processes. Set GPU with local_rank to perform barrier used by this process."""
torch.distributed.barrier(device_ids=[local_rank()])
def patch_module_loggers(module_namespace):
modules = inspect.getmembers(module_namespace, predicate=inspect.ismodule)
# check toplevel module
if hasattr(module_namespace, "logger"):
module_namespace.logger = logger
logger.info(f"Patching logger: {module_namespace.__name__}")
for _, mod in modules:
if hasattr(mod, "logger"):
mod.logger = logger
logger.info(f"Patching logger: {mod.__name__}")
class InterceptLogHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists.
level: str | int
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
# Find caller from where originated the logged message.
frame, depth = inspect.currentframe(), 0
while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
frame = frame.f_back
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
def configure_distributed_logging(
rank_zero_level="INFO",
rank_non_zero_level="WARNING",
):
"""This method configures logger to reduce noise in multi-node, multi-process evaluations (e.g. DeepSpeed)_summary_
Args:
rank_zero_level (str, optional): Log level on zero rank process. Defaults to "INFO".
rank_non_zero_level (str, optional): Log level on non-zero rank processes. Defaults to "WARNING".
"""
logger.remove()
rank = local_rank()
format = "<g>{time:YYYY-MM-DD HH:mm:ss.SSS}</g> | <lvl>{level}</lvl> | rank={extra[rank]} | <c>{name}</c>:<c>{function}</c>:<c>{line}</c> - <lvl>{message}</lvl>\n{exception}"
if rank != 0:
logger.configure(
extra={"rank": f"{global_rank()}:{rank}"},
handlers=[
{"sink": sys.stderr, "format": format, "level": rank_non_zero_level}
],
)
else:
logger.configure(
extra={"rank": f"{global_rank()}:{rank}"},
handlers=[{"sink": sys.stdout, "format": format, "level": rank_zero_level}],
)
# Attempt to intercept normal logging in libs
logging.basicConfig(handlers=[InterceptLogHandler()], level=0, force=True)
def get_cuda_version():
"""Get the installed CUDA version."""
try:
output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
version_line = output.strip().split("\n")[-1]
version = version_line.split(" ")[-1]
return version
except Exception as e:
logger.info(f"Cannot detect CUDA version. Exception occured: {e}")
return "N/A"
def log_runtime_info():
# Get Python runtime information
python_version = sys.version
python_implementation = platform.python_implementation()
python_build = platform.python_build()
# Get environment variables
env_vars = os.environ
# Get installed package versions
installed_packages = [
# (d.project_name, d.version) for d in pkg_resources.working_set
(d.metadata["Name"], d.version) for d in importlib.metadata.distributions()
]
# logger.info diagnostics
logger.info("Python Version: {}".format(python_version))
logger.info("Python Implementation: {}".format(python_implementation))
logger.info("Python Build: {}".format(python_build))
logger.info(f"Environment Variables: {env_vars}")
logger.info(f"Installed Packages: {installed_packages}")
logger.info(f"CUDA version: {get_cuda_version()}")
logger.info(f"Is CUDA available for Torch?: {torch.cuda.is_available()}")
logger.info(f"World size: {world_size()}")
def local_rank_zero(func):
"""
    Decorator that runs the wrapped function only on the local rank-zero process. Useful for logging statistics.
"""
def wrapper(*args, **kwargs):
if local_rank() == 0:
func(*args, **kwargs)
return wrapper
def global_rank_zero(func):
"""
    Decorator that runs the wrapped function only on the global rank-zero process. Useful for logging statistics.
"""
def wrapper(*args, **kwargs):
if global_rank() == 0:
func(*args, **kwargs)
return wrapper
import gpustat
def print_gpu_stats():
gpustat.cli.main([])
def supports_flash_attention(device_id=0):
"""Check if a GPU supports FlashAttention."""
major, minor = torch.cuda.get_device_capability(device_id)
# Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
return is_sm8x or is_sm90
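A minimal sketch of how these helpers might be wired together in an entry point (not part of the file above); it assumes a torchrun-style launch so that `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` are set:

```python
# Hypothetical usage of the helpers defined above; the entry-point structure is an assumption.
import torch


@global_rank_zero
def report():
    # Runs only on the global rank-0 process.
    log_runtime_info()


if __name__ == "__main__":
    configure_distributed_logging()  # INFO on rank 0, WARNING on other ranks
    report()
    use_flash = torch.cuda.is_available() and supports_flash_attention()
    logger.info(f"flash attention supported: {use_flash}, world size: {world_size()}")
```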

View File

@ -0,0 +1,24 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import os
from dotenv import load_dotenv
load_dotenv()
LOCAL_DATA_DIR = os.getenv("LOCAL_DATA_DIR", "/job/datasets")
LOCAL_EMBEDDINGS_DIR = os.getenv("LOCAL_EMBEDDINGS_DIR", "/job/embeddings")
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/job/model")
LOCAL_OUTPUT_DIR = os.getenv("LOCAL_OUTPUT_DIR", "/job/output")

View File

@ -0,0 +1,75 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from loguru import logger
from pdf2image import convert_from_path
from collections import Counter
def get_images_from_pdf(pdf_path, max_pages=None, dpi_resolution=144, save_dir='/tmp/', save_image=False, save_type='png', verbose=False):
pdf_path = Path(pdf_path)
assert pdf_path.exists(), f"{pdf_path} does not exist"
pdf_fname = pdf_path.name
images = convert_from_path(pdf_path, dpi=dpi_resolution)
# PIL.PpmImagePlugin.PpmImageFile -> PIL.Image.Image
images = [img.convert('RGB') for img in images]
# resizing to the most common image size so that we can stack in pytorch tensor
# PDFs (e.g., MMLongBench-Doc) have different image sizes
# width=1,224 and height=1,584
# 1440, 810
# 1191, 1684
# 1440, 1080
# 1536, 1152
# 1) find the most common image size
img_size_counter = Counter()
for img in images:
img_size_counter[img.size] += 1
common_img_size, common_img_size_count = img_size_counter.most_common(1)[0]
# 2) if pages have different sizes -> resize all pages to that image size
if len(images) != common_img_size_count:
logger.info(f"total: {len(images)} pages")
logger.info(f"resizing to the most common image size: {common_img_size} with count: {common_img_size_count}")
images = [img.resize(common_img_size) for img in images]
if save_image:
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
for page_index, page_image in enumerate(images):
save_page_path = save_dir / f"{pdf_fname}_{page_index+1}.{save_type}"
if not save_page_path.exists():
page_image.save(save_page_path)
if verbose:
logger.info(f"Page {page_index} saved at {save_page_path}")
return images
if __name__ == '__main__':
get_images_from_pdf(
pdf_path="./multimodalqa_screenshots_pdfs/0df5cc80bcd2a27b91224d658ad3a7b5.pdf",
save_dir='./tmp/',
save_image=True
)
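Because all pages returned by `get_images_from_pdf` share one size after the resizing step, they can be stacked directly into a single tensor downstream; a minimal sketch (the PDF path is a placeholder):

```python
# Hypothetical downstream use: stack the equally sized PIL pages into one
# (num_pages, 3, H, W) tensor. The PDF path is a placeholder.
import torch
from torchvision import transforms

to_tensor = transforms.ToTensor()
pages = get_images_from_pdf("./example.pdf")
page_tensor = torch.stack([to_tensor(page) for page in pages])
print(page_tensor.shape)
```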

View File

@ -0,0 +1,70 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import string
binary_page_retrieval_template = """
question: $question
Does this page have the answer to the question?
Answer only with yes or no.
""".strip()
concat_page_retrieval_template = """
question: $question
Which page is most relevant to the question?
""".strip()
concat_page_retrieval_with_answer_template = """
Find the page most relevant to the question and answer: "$question"
""".strip()
concate_page_answer_template = """
Find the answer to the question: "$question"
""".strip()
short_answer_template = """
question: $question
output only answer.
""".strip()
long_answer_template = """
question: $question
Answer the question with detailed explanation.
""".strip()
text_rag_template = """
DOCUMENTS:
$documents
QUESTION:
$question
INSTRUCTIONS:
Answer the QUESTION using the DOCUMENTS text above. Simply output the answer only.
Answer:
"""
binary_page_retrieval_template = string.Template(binary_page_retrieval_template)
concat_page_retrieval_template = string.Template(concat_page_retrieval_template)
concat_page_retrieval_with_answer_template = string.Template(concat_page_retrieval_with_answer_template)
concate_page_answer_template = string.Template(concate_page_answer_template)
short_answer_template = string.Template(short_answer_template)
long_answer_template = string.Template(long_answer_template)
text_rag_template = string.Template(text_rag_template)
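The names above are bound to `string.Template` objects, so prompts are rendered with `.substitute()`; a short sketch with placeholder question and document text:

```python
# Hypothetical rendering of the templates above; question/document strings are placeholders.
prompt = short_answer_template.substitute(question="What was the 2023 revenue?")
# -> "question: What was the 2023 revenue?\noutput only answer."

rag_prompt = text_rag_template.substitute(
    documents="<retrieved page text>",
    question="What was the 2023 revenue?",
)
```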

32
src/m3docrag/utils/tar.py Normal file
View File

@ -0,0 +1,32 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import os
import tarfile
from loguru import logger
def make_tarfile(source_dir, output_filename):
logger.info(f"Compressing {source_dir} to {output_filename} ...")
with tarfile.open(output_filename, "w:gz") as tar:
tar.add(source_dir, arcname='.')
logger.info(f"Compression done!")
def extract_tarfile(input_filename, target_dir):
logger.info(f"Extracting {input_filename} to {target_dir} ...")
with tarfile.open(input_filename) as f:
f.extractall(target_dir)
logger.info(f"Extraction done!")

View File

@ -0,0 +1,144 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from typing import Union, List
from pathlib import Path
import torch
from m3docrag.vqa import internvl2
from m3docrag.vqa import idefics2
from m3docrag.vqa import idefics3
from m3docrag.vqa import florence2
from m3docrag.vqa import qwen2
ALL_VQA_MODEL_TYPES = ['florence2', 'idefics2', 'internvl2', 'idefics3', 'qwen2']
def init(
model_name_or_path: Union[str, Path],
model_type: str,
**kwargs
):
if 'internvl2' == model_type.lower():
return internvl2.init(
model_name_or_path=model_name_or_path,
**kwargs
)
elif 'idefics2' == model_type.lower():
return idefics2.init(
model_name_or_path=model_name_or_path,
**kwargs
)
elif 'idefics3' == model_type.lower():
return idefics3.init(
model_name_or_path=model_name_or_path,
**kwargs
)
elif 'florence2' == model_type.lower():
return florence2.init(
model_name_or_path=model_name_or_path,
**kwargs
)
elif 'qwen2' == model_type.lower():
return qwen2.init(
model_name_or_path=model_name_or_path,
**kwargs
)
else:
raise NotImplementedError(f"{model_type} is unsupported. Supported: {ALL_VQA_MODEL_TYPES}")
def generate(
model_type: str,
model,
processor,
**kwargs
) -> List[str]:
if 'internvl2' == model_type.lower():
return internvl2.generate(
model=model,
processor=processor,
**kwargs
)
elif 'idefics2' == model_type.lower():
return idefics2.generate(
model=model,
processor=processor,
**kwargs
)
elif 'idefics3' == model_type.lower():
return idefics3.generate(
model=model,
processor=processor,
**kwargs
)
elif 'florence2' == model_type.lower():
return florence2.generate(
model=model,
processor=processor,
**kwargs
)
elif 'qwen2' == model_type.lower():
return qwen2.generate(
model=model,
processor=processor,
**kwargs
)
else:
raise NotImplementedError(f"{model_type} is unsupported. Supported: {ALL_VQA_MODEL_TYPES}")
class VQAModel:
def __init__(self, model_name_or_path: Union[str, Path], model_type: str, **kwargs):
model_loaded = init(model_name_or_path=model_name_or_path, model_type=model_type, **kwargs)
model = model_loaded['model']
if 'tokenizer' in model_loaded:
processor = model_loaded['tokenizer']
else:
processor = model_loaded['processor']
if isinstance(model, torch.nn.Module):
model = model.eval()
# greedy decoding
if hasattr(model, 'generation_config'):
model.generation_config.temperature=None
model.generation_config.top_p=None
model.generation_config.top_k=None
self.model = model
self.processor = processor
self.model_type = model_type
def generate(self, images, question) -> str:
responses = generate(
model_type=self.model_type,
model=self.model,
processor=self.processor,
images=images,
question=question,
)
assert isinstance(responses, list), responses
out_text = responses[0]
out_text = out_text.strip()
return out_text
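A usage sketch for `VQAModel`; the checkpoint id and page-image paths are assumptions, and any entry of `ALL_VQA_MODEL_TYPES` can be passed as `model_type`:

```python
# Hypothetical usage of VQAModel; the checkpoint id and image paths are placeholders.
from PIL import Image

vqa = VQAModel(
    model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",  # assumed HF checkpoint
    model_type="qwen2",
)
pages = [Image.open("retrieved_page_1.png"), Image.open("retrieved_page_2.png")]
answer = vqa.generate(images=pages, question="Who authored this report?")
print(answer)
```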

View File

@ -0,0 +1,169 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
from typing import Union, List
def init(
model_name_or_path,
dtype=torch.bfloat16,
# model_max_length=None,
**kwargs,
):
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True
)
model.eval()
processor = AutoProcessor.from_pretrained(
model_name_or_path,
trust_remote_code=True,
# model_max_length=model_max_length
)
return {
'model': model,
'processor': processor
}
def generate(
model,
processor,
question,
images
) -> List[str]:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
module = model.module
else:
module = model
answer = generate_caption(module, processor=processor, images=images, text_input=question)
return answer
@torch.no_grad()
def generate_caption(
model,
processor,
images=None,
# task_prompt="<MORE_DETAILED_CAPTION>",
text_input=None,
input_ids=None,
pixel_values=None,
max_new_tokens=77,
num_beams=1,
do_sample=False,
decode_text=True,
**generate_kwargs
):
if input_ids is None and pixel_values is None:
if isinstance(images, Image.Image):
images = [images]
B = len(images)
if isinstance(text_input, str):
text_input = [text_input] * B
inputs = processor(
text=text_input,
images=images,
return_tensors="pt")
input_ids = inputs["input_ids"]
pixel_values = inputs["pixel_values"]
p = next(iter(model.parameters()))
device = p.device
dtype = p.dtype
generated_ids = model.generate(
input_ids=input_ids.to(device),
pixel_values=pixel_values.to(device, dtype),
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=do_sample,
**generate_kwargs
)
if decode_text:
out_text = decode_predictions(processor, generated_ids)
return out_text
else:
return generated_ids
def decode_predictions(processor, generated_ids):
B = len(generated_ids)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
out_text = []
for i in range(B):
out_text.append(generated_text[i].replace('<s>', '').replace('</s>', '').replace('<pad>', '').strip())
return out_text
# def load_model_from_ckpt(
# ckpt_dir,
# load_abstractor=False,
# num_input_tokens=576,
# num_query_tokens=64,
# proj_type='c-abs',
# projection_dim=1024,
# strict=False
# ):
# florence_config = Florence2Config.from_pretrained(ckpt_dir)
# if load_abstractor:
# abstractor_config = AbstractorConfig(
# num_input_tokens=num_input_tokens,
# num_query_tokens=num_query_tokens,
# proj_type=proj_type,
# projection_dim=projection_dim,
# )
# florence_config.abstractor_config = abstractor_config
# florence_config.vision_config.model_type = 'davit'
# model = Florence2ForConditionalGeneration(config=florence_config)
# ckpt_path = Path(ckpt_dir) / 'model.safetensors'
# if ckpt_path.exists():
# logger.info(f"loading checkpoints from {ckpt_path}")
# state_dict = safetensors.torch.load_file(ckpt_path, device="cpu")
# else:
# ckpt_path = Path(ckpt_dir) / 'pytorch_model.bin'
# logger.info(f"loading checkpoints from {ckpt_path}")
# state_dict = torch.load(
# ckpt_path,
# map_location="cpu",
# )
# load_result = model.load_state_dict(state_dict, strict=strict)
# logger.info(load_result)
# return model
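A sketch of calling these Florence-2 helpers directly; the checkpoint id and image path are assumptions, and `<MORE_DETAILED_CAPTION>` follows the task prompt hinted at in the commented-out default above:

```python
# Hypothetical direct use of the Florence-2 helpers above; checkpoint id and image path are placeholders.
from PIL import Image

loaded = init("microsoft/Florence-2-large")  # assumed HF checkpoint (uses trust_remote_code)
captions = generate_caption(
    loaded["model"],
    loaded["processor"],
    images=Image.open("page_1.png"),
    text_input="<MORE_DETAILED_CAPTION>",
    max_new_tokens=128,
)
print(captions[0])
```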

View File

@ -0,0 +1,137 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
from transformers import Idefics2ForConditionalGeneration
from transformers import Idefics2Processor
from transformers import BitsAndBytesConfig
from typing import Union, List
def init(
model_name_or_path,
do_image_splitting=True,
bits=4,
dtype=torch.bfloat16,
**kwargs,
):
if bits == 4:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=dtype
)
else:
bnb_config = None
model = Idefics2ForConditionalGeneration.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
quantization_config=bnb_config,
low_cpu_mem_usage=True
)
model.eval()
processor = Idefics2Processor.from_pretrained(
model_name_or_path,
do_image_splitting=do_image_splitting,
# size= {"longest_edge": image_size, "shortest_edge": 378}
)
return {
'model': model,
'processor': processor
}
def generate(
model,
processor,
question,
images
) -> List[str]:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
module = model.module
else:
module = model
messages = idefics2_create_message(images=images, question=question)
examples = [{'images': images, 'messages': messages}]
batch = idefics2_collate_fn(examples, processor)
for k in batch:
batch[k] = batch[k].to(module.device)
generated_ids = module.generate(**batch, max_new_tokens=50, do_sample=False)
answer = processor.batch_decode(
generated_ids[:, batch.input_ids.size(1):],
skip_special_tokens=True)
return answer
def idefics2_collate_fn(examples,
processor,
model_max_length=3600,
is_train=False,
image_token_id=None,
):
if image_token_id is None:
image_token_id = processor.tokenizer.additional_special_tokens_ids[processor.tokenizer.additional_special_tokens.index("<image>")]
texts = []
images = []
for example in examples:
prompt = processor.apply_chat_template(example['messages'], add_generation_prompt=not is_train)
texts.append(prompt)
images.append(example['images'])
if is_train:
batch = processor(text=texts, images=images, return_tensors="pt",
padding=True, truncation=True,
max_length=model_max_length)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
batch["labels"] = labels
else:
batch = processor(text=texts, images=images, return_tensors="pt",
padding=True,
truncation=True,
max_length=model_max_length
)
return batch
def idefics2_create_message(images, question, is_train=False, target_text=None):
content = []
for page_i in range(len(images)):
# content += [{"type": "text", "text": f"page {page_i}: "}]
content += [{"type": "image"}]
content += [{"type": "text", "text": question}]
messages = [{"role": "user", "content": content}]
if is_train:
messages += [{"role": "assistant", "content": [{"type": "text", "text": target_text}]}]
return messages
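For reference, a sketch of the chat structure `idefics2_create_message` builds for two pages; the blank images are placeholders, since the helper only uses the number of images:

```python
# Hypothetical illustration of the message layout produced above; blank placeholder images.
from PIL import Image

pages = [Image.new("RGB", (640, 480)), Image.new("RGB", (640, 480))]
messages = idefics2_create_message(images=pages, question="What year is shown?")
# messages == [{"role": "user",
#               "content": [{"type": "image"}, {"type": "image"},
#                           {"type": "text", "text": "What year is shown?"}]}]
```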

View File

@ -0,0 +1,129 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers import BitsAndBytesConfig
from typing import Union, List
def init(
model_name_or_path,
bits=4,
dtype=torch.bfloat16,
**kwargs,
):
if bits == 4:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=dtype
)
else:
bnb_config = None
model = AutoModelForVision2Seq.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
quantization_config=bnb_config,
low_cpu_mem_usage=True
)
model.eval()
processor = AutoProcessor.from_pretrained(
model_name_or_path,
)
return {
'model': model,
'processor': processor
}
def generate(
model,
processor,
question,
images
) -> List[str]:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
module = model.module
else:
module = model
messages = idefics3_create_message(images=images, question=question)
examples = [{'images': images, 'messages': messages}]
batch = idefics3_collate_fn(examples, processor)
for k in batch:
batch[k] = batch[k].to(module.device)
generated_ids = module.generate(**batch, max_new_tokens=50, do_sample=False)
answer = processor.batch_decode(
generated_ids[:, batch.input_ids.size(1):],
skip_special_tokens=True)
return answer
def idefics3_collate_fn(examples,
processor,
is_train=False,
image_token_id=None,
):
texts = []
images = []
for example in examples:
prompt = processor.apply_chat_template(example['messages'], add_generation_prompt=not is_train)
texts.append(prompt)
images.append(example['images'])
if is_train:
batch = processor(text=texts, images=images, return_tensors="pt",
padding=True, truncation=True,
# max_length=model_max_length
)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
        # NOTE: unlike idefics2_collate_fn, image_token_id is not derived from the
        # processor here, so it must be supplied by the caller when is_train=True.
        labels[labels == image_token_id] = -100
batch["labels"] = labels
else:
batch = processor(text=texts, images=images, return_tensors="pt",
padding=True,
truncation=True,
# max_length=model_max_length
)
return batch
def idefics3_create_message(images, question, is_train=False, target_text=None):
content = []
for page_i in range(len(images)):
# content += [{"type": "text", "text": f"page {page_i}: "}]
content += [{"type": "image"}]
content += [{"type": "text", "text": question}]
messages = [{"role": "user", "content": content}]
if is_train:
messages += [{"role": "assistant", "content": [{"type": "text", "text": target_text}]}]
return messages

View File

@ -0,0 +1,288 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
import torchvision.transforms as T
from pathlib import Path
# from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
from typing import Union, List
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image_file, input_size=448, max_num=12):
if isinstance(image_file, (str, Path)):
image = Image.open(image_file).convert('RGB')
    elif isinstance(image_file, Image.Image):
        image = image_file
    else:
        raise TypeError(f"Unsupported image input type: {type(image_file)}")
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
# def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
# if bound:
# start, end = bound[0], bound[1]
# else:
# start, end = -100000, 100000
# start_idx = max(first_idx, round(start * fps))
# end_idx = min(round(end * fps), max_frame)
# seg_size = float(end_idx - start_idx) / num_segments
# frame_indices = np.array([
# int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
# for idx in range(num_segments)
# ])
# return frame_indices
# def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
# vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
# max_frame = len(vr) - 1
# fps = float(vr.get_avg_fps())
# pixel_values_list, num_patches_list = [], []
# transform = build_transform(input_size=input_size)
# frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
# for frame_index in frame_indices:
# img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
# img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
# pixel_values = [transform(tile) for tile in img]
# pixel_values = torch.stack(pixel_values)
# num_patches_list.append(pixel_values.shape[0])
# pixel_values_list.append(pixel_values)
# pixel_values = torch.cat(pixel_values_list)
# return pixel_values, num_patches_list
def init(
model_name_or_path,
dtype=torch.bfloat16,
use_flash_attn=True,
**kwargs,
):
model = AutoModel.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
use_flash_attn=use_flash_attn,
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False)
return {
'model': model,
# 'tokenizer': tokenizer
'processor': tokenizer
}
def generate(
model,
# tokenizer,
processor,
question,
images
) -> List[str]:
tokenizer = processor
if len(images) == 1:
pixel_values = load_image(
images[0],
max_num=12).to(torch.bfloat16).to(model.device)
else:
raise NotImplementedError
generation_config = dict(max_new_tokens=1024, do_sample=False)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
module = model.module
else:
module = model
answer = module.chat(tokenizer, pixel_values, question, generation_config)
return [answer]
if __name__ == '__main__':
# If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section.
path = 'OpenGVLab/InternVL2-8B'
model = AutoModel.from_pretrained(
path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
# set the max number of tiles in `max_num`
pixel_values = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=False)
    # pure-text conversation
question = 'Hello, who are you?'
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
question = 'Can you tell me a story?'
response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True)
print(f'User: {question}\nAssistant: {response}')
    # single-image single-round conversation
question = '<image>\nPlease describe the image shortly.'
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(f'User: {question}\nAssistant: {response}')
    # single-image multi-round conversation
question = '<image>\nPlease describe the image in detail.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
question = 'Please write a poem according to the image.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
print(f'User: {question}\nAssistant: {response}')
    # multi-image multi-round conversation, combined images
pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
question = '<image>\nDescribe the two images in detail.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config,
history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
question = 'What are the similarities and differences between these two images.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config,
history=history, return_history=True)
print(f'User: {question}\nAssistant: {response}')
    # multi-image multi-round conversation, separate images
pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]
question = 'Image-1: <image>\nImage-2: <image>\nDescribe the two images in detail.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config,
num_patches_list=num_patches_list,
history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
question = 'What are the similarities and differences between these two images.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config,
num_patches_list=num_patches_list,
history=history, return_history=True)
print(f'User: {question}\nAssistant: {response}')
    # batch inference, single image per sample
pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]
pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
questions = ['<image>\nDescribe the image in detail.'] * len(num_patches_list)
responses = model.batch_chat(tokenizer, pixel_values,
num_patches_list=num_patches_list,
questions=questions,
generation_config=generation_config)
for question, response in zip(questions, responses):
print(f'User: {question}\nAssistant: {response}')
    # video multi-round conversation
    # NOTE: this demo needs the `load_video` helper and the decord dependency, both of
    # which are commented out above, so the block is kept commented out here.
    # video_path = './examples/red-panda.mp4'
    # pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
    # pixel_values = pixel_values.to(torch.bfloat16).cuda()
    # video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
    # question = video_prefix + 'What is the red panda doing?'
    # # Frame1: <image>\nFrame2: <image>\n...\nFrame8: <image>\n{question}
    # response, history = model.chat(tokenizer, pixel_values, question, generation_config,
    #                                num_patches_list=num_patches_list, history=None, return_history=True)
    # print(f'User: {question}\nAssistant: {response}')
    # question = 'Describe this video in detail. Don\'t repeat.'
    # response, history = model.chat(tokenizer, pixel_values, question, generation_config,
    #                                num_patches_list=num_patches_list, history=history, return_history=True)
    # print(f'User: {question}\nAssistant: {response}')

100
src/m3docrag/vqa/qwen2.py Normal file
View File

@ -0,0 +1,100 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict, List
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
def init(
model_name_or_path,
dtype=torch.bfloat16,
bits=16,
attn_implementation="flash_attention_2",
**kwargs,
):
if bits == 4:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=dtype
)
else:
bnb_config = None
# Load the model in half-precision on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
attn_implementation=attn_implementation,
quantization_config=bnb_config,
vision_config={"torch_dtype": dtype}
)
model.eval()
processor = AutoProcessor.from_pretrained(model_name_or_path)
return {
'model': model,
'processor': processor
}
def generate(
model,
processor,
question,
images
) -> List[str]:
image_content = [{"type": "image", "image": "dummy_content"}] * len(images)
messages = [
{
"role": "user",
"content": image_content + [{"type": "text", "text": question}]
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=images,
# videos=video_inputs,
padding=True,
return_tensors="pt",
)
p = next(iter(model.parameters()))
inputs = inputs.to(p.device)
# Inference
generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
assert isinstance(output_text, list), output_text
return output_text
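A standalone sketch of these Qwen2-VL helpers; the checkpoint id and page image are assumptions, and passing `bits=4` selects the NF4 quantization path in `init`:

```python
# Hypothetical standalone use of the helpers above; checkpoint id and image path are placeholders.
from PIL import Image

loaded = init("Qwen/Qwen2-VL-7B-Instruct", bits=4)  # assumed HF checkpoint
answers = generate(
    model=loaded["model"],
    processor=loaded["processor"],
    question="What is the title of this document?",
    images=[Image.open("page_1.png")],
)
print(answers[0])
```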