improve resume logic.

Nathan TeBlunthuis 2025-12-07 06:06:26 -08:00
parent 577ddc87f5
commit 783f5fd8bc
2 changed files with 165 additions and 54 deletions

View File

@@ -33,6 +33,7 @@ from wikiq.resume import (
    get_resume_point,
    setup_resume_temp_output,
    finalize_resume_merge,
+   get_checkpoint_path,
)

TO_ENCODE = ("title", "editor")
@@ -309,10 +310,49 @@ class WikiqParser:
        else:
            self.output_file = open(output_file, "wb")

+        # Checkpoint file for tracking resume point
+        self.checkpoint_file = None
+        self.checkpoint_state = {}  # namespace -> {"pageid", "revid"} dict when partitioned; a flat {"pageid", "revid"} dict otherwise
+
    def request_shutdown(self):
        """Request graceful shutdown. The process() method will exit after completing the current batch."""
        self.shutdown_requested = True

+    def _open_checkpoint(self, output_file):
+        """Open checkpoint file for writing. Keeps file open for performance."""
+        if not self.output_parquet or output_file == sys.stdout.buffer:
+            return
+        checkpoint_path = get_checkpoint_path(output_file)
+        Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
+        self.checkpoint_file = open(checkpoint_path, 'w')
+        print(f"Checkpoint file opened: {checkpoint_path}", file=sys.stderr)
+
+    def _update_checkpoint(self, pageid, revid, namespace=None):
+        """Update checkpoint state and write to file."""
+        if self.checkpoint_file is None:
+            return
+        if self.partition_namespaces:
+            self.checkpoint_state[namespace] = {"pageid": pageid, "revid": revid}
+        else:
+            self.checkpoint_state = {"pageid": pageid, "revid": revid}
+        self.checkpoint_file.seek(0)
+        self.checkpoint_file.truncate()
+        json.dump(self.checkpoint_state, self.checkpoint_file)
+        self.checkpoint_file.flush()
+
+    def _close_checkpoint(self, delete=False):
+        """Close checkpoint file, optionally deleting it."""
+        if self.checkpoint_file is None:
+            return
+        checkpoint_path = self.checkpoint_file.name
+        self.checkpoint_file.close()
+        self.checkpoint_file = None
+        if delete and os.path.exists(checkpoint_path):
+            os.remove(checkpoint_path)
+            print(f"Checkpoint file deleted (processing complete): {checkpoint_path}", file=sys.stderr)
+        else:
+            print(f"Checkpoint file preserved for resume: {checkpoint_path}", file=sys.stderr)
+
    def _write_batch(self, row_buffer, schema, writer, pq_writers, ns_paths, sorting_cols, namespace=None):
        """Write a batch of rows to the appropriate writer.
@@ -407,6 +447,11 @@ class WikiqParser:
        if temp_output_file is not None:
            self.output_file = temp_output_file

+        # Open checkpoint file for tracking resume point
+        # Use original_output_file if resuming, otherwise self.output_file
+        checkpoint_output = original_output_file if original_output_file else self.output_file
+        self._open_checkpoint(checkpoint_output)
+
        # Construct dump file iterator
        dump = WikiqIterator(self.input_file, collapse_user=self.collapse_user)
@@ -868,6 +913,10 @@ class WikiqParser:
            if should_write and len(row_buffer.get("revid", [])) > 0:
                namespace = page.mwpage.namespace if self.partition_namespaces else None
                self._write_batch(row_buffer, schema, writer, pq_writers, ns_paths, sorting_cols, namespace)
+                # Update checkpoint with last written position
+                last_pageid = row_buffer["articleid"][-1]
+                last_revid = row_buffer["revid"][-1]
+                self._update_checkpoint(last_pageid, last_revid, namespace)
                gc.collect()

            # If shutdown was requested, break from page loop
@@ -894,6 +943,9 @@ class WikiqParser:
                original_partition_dir
            )

+        # Close checkpoint file; delete it only if we completed without interruption
+        self._close_checkpoint(delete=not self.shutdown_requested)
+

def match_archive_suffix(input_filename):
    if re.match(r".*\.7z$", input_filename):
        cmd = ["7za", "x", "-so", input_filename]
@@ -1155,9 +1207,7 @@ def main():
    print(args, file=sys.stderr)

    if len(args.dumpfiles) > 0:
        for filename in args.dumpfiles:
-            input_file = open_input_file(filename, args.fandom_2020)
-
-            # open directory for output
+            # Determine output file path before opening input (so resume errors are caught early)
            if args.output:
                output = args.output[0]
            else:
@@ -1165,25 +1215,21 @@ def main():
            output_parquet = output.endswith(".parquet")

-            print("Processing file: %s" % filename, file=sys.stderr)
-
            if args.stdout:
-                # Parquet libraries need a binary output, so just sys.stdout doesn't work.
                output_file = sys.stdout.buffer
            elif os.path.isdir(output) or output_parquet:
-                filename = os.path.join(output, os.path.basename(filename))
-                output_file = get_output_filename(filename, parquet=output_parquet)
+                output_filename = os.path.join(output, os.path.basename(filename))
+                output_file = get_output_filename(output_filename, parquet=output_parquet)
            else:
                output_file = output

-            # Handle resume functionality
+            # Handle resume functionality before opening input file
            resume_point = None
            if args.resume:
                if output_parquet and not args.stdout:
                    resume_point = get_resume_point(output_file, args.partition_namespaces)
                    if resume_point is not None:
                        if args.partition_namespaces:
-                            # Dict mapping namespace -> (pageid, revid)
                            ns_list = sorted(resume_point.keys())
                            print(f"Resuming with per-namespace resume points for {len(ns_list)} namespaces", file=sys.stderr)
                            for ns in ns_list:
@@ -1201,6 +1247,10 @@ def main():
                else:
                    sys.exit("Error: --resume only works with parquet output (not stdout or TSV)")

+            # Now open the input file
+            print("Processing file: %s" % filename, file=sys.stderr)
+            input_file = open_input_file(filename, args.fandom_2020)
+
            wikiq = WikiqParser(
                input_file,
                output_file,
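The new checkpoint methods above boil down to a simple cycle: open one small JSON file next to the parquet output, rewrite it in place after every written batch, and delete it only when processing finishes cleanly. A minimal standalone sketch of that cycle follows; the output path and the sample pageid/revid values are made up, and this is an illustration rather than the WikiqParser code itself.

```python
# Minimal standalone sketch of the checkpoint cycle (illustration only, not WikiqParser).
# The ".checkpoint" suffix mirrors get_checkpoint_path(); the output path and the
# pageid/revid values below are hypothetical.
import json
import os

output_file = "dump.parquet"                      # hypothetical output path
checkpoint_path = str(output_file) + ".checkpoint"

checkpoint_file = open(checkpoint_path, "w")      # kept open, as in _open_checkpoint()
checkpoint_state = {}

def update_checkpoint(pageid, revid, namespace=None):
    """Overwrite the checkpoint with the latest written position (cf. _update_checkpoint)."""
    if namespace is not None:
        # Partitioned shape: one entry per namespace.
        checkpoint_state[namespace] = {"pageid": pageid, "revid": revid}
    else:
        # Single-file shape: a flat {"pageid", "revid"} dict.
        checkpoint_state.clear()
        checkpoint_state.update({"pageid": pageid, "revid": revid})
    checkpoint_file.seek(0)
    checkpoint_file.truncate()
    json.dump(checkpoint_state, checkpoint_file)
    checkpoint_file.flush()

update_checkpoint(pageid=54, revid=325)           # after each written batch
checkpoint_file.close()
os.remove(checkpoint_path)                        # only after a clean, complete run
```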

View File

@@ -4,20 +4,67 @@ Checkpoint and resume functionality for wikiq parquet output.

This module handles:
- Finding resume points in existing parquet output
- Merging resumed data with existing output (streaming, memory-efficient)
+- Checkpoint file management for fast resume point lookup
"""

+import json
import os
import sys

-import pyarrow.dataset as ds
import pyarrow.parquet as pq
-import pyarrow.compute as pc
+
+
+def get_checkpoint_path(output_file):
+    """Get the path to the checkpoint file for a given output file."""
+    return str(output_file) + ".checkpoint"
+
+
+def read_checkpoint(output_file):
+    """
+    Read resume point from checkpoint file if it exists.
+
+    Checkpoint format:
+        Single file:  {"pageid": 54, "revid": 325}
+        Partitioned:  {"0": {"pageid": 54, "revid": 325}, "1": {"pageid": 123, "revid": 456}}
+
+    Returns:
+        For single files: A tuple (pageid, revid), or None if not found.
+        For partitioned: A dict mapping namespace -> (pageid, revid), or None.
+    """
+    checkpoint_path = get_checkpoint_path(output_file)
+    if not os.path.exists(checkpoint_path):
+        return None
+    try:
+        with open(checkpoint_path, 'r') as f:
+            data = json.load(f)
+        if not data:
+            return None
+        # Single-file format: {"pageid": ..., "revid": ...}
+        if "pageid" in data and "revid" in data:
+            return (data["pageid"], data["revid"])
+        # Partitioned format: {"0": {"pageid": ..., "revid": ...}, ...}
+        result = {}
+        for key, value in data.items():
+            result[int(key)] = (value["pageid"], value["revid"])
+        return result if result else None
+    except (json.JSONDecodeError, IOError, KeyError, TypeError) as e:
+        print(f"Warning: Could not read checkpoint file {checkpoint_path}: {e}", file=sys.stderr)
+        return None
+
+
def get_resume_point(output_file, partition_namespaces=False):
    """
    Find the resume point(s) from existing parquet output.

+    First checks for a checkpoint file (fast), then falls back to scanning
+    the parquet output (slow, for backwards compatibility).
+
    Args:
        output_file: Path to the output file. For single files, this is the parquet file path.
                     For partitioned namespaces, this is the path like dir/dump.parquet where
@@ -30,6 +77,14 @@ def get_resume_point(output_file, partition_namespaces=False):
        For partitioned: A dict mapping namespace -> (pageid, revid) for each partition,
                         or None if no partitions exist.
    """
+    # First try checkpoint file (fast)
+    checkpoint_result = read_checkpoint(output_file)
+    if checkpoint_result is not None:
+        print(f"Resume point found in checkpoint file", file=sys.stderr)
+        return checkpoint_result
+
+    # Fall back to scanning parquet (slow, for backwards compatibility)
+    print(f"No checkpoint file found, scanning parquet output...", file=sys.stderr)
    try:
        if partition_namespaces:
            return _get_resume_point_partitioned(output_file)
@@ -40,14 +95,40 @@ def get_resume_point(output_file, partition_namespaces=False):
        return None


+def _get_last_row_resume_point(pq_path):
+    """Get resume point by reading only the last row group of a parquet file.
+
+    Since data is written in page/revision order, the last row group contains
+    the highest pageid/revid, and the last row in that group is the resume point.
+    """
+    pf = pq.ParquetFile(pq_path)
+    if pf.metadata.num_row_groups == 0:
+        return None
+    last_rg_idx = pf.metadata.num_row_groups - 1
+    table = pf.read_row_group(last_rg_idx, columns=['articleid', 'revid'])
+    if table.num_rows == 0:
+        return None
+    max_pageid = table['articleid'][-1].as_py()
+    max_revid = table['revid'][-1].as_py()
+    return (max_pageid, max_revid)
+
+
def _get_resume_point_partitioned(output_file):
    """Find per-namespace resume points from partitioned output.

-    Returns a dict mapping namespace -> (max_pageid, max_revid) for each partition.
-    This allows resume to correctly handle cases where different namespaces have
-    different progress due to interleaved dump ordering.
+    Only looks for the specific output file in each namespace directory.
+    Returns a dict mapping namespace -> (max_pageid, max_revid) for each partition
+    where the output file exists.
+
+    Args:
+        output_file: Path like 'dir/output.parquet' where namespace=* subdirectories
+                     contain files named 'output.parquet'.
    """
    partition_dir = os.path.dirname(output_file)
+    output_filename = os.path.basename(output_file)

    if not os.path.exists(partition_dir) or not os.path.isdir(partition_dir):
        return None
@@ -58,32 +139,18 @@ def _get_resume_point_partitioned(output_file):
    resume_points = {}
    for ns_dir in namespace_dirs:
        ns = int(ns_dir.split('=')[1])
-        ns_path = os.path.join(partition_dir, ns_dir)
-
-        # Find parquet files in this namespace directory
-        parquet_files = [f for f in os.listdir(ns_path) if f.endswith('.parquet')]
-        if not parquet_files:
+        pq_path = os.path.join(partition_dir, ns_dir, output_filename)
+        if not os.path.exists(pq_path):
            continue
-
-        # Read all parquet files in this namespace
-        for pq_file in parquet_files:
-            pq_path = os.path.join(ns_path, pq_file)
-            try:
-                pf = pq.ParquetFile(pq_path)
-                table = pf.read(columns=['articleid', 'revid'])
-                if table.num_rows == 0:
-                    continue
-                max_pageid = pc.max(table['articleid']).as_py()
-                mask = pc.equal(table['articleid'], max_pageid)
-                max_revid = pc.max(pc.filter(table['revid'], mask)).as_py()
-                # Keep the highest pageid for this namespace
-                if ns not in resume_points or max_pageid > resume_points[ns][0]:
-                    resume_points[ns] = (max_pageid, max_revid)
-            except Exception as e:
-                print(f"Warning: Could not read {pq_path}: {e}", file=sys.stderr)
-                continue
+        try:
+            result = _get_last_row_resume_point(pq_path)
+            if result is not None:
+                resume_points[ns] = result
+        except Exception as e:
+            print(f"Warning: Could not read {pq_path}: {e}", file=sys.stderr)
+            continue

    return resume_points if resume_points else None
@@ -96,18 +163,7 @@ def _get_resume_point_single_file(output_file):
    if os.path.isdir(output_file):
        return None

-    # Find the row with the highest pageid
-    pf = pq.ParquetFile(output_file)
-    table = pf.read(columns=['articleid', 'revid'])
-    if table.num_rows == 0:
-        return None
-
-    max_pageid = pc.max(table['articleid']).as_py()
-
-    # Filter to row(s) with max pageid and get max revid
-    mask = pc.equal(table['articleid'], max_pageid)
-    max_revid = pc.max(pc.filter(table['revid'], mask)).as_py()
-
-    return (max_pageid, max_revid)
+    return _get_last_row_resume_point(output_file)


def merge_parquet_files(original_path, temp_path, merged_path):
@@ -266,12 +322,17 @@ def setup_resume_temp_output(output_file, partition_namespaces):
    temp_output_file = None
    original_partition_dir = None

-    # For partitioned namespaces, check if the partition directory exists
+    # For partitioned namespaces, check if the specific output file exists in any namespace
    if partition_namespaces:
        partition_dir = os.path.dirname(output_file)
-        output_exists = os.path.isdir(partition_dir) and any(
-            d.startswith('namespace=') for d in os.listdir(partition_dir)
-        )
+        output_filename = os.path.basename(output_file)
+        output_exists = False
+        if os.path.isdir(partition_dir):
+            for d in os.listdir(partition_dir):
+                if d.startswith('namespace='):
+                    if os.path.exists(os.path.join(partition_dir, d, output_filename)):
+                        output_exists = True
+                        break
        if output_exists:
            original_partition_dir = partition_dir
        else:
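For reference, the two checkpoint shapes handled throughout this commit map onto resume points as follows. This is a small self-contained sketch mirroring the read_checkpoint() docstring; the pageid/revid values are invented for illustration.

```python
# Toy illustration of the two checkpoint shapes read_checkpoint() distinguishes.
# The concrete pageid/revid values are made up.
import json

single = json.loads('{"pageid": 54, "revid": 325}')
partitioned = json.loads('{"0": {"pageid": 54, "revid": 325}, "1": {"pageid": 123, "revid": 456}}')

def interpret(data):
    # Single-file shape -> (pageid, revid) tuple.
    if "pageid" in data and "revid" in data:
        return (data["pageid"], data["revid"])
    # Partitioned shape -> {namespace (int): (pageid, revid)}.
    return {int(ns): (v["pageid"], v["revid"]) for ns, v in data.items()}

print(interpret(single))       # (54, 325)
print(interpret(partitioned))  # {0: (54, 325), 1: (123, 456)}
```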