enable --resuming from interrupted jobs.

This commit is contained in:
Nathan TeBlunthuis 2025-11-30 20:36:31 -08:00
parent 95b33123e3
commit 3c26185739
2 changed files with 265 additions and 4 deletions

View File

@ -30,7 +30,7 @@ from wikiq.wiki_diff_matcher import WikiDiffMatcher
TO_ENCODE = ("title", "editor")
PERSISTENCE_RADIUS = 7
DIFF_TIMEOUT = 30
DIFF_TIMEOUT = 60
from pathlib import Path
import pyarrow as pa
@ -241,10 +241,12 @@ class WikiqParser:
output_parquet: bool = True,
batch_size: int = 1024,
partition_namespaces: bool = False,
resume_from_revid: int = None,
):
"""
Parameters:
persist : what persistence method to use. Takes a PersistMethod value
resume_from_revid : if set, skip all revisions up to and including this revid
"""
self.input_file = input_file
@ -255,6 +257,7 @@ class WikiqParser:
self.diff = diff
self.text = text
self.partition_namespaces = partition_namespaces
self.resume_from_revid = resume_from_revid
if namespaces is not None:
self.namespace_filter = set(namespaces)
else:
@ -341,6 +344,18 @@ class WikiqParser:
# r'output/wikiq-\1-\2.tsv',
# input_filename)
# Track whether we've passed the resume point
found_resume_point = self.resume_from_revid is None
# When resuming with parquet, write new data to temp file and merge at the end
original_output_file = None
temp_output_file = None
if self.resume_from_revid is not None and self.output_parquet:
if isinstance(self.output_file, str) and os.path.exists(self.output_file):
original_output_file = self.output_file
temp_output_file = self.output_file + ".resume_temp"
self.output_file = temp_output_file
# Construct dump file iterator
dump = WikiqIterator(self.input_file, collapse_user=self.collapse_user)
@ -539,6 +554,28 @@ class WikiqParser:
n_revs = 0
# If we're resuming and haven't found the resume point yet, check this batch
skip_batch = False
if not found_resume_point and self.resume_from_revid is not None:
batch_has_resume_point = False
for revs in batch:
revs_list = list(revs)
for rev in revs_list:
if rev.id == self.resume_from_revid:
batch_has_resume_point = True
found_resume_point = True
print(f"Found resume point at revid {self.resume_from_revid}", file=sys.stderr)
break
if batch_has_resume_point:
break
# If this batch doesn't contain the resume point, skip it entirely
if not batch_has_resume_point:
skip_batch = True
if skip_batch:
continue
for revs in batch:
# Revisions may or may not be grouped into lists of contiguous revisions by the
# same user. We call these "edit sessions". Otherwise revs is a list containing
@ -702,9 +739,30 @@ class WikiqParser:
if not self.text and self.persist != PersistMethod.none:
del row_buffer["text"]
if self.partition_namespaces is True:
writer = pq_writers[page.mwpage.namespace]
writer.write(pa.record_batch(row_buffer, schema=schema))
# If we just found the resume point in this batch, filter to only write revisions after it
if self.resume_from_revid is not None:
revids = row_buffer["revid"]
# Find the index of the resume revid
resume_idx = None
for idx, revid in enumerate(revids):
if revid == self.resume_from_revid:
resume_idx = idx
break
if resume_idx is not None:
# Only write revisions after the resume point
if resume_idx + 1 < len(revids):
row_buffer = {k: v[resume_idx + 1:] for k, v in row_buffer.items()}
print(f"Resuming output starting at revid {row_buffer['revid'][0]}", file=sys.stderr)
else:
# The resume point was the last revision in this batch, skip writing
continue
# Only write if there are rows to write
if len(row_buffer.get("revid", [])) > 0:
if self.partition_namespaces is True:
writer = pq_writers[page.mwpage.namespace]
writer.write(pa.record_batch(row_buffer, schema=schema))
gc.collect()
page_count += 1
@ -718,6 +776,54 @@ class WikiqParser:
else:
writer.close()
# If we were resuming, merge the original file with the new temp file
if original_output_file is not None and temp_output_file is not None:
print("Merging resumed data with existing output...", file=sys.stderr)
try:
# Create a merged output file
merged_output_file = original_output_file + ".merged"
# Open the original file
original_pq = pq.ParquetFile(original_output_file)
# Open the temp file
temp_pq = pq.ParquetFile(temp_output_file)
# Create a writer for the merged file
merged_writer = None
# Copy all row groups from the original file
for i in range(original_pq.num_row_groups):
row_group = original_pq.read_row_group(i)
if merged_writer is None:
merged_writer = pq.ParquetWriter(
merged_output_file,
row_group.schema,
flavor="spark"
)
merged_writer.write_table(row_group)
# Append all row groups from the temp file
for i in range(temp_pq.num_row_groups):
row_group = temp_pq.read_row_group(i)
merged_writer.write_table(row_group)
# Close the writer
if merged_writer is not None:
merged_writer.close()
# Replace the original file with the merged file
os.remove(original_output_file)
os.rename(merged_output_file, original_output_file)
# Clean up the temp file
os.remove(temp_output_file)
print("Merge complete.", file=sys.stderr)
except Exception as e:
print(f"Error merging resume data: {e}", file=sys.stderr)
print(f"New data saved in: {temp_output_file}", file=sys.stderr)
raise
def match_archive_suffix(input_filename):
if re.match(r".*\.7z$", input_filename):
@ -758,6 +864,40 @@ def open_output_file(input_filename):
return output_file
def get_last_revid_from_parquet(output_file):
"""
Read the last revid from a parquet file by reading only the last row group.
Returns None if the file doesn't exist or is empty.
"""
try:
file_path = output_file
if not os.path.exists(file_path):
return None
# Open the parquet file
parquet_file = pq.ParquetFile(file_path)
# Get the number of row groups
num_row_groups = parquet_file.num_row_groups
if num_row_groups == 0:
return None
# Read only the last row group, and only the revid column
last_row_group = parquet_file.read_row_group(num_row_groups - 1, columns=['revid'])
if last_row_group.num_rows == 0:
return None
# Get the last revid from this row group
last_revid = last_row_group.column('revid')[-1].as_py()
return last_revid
except Exception as e:
print(f"Error reading last revid from {file_path}: {e}", file=sys.stderr)
return None
def main():
parser = argparse.ArgumentParser(
description="Parse MediaWiki XML database dumps into tab delimited data."
@ -910,6 +1050,13 @@ def main():
help="How many revisions to process in each batch. This ends up being the Parquet row group size",
)
parser.add_argument(
"--resume",
dest="resume",
action="store_true",
help="Resume processing from the last successfully written revision in the output file.",
)
args = parser.parse_args()
# set persistence method
@ -954,6 +1101,18 @@ def main():
else:
output_file = output
# Handle resume functionality
resume_from_revid = None
if args.resume:
if output_parquet and not args.stdout:
resume_from_revid = get_last_revid_from_parquet(output_file)
if resume_from_revid is not None:
print(f"Resuming from last written revid: {resume_from_revid}", file=sys.stderr)
else:
print("Resume requested but no existing output file found, starting from beginning", file=sys.stderr)
else:
print("Warning: --resume only works with parquet output (not stdout or TSV)", file=sys.stderr)
wikiq = WikiqParser(
input_file,
output_file,
@ -970,6 +1129,7 @@ def main():
output_parquet=output_parquet,
partition_namespaces=args.partition_namespaces,
batch_size=args.batch_size,
resume_from_revid=resume_from_revid,
)
wikiq.process()
@ -978,6 +1138,9 @@ def main():
input_file.close()
else:
if args.resume:
print("Warning: --resume cannot be used with stdin/stdout", file=sys.stderr)
wikiq = WikiqParser(
sys.stdin,
sys.stdout,
@ -993,6 +1156,7 @@ def main():
diff=args.diff,
text=args.text,
batch_size=args.batch_size,
resume_from_revid=None,
)
wikiq.process()

