Improve resume logic.

This commit is contained in:
Nathan TeBlunthuis
2025-12-07 06:06:26 -08:00
parent 577ddc87f5
commit 783f5fd8bc
2 changed files with 165 additions and 54 deletions

View File

@@ -4,20 +4,67 @@ Checkpoint and resume functionality for wikiq parquet output.
This module handles:
- Finding resume points in existing parquet output
- Merging resumed data with existing output (streaming, memory-efficient)
- Checkpoint file management for fast resume point lookup
"""
import json
import os
import sys
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import pyarrow.compute as pc
def get_checkpoint_path(output_file):
    """Return the checkpoint file path associated with *output_file*.

    The checkpoint lives next to the output as ``<output_file>.checkpoint``.
    """
    return f"{output_file}.checkpoint"
def read_checkpoint(output_file):
    """
    Read resume point from checkpoint file if it exists.

    Checkpoint format:
        Single file: {"pageid": 54, "revid": 325}
        Partitioned: {"0": {"pageid": 54, "revid": 325}, "1": {"pageid": 123, "revid": 456}}

    Args:
        output_file: Path to the parquet output file; the checkpoint is
            located via get_checkpoint_path(output_file).

    Returns:
        For single files: a tuple (pageid, revid), or None if not found.
        For partitioned output: a dict mapping namespace (int) -> (pageid, revid),
        or None. A missing, empty, or malformed checkpoint yields None
        (a warning is printed to stderr for malformed/unreadable files).
    """
    checkpoint_path = get_checkpoint_path(output_file)
    if not os.path.exists(checkpoint_path):
        return None
    try:
        with open(checkpoint_path, 'r') as f:
            data = json.load(f)
        if not data:
            return None
        # Single-file format: {"pageid": ..., "revid": ...}
        if "pageid" in data and "revid" in data:
            return (data["pageid"], data["revid"])
        # Partitioned format: {"0": {"pageid": ..., "revid": ...}, ...}
        result = {}
        for key, value in data.items():
            result[int(key)] = (value["pageid"], value["revid"])
        return result if result else None
    # ValueError covers non-integer namespace keys in int(key) (it also
    # subsumes json.JSONDecodeError, kept explicit for readability);
    # KeyError/TypeError cover malformed per-namespace entries.
    except (json.JSONDecodeError, IOError, KeyError, TypeError, ValueError) as e:
        print(f"Warning: Could not read checkpoint file {checkpoint_path}: {e}", file=sys.stderr)
        return None
def get_resume_point(output_file, partition_namespaces=False):
"""
Find the resume point(s) from existing parquet output.
First checks for a checkpoint file (fast), then falls back to scanning
the parquet output (slow, for backwards compatibility).
Args:
output_file: Path to the output file. For single files, this is the parquet file path.
For partitioned namespaces, this is the path like dir/dump.parquet where
@@ -30,6 +77,14 @@ def get_resume_point(output_file, partition_namespaces=False):
For partitioned: A dict mapping namespace -> (pageid, revid) for each partition,
or None if no partitions exist.
"""
# First try checkpoint file (fast)
checkpoint_result = read_checkpoint(output_file)
if checkpoint_result is not None:
print(f"Resume point found in checkpoint file", file=sys.stderr)
return checkpoint_result
# Fall back to scanning parquet (slow, for backwards compatibility)
print(f"No checkpoint file found, scanning parquet output...", file=sys.stderr)
try:
if partition_namespaces:
return _get_resume_point_partitioned(output_file)
@@ -40,14 +95,40 @@ def get_resume_point(output_file, partition_namespaces=False):
return None
def _get_last_row_resume_point(pq_path):
    """Get resume point by reading only the trailing row group(s) of a parquet file.

    Since data is written in page/revision order, the last row of the last
    non-empty row group carries the highest pageid/revid written so far.

    Args:
        pq_path: Path to the parquet file.

    Returns:
        A tuple (pageid, revid), or None if the file contains no rows.
    """
    pf = pq.ParquetFile(pq_path)
    # Walk row groups from the end so that a trailing empty row group
    # does not hide data present in earlier groups.
    for rg_idx in range(pf.metadata.num_row_groups - 1, -1, -1):
        table = pf.read_row_group(rg_idx, columns=['articleid', 'revid'])
        if table.num_rows == 0:
            continue
        max_pageid = table['articleid'][-1].as_py()
        max_revid = table['revid'][-1].as_py()
        return (max_pageid, max_revid)
    # No row groups, or all row groups empty.
    return None
def _get_resume_point_partitioned(output_file):
"""Find per-namespace resume points from partitioned output.
Returns a dict mapping namespace -> (max_pageid, max_revid) for each partition.
This allows resume to correctly handle cases where different namespaces have
different progress due to interleaved dump ordering.
Only looks for the specific output file in each namespace directory.
Returns a dict mapping namespace -> (max_pageid, max_revid) for each partition
where the output file exists.
Args:
output_file: Path like 'dir/output.parquet' where namespace=* subdirectories
contain files named 'output.parquet'.
"""
partition_dir = os.path.dirname(output_file)
output_filename = os.path.basename(output_file)
if not os.path.exists(partition_dir) or not os.path.isdir(partition_dir):
return None
@@ -58,32 +139,18 @@ def _get_resume_point_partitioned(output_file):
resume_points = {}
for ns_dir in namespace_dirs:
ns = int(ns_dir.split('=')[1])
ns_path = os.path.join(partition_dir, ns_dir)
pq_path = os.path.join(partition_dir, ns_dir, output_filename)
# Find parquet files in this namespace directory
parquet_files = [f for f in os.listdir(ns_path) if f.endswith('.parquet')]
if not parquet_files:
if not os.path.exists(pq_path):
continue
# Read all parquet files in this namespace
for pq_file in parquet_files:
pq_path = os.path.join(ns_path, pq_file)
try:
pf = pq.ParquetFile(pq_path)
table = pf.read(columns=['articleid', 'revid'])
if table.num_rows == 0:
continue
max_pageid = pc.max(table['articleid']).as_py()
mask = pc.equal(table['articleid'], max_pageid)
max_revid = pc.max(pc.filter(table['revid'], mask)).as_py()
# Keep the highest pageid for this namespace
if ns not in resume_points or max_pageid > resume_points[ns][0]:
resume_points[ns] = (max_pageid, max_revid)
except Exception as e:
print(f"Warning: Could not read {pq_path}: {e}", file=sys.stderr)
continue
try:
result = _get_last_row_resume_point(pq_path)
if result is not None:
resume_points[ns] = result
except Exception as e:
print(f"Warning: Could not read {pq_path}: {e}", file=sys.stderr)
continue
return resume_points if resume_points else None
@@ -96,18 +163,7 @@ def _get_resume_point_single_file(output_file):
if os.path.isdir(output_file):
return None
# Find the row with the highest pageid
pf = pq.ParquetFile(output_file)
table = pf.read(columns=['articleid', 'revid'])
if table.num_rows == 0:
return None
max_pageid = pc.max(table['articleid']).as_py()
# Filter to row(s) with max pageid and get max revid
mask = pc.equal(table['articleid'], max_pageid)
max_revid = pc.max(pc.filter(table['revid'], mask)).as_py()
return (max_pageid, max_revid)
return _get_last_row_resume_point(output_file)
def merge_parquet_files(original_path, temp_path, merged_path):
@@ -266,12 +322,17 @@ def setup_resume_temp_output(output_file, partition_namespaces):
temp_output_file = None
original_partition_dir = None
# For partitioned namespaces, check if the partition directory exists
# For partitioned namespaces, check if the specific output file exists in any namespace
if partition_namespaces:
partition_dir = os.path.dirname(output_file)
output_exists = os.path.isdir(partition_dir) and any(
d.startswith('namespace=') for d in os.listdir(partition_dir)
)
output_filename = os.path.basename(output_file)
output_exists = False
if os.path.isdir(partition_dir):
for d in os.listdir(partition_dir):
if d.startswith('namespace='):
if os.path.exists(os.path.join(partition_dir, d, output_filename)):
output_exists = True
break
if output_exists:
original_partition_dir = partition_dir
else: