Release commit
Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
examples/run_page_embedding.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

import accelerate
import safetensors.torch  # import the submodule so safetensors.torch.save_file below resolves
import torch
import transformers
from accelerate import Accelerator
from loguru import logger
from tqdm import tqdm

from m3docrag.datasets.m3_docvqa import M3DocVQADataset
from m3docrag.retrieval import ColPaliRetrievalModel
from m3docrag.utils.args import parse_args
from m3docrag.utils.distributed import (
    barrier,
    global_rank,
    is_distributed,
    local_rank,
    log_runtime_info,
    print_gpu_stats,
)
from m3docrag.utils.paths import (
    LOCAL_DATA_DIR,
    LOCAL_MODEL_DIR,
)

logger.info(torch.__version__)
logger.info(transformers.__version__)
logger.info(accelerate.__version__)


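# Flow of this script: for every document in the dataset, encode each page
# image into ColPali multi-vector embeddings, cast them to bfloat16, and write
# one .safetensors file per document (per page for mp-docvqa).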
def main():
    args = parse_args()

    log_runtime_info()
    print_gpu_stats()

    accelerator = Accelerator()

    if not is_distributed() or global_rank() == 0:
        logger.info(f"Process {global_rank()}:{local_rank()} - args {args}")

    if is_distributed():
        barrier()

    local_data_dir = Path(LOCAL_DATA_DIR) / args.data_name
    local_retrieval_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_model_name_or_path
    )
    local_retrieval_adapter_model_dir = (
        Path(LOCAL_MODEL_DIR) / args.retrieval_adapter_model_name_or_path
    )

    # Check that the datasets / model checkpoints are available locally
    if not is_distributed() or global_rank() == 0:
        if not local_data_dir.exists():
            raise ValueError(f"Data directory {local_data_dir} does not exist")

        assert args.use_retrieval, args.use_retrieval

        if not local_retrieval_model_dir.exists():
            raise ValueError(
                f"Retrieval model directory {local_retrieval_model_dir} does not exist"
            )

        if args.retrieval_model_type == "colpali":
            if not local_retrieval_adapter_model_dir.exists():
                raise ValueError(
                    f"Retrieval adapter model directory {local_retrieval_adapter_model_dir} does not exist"
                )

    if is_distributed():
        barrier()

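    # ColPali pairs a vision-language backbone with a retrieval adapter;
    # both are loaded from the local checkpoint directories resolved above.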
    if args.retrieval_model_type == "colpali":
        colpali_model = ColPaliRetrievalModel(
            backbone_name_or_path=local_retrieval_model_dir,
            adapter_name_or_path=local_retrieval_adapter_model_dir,
        )
        retrieval_model = colpali_model

    if args.data_name == "m3-docvqa":
        dataset = M3DocVQADataset(args=args)

    def collate_fn(examples):
        # Keep doc_id / images as plain Python lists instead of collated tensors
        out = {}
        if args.retrieval_model_type == "colpali":
            for k in ["doc_id", "images"]:
                out[k] = [ex[k] for ex in examples]
        return out

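    # batch_size=1: each iteration yields a single document (all of its pages).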
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=1,
        shuffle=False,
        batch_sampler=None,
        sampler=None,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
    )

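    # accelerator.prepare() places the model on the right device and, under a
    # multi-process launch, shards the data loader across ranks.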
    retrieval_model.model, data_loader = accelerator.prepare(
        retrieval_model.model, data_loader
    )

    all_results = []

    save_dir = Path(args.output_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Results will be saved at: {save_dir}")

    for i, datum in enumerate(tqdm(data_loader)):
        print(f"{i} / {len(data_loader)}")

        if args.data_name == "mp-docvqa":
            # MP-DocVQA examples are single pages; other datasets are whole documents
            page_name = datum["page_name"][0]
            logger.info(page_name)
        else:
            doc_id = datum["doc_id"][0]
            logger.info(doc_id)

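        # ColPali represents each page as a sequence of token-level vectors
        # (multi-vector / late-interaction retrieval), not one pooled vector.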
        if args.retrieval_model_type == "colpali":
            images = datum["images"][0]

            doc_embs = colpali_model.encode_images(
                images=images,
                batch_size=args.per_device_eval_batch_size,
                to_cpu=True,
                use_tqdm=False,
            )

            # [n_pages, n_tokens, emb_dim]
            doc_embs = torch.stack(doc_embs, dim=0)

        # Store embeddings as BF16 by default (halves storage vs. FP32)
        doc_embs = doc_embs.to(torch.bfloat16)

        logger.info(doc_embs.shape)
        if args.retrieval_model_type == "colpali":
            # Peek at the first few values of the first page's first token
            logger.info(doc_embs[0, 0, :5])

        # Save the embedding
        if args.data_name == "mp-docvqa":
            local_save_fname = f"{page_name}.safetensors"
        else:
            local_save_fname = f"{doc_id}.safetensors"
        local_save_path = save_dir / local_save_fname

        if args.retrieval_model_type == "colpali":
            # safetensors stores a flat dict of named tensors
            safetensors.torch.save_file({"embeddings": doc_embs}, local_save_path)

        all_results.append({"save_path": local_save_path})

    logger.info(
        f"Process {global_rank()}:{local_rank()} Results saved at {save_dir}"
    )

    if is_distributed():
        barrier()


if __name__ == "__main__":
    main()
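
# Minimal sketch of reading a saved file back (path illustrative; assumes the
# "embeddings" key written above):
#
#     from safetensors.torch import load_file
#     doc_embs = load_file("<output_dir>/<doc_id>.safetensors")["embeddings"]
#     # -> tensor of shape [n_pages, n_tokens, emb_dim], dtype torch.bfloat16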