Update M3DocVQA download (#7)
This commit is contained in:
@ -33,8 +33,8 @@ The scripts allows users to:
|
|||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone <url-tbd>
|
git clone https://github.com/bloomberg/m3docrag
|
||||||
cd <repo-name-tbd>/m3docvqa
|
cd m3docrag/m3docvqa
|
||||||
```
|
```
|
||||||
|
|
||||||
### Install Python Package
|
### Install Python Package
|
||||||
@ -111,90 +111,134 @@ Output:
|
|||||||
|
|
||||||
A JSONL file `id_url_mapping.jsonl` containing the ID and corresponding URL mappings.
|
A JSONL file `id_url_mapping.jsonl` containing the ID and corresponding URL mappings.
|
||||||
|
|
||||||
### Step 3: Download Wikipedia Articles as PDFs
|
### Step 3: Create Split Files
|
||||||
|
Use the `create_splits` action to create the per-split doc ids.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py create_splits --split_metadata_file=./multimodalqa/MMQA_dev.jsonl --split=dev
|
||||||
|
python main.py create_splits --split_metadata_file=./multimodalqa/MMQA_train.jsonl --split=train
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note** - In the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper, we only use the `dev` split for our experiments.
|
||||||
|
|
||||||
|
Output:
|
||||||
|
|
||||||
|
- Files that store document IDs of each split: `./dev_doc_ids.json` and `./train_doc_ids.json`.
|
||||||
|
|
||||||
|
|
||||||
|
### Step 4: Download Wikipedia Articles as PDFs
|
||||||
Use the `download_pdfs` action to download Wikipedia articles in a PDF format based on the generated mapping.
|
Use the `download_pdfs` action to download Wikipedia articles in a PDF format based on the generated mapping.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python main.py download_pdfs --metadata_path=./id_url_mapping.jsonl --pdf_dir=./pdfs --result_log_path=./download_results.jsonl --first_n=10 --supporting_doc_ids_per_split=./supporting_doc_ids_per_split.json --split=dev
|
python main.py download_pdfs --metadata_path=./id_url_mapping.jsonl --pdf_dir=./pdfs_dev --result_log_dir=./download_logs/ --first_n=10 --per_split_doc_ids=./dev_doc_ids.json
|
||||||
```
|
```
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
- `--metadata_path`: Path to the id_url_mapping.jsonl file.
|
- `--metadata_path`: Path to the id_url_mapping.jsonl file.
|
||||||
- `--pdf_dir`: Directory to save the downloaded PDFs.
|
- `--pdf_dir`: Directory to save the downloaded PDFs.
|
||||||
- `--result_log_path`: Path to log the download results.
|
- `--result_log_dir`: Directory to log the download results.
|
||||||
- `--first_n`: Downloads the first N PDFs for testing. **Do not use this option for downloading all the PDFs.**
|
- `--first_n`: Downloads the first N PDFs for testing (default is -1, which means all the PDFs).
|
||||||
- `--supporting_doc_ids_per_split`: Path to JSON file containing document IDs for each split. `dev` is the default split, as all of the experimental results in the `M3DocRAG` paper were reported on the `dev` split. Anyone interested in downloading the PDFs in the `train` split can provide `--supporting_doc_ids_per_split=train` as the option. In case anyone is interested in downloading all the PDFs, one can also provide `--supporting_doc_ids_per_split=all` as an option.
|
- `--per_split_doc_ids`: Path to JSON file containing document IDs for each split. `dev_doc_ids.json` is the default file, as all of the experimental results in the `M3DocRAG` paper were reported on the `dev` split. Anyone interested in downloading the PDFs in the `train` split can provide `--per_split_doc_ids=./train_doc_ids.json` as the option.
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
|
|
||||||
- PDF files for Wikipedia articles, saved in the `./pdfs/` directory.
|
- PDF files for Wikipedia articles, saved in the `./pdfs_dev/` directory.
|
||||||
- A `download_results.jsonl` file logging the status of each download.
|
- A `download_results.jsonl` file logging the status of each download.
|
||||||
|
|
||||||
### Step 4: Check PDF Integrity
|
If you want to download PDFs in parallel, you can try following commands with arguments `proc_id` and `n_proc`. `proc_id` is the process ID (default is 0), and `n_proc` is the total number of processes (default is 1).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# e.g., distributed in 4 parallel jobs on the first 20 PDFs
|
||||||
|
N_total_processes=4
|
||||||
|
|
||||||
|
for i in $(seq 0 $((N_total_processes - 1)));
|
||||||
|
do
|
||||||
|
echo $i
|
||||||
|
python main.py \
|
||||||
|
download_pdfs \
|
||||||
|
--metadata_path './id_url_mapping.jsonl' \
|
||||||
|
--pdf_dir './pdfs_dev' \
|
||||||
|
--result_log_dir './download_logs/' \
|
||||||
|
--per_split_doc_ids './dev_doc_ids.json' \
|
||||||
|
--first_n=20 \
|
||||||
|
--proc_id=$i \
|
||||||
|
--n_proc=$N_total_processes &
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
# e.g., distributed in 16 parallel jobs on all dev PDFs
|
||||||
|
N_total_processes=16
|
||||||
|
|
||||||
|
for i in $(seq 0 $((N_total_processes - 1)));
|
||||||
|
do
|
||||||
|
echo $i
|
||||||
|
python main.py \
|
||||||
|
download_pdfs \
|
||||||
|
--metadata_path './id_url_mapping.jsonl' \
|
||||||
|
--pdf_dir './pdfs_dev' \
|
||||||
|
--result_log_dir './download_logs/' \
|
||||||
|
--per_split_doc_ids './dev_doc_ids.json' \
|
||||||
|
--first_n=-1 \
|
||||||
|
--proc_id=$i \
|
||||||
|
--n_proc=$N_total_processes &
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Step 5: Check PDF Integrity
|
||||||
Use the `check_pdfs` action to verify the integrity of the downloaded PDFs.
|
Use the `check_pdfs` action to verify the integrity of the downloaded PDFs.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python main.py check_pdfs --pdf_dir=./pdfs
|
python main.py check_pdfs --pdf_dir=./pdfs_dev
|
||||||
```
|
```
|
||||||
Output:
|
Output:
|
||||||
|
|
||||||
Identifies and logs corrupted or unreadable PDFs.
|
Identifies and logs corrupted or unreadable PDFs.
|
||||||
|
|
||||||
### Step 5: Organize Files into Splits
|
|
||||||
Use the `organize_files` action to organize the downloaded PDFs into specific splits (e.g., `train`, `dev`) based on a split information file.
|
### (Optional) Step 6: Extract Images from PDFs
|
||||||
|
When created embeddings in the [M3DocRAG](https://arxiv.org/abs/2411.04952) experiment, we extract images from the downloaded PDFs on the fly. But if the users want to extract images from the downloaded PDFs and save them for future use, they can use the `extract_images` action.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=dev --split_metadata_file=./multimodalqa/MMQA_dev.jsonl
|
python main.py extract_images --pdf_dir=./pdfs_dev/ --image_dir=./images_dev
|
||||||
```
|
|
||||||
|
|
||||||
If train split is needed:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python main.py organize_files --all_pdf_dir=./pdfs --target_dir_base=./splits --split=train --split_metadata_file=./multimodalqa/MMQA_train.jsonl
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
|
|
||||||
- Organized PDFs into directories in `./splits/pdfs_train/` and `./splits/pdfs_dev/`.
|
Extracted images from the PDFs in the dev split are saved in the `./images_dev` directory.
|
||||||
- Files that store document IDs of each split `./train_doc_ids.json` and `./dev_doc_ids.json`.
|
|
||||||
|
|
||||||
**Note** - In the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper, we only use the `dev` split for our experiments.
|
|
||||||
|
|
||||||
### Step 6: Extract Images from PDFs
|
|
||||||
Use the `extract_images` action to extract images from the downloaded PDFs. A PNG image of each page of the PDFs is extracted. These images are used for both `retrieval` using `ColPali/ColQwen`, as well as `question answering` using the LLMs mentioned in the [M3DocRAG](https://arxiv.org/abs/2411.04952) paper.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python main.py extract_images --pdf_dir=./splits/pdfs_dev/ --image_dir=./images/images_dev
|
|
||||||
```
|
|
||||||
|
|
||||||
Output:
|
|
||||||
|
|
||||||
Extracted images from the PDFs in the dev split are saved in the `./images/images_dev` directory.
|
|
||||||
|
|
||||||
After following these steps, your dataset directory structure will look like this:
|
After following these steps, your dataset directory structure will look like this:
|
||||||
|
|
||||||
```
|
```bash
|
||||||
./
|
./
|
||||||
|
# original MMQA files
|
||||||
|-- multimodalqa/
|
|-- multimodalqa/
|
||||||
| |-- MMQA_train.jsonl
|
| |-- MMQA_train.jsonl
|
||||||
| |-- MMQA_dev.jsonl
|
| |-- MMQA_dev.jsonl
|
||||||
| |-- MMQA_texts.jsonl
|
| |-- MMQA_texts.jsonl
|
||||||
| |-- MMQA_images.jsonl
|
| |-- MMQA_images.jsonl
|
||||||
| |-- MMQA_tables.jsonl
|
| |-- MMQA_tables.jsonl
|
||||||
|
# generated files
|
||||||
|-- id_url_mapping.jsonl
|
|-- id_url_mapping.jsonl
|
||||||
|-- dev_doc_ids.json
|
|-- dev_doc_ids.json
|
||||||
|-- train_doc_ids.json
|
|-- train_doc_ids.json
|
||||||
|-- supporting_doc_ids_per_split.json
|
|-- supporting_doc_ids_per_split.json
|
||||||
|-- download_results.jsonl
|
# download logs
|
||||||
|-- pdfs/
|
|-- download_logs/
|
||||||
| |-- <article_1>.pdf
|
| |-- <process_id>_<first_n>.jsonl
|
||||||
| |-- <article_2>.pdf
|
# downloaded PDFs
|
||||||
|-- images/
|
|-- pdfs_dev/
|
||||||
|-- |--images_dev/
|
| |-- <article_dev_1>.pdf
|
||||||
| | |-- <doc_id_1_page_1>.png
|
| |-- <article_dev_2>.pdf
|
||||||
| | |-- <doc_id_2_page_2>.png
|
# (Below are optional outputs)
|
||||||
|-- splits/
|
# |-- pdfs_train/
|
||||||
| |-- pdfs_dev/
|
# | |-- <article_train_1>.pdf
|
||||||
| | |-- <doc_id_1>.pdf
|
# | |-- <article_train_2>.pdf
|
||||||
| | |-- <doc_id_2>.pdf
|
# |-- images_dev/
|
||||||
|
# | |-- <doc_id_dev_1_page_1>.png
|
||||||
|
# | |-- <doc_id_dev_2_page_2>.png
|
||||||
|
# |-- images_train/
|
||||||
|
# | |-- <doc_id_train_1_page_1>.png
|
||||||
|
# | |-- <doc_id_train_2_page_2>.png
|
||||||
```
|
```
|
||||||
|
@ -40,7 +40,7 @@ import jsonlines
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from m3docvqa.downloader import download_wiki_page
|
from m3docvqa.downloader import download_wiki_page
|
||||||
from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf
|
from m3docvqa.pdf_utils import is_pdf_downloaded, is_pdf_clean, get_images_from_pdf
|
||||||
from m3docvqa.split_utils import create_split_dirs
|
from m3docvqa.split_utils import create_split_files
|
||||||
from m3docvqa.mmqa_downloader import download_and_decompress_mmqa
|
from m3docvqa.mmqa_downloader import download_and_decompress_mmqa
|
||||||
from m3docvqa.wiki_mapper import generate_wiki_links_mapping
|
from m3docvqa.wiki_mapper import generate_wiki_links_mapping
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -52,6 +52,7 @@ def _prepare_download(
|
|||||||
output_dir: Path | str,
|
output_dir: Path | str,
|
||||||
first_n: int,
|
first_n: int,
|
||||||
doc_ids: set,
|
doc_ids: set,
|
||||||
|
check_downloaded: bool = False,
|
||||||
) -> tuple[list[str], list[Path]]:
|
) -> tuple[list[str], list[Path]]:
|
||||||
"""Prepare URLs and save paths for downloading.
|
"""Prepare URLs and save paths for downloading.
|
||||||
|
|
||||||
@ -74,12 +75,14 @@ def _prepare_download(
|
|||||||
break
|
break
|
||||||
|
|
||||||
doc_id = line.get("id")
|
doc_id = line.get("id")
|
||||||
|
url = line.get("url")
|
||||||
if doc_ids and doc_id not in doc_ids:
|
if doc_ids and doc_id not in doc_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
url = line.get("url")
|
|
||||||
save_path = output_dir / f"{doc_id}.pdf"
|
save_path = output_dir / f"{doc_id}.pdf"
|
||||||
if not is_pdf_downloaded(save_path):
|
if check_downloaded and is_pdf_downloaded(save_path):
|
||||||
|
continue
|
||||||
|
|
||||||
urls.append(url)
|
urls.append(url)
|
||||||
save_paths.append(save_path)
|
save_paths.append(save_path)
|
||||||
|
|
||||||
@ -89,32 +92,31 @@ def _prepare_download(
|
|||||||
def download_pdfs(
|
def download_pdfs(
|
||||||
metadata_path: Path | str,
|
metadata_path: Path | str,
|
||||||
pdf_dir: Path | str,
|
pdf_dir: Path | str,
|
||||||
result_log_path: Path | str,
|
result_log_dir: Path | str,
|
||||||
supporting_doc_ids_per_split: Path | str,
|
per_split_doc_ids: Path | str,
|
||||||
first_n: int = -1,
|
first_n: int = -1,
|
||||||
proc_id: int = 0,
|
proc_id: int = 0,
|
||||||
n_proc: int = 1,
|
n_proc: int = 1,
|
||||||
split: str = 'dev',
|
check_downloaded: bool = False,
|
||||||
):
|
):
|
||||||
"""Download Wikipedia pages as PDFs."""
|
"""Download Wikipedia pages as PDFs."""
|
||||||
# Load document ids for the specified split
|
# Load document ids for the specified split
|
||||||
if supporting_doc_ids_per_split:
|
if per_split_doc_ids:
|
||||||
with open(supporting_doc_ids_per_split, "r") as f:
|
with open(per_split_doc_ids, "r") as f:
|
||||||
doc_ids_per_split = json.load(f)
|
doc_ids = json.load(f)
|
||||||
split_doc_ids = {
|
logger.info(f"Downloading documents with {len(doc_ids)} document IDs from {metadata_path}.")
|
||||||
"train": set(doc_ids_per_split.get("train", [])),
|
|
||||||
"dev": set(doc_ids_per_split.get("dev", [])),
|
|
||||||
"all": set(doc_ids_per_split.get("train", []) + doc_ids_per_split.get("dev", []))
|
|
||||||
}
|
|
||||||
if split not in split_doc_ids:
|
|
||||||
raise ValueError(f"Invalid or missing split. Expected one of {split_doc_ids.keys()}")
|
|
||||||
doc_ids = split_doc_ids.get(split, split_doc_ids.get("all"))
|
|
||||||
logger.info(f"Downloading documents for split: {split} with {len(doc_ids)} document IDs.")
|
|
||||||
|
|
||||||
urls, save_paths = _prepare_download(metadata_path, pdf_dir, first_n, doc_ids)
|
urls, save_paths = _prepare_download(metadata_path, pdf_dir, first_n, doc_ids, check_downloaded)
|
||||||
logger.info(f"Starting download of {len(urls)} PDFs to {pdf_dir}")
|
|
||||||
download_results = download_wiki_page(urls, save_paths, "pdf", result_log_path, proc_id, n_proc)
|
# split urls and save_paths (both are lists) into n_proc chunks
|
||||||
logger.info(f"Download completed with {sum(download_results)} successful downloads out of {len(urls)}")
|
if n_proc > 1:
|
||||||
|
logger.info(f"[{proc_id}/{n_proc}] Splitting {len(urls)} URLs into {n_proc} chunks")
|
||||||
|
urls = urls[proc_id::n_proc]
|
||||||
|
save_paths = save_paths[proc_id::n_proc]
|
||||||
|
|
||||||
|
logger.info(f"[{proc_id}/{n_proc}] Starting download of {len(urls)} PDFs to {pdf_dir}")
|
||||||
|
download_results = download_wiki_page(urls, save_paths, "pdf", result_log_dir, proc_id, n_proc)
|
||||||
|
logger.info(f"[{proc_id}/{n_proc}] Download completed with {sum(download_results)} successful downloads out of {len(urls)}")
|
||||||
|
|
||||||
|
|
||||||
def check_pdfs(pdf_dir: str, proc_id: int = 0, n_proc: int = 1):
|
def check_pdfs(pdf_dir: str, proc_id: int = 0, n_proc: int = 1):
|
||||||
@ -147,18 +149,16 @@ def extract_images(pdf_dir: str, image_dir: str, save_type='png'):
|
|||||||
|
|
||||||
for pdf_path in tqdm(pdf_files, desc="Extracting images", unit="PDF"):
|
for pdf_path in tqdm(pdf_files, desc="Extracting images", unit="PDF"):
|
||||||
get_images_from_pdf(pdf_path, save_dir=image_dir, save_type=save_type)
|
get_images_from_pdf(pdf_path, save_dir=image_dir, save_type=save_type)
|
||||||
logger.info(f"Images extracted from PDFs in {pdf_dir}")
|
logger.info(f"Images extracted from {pdf_dir} and saved to {image_dir}")
|
||||||
|
|
||||||
|
|
||||||
def organize_files(all_pdf_dir: Path | str, target_dir_base: Path | str, split_metadata_file: str | Path, split: str):
|
def create_splits(split_metadata_file: str | Path, split: str):
|
||||||
"""Organizes PDFs into directory splits based on split information file."""
|
"""Create the per-split doc ids."""
|
||||||
create_split_dirs(
|
create_split_files(
|
||||||
all_pdf_dir=all_pdf_dir,
|
|
||||||
target_dir_base=target_dir_base,
|
|
||||||
split_metadata_file=split_metadata_file,
|
split_metadata_file=split_metadata_file,
|
||||||
split=split,
|
split=split,
|
||||||
)
|
)
|
||||||
logger.info(f"Files organized for {split} split: in {target_dir_base}")
|
logger.info(f"Doc Ids Files created for {split} split")
|
||||||
|
|
||||||
|
|
||||||
def download_mmqa(output_dir: str):
|
def download_mmqa(output_dir: str):
|
||||||
@ -193,7 +193,7 @@ def main():
|
|||||||
"download_pdfs": download_pdfs,
|
"download_pdfs": download_pdfs,
|
||||||
"check_pdfs": check_pdfs,
|
"check_pdfs": check_pdfs,
|
||||||
"extract_images": extract_images,
|
"extract_images": extract_images,
|
||||||
"organize_files": organize_files,
|
"create_splits": create_splits,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,11 +47,6 @@ def _download_wiki_page(args: tuple[int, int, str, str, str, int]) -> tuple[bool
|
|||||||
"""
|
"""
|
||||||
order_i, total, url, save_path, save_type, proc_id = args
|
order_i, total, url, save_path, save_type, proc_id = args
|
||||||
|
|
||||||
if is_pdf_downloaded(save_path):
|
|
||||||
if proc_id == 0:
|
|
||||||
logger.info(f"{order_i} / {total} - {save_path} already downloaded")
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with sync_playwright() as p:
|
with sync_playwright() as p:
|
||||||
browser = p.chromium.launch(headless=True)
|
browser = p.chromium.launch(headless=True)
|
||||||
@ -78,7 +73,7 @@ def download_wiki_page(
|
|||||||
urls: list[str],
|
urls: list[str],
|
||||||
save_paths: list[str],
|
save_paths: list[str],
|
||||||
save_type: str,
|
save_type: str,
|
||||||
result_jsonl_path: str,
|
result_log_dir: str,
|
||||||
proc_id: int = 0,
|
proc_id: int = 0,
|
||||||
n_proc: int = 1
|
n_proc: int = 1
|
||||||
) -> list[bool]:
|
) -> list[bool]:
|
||||||
@ -88,7 +83,7 @@ def download_wiki_page(
|
|||||||
urls (List[str]): List of Wikipedia URLs to download.
|
urls (List[str]): List of Wikipedia URLs to download.
|
||||||
save_paths (List[str]): List of paths where each downloaded file will be saved.
|
save_paths (List[str]): List of paths where each downloaded file will be saved.
|
||||||
save_type (str): File type to save each page as ('pdf' or 'png').
|
save_type (str): File type to save each page as ('pdf' or 'png').
|
||||||
result_jsonl_path (str): Path to the JSONL file where download results will be logged.
|
result_log_dir (str): Path to the directory where the download results will be logged.
|
||||||
proc_id (int, optional): Process ID for parallel processing. Defaults to 0.
|
proc_id (int, optional): Process ID for parallel processing. Defaults to 0.
|
||||||
n_proc (int, optional): Total number of processes running in parallel. Defaults to 1.
|
n_proc (int, optional): Total number of processes running in parallel. Defaults to 1.
|
||||||
|
|
||||||
@ -99,13 +94,15 @@ def download_wiki_page(
|
|||||||
all_args = [(i, total, url, str(save_path), save_type, proc_id)
|
all_args = [(i, total, url, str(save_path), save_type, proc_id)
|
||||||
for i, (url, save_path) in enumerate(zip(urls, save_paths))]
|
for i, (url, save_path) in enumerate(zip(urls, save_paths))]
|
||||||
|
|
||||||
|
# create log directory if it doesn't exist
|
||||||
|
log_dir = Path(result_log_dir)
|
||||||
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
pbar = tqdm(total=len(all_args), ncols=100, disable=not (proc_id == 0))
|
pbar = tqdm(total=len(all_args), ncols=100, disable=not (proc_id == 0))
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
n_downloaded = 0
|
n_downloaded = 0
|
||||||
|
|
||||||
# Log results to a JSONL file
|
|
||||||
with jsonlines.open(result_jsonl_path, 'w') as writer:
|
|
||||||
for args in all_args:
|
for args in all_args:
|
||||||
downloaded, error = _download_wiki_page(args)
|
downloaded, error = _download_wiki_page(args)
|
||||||
|
|
||||||
@ -116,6 +113,10 @@ def download_wiki_page(
|
|||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
results.append(downloaded)
|
results.append(downloaded)
|
||||||
|
|
||||||
|
# Write to process-specific log file
|
||||||
|
proc_result_path = log_dir / f'process_{proc_id}_{n_proc}.jsonl'
|
||||||
|
with jsonlines.open(proc_result_path, mode='a') as writer:
|
||||||
writer.write({
|
writer.write({
|
||||||
'downloaded': downloaded,
|
'downloaded': downloaded,
|
||||||
'args': [arg if not isinstance(arg, Path) else str(arg) for arg in args],
|
'args': [arg if not isinstance(arg, Path) else str(arg) for arg in args],
|
||||||
|
@ -30,19 +30,15 @@ import json
|
|||||||
import jsonlines
|
import jsonlines
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
def create_split_dirs(
|
|
||||||
all_pdf_dir: str | Path,
|
def create_split_files(
|
||||||
target_dir_base: str | Path,
|
|
||||||
split_metadata_file: str | Path,
|
split_metadata_file: str | Path,
|
||||||
split: str
|
split: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Copies specified PDF files into a target directory based on a given split.
|
"""Create the per-split doc ids.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
all_pdf_dir (Union[str, Path]): Path to the directory containing all downloaded PDF files.
|
|
||||||
target_dir_base (Union[str, Path]): Base directory where the split-specific directory will be created.
|
|
||||||
split_metadata_file (Union[str, Path]): Path to the metadata JSONL file for the split.
|
split_metadata_file (Union[str, Path]): Path to the metadata JSONL file for the split.
|
||||||
split (str): Split type ('train' or 'dev').
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If the JSONL metadata file does not exist.
|
FileNotFoundError: If the JSONL metadata file does not exist.
|
||||||
@ -52,10 +48,6 @@ def create_split_dirs(
|
|||||||
if split not in {"train", "dev"}:
|
if split not in {"train", "dev"}:
|
||||||
raise ValueError(f"Invalid split: {split}. Expected 'train' or 'dev'.")
|
raise ValueError(f"Invalid split: {split}. Expected 'train' or 'dev'.")
|
||||||
|
|
||||||
all_pdf_dir = Path(all_pdf_dir)
|
|
||||||
target_dir = Path(target_dir_base) / f'pdfs_{split}'
|
|
||||||
target_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Validate metadata file
|
# Validate metadata file
|
||||||
split_metadata_file = Path(split_metadata_file)
|
split_metadata_file = Path(split_metadata_file)
|
||||||
if not split_metadata_file.exists():
|
if not split_metadata_file.exists():
|
||||||
@ -76,18 +68,3 @@ def create_split_dirs(
|
|||||||
with open(split_doc_ids_output_path, 'w') as f:
|
with open(split_doc_ids_output_path, 'w') as f:
|
||||||
json.dump(split_doc_ids, f, indent=4)
|
json.dump(split_doc_ids, f, indent=4)
|
||||||
logger.info(f"Split {split} -> saved doc IDs at {split_doc_ids_output_path}")
|
logger.info(f"Split {split} -> saved doc IDs at {split_doc_ids_output_path}")
|
||||||
|
|
||||||
# Copy PDF files to the target directory
|
|
||||||
missing_files = []
|
|
||||||
for doc_id in split_doc_ids:
|
|
||||||
pdf_file = all_pdf_dir / f"{doc_id}.pdf"
|
|
||||||
if pdf_file.exists():
|
|
||||||
shutil.copy(pdf_file, target_dir / pdf_file.name)
|
|
||||||
else:
|
|
||||||
missing_files.append(pdf_file)
|
|
||||||
|
|
||||||
if missing_files:
|
|
||||||
logger.warning(f"Warning: {len(missing_files)} files are missing and will be skipped.")
|
|
||||||
for missing_file in missing_files:
|
|
||||||
logger.warning(f" Missing: {missing_file}")
|
|
||||||
|
|
||||||
|
@ -1,107 +0,0 @@
|
|||||||
# Copyright 2024 Bloomberg Finance L.P.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
import json
|
|
||||||
import jsonlines
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
from m3docvqa.split_utils import create_split_dirs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_pdf_directory(tmp_path):
|
|
||||||
# Create a temporary directory for PDFs
|
|
||||||
pdf_dir = tmp_path / "pdfs"
|
|
||||||
pdf_dir.mkdir()
|
|
||||||
# Add some mock PDF files
|
|
||||||
(pdf_dir / "doc1.pdf").write_text("PDF content for doc1")
|
|
||||||
(pdf_dir / "doc2.pdf").write_text("PDF content for doc2")
|
|
||||||
return pdf_dir
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_metadata_file(tmp_path):
|
|
||||||
# Create a temporary metadata file in JSONL format
|
|
||||||
metadata_file = tmp_path / "MMQA_train.jsonl"
|
|
||||||
data = [
|
|
||||||
{"supporting_context": [{"doc_id": "doc1"}]},
|
|
||||||
{"supporting_context": [{"doc_id": "doc2"}]}
|
|
||||||
]
|
|
||||||
with jsonlines.open(metadata_file, mode='w') as writer:
|
|
||||||
writer.write_all(data)
|
|
||||||
return metadata_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_target_directory(tmp_path):
|
|
||||||
return tmp_path / "target"
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_split_dirs(mock_pdf_directory, mock_metadata_file, mock_target_directory):
|
|
||||||
"""Test the create_split_dirs function."""
|
|
||||||
# Prepare the split directory
|
|
||||||
split = "train"
|
|
||||||
|
|
||||||
# Call the function to create split directories
|
|
||||||
create_split_dirs(
|
|
||||||
all_pdf_dir=mock_pdf_directory,
|
|
||||||
target_dir_base=mock_target_directory,
|
|
||||||
split_metadata_file=mock_metadata_file,
|
|
||||||
split=split
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert that the target directory exists and contains the expected PDF files
|
|
||||||
target_dir = mock_target_directory / f"pdfs_{split}"
|
|
||||||
assert target_dir.exists(), f"Directory {target_dir} was not created"
|
|
||||||
assert (target_dir / "doc1.pdf").exists(), "doc1.pdf was not copied"
|
|
||||||
assert (target_dir / "doc2.pdf").exists(), "doc2.pdf was not copied"
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_split_dirs_missing_pdf(mock_metadata_file, mock_target_directory):
|
|
||||||
"""Test create_split_dirs when PDF files are missing."""
|
|
||||||
# Prepare the split directory
|
|
||||||
split = "train"
|
|
||||||
all_pdf_dir = Path("non_existing_pdf_dir")
|
|
||||||
|
|
||||||
# Call the function and verify that the missing PDFs are handled correctly
|
|
||||||
create_split_dirs(
|
|
||||||
all_pdf_dir=all_pdf_dir,
|
|
||||||
target_dir_base=mock_target_directory,
|
|
||||||
split_metadata_file=mock_metadata_file,
|
|
||||||
split=split
|
|
||||||
)
|
|
||||||
|
|
||||||
target_dir = mock_target_directory / f"pdfs_{split}"
|
|
||||||
assert target_dir.exists(), f"Directory {target_dir} was not created"
|
|
||||||
assert not (target_dir / "doc1.pdf").exists(), "doc1.pdf should not exist"
|
|
||||||
assert not (target_dir / "doc2.pdf").exists(), "doc2.pdf should not exist"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("split, expected_error", [
|
|
||||||
("test_split", ValueError), # Invalid split type
|
|
||||||
(None, ValueError), # Missing split
|
|
||||||
])
|
|
||||||
def test_create_split_dirs_invalid_split_type(mock_pdf_directory, mock_metadata_file, mock_target_directory, split, expected_error):
|
|
||||||
"""Test invalid split types in create_split_dirs."""
|
|
||||||
with pytest.raises(expected_error):
|
|
||||||
create_split_dirs(
|
|
||||||
all_pdf_dir=mock_pdf_directory,
|
|
||||||
target_dir_base=mock_target_directory,
|
|
||||||
split_metadata_file=mock_metadata_file,
|
|
||||||
split=split
|
|
||||||
)
|
|
Reference in New Issue
Block a user