View File

@ -439,3 +439,100 @@ def test_parquet():
pytest.fail(exc)
# assert_frame_equal(test, baseline, check_like=True, check_dtype=False)
def test_resume():
"""Test that --resume properly resumes processing from the last written revid."""
import pyarrow.parquet as pq
# First, create a complete baseline output
tester_full = WikiqTester(SAILORMOON, "resume_full", in_compression="7z", out_format="parquet")
try:
tester_full.call_wikiq("--fandom-2020")
except subprocess.CalledProcessError as exc:
pytest.fail(exc.stderr.decode("utf8"))
# Read the full output
full_output_path = os.path.join(tester_full.output, f"{SAILORMOON}.parquet")
full_table = pq.read_table(full_output_path)
# Get the middle revid to use as the resume point
middle_idx = len(full_table) // 2
resume_revid = full_table.column("revid")[middle_idx].as_py()
print(f"Total revisions: {len(full_table)}, Resume point: {middle_idx}, Resume revid: {resume_revid}")
# Create a partial output by copying row groups to preserve the exact schema
tester_partial = WikiqTester(SAILORMOON, "resume_partial", in_compression="7z", out_format="parquet")
partial_output_path = os.path.join(tester_partial.output, f"{SAILORMOON}.parquet")
# Create partial output by filtering the table and writing with the same schema
partial_table = full_table.slice(0, middle_idx + 1)
pq.write_table(partial_table, partial_output_path)
# Now resume from the partial output
try:
tester_partial.call_wikiq("--fandom-2020", "--resume")
except subprocess.CalledProcessError as exc:
pytest.fail(exc.stderr.decode("utf8"))
# Read the resumed output
resumed_table = pq.read_table(partial_output_path)
# The resumed output should match the full output
# Convert to dataframes for comparison, sorting by revid
resumed_df = resumed_table.to_pandas().sort_values("revid").reset_index(drop=True)
full_df = full_table.to_pandas().sort_values("revid").reset_index(drop=True)
# Compare the dataframes
assert_frame_equal(resumed_df, full_df, check_like=True, check_dtype=False)
print(f"Resume test passed! Original: {len(full_df)} rows, Resumed: {len(resumed_df)} rows")
def test_resume_with_diff():
"""Test that --resume works correctly with diff computation."""
import pyarrow.parquet as pq
# First, create a complete baseline output with diff
tester_full = WikiqTester(SAILORMOON, "resume_diff_full", in_compression="7z", out_format="parquet")
try:
tester_full.call_wikiq("--diff", "--fandom-2020")
except subprocess.CalledProcessError as exc:
pytest.fail(exc.stderr.decode("utf8"))
# Read the full output
full_output_path = os.path.join(tester_full.output, f"{SAILORMOON}.parquet")
full_table = pq.read_table(full_output_path)
# Get a revid about 1/3 through to use as the resume point
resume_idx = len(full_table) // 3
resume_revid = full_table.column("revid")[resume_idx].as_py()
print(f"Total revisions: {len(full_table)}, Resume point: {resume_idx}, Resume revid: {resume_revid}")
# Create a partial output by filtering the table to preserve the exact schema
tester_partial = WikiqTester(SAILORMOON, "resume_diff_partial", in_compression="7z", out_format="parquet")
partial_output_path = os.path.join(tester_partial.output, f"{SAILORMOON}.parquet")
# Create partial output by slicing the table
partial_table = full_table.slice(0, resume_idx + 1)
pq.write_table(partial_table, partial_output_path)
# Now resume from the partial output
try:
tester_partial.call_wikiq("--diff", "--fandom-2020", "--resume")
except subprocess.CalledProcessError as exc:
pytest.fail(exc.stderr.decode("utf8"))
# Read the resumed output
resumed_table = pq.read_table(partial_output_path)
# Convert to dataframes for comparison, sorting by revid
resumed_df = resumed_table.to_pandas().sort_values("revid").reset_index(drop=True)
full_df = full_table.to_pandas().sort_values("revid").reset_index(drop=True)
# Compare the dataframes
assert_frame_equal(resumed_df, full_df, check_like=True, check_dtype=False)
print(f"Resume with diff test passed! Original: {len(full_df)} rows, Resumed: {len(resumed_df)} rows")