enable --resuming from interrupted jobs.

This commit is contained in:
Nathan TeBlunthuis
2025-11-30 20:36:31 -08:00
parent 95b33123e3
commit 3c26185739
2 changed files with 265 additions and 4 deletions

View File

@@ -30,7 +30,7 @@ from wikiq.wiki_diff_matcher import WikiDiffMatcher
TO_ENCODE = ("title", "editor")
PERSISTENCE_RADIUS = 7
DIFF_TIMEOUT = 30
DIFF_TIMEOUT = 60
from pathlib import Path
import pyarrow as pa
@@ -241,10 +241,12 @@ class WikiqParser:
output_parquet: bool = True,
batch_size: int = 1024,
partition_namespaces: bool = False,
resume_from_revid: int = None,
):
"""
Parameters:
persist : what persistence method to use. Takes a PersistMethod value
resume_from_revid : if set, skip all revisions up to and including this revid
"""
self.input_file = input_file
@@ -255,6 +257,7 @@ class WikiqParser:
self.diff = diff
self.text = text
self.partition_namespaces = partition_namespaces
self.resume_from_revid = resume_from_revid
if namespaces is not None:
self.namespace_filter = set(namespaces)
else:
@@ -341,6 +344,18 @@ class WikiqParser:
# r'output/wikiq-\1-\2.tsv',
# input_filename)
# Track whether we've passed the resume point
found_resume_point = self.resume_from_revid is None
# When resuming with parquet, write new data to temp file and merge at the end
original_output_file = None
temp_output_file = None
if self.resume_from_revid is not None and self.output_parquet:
if isinstance(self.output_file, str) and os.path.exists(self.output_file):
original_output_file = self.output_file
temp_output_file = self.output_file + ".resume_temp"
self.output_file = temp_output_file
# Construct dump file iterator
dump = WikiqIterator(self.input_file, collapse_user=self.collapse_user)
@@ -539,6 +554,28 @@ class WikiqParser:
n_revs = 0
# If we're resuming and haven't found the resume point yet, check this batch
skip_batch = False
if not found_resume_point and self.resume_from_revid is not None:
batch_has_resume_point = False
for revs in batch:
revs_list = list(revs)
for rev in revs_list:
if rev.id == self.resume_from_revid:
batch_has_resume_point = True
found_resume_point = True
print(f"Found resume point at revid {self.resume_from_revid}", file=sys.stderr)
break
if batch_has_resume_point:
break
# If this batch doesn't contain the resume point, skip it entirely
if not batch_has_resume_point:
skip_batch = True
if skip_batch:
continue
for revs in batch:
# Revisions may or may not be grouped into lists of contiguous revisions by the
# same user. We call these "edit sessions". Otherwise revs is a list containing
@@ -702,9 +739,30 @@ class WikiqParser:
if not self.text and self.persist != PersistMethod.none:
del row_buffer["text"]
if self.partition_namespaces is True:
writer = pq_writers[page.mwpage.namespace]
writer.write(pa.record_batch(row_buffer, schema=schema))
# If we just found the resume point in this batch, filter to only write revisions after it
if self.resume_from_revid is not None:
revids = row_buffer["revid"]
# Find the index of the resume revid
resume_idx = None
for idx, revid in enumerate(revids):
if revid == self.resume_from_revid:
resume_idx = idx
break
if resume_idx is not None:
# Only write revisions after the resume point
if resume_idx + 1 < len(revids):
row_buffer = {k: v[resume_idx + 1:] for k, v in row_buffer.items()}
print(f"Resuming output starting at revid {row_buffer['revid'][0]}", file=sys.stderr)
else:
# The resume point was the last revision in this batch, skip writing
continue
# Only write if there are rows to write
if len(row_buffer.get("revid", [])) > 0:
if self.partition_namespaces is True:
writer = pq_writers[page.mwpage.namespace]
writer.write(pa.record_batch(row_buffer, schema=schema))
gc.collect()
page_count += 1
@@ -718,6 +776,54 @@ class WikiqParser:
else:
writer.close()
# If we were resuming, merge the original file with the new temp file
if original_output_file is not None and temp_output_file is not None:
print("Merging resumed data with existing output...", file=sys.stderr)
try:
# Create a merged output file
merged_output_file = original_output_file + ".merged"
# Open the original file
original_pq = pq.ParquetFile(original_output_file)
# Open the temp file
temp_pq = pq.ParquetFile(temp_output_file)
# Create a writer for the merged file
merged_writer = None
# Copy all row groups from the original file
for i in range(original_pq.num_row_groups):
row_group = original_pq.read_row_group(i)
if merged_writer is None:
merged_writer = pq.ParquetWriter(
merged_output_file,
row_group.schema,
flavor="spark"
)
merged_writer.write_table(row_group)
# Append all row groups from the temp file
for i in range(temp_pq.num_row_groups):
row_group = temp_pq.read_row_group(i)
merged_writer.write_table(row_group)
# Close the writer
if merged_writer is not None:
merged_writer.close()
# Replace the original file with the merged file
os.remove(original_output_file)
os.rename(merged_output_file, original_output_file)
# Clean up the temp file
os.remove(temp_output_file)
print("Merge complete.", file=sys.stderr)
except Exception as e:
print(f"Error merging resume data: {e}", file=sys.stderr)
print(f"New data saved in: {temp_output_file}", file=sys.stderr)
raise
def match_archive_suffix(input_filename):
if re.match(r".*\.7z$", input_filename):
@@ -758,6 +864,40 @@ def open_output_file(input_filename):
return output_file
def get_last_revid_from_parquet(output_file):
"""
Read the last revid from a parquet file by reading only the last row group.
Returns None if the file doesn't exist or is empty.
"""
try:
file_path = output_file
if not os.path.exists(file_path):
return None
# Open the parquet file
parquet_file = pq.ParquetFile(file_path)
# Get the number of row groups
num_row_groups = parquet_file.num_row_groups
if num_row_groups == 0:
return None
# Read only the last row group, and only the revid column
last_row_group = parquet_file.read_row_group(num_row_groups - 1, columns=['revid'])
if last_row_group.num_rows == 0:
return None
# Get the last revid from this row group
last_revid = last_row_group.column('revid')[-1].as_py()
return last_revid
except Exception as e:
print(f"Error reading last revid from {file_path}: {e}", file=sys.stderr)
return None
def main():
parser = argparse.ArgumentParser(
description="Parse MediaWiki XML database dumps into tab delimited data."
@@ -910,6 +1050,13 @@ def main():
help="How many revisions to process in each batch. This ends up being the Parquet row group size",
)
parser.add_argument(
"--resume",
dest="resume",
action="store_true",
help="Resume processing from the last successfully written revision in the output file.",
)
args = parser.parse_args()
# set persistence method
@@ -954,6 +1101,18 @@ def main():
else:
output_file = output
# Handle resume functionality
resume_from_revid = None
if args.resume:
if output_parquet and not args.stdout:
resume_from_revid = get_last_revid_from_parquet(output_file)
if resume_from_revid is not None:
print(f"Resuming from last written revid: {resume_from_revid}", file=sys.stderr)
else:
print("Resume requested but no existing output file found, starting from beginning", file=sys.stderr)
else:
print("Warning: --resume only works with parquet output (not stdout or TSV)", file=sys.stderr)
wikiq = WikiqParser(
input_file,
output_file,
@@ -970,6 +1129,7 @@ def main():
output_parquet=output_parquet,
partition_namespaces=args.partition_namespaces,
batch_size=args.batch_size,
resume_from_revid=resume_from_revid,
)
wikiq.process()
@@ -978,6 +1138,9 @@ def main():
input_file.close()
else:
if args.resume:
print("Warning: --resume cannot be used with stdin/stdout", file=sys.stderr)
wikiq = WikiqParser(
sys.stdin,
sys.stdout,
@@ -993,6 +1156,7 @@ def main():
diff=args.diff,
text=args.text,
batch_size=args.batch_size,
resume_from_revid=None,
)
wikiq.process()