Release commit
Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>

m3docvqa/.gitignore (new file, vendored, 167 lines)
@@ -0,0 +1,167 @@
.vscode/
notebooks/node_modules/
notebooks/package-lock.json
dataset/*.jsonl
dataset/img
dataset/cache
dataset/img_features
baselines/data/
baselines/output/
baselines/image_qa/training_stats/
baselines/image_qa/checkpoints/
baselines/image_qa/analysis/
deps/vilbert-multi-task/data/
deps/vilbert-multi-task/save/
deps/vilbert-multi-task/multi_task_model*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# IDE
.idea/

# DATA
data/

# for MacOS
*.DS_Store

*.png
dataset
multimodalqa_screenshots
*.json
*.pdf
*.jsonl*

setup_BCOS_README.md
*.ipynb

.env
READDME copy.md

m3docvqa/README.md (new file, 200 lines)
@@ -0,0 +1,200 @@

# M3DocVQA

Dataset generation package for [M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding](https://m3docrag.github.io/).

## Summary

M3DocVQA (Multi-modal Multi-page Multi-Document Visual Question Answering) is a new benchmark for evaluating open-domain DocVQA over 3,000+ PDF documents with 40,000+ pages. M3DocVQA significantly raises the challenge of DocVQA by requiring answers to be found in a large document corpus. By extending the MultimodalQA dataset's closed-domain context to an open-domain setting, M3DocVQA introduces 2,441 multi-hop questions spanning 3,368 PDF documents, which collectively contain over 41,005 pages of diverse multi-modal content, including text, images, and tables. The generated dataset presents real-world challenges by requiring models to navigate complex reasoning paths across pages and within various types of document elements, better reflecting the intricacies of document understanding.

<img src='../assets/dataset.png'>

Comparison of existing DocVQA datasets (left: e.g., DocVQA) and the generated `M3DocVQA` dataset (right). In contrast to previous DocVQA datasets, whose questions are specific to a single provided PDF (e.g., `What was the gross profit in the year 2009?`), M3DocVQA contains information-seeking questions that benchmark open-domain question answering capabilities across more than `3,000 PDF documents` (i.e., `40,000+ pages`).

<img src='../assets/data_collection.png'>

We extend the question-answer pairs from a short-context VQA dataset to a more complex setting that includes:
1. PDF documents.
2. Open-domain contexts.

We first collect the URLs of all supporting contexts (Wikipedia documents) of the individual questions in [MultimodalQA](https://github.com/allenai/multimodalqa). This tool then creates PDF versions of those pages from their URLs by rendering them in a Chromium web browser.
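
Under the hood, the rendering is done with Playwright and headless Chromium (see `src/m3docvqa/downloader.py`). The following is a minimal sketch of that rendering step for a single page, simplified from the package's `_download_wiki_page` helper; the URL and output path are illustrative:

```python
# Minimal sketch of the Chromium-based rendering used by the downloader
# (simplified from src/m3docvqa/downloader.py; URL and path are illustrative).
from playwright.sync_api import sync_playwright

url = "https://en.wikipedia.org/wiki/Example"  # illustrative URL
save_path = "example.pdf"                       # illustrative output path

with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)
    page = browser.new_context(ignore_https_errors=True).new_page()
    page.goto(url)
    page.emulate_media(media="screen")  # render the on-screen layout, not print CSS
    page.pdf(path=save_path)            # Chromium-only: save the rendered page as a PDF
    browser.close()
```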

## M3DocVQA Dataset Creation Pipeline

This part of the repository provides scripts to create the `M3DocVQA` dataset, including functionality to download Wikipedia pages as PDFs, check and clean corrupted PDFs, extract images, and organize files into directories for training and evaluation.

### Overview

The scripts allow users to:
- Download Wikipedia pages in either PDF or PNG format.
- Verify and clean downloaded PDFs.
- Extract images from PDFs.
- Organize files into directories based on split information for training/evaluation.

## Installation

```
git clone <url-tbd>
cd <repo-name-tbd>/m3docvqa
```

### Install Python Package

We used Python 3.10.

```bash
pip install -e .
```

### Setup Playwright

```bash
# e.g., download browsers, ffmpeg, etc.
playwright install
playwright install-deps
```

### Test the Package

```bash
pytest tests
```

**Note**: The tests might fail if `poppler-utils` is not installed on your system; `pdf2image` requires it. Please refer to these [detailed instructions](https://pdf2image.readthedocs.io/en/latest/installation.html).

### Additional Setup

Ensure the required directories and metadata files are available before running the scripts. Follow the steps below to obtain the required data.

## Usage

The main script (`main.py`) supports several actions, each of which targets a specific portion of the dataset creation process.

### Command Structure
```bash
python main.py <action> [options]
```

### Available Actions
- `download_pdfs`: Download PDFs from URLs provided in the metadata.
- `check_pdfs`: Verify that the downloaded PDFs are valid.
- `extract_images`: Extract images from the pages of the downloaded PDFs.
- `organize_files`: Organize downloaded PDFs into specified directory splits.
- `download_mmqa`: Download and decompress the MMQA dataset.
- `generate_wiki_mapping`: Generate a mapping of `id` to `url` from multiple JSONL files.

## Steps for Generating the M3DocVQA Dataset

### Step 1: Download the MultiModalQA Dataset
Use the `download_mmqa` action to download and decompress the MultiModalQA dataset files.

```bash
python main.py download_mmqa --output_dir=./multimodalqa
```

Output: the decompressed JSONL files

```bash
MMQA_train.jsonl
MMQA_dev.jsonl
MMQA_texts.jsonl
MMQA_images.jsonl
MMQA_tables.jsonl
```

These files will be stored in the `./multimodalqa/` directory.
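
Each of these files is in JSON Lines format; for example, the text/image/table files contain records with `id`, `url`, `title`, and `text` fields (see `src/m3docvqa/wiki_mapper.py`). A minimal sketch for peeking at a few records, assuming the output directory used above:

```python
# Peek at the first few records of one of the decompressed MMQA files.
import jsonlines

with jsonlines.open("./multimodalqa/MMQA_texts.jsonl") as reader:
    for i, record in enumerate(reader):
        print(record.get("id"), record.get("url"))
        if i == 2:  # stop after three records
            break
```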

### Step 2: Generate Wiki Mapping
Use the `generate_wiki_mapping` action to create a mapping of `id` to `url` from the downloaded JSONL files.

```bash
python main.py generate_wiki_mapping --text=./multimodalqa/MMQA_texts.jsonl --image=./multimodalqa/MMQA_images.jsonl --table=./multimodalqa/MMQA_tables.jsonl --output=./id_url_mapping.jsonl
```

Output:

A JSONL file `id_url_mapping.jsonl` containing the ID and corresponding URL mappings.
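
Each line of this file is a JSON object with `id` and `url` fields. A minimal sketch for loading the mapping back into a Python dictionary, assuming the output path used above:

```python
# Load the id -> url mapping produced in this step into a dict.
import jsonlines

id_to_url = {}
with jsonlines.open("./id_url_mapping.jsonl") as reader:
    for entry in reader:
        id_to_url[entry["id"]] = entry["url"]

print(f"{len(id_to_url)} documents mapped")
```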

### Step 3: Download Wikipedia Articles as PDFs
Use the `download_pdfs` action to download Wikipedia articles as PDFs based on the generated mapping.

```bash
python main.py download_pdfs --metadata_path=./id_url_mapping.jsonl --pdf_dir=./pdfs --result_log_path=./download_results.jsonl --first_n=10 --supporting_doc_ids_per_split=./supporting_doc_ids_per_split.json --split=dev
```

Options:
- `--metadata_path`: Path to the `id_url_mapping.jsonl` file.
- `--pdf_dir`: Directory to save the downloaded PDFs.
- `--result_log_path`: Path to log the download results.
- `--first_n`: Download only the first N PDFs (useful for testing). **Do not use this option when downloading all the PDFs.**
- `--supporting_doc_ids_per_split`: Path to the JSON file containing the document IDs for each split (its expected structure is sketched below). `dev` is the default split, as all of the experimental results in the `M3DocRAG` paper were reported on the `dev` split. To download the PDFs of the `train` split, pass `--split=train`; to download all PDFs, pass `--split=all`.
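
For reference, `main.py` expects `supporting_doc_ids_per_split.json` to be a JSON object with `train` and `dev` keys, each mapping to a list of document IDs. A minimal, purely illustrative sketch of that structure (the IDs are placeholders):

```python
# Illustrative structure of supporting_doc_ids_per_split.json (placeholder IDs).
import json

example = {
    "train": ["doc_id_a", "doc_id_b"],
    "dev": ["doc_id_c"],
}
with open("supporting_doc_ids_per_split.example.json", "w") as f:
    json.dump(example, f, indent=2)
```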

Output:

- PDF files for Wikipedia articles, saved in the `./pdfs/` directory.
- A `download_results.jsonl` file logging the status of each download.

### Step 4: Check PDF Integrity
Use the `check_pdfs` action to verify the integrity of the downloaded PDFs.

```bash
python main.py check_pdfs --pdf_dir=./pdfs
```

Output:

Identifies and logs corrupted or unreadable PDFs.
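
The same checks can be run programmatically on a single file with the helpers from `m3docvqa.pdf_utils`; a minimal sketch with an illustrative path:

```python
# Check a single downloaded PDF with the same helpers used by check_pdfs.
from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean

pdf_path = "./pdfs/some_article.pdf"  # illustrative path
print("downloaded:", is_pdf_downloaded(pdf_path))
print("clean:", is_pdf_clean(pdf_path))
```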

### Step 5: Organize Files into Splits
Use the `organize_files` action to organize the downloaded PDFs into specific splits (e.g., `train`, `dev`) based on a split information file.

```bash
python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=dev --split_metadata_file=./multimodalqa/MMQA_dev.jsonl
```

If the `train` split is needed:

```bash
python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=train --split_metadata_file=./multimodalqa/MMQA_train.jsonl
```

Output:

- PDFs organized into `./splits/pdfs_dev/` (and `./splits/pdfs_train/` if requested).
- Files that store the document IDs of each split: `./dev_doc_ids.json` and `./train_doc_ids.json`.

**Note**: In the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper, we only use the `dev` split for our experiments.

### Step 6: Extract Images from PDFs
Use the `extract_images` action to extract images from the downloaded PDFs. A PNG image of each page of each PDF is extracted. These images are used both for `retrieval` with `ColPali/ColQwen` and for `question answering` with the LLMs mentioned in the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper.

```bash
python main.py extract_images --pdf_dir=./splits/pdfs_dev/ --image_dir=./images/images_dev
```

Output:

Extracted images from the PDFs in the dev split are saved in the `./images/images_dev` directory.
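
Image extraction can also be called directly from Python via `get_images_from_pdf`; a minimal sketch with illustrative paths (144 DPI is the function's default resolution):

```python
# Extract page images from one PDF without going through main.py.
from m3docvqa.pdf_utils import get_images_from_pdf

images = get_images_from_pdf(
    "./splits/pdfs_dev/some_doc.pdf",  # illustrative path
    save_dir="./images/images_dev",
    dpi_resolution=144,
    save_type="png",
)
print(f"Extracted {len(images)} page images")
```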

After following these steps, your dataset directory structure will look like this:

```
./
|-- multimodalqa/
| |-- MMQA_train.jsonl
| |-- MMQA_dev.jsonl
| |-- MMQA_texts.jsonl
| |-- MMQA_images.jsonl
| |-- MMQA_tables.jsonl
|-- id_url_mapping.jsonl
|-- dev_doc_ids.json
|-- train_doc_ids.json
|-- supporting_doc_ids_per_split.json
|-- download_results.jsonl
|-- pdfs/
| |-- <article_1>.pdf
| |-- <article_2>.pdf
|-- images/
| |-- images_dev/
| | |-- <doc_id_1_page_1>.png
| | |-- <doc_id_2_page_2>.png
|-- splits/
| |-- pdfs_dev/
| | |-- <doc_id_1>.pdf
| | |-- <doc_id_2>.pdf
```

m3docvqa/main.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Main Script for M3DocVQA Dataset Creation Pipeline.

This script orchestrates downloading PDFs or PNGs, checking for corrupted PDFs, extracting images,
organizing them into directories, downloading/decompressing MMQA data, and creating wiki links mapping.

Usage:
    python main.py <action> [other options]

Actions:
    - download_pdfs: Download PDFs from URLs provided in metadata.
    - check_pdfs: Verify if the downloaded PDFs are valid.
    - extract_images: Extract images from the pages of downloaded PDFs.
    - organize_files: Organize downloaded PDFs into specified directory splits.
    - download_mmqa: Download and decompress the MMQA dataset.
    - generate_wiki_mapping: Generate a mapping of 'id' to 'url' from multiple JSONL files.

Example:
    python main.py generate_wiki_mapping --text=MMQA_texts.jsonl --image=MMQA_images.jsonl --table=MMQA_tables.jsonl --output=id_url_mapping.jsonl
"""

import fire
import json
import jsonlines
from pathlib import Path
from m3docvqa.downloader import download_wiki_page
from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf
from m3docvqa.split_utils import create_split_dirs
from m3docvqa.mmqa_downloader import download_and_decompress_mmqa
from m3docvqa.wiki_mapper import generate_wiki_links_mapping
from loguru import logger
from tqdm.auto import tqdm


def _prepare_download(
    metadata_path: Path | str,
    output_dir: Path | str,
    first_n: int,
    doc_ids: set,
) -> tuple[list[str], list[Path]]:
    """Prepare URLs and save paths for downloading.

    Args:
        metadata_path (Path): Path to the metadata JSONL file.
        output_dir (str): Directory where files will be saved.
        first_n (int): Maximum number of entries to process.
        doc_ids (set): Set of doc ids to filter for downloading.

    Returns:
        tuple[list[str], list[Path]]: URLs and save paths for downloading.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    urls, save_paths = [], []
    with jsonlines.open(metadata_path) as reader:
        for line in reader:
            if len(urls) == first_n:
                break

            doc_id = line.get("id")
            if doc_ids and doc_id not in doc_ids:
                continue

            url = line.get("url")
            save_path = output_dir / f"{doc_id}.pdf"
            if not is_pdf_downloaded(save_path):
                urls.append(url)
                save_paths.append(save_path)

    return urls, save_paths


def download_pdfs(
    metadata_path: Path | str,
    pdf_dir: Path | str,
    result_log_path: Path | str,
    supporting_doc_ids_per_split: Path | str,
    first_n: int = -1,
    proc_id: int = 0,
    n_proc: int = 1,
    split: str = 'dev',
):
    """Download Wikipedia pages as PDFs."""
    # Load document ids for the specified split
    if supporting_doc_ids_per_split:
        with open(supporting_doc_ids_per_split, "r") as f:
            doc_ids_per_split = json.load(f)
        split_doc_ids = {
            "train": set(doc_ids_per_split.get("train", [])),
            "dev": set(doc_ids_per_split.get("dev", [])),
            "all": set(doc_ids_per_split.get("train", []) + doc_ids_per_split.get("dev", []))
        }
        if split not in split_doc_ids:
            raise ValueError(f"Invalid or missing split. Expected one of {split_doc_ids.keys()}")
        doc_ids = split_doc_ids.get(split, split_doc_ids.get("all"))
        logger.info(f"Downloading documents for split: {split} with {len(doc_ids)} document IDs.")

    urls, save_paths = _prepare_download(metadata_path, pdf_dir, first_n, doc_ids)
    logger.info(f"Starting download of {len(urls)} PDFs to {pdf_dir}")
    download_results = download_wiki_page(urls, save_paths, "pdf", result_log_path, proc_id, n_proc)
    logger.info(f"Download completed with {sum(download_results)} successful downloads out of {len(urls)}")


def check_pdfs(pdf_dir: str, proc_id: int = 0, n_proc: int = 1):
    """Verifies the integrity of downloaded PDFs."""
    corrupted_paths = []
    total_checked, corrupted_count = 0, 0

    pdf_files = list(Path(pdf_dir).glob("*.pdf"))
    for pdf_path in tqdm(pdf_files, disable=(proc_id != 0), desc="Checking PDFs"):
        total_checked += 1
        if not is_pdf_downloaded(pdf_path) or not is_pdf_clean(pdf_path):
            corrupted_paths.append(pdf_path)
            corrupted_count += 1

    logger.info(f"Checked {total_checked} PDFs: {corrupted_count} corrupted files.")
    if corrupted_paths:
        logger.warning(f"Corrupted PDFs: {corrupted_paths}")


def extract_images(pdf_dir: str, image_dir: str, save_type='png'):
    """Extracts images from downloaded PDFs."""
    Path(image_dir).mkdir(parents=True, exist_ok=True)

    pdf_files = list(Path(pdf_dir).glob("*.pdf"))
    if not pdf_files:
        logger.warning(f"No PDFs found in {pdf_dir} for image extraction.")
        return

    logger.info(f"Starting image extraction from {len(pdf_files)} PDFs in {pdf_dir}.")

    for pdf_path in tqdm(pdf_files, desc="Extracting images", unit="PDF"):
        get_images_from_pdf(pdf_path, save_dir=image_dir, save_type=save_type)
    logger.info(f"Images extracted from PDFs in {pdf_dir}")


def organize_files(all_pdf_dir: Path | str, target_dir_base: Path | str, split_metadata_file: str | Path, split: str):
    """Organizes PDFs into directory splits based on split information file."""
    create_split_dirs(
        all_pdf_dir=all_pdf_dir,
        target_dir_base=target_dir_base,
        split_metadata_file=split_metadata_file,
        split=split,
    )
    logger.info(f"Files organized for {split} split: in {target_dir_base}")


def download_mmqa(output_dir: str):
    """Downloads and decompresses the MMQA dataset.

    Args:
        output_dir (str): Directory where the MMQA files will be downloaded and decompressed.
    """
    logger.info(f"Starting MMQA dataset download to {output_dir}")
    download_and_decompress_mmqa(output_dir)
    logger.info(f"MMQA dataset downloaded and decompressed successfully in {output_dir}")


def generate_wiki_mapping(text: str, image: str, table: str, output: str = "id_url_mapping.jsonl"):
    """Generates a mapping of 'id' to 'url' from multiple JSONL files.

    Args:
        text (str): Path to the JSONL file containing text data from the multimodalqa dataset with 'id' and 'url' fields.
        image (str): Path to the JSONL file containing image data from the multimodalqa dataset with 'id' and 'url' fields.
        table (str): Path to the JSONL file containing table data from the multimodalqa dataset with 'id' and 'url' fields.
        output (str): Path to save the output JSONL file. Defaults to 'id_url_mapping.jsonl'.
    """
    logger.info("Starting wiki mapping generation...")
    generate_wiki_links_mapping(text_file=text, image_file=image, table_file=table, output_file=output)
    logger.info(f"Wiki mapping successfully saved to {output}")


def main():
    fire.Fire({
        "download_mmqa": download_mmqa,
        "generate_wiki_mapping": generate_wiki_mapping,
        "download_pdfs": download_pdfs,
        "check_pdfs": check_pdfs,
        "extract_images": extract_images,
        "organize_files": organize_files,
    })


if __name__ == "__main__":
    main()

m3docvqa/pyproject.toml (new file, 36 lines)
@@ -0,0 +1,36 @@
[build-system]
requires = ["setuptools>=69.5"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "m3docvqa"
version = "0.0.1"
description = "M3DocVQA - Dataset package for M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding."
readme = "README.md"
requires-python = ">=3.10"
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
    "loguru",
    "jsonlines",
    "fire",
    "pytest-playwright",
    "figure",
    "pdf2image",
    "pillow",
    "numpy<2.0.0",
    "pdfrw",
    "tqdm",
    "reportlab", # only used in the test cases
]

[tool.ruff]
target-version = "py310"

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]

m3docvqa/src/m3docvqa/downloader.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""
Downloader Module for M3DocVQA

This module provides functions to download Wikipedia pages in either PDF or PNG format
for the M3DocVQA dataset. It uses Playwright to load and capture the pages in a headless
browser environment and saves each page in the specified format.

Functions:
    - _download_wiki_page: Downloads a single Wikipedia page as a PDF or PNG.
    - download_wiki_page: Manages the downloading of multiple Wikipedia pages.
"""

from playwright.sync_api import sync_playwright
from loguru import logger
from pathlib import Path
import jsonlines
from tqdm.auto import tqdm
from m3docvqa.pdf_utils import is_pdf_downloaded


def _download_wiki_page(args: tuple[int, int, str, str, str, int]) -> tuple[bool, Exception | None]:
    """Download a single Wikipedia page as a PDF or PNG using Playwright.

    Args:
        args (tuple[int, int, str, str, str, int]): Contains order in batch, total count, URL, save path,
            save type ('pdf' or 'png'), and process ID.

    Returns:
        tuple[bool, Exception | None]: A tuple where the first element is a boolean indicating success,
            and the second element is an exception if an error occurred, or None otherwise.
    """
    order_i, total, url, save_path, save_type, proc_id = args

    if is_pdf_downloaded(save_path):
        if proc_id == 0:
            logger.info(f"{order_i} / {total} - {save_path} already downloaded")
        return True, None

    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            context = browser.new_context(ignore_https_errors=True)
            page = context.new_page()
            page.set_default_timeout(30000)  # 30 seconds timeout

            page.goto(url)
            if save_type == 'png':
                page.screenshot(path=save_path, full_page=True)
            elif save_type == 'pdf':
                page.emulate_media(media="screen")
                page.pdf(path=save_path)

            browser.close()

        return True, None
    except Exception as error:
        logger.warning(f"Failed to download {url} as {save_type}. Error: {error}")
        return False, error


def download_wiki_page(
    urls: list[str],
    save_paths: list[str],
    save_type: str,
    result_jsonl_path: str,
    proc_id: int = 0,
    n_proc: int = 1
) -> list[bool]:
    """Download multiple Wikipedia pages and log progress.

    Args:
        urls (list[str]): List of Wikipedia URLs to download.
        save_paths (list[str]): List of paths where each downloaded file will be saved.
        save_type (str): File type to save each page as ('pdf' or 'png').
        result_jsonl_path (str): Path to the JSONL file where download results will be logged.
        proc_id (int, optional): Process ID for parallel processing. Defaults to 0.
        n_proc (int, optional): Total number of processes running in parallel. Defaults to 1.

    Returns:
        list[bool]: A list of booleans indicating whether each download was successful.
    """
    total = len(urls)
    all_args = [(i, total, url, str(save_path), save_type, proc_id)
                for i, (url, save_path) in enumerate(zip(urls, save_paths))]

    pbar = tqdm(total=len(all_args), ncols=100, disable=not (proc_id == 0))

    results = []
    n_downloaded = 0

    # Log results to a JSONL file
    with jsonlines.open(result_jsonl_path, 'w') as writer:
        for args in all_args:
            downloaded, error = _download_wiki_page(args)

            if downloaded:
                n_downloaded += 1

            pbar.set_description(f"Process: {proc_id}/{n_proc} - Downloaded: {n_downloaded}/{total}")
            pbar.update(1)

            results.append(downloaded)
            writer.write({
                'downloaded': downloaded,
                'args': [arg if not isinstance(arg, Path) else str(arg) for arg in args],
                'error': str(error) if error else None
            })

    pbar.close()
    return results

m3docvqa/src/m3docvqa/mmqa_downloader.py (new file, 119 lines)
@@ -0,0 +1,119 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Downloads the portion of the multimodalqa dataset from https://github.com/allenai/multimodalqa/tree/master/dataset
that is useful for creating the m3docvqa dataset.
"""
import gzip
import requests
from loguru import logger
from pathlib import Path


def download_file(url: str, output_path: str) -> None:
    """Downloads a file from a given URL and saves it to the specified output path.

    Args:
        url (str): The URL of the file to download.
        output_path (str): The path where the downloaded file will be saved.

    Raises:
        requests.exceptions.RequestException: If the file could not be downloaded.
    """
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(output_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        logger.info(f"File downloaded successfully: {output_path}")
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download file from {url}: {e}")
        raise


def decompress_gz_file(input_path: str | Path, output_path: str | Path) -> None:
    """Decompresses a `.gz` file into its original format.

    Args:
        input_path (str | Path): Path to the `.gz` file.
        output_path (str | Path): Path where the decompressed file will be written.

    Raises:
        ValueError: If the input path does not exist or is not a file.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)

    if not input_path.is_file():
        raise ValueError(f"The input file {input_path} does not exist or is not a file.")

    with gzip.open(input_path, "rb") as f_in, open(output_path, "wb") as f_out:
        f_out.write(f_in.read())
    logger.info(f"Decompressed {input_path} to {output_path}")


def download_and_decompress_mmqa(output_directory: str | Path) -> None:
    """Downloads and decompresses the MultiModalQA dataset files into the specified directory.

    Args:
        output_directory (str | Path): The directory where the files will be stored.

    Steps:
        1. Creates the output directory if it doesn't exist.
        2. Downloads the `.jsonl.gz` files.
        3. Decompresses each `.gz` file into its `.jsonl` format.
        4. Removes the `.gz` files after decompression.

    Raises:
        requests.exceptions.RequestException: If any of the files could not be downloaded.
    """
    # Define base URL and file names
    base_url = "https://github.com/allenai/multimodalqa/raw/refs/heads/master/dataset/"
    files = [
        "MMQA_texts.jsonl.gz",
        "MMQA_tables.jsonl.gz",
        "MMQA_images.jsonl.gz",
        "MMQA_dev.jsonl.gz",
        "MMQA_train.jsonl.gz",
    ]

    output_directory = Path(output_directory)

    # Ensure the output directory exists
    if not output_directory.exists():
        output_directory.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created output directory: {output_directory}")

    for file_name in files:
        compressed_path = output_directory / file_name
        decompressed_path = output_directory / file_name.replace(".gz", "")

        try:
            # Step 1: Download the file
            logger.info(f"Downloading {file_name}...")
            download_file(base_url + file_name, compressed_path)

            # Step 2: Decompress the file
            logger.info(f"Decompressing {file_name}...")
            decompress_gz_file(compressed_path, decompressed_path)

            # Step 3: Remove the compressed file
            compressed_path.unlink()
            logger.info(f"Removed compressed file: {compressed_path}")
        except Exception as e:
            logger.error(f"Error processing {file_name}: {e}")
            raise

m3docvqa/src/m3docvqa/pdf_utils.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""PDF Utilities Module for M3DocVQA.

This module provides utility functions for managing and processing PDF files in the M3DocVQA dataset.
It includes functions for checking if a PDF has been downloaded, verifying if a PDF is clean (not corrupted),
and extracting images from PDF pages.

Functions:
    - is_pdf_downloaded: Checks if a given PDF file exists and can be opened without errors.
    - is_pdf_clean: Checks if a PDF file is clean (not corrupted) and can be read without issues.
    - get_images_from_pdf: Extracts images from each page of a PDF and optionally saves them in a specified directory.
"""

from pdf2image import convert_from_path
from PIL import Image
from pdfrw import PdfReader
from pathlib import Path
from io import BytesIO
from loguru import logger


def is_pdf_downloaded(pdf_path: str) -> bool:
    """Check if the PDF file exists and can be opened without errors.

    Args:
        pdf_path (str): Path to the PDF file.

    Returns:
        bool: True if the PDF file is downloaded and accessible; False otherwise.
    """
    try:
        with open(pdf_path, "rb") as f:
            f.read(1)  # Attempt to read a byte to verify file exists and is accessible
        return True
    except Exception as e:
        logger.trace(f"Failed to open PDF at {pdf_path}: {e}")
        return False


def is_pdf_clean(pdf_path: str) -> bool:
    """Verify if a PDF file is clean (not corrupted) and can be read without errors.

    Args:
        pdf_path (str): Path to the PDF file.

    Returns:
        bool: True if the PDF file is clean and readable; False otherwise.
    """
    try:
        with open(pdf_path, "rb") as f:
            idata = f.read()
        ibuffer = BytesIO(idata)
        PdfReader(ibuffer)  # Attempt to read the PDF structure for validity
        return True
    except Exception as error:
        logger.warning(f"PDF at {pdf_path} is corrupted or unreadable: {error}")
        return False


def get_images_from_pdf(
    pdf_path: str,
    save_dir: str = None,
    max_pages: int = None,
    dpi_resolution: int = 144,
    save_type: str = 'png'
) -> list[Image.Image]:
    """Extract images from each page of a PDF and optionally save them to a directory.

    Args:
        pdf_path (str): Path to the PDF file.
        save_dir (str, optional): Directory where images will be saved. If None, images are not saved. Defaults to None.
        max_pages (int, optional): Maximum number of pages to process. If None, all pages are processed. Defaults to None.
        dpi_resolution (int, optional): Resolution for image extraction. Defaults to 144.
        save_type (str, optional): Image file type to save as ('png', 'jpg', etc.). Defaults to 'png'.

    Returns:
        list[Image.Image]: A list of images extracted from each page of the PDF.
    """
    pdf_path_obj = Path(pdf_path)
    assert pdf_path_obj.exists(), f"PDF file {pdf_path} does not exist."

    out_images = []

    # Create save directory if saving images is enabled
    if save_dir:
        save_dir_path = Path(save_dir)
        save_dir_path.mkdir(exist_ok=True, parents=True)

    try:
        # Convert PDF to images using pdf2image
        images = convert_from_path(pdf_path, dpi=dpi_resolution)
        logger.info(f"PDF {pdf_path} has {len(images)} pages.")

        # Limit the number of pages processed if max_pages is set
        if max_pages:
            images = images[:max_pages]

        for page_index, image in enumerate(images):
            out_images.append(image)

            # Save image if save directory is specified
            if save_dir:
                save_page_path = save_dir_path / f"{pdf_path_obj.stem}_{page_index + 1}.{save_type}"
                if not save_page_path.exists():
                    image.save(save_page_path)
                    logger.info(f"Saved page {page_index + 1} as image at {save_page_path}")

    except Exception as e:
        logger.error(f"Error extracting images from PDF {pdf_path}: {e}")

    return out_images

m3docvqa/src/m3docvqa/split_utils.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Split Utilities Module for M3DocVQA.

This module provides utilities for organizing PDF files into split directories (e.g., train, dev).

Functions:
    - create_split_dirs: Copies specified PDF files into separate split directories (e.g., train, dev).
"""

from pathlib import Path
import shutil
import json
import jsonlines
from loguru import logger


def create_split_dirs(
    all_pdf_dir: str | Path,
    target_dir_base: str | Path,
    split_metadata_file: str | Path,
    split: str
) -> None:
    """Copies specified PDF files into a target directory based on a given split.

    Args:
        all_pdf_dir (str | Path): Path to the directory containing all downloaded PDF files.
        target_dir_base (str | Path): Base directory where the split-specific directory will be created.
        split_metadata_file (str | Path): Path to the metadata JSONL file for the split.
        split (str): Split type ('train' or 'dev').

    Raises:
        FileNotFoundError: If the JSONL metadata file does not exist.
        ValueError: If the split is not 'train' or 'dev'.
    """
    # Validate split type
    if split not in {"train", "dev"}:
        raise ValueError(f"Invalid split: {split}. Expected 'train' or 'dev'.")

    all_pdf_dir = Path(all_pdf_dir)
    target_dir = Path(target_dir_base) / f'pdfs_{split}'
    target_dir.mkdir(parents=True, exist_ok=True)

    # Validate metadata file
    split_metadata_file = Path(split_metadata_file)
    if not split_metadata_file.exists():
        raise FileNotFoundError(f"Metadata file for split '{split}' not found: {split_metadata_file}")

    # Load all doc IDs for the split
    split_doc_ids = []
    with jsonlines.open(split_metadata_file) as reader:
        for obj in reader:
            split_doc_ids.extend(doc['doc_id'] for doc in obj['supporting_context'])

    # Remove duplicates and log the count
    split_doc_ids = sorted(set(split_doc_ids))
    logger.info(f"Split {split} -> # supporting context: {len(split_doc_ids)}")

    # Save the split-specific IDs to a JSON file
    split_doc_ids_output_path = Path(f'./{split}_doc_ids.json')
    with open(split_doc_ids_output_path, 'w') as f:
        json.dump(split_doc_ids, f, indent=4)
    logger.info(f"Split {split} -> saved doc IDs at {split_doc_ids_output_path}")

    # Copy PDF files to the target directory
    missing_files = []
    for doc_id in split_doc_ids:
        pdf_file = all_pdf_dir / f"{doc_id}.pdf"
        if pdf_file.exists():
            shutil.copy(pdf_file, target_dir / pdf_file.name)
        else:
            missing_files.append(pdf_file)

    if missing_files:
        logger.warning(f"Warning: {len(missing_files)} files are missing and will be skipped.")
        for missing_file in missing_files:
            logger.warning(f"  Missing: {missing_file}")

m3docvqa/src/m3docvqa/wiki_mapper.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Wiki Mapper.

This module provides functionality to parse the already-downloaded multimodalqa JSONL files that contain 'id' and 'url' mappings,
merge them into a single mapping, and save the result to a JSONL file.

Each JSONL file should contain one JSON object per line with the following structure:
    {
        "title": "Article Title",
        "url": "https://en.wikipedia.org/wiki/Article_Title",
        "id": "unique_id",
        "text": "Text description of the article."
    }
"""

import json
from pathlib import Path
from loguru import logger


def parse_jsonl(file_path: str | Path) -> dict[str, str]:
    """Parses a JSONL file from the multimodalqa dataset to extract a mapping of 'id' to 'url'.

    Args:
        file_path (str | Path): Path to the JSONL file.

    Returns:
        dict[str, str]: A dictionary mapping each 'id' to its corresponding 'url'.

    Raises:
        FileNotFoundError: If the JSONL file does not exist.
        ValueError: If the file contains invalid JSON lines.
    """
    file_path = Path(file_path)
    if not file_path.is_file():
        logger.error(f"File not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    id_url_mapping = {}
    try:
        with file_path.open("r") as file:
            for line in file:
                data = json.loads(line.strip())
                entry_id = data.get("id")
                url = data.get("url")
                if entry_id and url:
                    id_url_mapping[entry_id] = url
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON in file {file_path}: {e}")
        raise ValueError(f"Invalid JSON in file {file_path}: {e}")

    logger.info(f"Parsed {len(id_url_mapping)} entries from {file_path}")
    return id_url_mapping


def merge_mappings(mappings: list[dict[str, str]]) -> dict[str, str]:
    """Merges multiple mappings into a single dictionary.

    Args:
        mappings (list[dict[str, str]]): A list of dictionaries containing 'id' to 'url' mappings.

    Returns:
        dict[str, str]: A merged dictionary containing all 'id' to 'url' mappings.
    """
    merged_mapping = {}
    for mapping in mappings:
        merged_mapping.update(mapping)
    logger.info(f"Merged {len(mappings)} mappings with a total of {len(merged_mapping)} entries.")
    return merged_mapping


def save_mapping_to_jsonl(mapping: dict[str, str], output_file: str | Path) -> None:
    """Saves the 'id'-to-'url' mapping to a JSONL file.

    Args:
        mapping (dict[str, str]): The dictionary containing 'id' to 'url' mappings.
        output_file (str | Path): The path to the output JSONL file.

    Raises:
        IOError: If the file cannot be written.
    """
    output_file = Path(output_file)
    try:
        with output_file.open("w") as file:
            for entry_id, url in mapping.items():
                json.dump({"id": entry_id, "url": url}, file)
                file.write("\n")
        logger.info(f"Mapping saved to {output_file}")
    except IOError as e:
        logger.error(f"Error writing to file {output_file}: {e}")
        raise


def generate_wiki_links_mapping(
    text_file: str | Path, image_file: str | Path, table_file: str | Path, output_file: str | Path = "id_url_mapping.jsonl"
) -> None:
    """Orchestrates the process of parsing input files, merging mappings, and saving the result to JSONL.

    Args:
        text_file (str | Path): Path to the JSONL file containing text data with 'id' and 'url' fields.
        image_file (str | Path): Path to the JSONL file containing image data with 'id' and 'url' fields.
        table_file (str | Path): Path to the JSONL file containing table data with 'id' and 'url' fields.
        output_file (str | Path): Path to save the output JSONL file. Defaults to 'id_url_mapping.jsonl'.

    Raises:
        Exception: If any part of the pipeline fails.
    """
    try:
        # Parse input files
        logger.info("Parsing JSONL files...")
        text_mapping = parse_jsonl(text_file)
        image_mapping = parse_jsonl(image_file)
        table_mapping = parse_jsonl(table_file)

        # Merge mappings
        logger.info("Merging mappings...")
        merged_mapping = merge_mappings([text_mapping, image_mapping, table_mapping])

        # Save the merged mapping
        logger.info("Saving merged mapping to output file...")
        save_mapping_to_jsonl(merged_mapping, output_file)
        logger.info(f"Mapping successfully generated and saved to {output_file}")
    except Exception as e:
        logger.error(f"Error generating wiki links mapping: {e}")
        raise

m3docvqa/tests/test_downloader.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from unittest.mock import patch, MagicMock
from pathlib import Path
import jsonlines
from m3docvqa.downloader import _download_wiki_page, download_wiki_page


@pytest.fixture
def test_urls_and_paths(tmp_path):
    """Fixture to provide sample URLs and save paths for testing."""
    urls = ["https://en.wikipedia.org/wiki/SamplePage1", "https://en.wikipedia.org/wiki/SamplePage2"]
    save_paths = [str(tmp_path / "sample1.pdf"), str(tmp_path / "sample2.pdf")]
    return urls, save_paths


@patch("m3docvqa.downloader.sync_playwright")
def test__download_wiki_page_pdf(mock_playwright, tmp_path):
    """Test downloading a single page as a PDF."""
    url = "https://en.wikipedia.org/wiki/SamplePage"
    save_path = tmp_path / "sample.pdf"
    args = (0, 1, url, str(save_path), 'pdf', 0)

    # Mock Playwright behavior
    mock_browser = MagicMock()
    mock_context = MagicMock()
    mock_page = MagicMock()
    mock_playwright.return_value.__enter__.return_value.chromium.launch.return_value = mock_browser
    mock_browser.new_context.return_value = mock_context
    mock_context.new_page.return_value = mock_page

    # Call the function
    downloaded, error = _download_wiki_page(args)

    # Assertions
    assert downloaded is True
    assert error is None
    mock_page.goto.assert_called_once_with(url)
    mock_page.pdf.assert_called_once_with(path=str(save_path))


@patch("m3docvqa.downloader.sync_playwright")
def test__download_wiki_page_png(mock_playwright, tmp_path):
    """Test downloading a single page as a PNG."""
    url = "https://en.wikipedia.org/wiki/SamplePage"
    save_path = tmp_path / "sample.png"
    args = (0, 1, url, str(save_path), 'png', 0)

    # Mock Playwright behavior
    mock_browser = MagicMock()
    mock_context = MagicMock()
    mock_page = MagicMock()
    mock_playwright.return_value.__enter__.return_value.chromium.launch.return_value = mock_browser
    mock_browser.new_context.return_value = mock_context
    mock_context.new_page.return_value = mock_page

    # Call the function
    downloaded, error = _download_wiki_page(args)

    # Assertions
    assert downloaded is True
    assert error is None
    mock_page.goto.assert_called_once_with(url)
    mock_page.screenshot.assert_called_once_with(path=str(save_path), full_page=True)


@patch("m3docvqa.downloader._download_wiki_page")
def test_download_wiki_page_batch(mock_download_wiki_page, tmp_path, test_urls_and_paths):
    """Test batch downloading multiple Wikipedia pages."""
    urls, save_paths = test_urls_and_paths
    result_jsonl_path = tmp_path / "download_results.jsonl"

    # Mock individual downloads to always succeed
    mock_download_wiki_page.side_effect = [(True, None), (True, None)]

    # Call the function
    results = download_wiki_page(urls, save_paths, 'pdf', str(result_jsonl_path), proc_id=0, n_proc=1)

    # Assertions
    assert results == [True, True]
    assert result_jsonl_path.exists()

    # Check JSONL log entries
    with jsonlines.open(result_jsonl_path, 'r') as reader:
        log_entries = list(reader)
        assert len(log_entries) == 2
        assert log_entries[0]['downloaded'] is True
        assert log_entries[0]['error'] is None
        assert log_entries[1]['downloaded'] is True
        assert log_entries[1]['error'] is None

m3docvqa/tests/test_pdf_utils.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf
from pathlib import Path
from PIL import Image
from reportlab.pdfgen import canvas  # For creating sample PDFs


@pytest.fixture
def sample_pdf(tmp_path) -> Path:
    """Create a temporary sample PDF file for testing."""
    pdf_path = tmp_path / "sample.pdf"
    c = canvas.Canvas(str(pdf_path))
    c.drawString(100, 100, "Sample PDF text for testing.")  # Add sample text to the PDF
    c.save()
    return pdf_path


@pytest.fixture
def corrupted_pdf(tmp_path) -> Path:
    """Create a temporary, corrupted PDF file for testing."""
    pdf_path = tmp_path / "corrupted.pdf"
    pdf_path.write_bytes(b"%PDF-1.4 corrupted content")  # Write incomplete/corrupted PDF content
    return pdf_path


def test_is_pdf_downloaded_existing_pdf(sample_pdf):
    """Test is_pdf_downloaded on a valid, existing PDF."""
    assert is_pdf_downloaded(str(sample_pdf)) is True, "Expected PDF to be recognized as downloaded."


def test_is_pdf_downloaded_nonexistent_pdf(tmp_path):
    """Test is_pdf_downloaded on a non-existent PDF file."""
    non_existent_pdf = tmp_path / "non_existent.pdf"
    assert is_pdf_downloaded(str(non_existent_pdf)) is False, "Expected non-existent PDF to be marked as not downloaded."


def test_is_pdf_clean_valid_pdf(sample_pdf):
    """Test is_pdf_clean on a valid, clean PDF."""
    assert is_pdf_clean(str(sample_pdf)) is True, "Expected PDF to be recognized as clean."


def test_is_pdf_clean_corrupted_pdf(corrupted_pdf):
    """Test is_pdf_clean on a corrupted PDF."""
    assert is_pdf_clean(str(corrupted_pdf)) is False, "Expected corrupted PDF to be marked as not clean."


def test_get_images_from_pdf_extract_images(sample_pdf, tmp_path):
    """Test get_images_from_pdf to ensure it extracts images correctly."""
    image_dir = tmp_path / "images"
    images = get_images_from_pdf(str(sample_pdf), save_dir=str(image_dir), dpi_resolution=72, save_type='png')

    # Verify that at least one image was extracted
    assert len(images) > 0, "Expected at least one image to be extracted from the PDF."

    # Verify that images were saved to the directory
    saved_images = list(image_dir.glob("*.png"))
    assert len(saved_images) == len(images), "Expected number of saved images to match the number of extracted images."

    # Verify that the saved image files exist and are valid
    for image_path in saved_images:
        with Image.open(image_path) as img:
            assert img.format == "PNG", "Expected saved image to be in PNG format."


def test_get_images_from_pdf_no_save_dir(sample_pdf):
    """Test get_images_from_pdf without saving images, only returning them as a list."""
    images = get_images_from_pdf(str(sample_pdf), save_dir=None, dpi_resolution=72)
    assert len(images) > 0, "Expected at least one image to be returned without saving."
    assert all(isinstance(image, Image.Image) for image in images), "Expected all returned items to be PIL Image objects."

m3docvqa/tests/test_split_utils.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from pathlib import Path
import shutil
import json
import jsonlines
from unittest.mock import MagicMock, patch
from m3docvqa.split_utils import create_split_dirs


@pytest.fixture
def mock_pdf_directory(tmp_path):
    # Create a temporary directory for PDFs
    pdf_dir = tmp_path / "pdfs"
    pdf_dir.mkdir()
    # Add some mock PDF files
    (pdf_dir / "doc1.pdf").write_text("PDF content for doc1")
    (pdf_dir / "doc2.pdf").write_text("PDF content for doc2")
    return pdf_dir


@pytest.fixture
def mock_metadata_file(tmp_path):
    # Create a temporary metadata file in JSONL format
    metadata_file = tmp_path / "MMQA_train.jsonl"
    data = [
        {"supporting_context": [{"doc_id": "doc1"}]},
        {"supporting_context": [{"doc_id": "doc2"}]}
    ]
    with jsonlines.open(metadata_file, mode='w') as writer:
        writer.write_all(data)
    return metadata_file


@pytest.fixture
def mock_target_directory(tmp_path):
    return tmp_path / "target"


def test_create_split_dirs(mock_pdf_directory, mock_metadata_file, mock_target_directory):
    """Test the create_split_dirs function."""
    # Prepare the split directory
    split = "train"

    # Call the function to create split directories
    create_split_dirs(
        all_pdf_dir=mock_pdf_directory,
        target_dir_base=mock_target_directory,
        split_metadata_file=mock_metadata_file,
        split=split
    )

    # Assert that the target directory exists and contains the expected PDF files
    target_dir = mock_target_directory / f"pdfs_{split}"
    assert target_dir.exists(), f"Directory {target_dir} was not created"
    assert (target_dir / "doc1.pdf").exists(), "doc1.pdf was not copied"
    assert (target_dir / "doc2.pdf").exists(), "doc2.pdf was not copied"


def test_create_split_dirs_missing_pdf(mock_metadata_file, mock_target_directory):
    """Test create_split_dirs when PDF files are missing."""
    # Prepare the split directory
    split = "train"
    all_pdf_dir = Path("non_existing_pdf_dir")

    # Call the function and verify that the missing PDFs are handled correctly
    create_split_dirs(
        all_pdf_dir=all_pdf_dir,
        target_dir_base=mock_target_directory,
        split_metadata_file=mock_metadata_file,
        split=split
    )

    target_dir = mock_target_directory / f"pdfs_{split}"
    assert target_dir.exists(), f"Directory {target_dir} was not created"
    assert not (target_dir / "doc1.pdf").exists(), "doc1.pdf should not exist"
    assert not (target_dir / "doc2.pdf").exists(), "doc2.pdf should not exist"


@pytest.mark.parametrize("split, expected_error", [
    ("test_split", ValueError),  # Invalid split type
    (None, ValueError),  # Missing split
])
def test_create_split_dirs_invalid_split_type(mock_pdf_directory, mock_metadata_file, mock_target_directory, split, expected_error):
    """Test invalid split types in create_split_dirs."""
    with pytest.raises(expected_error):
        create_split_dirs(
            all_pdf_dir=mock_pdf_directory,
            target_dir_base=mock_target_directory,
            split_metadata_file=mock_metadata_file,
            split=split
        )