Files
M3DocRAG/examples/run_page_embedding.py
j-min 27aac8d521 Release commit
Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
2025-02-15 09:52:51 -05:00

181 lines
5.3 KiB
Python

# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import accelerate
import safetensors
import torch
import transformers
from accelerate import Accelerator
from loguru import logger
from tqdm import tqdm
from m3docrag.datasets.m3_docvqa import M3DocVQADataset
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.args import parse_args
from m3docrag.utils.distributed import (
barrier,
global_rank,
is_distributed,
local_rank,
log_runtime_info,
print_gpu_stats,
)
from m3docrag.utils.paths import (
LOCAL_DATA_DIR,
LOCAL_MODEL_DIR,
)
# Log the versions of the core libraries so every run's log records them.
for _lib in (torch, transformers, accelerate):
    logger.info(_lib.__version__)
del _lib
def main():
    """Compute and save ColPali page embeddings for every document in a dataset.

    Parses CLI arguments with ``parse_args``, checks that the local data and
    model directories exist, encodes each document's page images with the
    ColPali retrieval model, and writes one ``<doc_id>.safetensors`` file
    (tensor key ``"embeddings"``, stored as bfloat16) per document into
    ``args.output_dir``.

    Raises:
        ValueError: if ``args.retrieval_model_type`` or ``args.data_name`` is
            not supported, retrieval is disabled, or a required local
            directory does not exist.
    """
    # `import safetensors` alone does not load the `safetensors.torch`
    # submodule, so import the torch serializer explicitly.
    from safetensors.torch import save_file

    args = parse_args()
    log_runtime_info()
    print_gpu_stats()

    accelerator = Accelerator()

    if not is_distributed() or global_rank() == 0:
        logger.info(f"Process {global_rank()}:{local_rank()} - args {args}")

    if is_distributed():
        barrier()

    # Only ColPali on M3DocVQA is wired up below; fail fast with a clear
    # message instead of an unbound-name error later on.
    if args.retrieval_model_type != "colpali":
        raise ValueError(
            f"Unsupported retrieval_model_type: {args.retrieval_model_type!r} "
            "(only 'colpali' is supported)"
        )
    if args.data_name != "m3-docvqa":
        raise ValueError(
            f"Unsupported data_name: {args.data_name!r} "
            "(only 'm3-docvqa' is supported)"
        )

    local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
    local_retrieval_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path
    )
    local_retrieval_adapter_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path
    )

    # Validate on every rank: if only rank 0 raised, the remaining ranks
    # would block indefinitely in a later barrier().
    if not local_data_dir.exists():
        raise ValueError(f"Data directory {local_data_dir} does not exist")
    if not args.use_retrieval:
        raise ValueError(
            f"use_retrieval must be enabled to compute page embeddings "
            f"(got {args.use_retrieval!r})"
        )
    if not local_retrieval_model_dir.exists():
        raise ValueError(
            f"Retrieval model directory {local_retrieval_model_dir} does not exist"
        )
    if not local_retrieval_adapter_model_dir.exists():
        raise ValueError(
            f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist"
        )

    retrieval_model = ColPaliRetrievalModel(
        backbone_name_or_path=local_retrieval_model_dir,
        adapter_name_or_path=local_retrieval_adapter_model_dir,
    )

    dataset = M3DocVQADataset(args=args)

    def collate_fn(examples):
        # Keep only the fields the loop below reads, each as a per-batch list.
        return {k: [ex[k] for ex in examples] for k in ("doc_id", "images")}

    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=1,
        shuffle=False,
        batch_sampler=None,
        sampler=None,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
    )

    # Let accelerate place the model and shard the loader across processes.
    retrieval_model.model, data_loader = accelerator.prepare(
        retrieval_model.model, data_loader
    )

    save_dir = Path(args.output_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Results will be saved at: {save_dir}")

    num_batches = len(data_loader)
    for i, datum in enumerate(tqdm(data_loader)):
        print(f"{i} / {num_batches}")

        # batch_size=1, so each batch holds exactly one document.
        doc_id = datum["doc_id"][0]
        logger.info(doc_id)

        images = datum["images"][0]
        page_embs = retrieval_model.encode_images(
            images=images,
            batch_size=args.per_device_eval_batch_size,
            to_cpu=True,
            use_tqdm=False,
        )
        # [n_pages, n_tokens, emb_dim]; stored as BF16 by default.
        doc_embs = torch.stack(page_embs, dim=0).to(torch.bfloat16)

        logger.info(doc_embs.shape)
        logger.info(doc_embs[0, 0, :5])

        local_save_path = save_dir / f"{doc_id}.safetensors"
        save_file({"embeddings": doc_embs}, local_save_path)

    logger.info(
        f"Process {global_rank()}:{local_rank()} Results correctly saved at {save_dir}"
    )

    if is_distributed():
        barrier()
if __name__ == "__main__":
    # Run the embedding job only when executed as a script, not on import.
    main()