Release commit

Signed-off-by: Stephen Augustus <saugustus2@bloomberg.net>
j-min
2025-01-30 17:04:56 -05:00
committed by oir
parent e04aeadfb0
commit 27aac8d521
50 changed files with 5692 additions and 0 deletions

167
m3docvqa/.gitignore vendored Normal file

@@ -0,0 +1,167 @@
.vscode/
notebooks/node_modules/
notebooks/package-lock.json
dataset/*.jsonl
dataset/img
dataset/cache
dataset/img_features
baselines/data/
baselines/output/
baselines/image_qa/training_stats/
baselines/image_qa/checkpoints/
baselines/image_qa/analysis/
deps/vilbert-multi-task/data/
deps/vilbert-multi-task/save/
deps/vilbert-multi-task/multi_task_model*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# IDE
.idea/
# DATA
data/
# for MacOS
*.DS_Store
*.png
dataset
multimodalqa_screenshots
*.json
*.pdf
*.jsonl*
setup_BCOS_README.md
*.ipynb
.env
READDME copy.md

200
m3docvqa/README.md Normal file

@@ -0,0 +1,200 @@
# M3DocVQA
Dataset generation package for [M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding.](https://m3docrag.github.io/)
## Summary
M3DocVQA (Multi-modal Multi-page Multi-Document Visual Question Answering) is a new benchmark for evaluating open-domain DocVQA over 3,000+ PDF documents with 40,000+ pages. M3DocVQA significantly raises the challenge of DocVQA by requiring answers to be found across a large document corpus (Sec. 3). By extending the MultimodalQA dataset's closed-domain context to an open-domain setting, M3DocVQA introduces 2,441 multi-hop questions spanning 3,368 PDF documents, which collectively contain 41,005 pages of diverse multi-modal content, including text, images, and tables. The generated dataset presents real-world challenges by requiring models to navigate complex reasoning paths across pages and within various types of document elements, better reflecting the intricacies of document understanding.
<img src='../assets/dataset.png'>
Comparison of existing DocVQA datasets (left: e.g., DocVQA) and the generated `M3DocVQA` dataset (right). In contrast to previous DocVQA datasets, whose questions are specific to a single provided PDF (e.g., `What was the gross profit in the year 2009?`), M3DocVQA contains information-seeking questions that benchmark open-domain question answering capabilities across more than `3,000 PDF documents` (i.e., `40,000+ pages`).
<img src='../assets/data_collection.png'>
We extend the question-answer pairs from a short-context VQA dataset to a more complex setting that includes:
1. PDF documents.
2. Open-domain contexts.
We first collect the URLs of all supporting contexts (Wikipedia documents) for the individual questions in [MultimodalQA](https://github.com/allenai/multimodalqa). This tool then creates PDF versions of those pages by rendering their URLs in a Chromium web browser.
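For intuition, the rendering step conceptually looks like the sketch below: a simplified version of what the package's `downloader` module does with Playwright (the example URL is only illustrative).
```python
# Minimal sketch: render one Wikipedia page to PDF with Playwright's sync API.
from playwright.sync_api import sync_playwright

def render_url_to_pdf(url: str, save_path: str) -> None:
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)  # PDF export requires headless Chromium
        page = browser.new_context(ignore_https_errors=True).new_page()
        page.goto(url)
        page.emulate_media(media="screen")
        page.pdf(path=save_path)
        browser.close()

render_url_to_pdf("https://en.wikipedia.org/wiki/Mount_Everest", "Mount_Everest.pdf")
```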
## M3DocVQA Dataset Creation Pipeline
This part of the repository provides scripts to create the `M3DocVQA` dataset, including functionalities to download Wikipedia pages as PDFs, check and clean corrupted PDFs, extract images, and organize files into directories for training and evaluation.
### Overview
The scripts allow users to:
- Download Wikipedia pages in either PDF or PNG format.
- Verify and clean downloaded PDFs.
- Extract images from PDFs.
- Organize files into directories based on split information for training/evaluation.
## Installation
```bash
git clone <url-tbd>
cd <repo-name-tbd>/m3docvqa
```
### Install Python Package
We used Python 3.10.
```bash
pip install -e .
```
### Setup Playwright
```bash
# e.g., download browsers, ffmpeg, etc.
playwright install
playwright install-deps
```
### Test the Package
```bash
pytest tests
```
**Note**: The tests might fail if `poppler-utils` is not installed on your system. `pdf2image` requires `poppler-utils` (e.g., `sudo apt-get install poppler-utils` on Debian/Ubuntu or `brew install poppler` on macOS); please refer to these [detailed instructions](https://pdf2image.readthedocs.io/en/latest/installation.html).
### Additional Setup
Ensure the required directories and metadata files are available before running the scripts. Continue as directed to get the required data.
## Usage
The main script (`main.py`) supports several actions, each of which targets a specific portion of the dataset creation process.
### Command Structure
```bash
python main.py <action> [options]
```
### Available Actions
- `download_pdfs`: Download PDFs from URLs provided in the metadata.
- `check_pdfs`: Verify if the downloaded PDFs are valid.
- `extract_images`: Extract images from the pages of the downloaded PDFs.
- `organize_files`: Organize downloaded PDFs into specified directory splits.
- `download_mmqa`: Download and decompress the MMQA dataset.
- `generate_wiki_mapping`: Generate a mapping of 'id' to 'url' from multiple JSONL files.
## Steps for Generating the M3DocVQA Dataset
### Step 1: Download the MultiModalQA Dataset
Use the `download_mmqa` action to download and decompress the MultiModalQA dataset files.
```bash
python main.py download_mmqa --output_dir=./multimodalqa
```
Output:
Decompressed JSONL files
```bash
MMQA_train.jsonl
MMQA_dev.jsonl
MMQA_texts.jsonl
MMQA_images.jsonl
MMQA_tables.jsonl
```
These files will be stored in the `./multimodalqa/` directory.
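For a quick sanity check, you can count the questions and supporting documents in the dev file. This is an illustrative snippet; it assumes each line carries a `supporting_context` list of `{"doc_id": ...}` entries, which is how `split_utils.create_split_dirs` reads these files.
```python
# Count dev questions and the Wikipedia documents they rely on.
import jsonlines

with jsonlines.open("multimodalqa/MMQA_dev.jsonl") as reader:
    rows = list(reader)
doc_ids = {ctx["doc_id"] for row in rows for ctx in row["supporting_context"]}
print(f"{len(rows)} dev questions, {len(doc_ids)} supporting documents")
```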
### Step 2: Generate Wiki Mapping
Use the `generate_wiki_mapping` action to create a mapping of `id` to `url` from the downloaded JSONL files.
```bash
python main.py generate_wiki_mapping --text=./multimodalqa/MMQA_texts.jsonl --image=./multimodalqa/MMQA_images.jsonl --table=./multimodalqa/MMQA_tables.jsonl --output=./id_url_mapping.jsonl
```
Output:
A JSONL file `id_url_mapping.jsonl` containing the ID and corresponding URL mappings.
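Each line of the mapping is a small JSON object with `id` and `url` fields (as written by `wiki_mapper.save_mapping_to_jsonl`); a quick way to peek at the first few records:
```python
# Print the first three id -> url records of the mapping.
import jsonlines

with jsonlines.open("id_url_mapping.jsonl") as reader:
    for i, record in enumerate(reader):
        print(record["id"], "->", record["url"])
        if i >= 2:
            break
```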
### Step 3: Download Wikipedia Articles as PDFs
Use the `download_pdfs` action to download Wikipedia articles as PDFs based on the generated mapping.
```bash
python main.py download_pdfs --metadata_path=./id_url_mapping.jsonl --pdf_dir=./pdfs --result_log_path=./download_results.jsonl --first_n=10 --supporting_doc_ids_per_split=./supporting_doc_ids_per_split.json --split=dev
```
Options:
- `--metadata_path`: Path to the id_url_mapping.jsonl file.
- `--pdf_dir`: Directory to save the downloaded PDFs.
- `--result_log_path`: Path to log the download results.
- `--first_n`: Download only the first N PDFs (useful for testing). **Omit this option when downloading all the PDFs.**
- `--supporting_doc_ids_per_split`: Path to the JSON file containing the document IDs for each split.
- `--split`: Split whose supporting documents are downloaded. `dev` is the default split, as all of the experimental results in the `M3DocRAG` paper were reported on the `dev` split. Pass `--split=train` to download the PDFs in the `train` split, or `--split=all` to download all the PDFs.
Output:
- PDF files for Wikipedia articles, saved in the `./pdfs/` directory.
- A `download_results.jsonl` file logging the status of each download.
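Each log entry records whether the corresponding download succeeded; a short snippet to tally the results (the entry schema is the one written by `downloader.download_wiki_page`):
```python
# Tally the download log ({"downloaded": bool, "args": [...], "error": str | None} per line).
import jsonlines

succeeded = failed = 0
with jsonlines.open("download_results.jsonl") as reader:
    for entry in reader:
        if entry["downloaded"]:
            succeeded += 1
        else:
            failed += 1
print(f"{succeeded} downloads succeeded, {failed} failed")
```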
### Step 4: Check PDF Integrity
Use the `check_pdfs` action to verify the integrity of the downloaded PDFs.
```bash
python main.py check_pdfs --pdf_dir=./pdfs
```
Output:
Identifies and logs corrupted or unreadable PDFs.
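If you want to re-check an individual file, you can call the same helpers that `check_pdfs` uses (the path below is a placeholder):
```python
# Spot-check a single PDF with the helpers from m3docvqa.pdf_utils.
from m3docvqa.pdf_utils import is_pdf_clean, is_pdf_downloaded

pdf_path = "./pdfs/<doc_id>.pdf"  # placeholder path
print("downloaded:", is_pdf_downloaded(pdf_path), "| clean:", is_pdf_clean(pdf_path))
```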
### Step 5: Organize Files into Splits
Use the `organize_files` action to organize the downloaded PDFs into specific splits (e.g., `train`, `dev`) based on a split information file.
```bash
python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=dev --split_metadata_file=./multimodalqa/MMQA_dev.jsonl
```
If the `train` split is needed:
```bash
python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=train --split_metadata_file=./multimodalqa/MMQA_train.jsonl
```
Output:
- Organized PDFs into directories in `./splits/pdfs_train/` and `./splits/pdfs_dev/`.
- Files that store the document IDs of each split: `./train_doc_ids.json` and `./dev_doc_ids.json`.
**Note** - In the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper, we only use the `dev` split for our experiments.
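The doc-ID files are plain JSON lists (written by `split_utils.create_split_dirs`), so a quick count looks like this:
```python
# Count the supporting documents in the dev split.
import json

with open("dev_doc_ids.json") as f:
    dev_doc_ids = json.load(f)  # sorted list of doc IDs
print(f"{len(dev_doc_ids)} supporting documents in the dev split")
```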
### Step 6: Extract Images from PDFs
Use the `extract_images` action to extract images from the downloaded PDFs. A PNG image of each page of the PDFs is extracted. These images are used both for `retrieval` with `ColPali/ColQwen` and for `question answering` with the LLMs mentioned in the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper.
```bash
python main.py extract_images --pdf_dir=./splits/pdfs_dev/ --image_dir=./images/images_dev
```
Output:
Extracted images from the PDFs in the dev split are saved in the `./images/images_dev` directory.
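Page images are named `<doc_id>_<page_number>.png` (see `pdf_utils.get_images_from_pdf`), so collecting the pages of one document is a simple glob (the doc ID below is a placeholder):
```python
# List the extracted page images for one document.
from pathlib import Path

doc_id = "<doc_id_1>"  # placeholder
pages = sorted(Path("./images/images_dev").glob(f"{doc_id}_*.png"))
print(f"{len(pages)} page images for {doc_id}")
```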
After following these steps, your dataset directory structure will look like this:
```
./
|-- multimodalqa/
|   |-- MMQA_train.jsonl
|   |-- MMQA_dev.jsonl
|   |-- MMQA_texts.jsonl
|   |-- MMQA_images.jsonl
|   |-- MMQA_tables.jsonl
|-- id_url_mapping.jsonl
|-- dev_doc_ids.json
|-- train_doc_ids.json
|-- supporting_doc_ids_per_split.json
|-- download_results.jsonl
|-- pdfs/
|   |-- <article_1>.pdf
|   |-- <article_2>.pdf
|-- images/
|   |-- images_dev/
|   |   |-- <doc_id_1_page_1>.png
|   |   |-- <doc_id_2_page_2>.png
|-- splits/
|   |-- pdfs_dev/
|   |   |-- <doc_id_1>.pdf
|   |   |-- <doc_id_2>.pdf
```

201
m3docvqa/main.py Normal file

@@ -0,0 +1,201 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Main Script for M3DocVQA Dataset Creation Pipeline.
This script orchestrates downloading PDFs or PNGs, checking for corrupted PDFs, extracting images,
organizing them into directories, downloading/decompressing MMQA data, and creating wiki links mapping.
Usage:
python main.py <action> [other options]
Actions:
- download_pdfs: Download PDFs from URLs provided in metadata.
- check_pdfs: Verify if the downloaded PDFs are valid.
- extract_images: Extract images from the pages of downloaded PDFs.
- organize_files: Organize downloaded PDFs into specified directory splits.
- download_mmqa: Download and decompress the MMQA dataset.
- generate_wiki_mapping: Generate a mapping of 'id' to 'url' from multiple JSONL files.
Example:
python main.py generate_wiki_mapping --text=MMQA_texts.jsonl --image=MMQA_images.jsonl --table=MMQA_tables.jsonl --output=id_url_mapping.jsonl
"""
import fire
import json
import jsonlines
from pathlib import Path
from m3docvqa.downloader import download_wiki_page
from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf
from m3docvqa.split_utils import create_split_dirs
from m3docvqa.mmqa_downloader import download_and_decompress_mmqa
from m3docvqa.wiki_mapper import generate_wiki_links_mapping
from loguru import logger
from tqdm.auto import tqdm
def _prepare_download(
metadata_path: Path | str,
output_dir: Path | str,
first_n: int,
doc_ids: set,
) -> tuple[list[str], list[Path]]:
"""Prepare URLs and save paths for downloading.
Args:
metadata_path (Path): Path to the metadata JSONL file.
output_dir (str): Directory where files will be saved.
first_n (int): Maximum number of entries to process.
doc_ids (set): Set of doc ids to filter for downloading.
Returns:
tuple[list[str], list[Path]]: URLs and save paths for downloading.
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
urls, save_paths = [], []
with jsonlines.open(metadata_path) as reader:
for line in reader:
if len(urls) == first_n:
break
doc_id = line.get("id")
if doc_ids and doc_id not in doc_ids:
continue
url = line.get("url")
save_path = output_dir / f"{doc_id}.pdf"
if not is_pdf_downloaded(save_path):
urls.append(url)
save_paths.append(save_path)
return urls, save_paths
def download_pdfs(
metadata_path: Path | str,
pdf_dir: Path | str,
result_log_path: Path | str,
supporting_doc_ids_per_split: Path | str,
first_n: int = -1,
proc_id: int = 0,
n_proc: int = 1,
split: str = 'dev',
):
"""Download Wikipedia pages as PDFs."""
# Load document ids for the specified split
if supporting_doc_ids_per_split:
with open(supporting_doc_ids_per_split, "r") as f:
doc_ids_per_split = json.load(f)
split_doc_ids = {
"train": set(doc_ids_per_split.get("train", [])),
"dev": set(doc_ids_per_split.get("dev", [])),
"all": set(doc_ids_per_split.get("train", []) + doc_ids_per_split.get("dev", []))
}
if split not in split_doc_ids:
raise ValueError(f"Invalid or missing split. Expected one of {split_doc_ids.keys()}")
doc_ids = split_doc_ids.get(split, split_doc_ids.get("all"))
logger.info(f"Downloading documents for split: {split} with {len(doc_ids)} document IDs.")
urls, save_paths = _prepare_download(metadata_path, pdf_dir, first_n, doc_ids)
logger.info(f"Starting download of {len(urls)} PDFs to {pdf_dir}")
download_results = download_wiki_page(urls, save_paths, "pdf", result_log_path, proc_id, n_proc)
logger.info(f"Download completed with {sum(download_results)} successful downloads out of {len(urls)}")
def check_pdfs(pdf_dir: str, proc_id: int = 0, n_proc: int = 1):
"""Verifies the integrity of downloaded PDFs."""
corrupted_paths = []
total_checked, corrupted_count = 0, 0
pdf_files = list(Path(pdf_dir).glob("*.pdf"))
for pdf_path in tqdm(pdf_files, disable=(proc_id != 0), desc="Checking PDFs"):
total_checked += 1
if not is_pdf_downloaded(pdf_path) or not is_pdf_clean(pdf_path):
corrupted_paths.append(pdf_path)
corrupted_count += 1
logger.info(f"Checked {total_checked} PDFs: {corrupted_count} corrupted files.")
if corrupted_paths:
logger.warning(f"Corrupted PDFs: {corrupted_paths}")
def extract_images(pdf_dir: str, image_dir: str, save_type='png'):
"""Extracts images from downloaded PDFs."""
Path(image_dir).mkdir(parents=True, exist_ok=True)
pdf_files = list(Path(pdf_dir).glob("*.pdf"))
if not pdf_files:
logger.warning(f"No PDFs found in {pdf_dir} for image extraction.")
return
logger.info(f"Starting image extraction from {len(pdf_files)} PDFs in {pdf_dir}.")
for pdf_path in tqdm(pdf_files, desc="Extracting images", unit="PDF"):
get_images_from_pdf(pdf_path, save_dir=image_dir, save_type=save_type)
logger.info(f"Images extracted from PDFs in {pdf_dir}")
def organize_files(all_pdf_dir: Path | str, target_dir_base: Path | str, split_metadata_file: str | Path, split: str):
"""Organizes PDFs into directory splits based on split information file."""
create_split_dirs(
all_pdf_dir=all_pdf_dir,
target_dir_base=target_dir_base,
split_metadata_file=split_metadata_file,
split=split,
)
logger.info(f"Files organized for {split} split: in {target_dir_base}")
def download_mmqa(output_dir: str):
"""Downloads and decompresses the MMQA dataset.
Args:
output_dir (str): Directory where the MMQA files will be downloaded and decompressed.
"""
logger.info(f"Starting MMQA dataset download to {output_dir}")
download_and_decompress_mmqa(output_dir)
logger.info(f"MMQA dataset downloaded and decompressed successfully in {output_dir}")
def generate_wiki_mapping(text: str, image: str, table: str, output: str = "id_url_mapping.jsonl"):
"""Generates a mapping of 'id' to 'url' from multiple JSONL files.
Args:
text (str): Path to the JSONL file containing text data from multimodalqa dataset with 'id' and 'url' fields.
image (str): Path to the JSONL file containing image data from multimodalqa dataset with 'id' and 'url' fields.
table (str): Path to the JSONL file containing table data from multimodalqa dataset with 'id' and 'url' fields.
output (str): Path to save the output JSONL file. Defaults to 'id_url_mapping.jsonl'.
"""
logger.info("Starting wiki mapping generation...")
generate_wiki_links_mapping(text_file=text, image_file=image, table_file=table, output_file=output)
logger.info(f"Wiki mapping successfully saved to {output}")
def main():
fire.Fire({
"download_mmqa": download_mmqa,
"generate_wiki_mapping": generate_wiki_mapping,
"download_pdfs": download_pdfs,
"check_pdfs": check_pdfs,
"extract_images": extract_images,
"organize_files": organize_files,
})
if __name__ == "__main__":
main()

36
m3docvqa/pyproject.toml Normal file

@@ -0,0 +1,36 @@
[build-system]
requires = ["setuptools>=69.5"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
[project]
name = "m3docvqa"
version = "0.0.1"
description = "M3DocVQA - Dataset package for M3DocRAG: Multi-modal Retrieval is What You Need for Multi-page Multi-document Understanding."
readme = "README.md"
requires-python = ">=3.10"
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
"loguru",
"jsonlines",
"fire",
"pytest-playwright",
"figure",
"pdf2image",
"pillow",
"numpy<2.0.0",
"pdfrw",
"tqdm",
"reportlab", # only used in the test cases
]
[tool.ruff]
target-version = "py310"
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I"]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]


@@ -0,0 +1,126 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""
Downloader Module for M3DocVQA
This module provides functions to download Wikipedia pages in either PDF or PNG format
for the M3DocVQA dataset. It uses Playwright to load and capture the pages in a headless
browser environment and saves each page in the specified format.
Functions:
- _download_wiki_page: Downloads a single Wikipedia page as a PDF or PNG.
- download_wiki_page: Manages the downloading of multiple Wikipedia pages.
"""
from playwright.sync_api import sync_playwright
from loguru import logger
from pathlib import Path
import jsonlines
from tqdm.auto import tqdm
from m3docvqa.pdf_utils import is_pdf_downloaded
def _download_wiki_page(args: tuple[int, int, str, str, str, int]) -> tuple[bool, Exception | None]:
"""Download a single Wikipedia page as a PDF or PNG using Playwright.
Args:
args (Tuple[int, int, str, str, str, int]): Contains order in batch, total count, URL, save path,
save type ('pdf' or 'png'), and process ID.
Returns:
Tuple[bool, Optional[Exception]]: A tuple where the first element is a boolean indicating success,
and the second element is an exception if an error occurred, or None otherwise.
"""
order_i, total, url, save_path, save_type, proc_id = args
if is_pdf_downloaded(save_path):
if proc_id == 0:
logger.info(f"{order_i} / {total} - {save_path} already downloaded")
return True, None
try:
with sync_playwright() as p:
browser = p.chromium.launch(headless=True)
context = browser.new_context(ignore_https_errors=True)
page = context.new_page()
page.set_default_timeout(30000) # 30 seconds timeout
page.goto(url)
if save_type == 'png':
page.screenshot(path=save_path, full_page=True)
elif save_type == 'pdf':
page.emulate_media(media="screen")
page.pdf(path=save_path)
browser.close()
return True, None
except Exception as error:
logger.warning(f"Failed to download {url} as {save_type}. Error: {error}")
return False, error
def download_wiki_page(
urls: list[str],
save_paths: list[str],
save_type: str,
result_jsonl_path: str,
proc_id: int = 0,
n_proc: int = 1
) -> list[bool]:
"""Download multiple Wikipedia pages and log progress.
Args:
urls (List[str]): List of Wikipedia URLs to download.
save_paths (List[str]): List of paths where each downloaded file will be saved.
save_type (str): File type to save each page as ('pdf' or 'png').
result_jsonl_path (str): Path to the JSONL file where download results will be logged.
proc_id (int, optional): Process ID for parallel processing. Defaults to 0.
n_proc (int, optional): Total number of processes running in parallel. Defaults to 1.
Returns:
List[bool]: A list of booleans indicating whether each download was successful.
"""
total = len(urls)
all_args = [(i, total, url, str(save_path), save_type, proc_id)
for i, (url, save_path) in enumerate(zip(urls, save_paths))]
pbar = tqdm(total=len(all_args), ncols=100, disable=not (proc_id == 0))
results = []
n_downloaded = 0
# Log results to a JSONL file
with jsonlines.open(result_jsonl_path, 'w') as writer:
for args in all_args:
downloaded, error = _download_wiki_page(args)
if downloaded:
n_downloaded += 1
pbar.set_description(f"Process: {proc_id}/{n_proc} - Downloaded: {n_downloaded}/{total}")
pbar.update(1)
results.append(downloaded)
writer.write({
'downloaded': downloaded,
'args': [arg if not isinstance(arg, Path) else str(arg) for arg in args],
'error': str(error) if error else None
})
pbar.close()
return results


@@ -0,0 +1,119 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Downloads the portion of the multimodalqa dataset from https://github.com/allenai/multimodalqa/tree/master/dataset
that is useful for creating the m3docvqa dataset.
"""
import gzip
import requests
from loguru import logger
from pathlib import Path
def download_file(url: str, output_path: str) -> None:
"""Downloads a file from a given URL and saves it to the specified output path.
Args:
url (str): The URL of the file to download.
output_path (str): The path where the downloaded file will be saved.
Raises:
requests.exceptions.RequestException: If the file could not be downloaded.
"""
try:
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f"File downloaded successfully: {output_path}")
except requests.exceptions.RequestException as e:
logger.error(f"Failed to download file from {url}: {e}")
raise
def decompress_gz_file(input_path: str | Path, output_path: str | Path) -> None:
"""
Decompresses a `.gz` file into its original format.
Args:
input_path (str | Path): Path to the `.gz` file.
output_path (str | Path): Path where the decompressed file will be written.
Raises:
ValueError: If the input path does not exist or is not a file.
"""
input_path = Path(input_path)
output_path = Path(output_path)
if not input_path.is_file():
raise ValueError(f"The input file {input_path} does not exist or is not a file.")
with gzip.open(input_path, "rb") as f_in, open(output_path, "wb") as f_out:
f_out.write(f_in.read())
logger.info(f"Decompressed {input_path} to {output_path}")
def download_and_decompress_mmqa(output_directory: str | Path) -> None:
"""
Downloads and decompresses the MultiModalQA dataset files into the specified directory.
Args:
output_directory (str | Path): The directory where the files will be stored.
Steps:
1. Creates the output directory if it doesn't exist.
2. Downloads the `.jsonl.gz` files.
3. Decompresses each `.gz` file into its `.jsonl` format.
4. Removes the `.gz` files after decompression.
Raises:
requests.exceptions.RequestException: If any of the files could not be downloaded.
"""
# Define base URL and file names
base_url = "https://github.com/allenai/multimodalqa/raw/refs/heads/master/dataset/"
files = [
"MMQA_texts.jsonl.gz",
"MMQA_tables.jsonl.gz",
"MMQA_images.jsonl.gz",
"MMQA_dev.jsonl.gz",
"MMQA_train.jsonl.gz",
]
output_directory = Path(output_directory)
# Ensure the output directory exists
if not output_directory.exists():
output_directory.mkdir(parents=True, exist_ok=True)
logger.info(f"Created output directory: {output_directory}")
for file_name in files:
compressed_path = output_directory / file_name
decompressed_path = output_directory / file_name.replace(".gz", "")
try:
# Step 1: Download the file
logger.info(f"Downloading {file_name}...")
download_file(base_url + file_name, compressed_path)
# Step 2: Decompress the file
logger.info(f"Decompressing {file_name}...")
decompress_gz_file(compressed_path, decompressed_path)
# Step 3: Remove the compressed file
compressed_path.unlink()
logger.info(f"Removed compressed file: {compressed_path}")
except Exception as e:
logger.error(f"Error processing {file_name}: {e}")
raise


@@ -0,0 +1,126 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""PDF Utilities Module for M3DocVQA.
This module provides utility functions for managing and processing PDF files in the M3DocVQA dataset.
It includes functions for checking if a PDF has been downloaded, verifying if a PDF is clean (not corrupted),
and extracting images from PDF pages.
Functions:
- is_pdf_downloaded: Checks if a given PDF file exists and can be opened without errors.
- is_pdf_clean: Checks if a PDF file is clean (not corrupted) and can be read without issues.
- get_images_from_pdf: Extracts images from each page of a PDF and optionally saves them in a specified directory.
"""
from pdf2image import convert_from_path
from PIL import Image
from pdfrw import PdfReader
from pathlib import Path
from io import BytesIO
from loguru import logger
def is_pdf_downloaded(pdf_path: str) -> bool:
"""Check if the PDF file exists and can be opened without errors.
Args:
pdf_path (str): Path to the PDF file.
Returns:
bool: True if the PDF file is downloaded and accessible; False otherwise.
"""
try:
with open(pdf_path, "rb") as f:
f.read(1) # Attempt to read a byte to verify file exists and is accessible
return True
except Exception as e:
logger.trace(f"Failed to open PDF at {pdf_path}: {e}")
return False
def is_pdf_clean(pdf_path: str) -> bool:
"""Verify if a PDF file is clean (not corrupted) and can be read without errors.
Args:
pdf_path (str): Path to the PDF file.
Returns:
bool: True if the PDF file is clean and readable; False otherwise.
"""
try:
with open(pdf_path, "rb") as f:
idata = f.read()
ibuffer = BytesIO(idata)
PdfReader(ibuffer) # Attempt to read the PDF structure for validity
return True
except Exception as error:
logger.warning(f"PDF at {pdf_path} is corrupted or unreadable: {error}")
return False
def get_images_from_pdf(
pdf_path: str,
save_dir: str = None,
max_pages: int = None,
dpi_resolution: int = 144,
save_type: str = 'png'
) -> list[Image.Image]:
"""Extract images from each page of a PDF and optionally save them to a directory.
Args:
pdf_path (str): Path to the PDF file.
save_dir (str, optional): Directory where images will be saved. If None, images are not saved. Defaults to None.
max_pages (int, optional): Maximum number of pages to process. If None, all pages are processed. Defaults to None.
dpi_resolution (int, optional): Resolution for image extraction. Defaults to 144.
save_type (str, optional): Image file type to save as ('png', 'jpg', etc.). Defaults to 'png'.
Returns:
list[Image.Image]: A list of images extracted from each page of the PDF.
"""
pdf_path_obj = Path(pdf_path)
assert pdf_path_obj.exists(), f"PDF file {pdf_path} does not exist."
out_images = []
# Create save directory if saving images is enabled
if save_dir:
save_dir_path = Path(save_dir)
save_dir_path.mkdir(exist_ok=True, parents=True)
try:
# Convert PDF to images using pdf2image
images = convert_from_path(pdf_path, dpi=dpi_resolution)
logger.info(f"PDF {pdf_path} has {len(images)} pages.")
# Limit the number of pages processed if max_pages is set
if max_pages:
images = images[:max_pages]
for page_index, image in enumerate(images):
out_images.append(image)
# Save image if save directory is specified
if save_dir:
save_page_path = save_dir_path / f"{pdf_path_obj.stem}_{page_index + 1}.{save_type}"
if not save_page_path.exists():
image.save(save_page_path)
logger.info(f"Saved page {page_index + 1} as image at {save_page_path}")
except Exception as e:
logger.error(f"Error extracting images from PDF {pdf_path}: {e}")
return out_images


@@ -0,0 +1,93 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Split Utilities Module for M3DocVQA.
This module provides utilities for organizing PDF files into split directories (e.g., train, dev).
Functions:
- create_split_dirs: Copies specified PDF files into separate split directories (e.g., train, dev).
"""
from pathlib import Path
import shutil
import json
import jsonlines
from loguru import logger
def create_split_dirs(
all_pdf_dir: str | Path,
target_dir_base: str | Path,
split_metadata_file: str | Path,
split: str
) -> None:
"""Copies specified PDF files into a target directory based on a given split.
Args:
all_pdf_dir (Union[str, Path]): Path to the directory containing all downloaded PDF files.
target_dir_base (Union[str, Path]): Base directory where the split-specific directory will be created.
split_metadata_file (Union[str, Path]): Path to the metadata JSONL file for the split.
split (str): Split type ('train' or 'dev').
Raises:
FileNotFoundError: If the JSONL metadata file does not exist.
ValueError: If the split is not 'train' or 'dev'.
"""
# Validate split type
if split not in {"train", "dev"}:
raise ValueError(f"Invalid split: {split}. Expected 'train' or 'dev'.")
all_pdf_dir = Path(all_pdf_dir)
target_dir = Path(target_dir_base) / f'pdfs_{split}'
target_dir.mkdir(parents=True, exist_ok=True)
# Validate metadata file
split_metadata_file = Path(split_metadata_file)
if not split_metadata_file.exists():
raise FileNotFoundError(f"Metadata file for split '{split}' not found: {split_metadata_file}")
# Load all doc IDs for the split
split_doc_ids = []
with jsonlines.open(split_metadata_file) as reader:
for obj in reader:
split_doc_ids.extend(doc['doc_id'] for doc in obj['supporting_context'])
# Remove duplicates and log the count
split_doc_ids = sorted(set(split_doc_ids))
logger.info(f"Split {split} -> # supporting context: {len(split_doc_ids)}")
# Save the split-specific IDs to a JSON file
split_doc_ids_output_path = Path(f'./{split}_doc_ids.json')
with open(split_doc_ids_output_path, 'w') as f:
json.dump(split_doc_ids, f, indent=4)
logger.info(f"Split {split} -> saved doc IDs at {split_doc_ids_output_path}")
# Copy PDF files to the target directory
missing_files = []
for doc_id in split_doc_ids:
pdf_file = all_pdf_dir / f"{doc_id}.pdf"
if pdf_file.exists():
shutil.copy(pdf_file, target_dir / pdf_file.name)
else:
missing_files.append(pdf_file)
if missing_files:
logger.warning(f"Warning: {len(missing_files)} files are missing and will be skipped.")
for missing_file in missing_files:
logger.warning(f" Missing: {missing_file}")


@@ -0,0 +1,140 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Wiki Mapper.
This module provides functionality to parse MultimodalQA JSONL files that have already been downloaded and that contain 'id' and 'url' mappings,
merge them into a single mapping, and save the result to a JSONL file.
Each JSONL file should contain one JSON object per line with the following structure:
{
"title": "Article Title",
"url": "https://en.wikipedia.org/wiki/Article_Title",
"id": "unique_id",
"text": "Text description of the article."
}
"""
import json
from pathlib import Path
from loguru import logger
def parse_jsonl(file_path: str | Path) -> dict[str, str]:
"""Parses a JSONL file from the multimodalqa dataset to extract a mapping of 'id' to 'url'.
Args:
file_path (str | Path): Path to the JSONL file.
Returns:
dict[str, str]: A dictionary mapping each 'id' to its corresponding 'url'.
Raises:
FileNotFoundError: If the JSONL file does not exist.
ValueError: If the file contains invalid JSON lines.
"""
file_path = Path(file_path)
if not file_path.is_file():
logger.error(f"File not found: {file_path}")
raise FileNotFoundError(f"File not found: {file_path}")
id_url_mapping = {}
try:
with file_path.open("r") as file:
for line in file:
data = json.loads(line.strip())
entry_id = data.get("id")
url = data.get("url")
if entry_id and url:
id_url_mapping[entry_id] = url
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON in file {file_path}: {e}")
raise ValueError(f"Invalid JSON in file {file_path}: {e}")
logger.info(f"Parsed {len(id_url_mapping)} entries from {file_path}")
return id_url_mapping
def merge_mappings(mappings: list[dict[str, str]]) -> dict[str, str]:
"""Merges multiple mappings into a single dictionary.
Args:
mappings (list[dict[str, str]]): A list of dictionaries containing 'id' to 'url' mappings.
Returns:
dict[str, str]: A merged dictionary containing all 'id' to 'url' mappings.
"""
merged_mapping = {}
for mapping in mappings:
merged_mapping.update(mapping)
logger.info(f"Merged {len(mappings)} mappings with a total of {len(merged_mapping)} entries.")
return merged_mapping
def save_mapping_to_jsonl(mapping: dict[str, str], output_file: str | Path) -> None:
"""Saves the 'id'-to-'url' mapping to a JSONL file.
Args:
mapping (dict[str, str]): The dictionary containing 'id' to 'url' mappings.
output_file (str | Path): The path to the output JSONL file.
Raises:
IOError: If the file cannot be written.
"""
output_file = Path(output_file)
try:
with output_file.open("w") as file:
for entry_id, url in mapping.items():
json.dump({"id": entry_id, "url": url}, file)
file.write("\n")
logger.info(f"Mapping saved to {output_file}")
except IOError as e:
logger.error(f"Error writing to file {output_file}: {e}")
raise
def generate_wiki_links_mapping(
text_file: str | Path, image_file: str | Path, table_file: str | Path, output_file: str | Path = "id_url_mapping.jsonl"
) -> None:
"""Orchestrates the process of parsing input files, merging mappings, and saving the result to JSONL.
Args:
text_file (str | Path): Path to the JSONL file containing text data with 'id' and 'url' fields.
image_file (str | Path): Path to the JSONL file containing image data with 'id' and 'url' fields.
table_file (str | Path): Path to the JSONL file containing table data with 'id' and 'url' fields.
output_file (str | Path): Path to save the output JSONL file. Defaults to 'id_url_mapping.jsonl'.
Raises:
Exception: If any part of the pipeline fails.
"""
try:
# Parse input files
logger.info("Parsing JSONL files...")
text_mapping = parse_jsonl(text_file)
image_mapping = parse_jsonl(image_file)
table_mapping = parse_jsonl(table_file)
# Merge mappings
logger.info("Merging mappings...")
merged_mapping = merge_mappings([text_mapping, image_mapping, table_mapping])
# Save the merged mapping
logger.info("Saving merged mapping to output file...")
save_mapping_to_jsonl(merged_mapping, output_file)
logger.info(f"Mapping successfully generated and saved to {output_file}")
except Exception as e:
logger.error(f"Error generating wiki links mapping: {e}")
raise


@@ -0,0 +1,105 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from unittest.mock import patch, MagicMock
from pathlib import Path
import jsonlines
from m3docvqa.downloader import _download_wiki_page, download_wiki_page
@pytest.fixture
def test_urls_and_paths(tmp_path):
"""Fixture to provide sample URLs and save paths for testing."""
urls = ["https://en.wikipedia.org/wiki/SamplePage1", "https://en.wikipedia.org/wiki/SamplePage2"]
save_paths = [str(tmp_path / "sample1.pdf"), str(tmp_path / "sample2.pdf")]
return urls, save_paths
@patch("m3docvqa.downloader.sync_playwright")
def test__download_wiki_page_pdf(mock_playwright, tmp_path):
"""Test downloading a single page as a PDF."""
url = "https://en.wikipedia.org/wiki/SamplePage"
save_path = tmp_path / "sample.pdf"
args = (0, 1, url, str(save_path), 'pdf', 0)
# Mock Playwright behavior
mock_browser = MagicMock()
mock_context = MagicMock()
mock_page = MagicMock()
mock_playwright.return_value.__enter__.return_value.chromium.launch.return_value = mock_browser
mock_browser.new_context.return_value = mock_context
mock_context.new_page.return_value = mock_page
# Call the function
downloaded, error = _download_wiki_page(args)
# Assertions
assert downloaded is True
assert error is None
mock_page.goto.assert_called_once_with(url)
mock_page.pdf.assert_called_once_with(path=str(save_path))
@patch("m3docvqa.downloader.sync_playwright")
def test__download_wiki_page_png(mock_playwright, tmp_path):
"""Test downloading a single page as a PNG."""
url = "https://en.wikipedia.org/wiki/SamplePage"
save_path = tmp_path / "sample.png"
args = (0, 1, url, str(save_path), 'png', 0)
# Mock Playwright behavior
mock_browser = MagicMock()
mock_context = MagicMock()
mock_page = MagicMock()
mock_playwright.return_value.__enter__.return_value.chromium.launch.return_value = mock_browser
mock_browser.new_context.return_value = mock_context
mock_context.new_page.return_value = mock_page
# Call the function
downloaded, error = _download_wiki_page(args)
# Assertions
assert downloaded is True
assert error is None
mock_page.goto.assert_called_once_with(url)
mock_page.screenshot.assert_called_once_with(path=str(save_path), full_page=True)
@patch("m3docvqa.downloader._download_wiki_page")
def test_download_wiki_page_batch(mock_download_wiki_page, tmp_path, test_urls_and_paths):
"""Test batch downloading multiple Wikipedia pages."""
urls, save_paths = test_urls_and_paths
result_jsonl_path = tmp_path / "download_results.jsonl"
# Mock individual downloads to always succeed
mock_download_wiki_page.side_effect = [(True, None), (True, None)]
# Call the function
results = download_wiki_page(urls, save_paths, 'pdf', str(result_jsonl_path), proc_id=0, n_proc=1)
# Assertions
assert results == [True, True]
assert result_jsonl_path.exists()
# Check JSONL log entries
with jsonlines.open(result_jsonl_path, 'r') as reader:
log_entries = list(reader)
assert len(log_entries) == 2
assert log_entries[0]['downloaded'] is True
assert log_entries[0]['error'] is None
assert log_entries[1]['downloaded'] is True
assert log_entries[1]['error'] is None


@@ -0,0 +1,86 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf
from pathlib import Path
from PIL import Image
from reportlab.pdfgen import canvas # For creating sample PDFs
@pytest.fixture
def sample_pdf(tmp_path) -> Path:
"""Create a temporary sample PDF file for testing."""
pdf_path = tmp_path / "sample.pdf"
c = canvas.Canvas(str(pdf_path))
c.drawString(100, 100, "Sample PDF text for testing.") # Add sample text to the PDF
c.save()
return pdf_path
@pytest.fixture
def corrupted_pdf(tmp_path) -> Path:
"""Create a temporary, corrupted PDF file for testing."""
pdf_path = tmp_path / "corrupted.pdf"
pdf_path.write_bytes(b"%PDF-1.4 corrupted content") # Write incomplete/corrupted PDF content
return pdf_path
def test_is_pdf_downloaded_existing_pdf(sample_pdf):
"""Test is_pdf_downloaded on a valid, existing PDF."""
assert is_pdf_downloaded(str(sample_pdf)) is True, "Expected PDF to be recognized as downloaded."
def test_is_pdf_downloaded_nonexistent_pdf(tmp_path):
"""Test is_pdf_downloaded on a non-existent PDF file."""
non_existent_pdf = tmp_path / "non_existent.pdf"
assert is_pdf_downloaded(str(non_existent_pdf)) is False, "Expected non-existent PDF to be marked as not downloaded."
def test_is_pdf_clean_valid_pdf(sample_pdf):
"""Test is_pdf_clean on a valid, clean PDF."""
assert is_pdf_clean(str(sample_pdf)) is True, "Expected PDF to be recognized as clean."
def test_is_pdf_clean_corrupted_pdf(corrupted_pdf):
"""Test is_pdf_clean on a corrupted PDF."""
assert is_pdf_clean(str(corrupted_pdf)) is False, "Expected corrupted PDF to be marked as not clean."
def test_get_images_from_pdf_extract_images(sample_pdf, tmp_path):
"""Test get_images_from_pdf to ensure it extracts images correctly."""
image_dir = tmp_path / "images"
images = get_images_from_pdf(str(sample_pdf), save_dir=str(image_dir), dpi_resolution=72, save_type='png')
# Verify that at least one image was extracted
assert len(images) > 0, "Expected at least one image to be extracted from the PDF."
# Verify that images were saved to the directory
saved_images = list(image_dir.glob("*.png"))
assert len(saved_images) == len(images), "Expected number of saved images to match the number of extracted images."
# Verify that the saved image files exist and are valid
for image_path in saved_images:
with Image.open(image_path) as img:
assert img.format == "PNG", "Expected saved image to be in PNG format."
def test_get_images_from_pdf_no_save_dir(sample_pdf):
"""Test get_images_from_pdf without saving images, only returning them as a list."""
images = get_images_from_pdf(str(sample_pdf), save_dir=None, dpi_resolution=72)
assert len(images) > 0, "Expected at least one image to be returned without saving."
assert all(isinstance(image, Image.Image) for image in images), "Expected all returned items to be PIL Image objects."


@@ -0,0 +1,107 @@
# Copyright 2024 Bloomberg Finance L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from pathlib import Path
import shutil
import json
import jsonlines
from unittest.mock import MagicMock, patch
from m3docvqa.split_utils import create_split_dirs
@pytest.fixture
def mock_pdf_directory(tmp_path):
# Create a temporary directory for PDFs
pdf_dir = tmp_path / "pdfs"
pdf_dir.mkdir()
# Add some mock PDF files
(pdf_dir / "doc1.pdf").write_text("PDF content for doc1")
(pdf_dir / "doc2.pdf").write_text("PDF content for doc2")
return pdf_dir
@pytest.fixture
def mock_metadata_file(tmp_path):
# Create a temporary metadata file in JSONL format
metadata_file = tmp_path / "MMQA_train.jsonl"
data = [
{"supporting_context": [{"doc_id": "doc1"}]},
{"supporting_context": [{"doc_id": "doc2"}]}
]
with jsonlines.open(metadata_file, mode='w') as writer:
writer.write_all(data)
return metadata_file
@pytest.fixture
def mock_target_directory(tmp_path):
return tmp_path / "target"
def test_create_split_dirs(mock_pdf_directory, mock_metadata_file, mock_target_directory):
"""Test the create_split_dirs function."""
# Prepare the split directory
split = "train"
# Call the function to create split directories
create_split_dirs(
all_pdf_dir=mock_pdf_directory,
target_dir_base=mock_target_directory,
split_metadata_file=mock_metadata_file,
split=split
)
# Assert that the target directory exists and contains the expected PDF files
target_dir = mock_target_directory / f"pdfs_{split}"
assert target_dir.exists(), f"Directory {target_dir} was not created"
assert (target_dir / "doc1.pdf").exists(), "doc1.pdf was not copied"
assert (target_dir / "doc2.pdf").exists(), "doc2.pdf was not copied"
def test_create_split_dirs_missing_pdf(mock_metadata_file, mock_target_directory):
"""Test create_split_dirs when PDF files are missing."""
# Prepare the split directory
split = "train"
all_pdf_dir = Path("non_existing_pdf_dir")
# Call the function and verify that the missing PDFs are handled correctly
create_split_dirs(
all_pdf_dir=all_pdf_dir,
target_dir_base=mock_target_directory,
split_metadata_file=mock_metadata_file,
split=split
)
target_dir = mock_target_directory / f"pdfs_{split}"
assert target_dir.exists(), f"Directory {target_dir} was not created"
assert not (target_dir / "doc1.pdf").exists(), "doc1.pdf should not exist"
assert not (target_dir / "doc2.pdf").exists(), "doc2.pdf should not exist"
@pytest.mark.parametrize("split, expected_error", [
("test_split", ValueError), # Invalid split type
(None, ValueError), # Missing split
])
def test_create_split_dirs_invalid_split_type(mock_pdf_directory, mock_metadata_file, mock_target_directory, split, expected_error):
"""Test invalid split types in create_split_dirs."""
with pytest.raises(expected_error):
create_split_dirs(
all_pdf_dir=mock_pdf_directory,
target_dir_base=mock_target_directory,
split_metadata_file=mock_metadata_file,
split=split
)