Compare commits
10 Commits
6a4bf81e1a
...
c7eb374ceb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7eb374ceb | ||
|
|
4b8288c016 | ||
|
|
8590e5f920 | ||
|
|
93f6ed0ff5 | ||
|
|
5ebdb26d82 | ||
|
|
9e6b0fb64c | ||
|
|
d822085698 | ||
|
|
618c343898 | ||
|
|
3f1a9ba862 | ||
|
|
6988a281dc |
@@ -14,6 +14,7 @@ dependencies = [
|
||||
"mwtypes>=0.4.0",
|
||||
"mwxml>=0.3.6",
|
||||
"pyarrow>=20.0.0",
|
||||
"pyspark>=3.5.0",
|
||||
"pywikidiff2",
|
||||
"sortedcontainers>=2.4.0",
|
||||
"yamlconf>=0.2.6",
|
||||
@@ -21,13 +22,14 @@ dependencies = [
|
||||
|
||||
[project.scripts]
|
||||
wikiq = "wikiq:main"
|
||||
wikiq-spark = "wikiq_spark:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/wikiq"]
|
||||
packages = ["src/wikiq", "src/wikiq_spark"]
|
||||
|
||||
[tool.uv.sources]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,77 @@
|
||||
"""
|
||||
Checkpoint and resume functionality for wikiq parquet output.
|
||||
Checkpoint and resume functionality for wikiq output.
|
||||
|
||||
This module handles:
|
||||
- Finding resume points in existing parquet output
|
||||
- Merging resumed data with existing output (streaming, memory-efficient)
|
||||
- Finding resume points in existing output (JSONL or Parquet)
|
||||
- Merging resumed data with existing output (for Parquet, streaming, memory-efficient)
|
||||
- Checkpoint file management for fast resume point lookup
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import deque
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
|
||||
def get_checkpoint_path(output_file, partition_namespaces=False):
    """Return the checkpoint file location that belongs to ``output_file``.

    Without partitioning, the checkpoint is the output path with
    ``.checkpoint`` appended. With partitioning, it is named
    ``<output basename>.checkpoint`` and placed one directory above the
    partition directory, so it does not sit among the parquet files; using
    the output basename keeps it distinct for each output file.
    """
    if not partition_namespaces:
        return str(output_file) + ".checkpoint"

    # Split "<parent>/<partition_dir>/<name>" into its directory and name.
    partition_dir, output_name = os.path.split(output_file)
    checkpoint_name = output_name + ".checkpoint"
    return os.path.join(os.path.dirname(partition_dir), checkpoint_name)
|
||||
|
||||
|
||||
def read_checkpoint(checkpoint_path, partition_namespaces=False):
    """
    Read the resume point stored in a checkpoint file, if it exists.

    Checkpoint format:
        Single file: {"pageid": 54, "revid": 325} or {"pageid": 54, "revid": 325, "part": 2}
        Partitioned: {"0": {"pageid": 54, "revid": 325, "part": 1}, ...}

    Args:
        checkpoint_path: Path of the checkpoint JSON file.
        partition_namespaces: Not consulted here; the layout is detected from
            the file content. Kept so existing callers keep working.

    Returns:
        For single files: A tuple (pageid, revid), or (pageid, revid, part)
            when the stored part is greater than 0.
        For partitioned: A dict mapping namespace (int) -> (pageid, revid, part),
            where part defaults to 0 when absent.
        None when the file is missing, empty, or cannot be read or interpreted;
            in the latter case a warning is printed to stderr.
    """
    if not os.path.exists(checkpoint_path):
        return None

    try:
        with open(checkpoint_path, 'r') as f:
            data = json.load(f)

        if not data:
            return None

        # Single-file format: {"pageid": ..., "revid": ..., "part": ...}
        if "pageid" in data and "revid" in data:
            part = data.get("part", 0)
            if part > 0:
                return (data["pageid"], data["revid"], part)
            return (data["pageid"], data["revid"])

        # Partitioned format: {"0": {"pageid": ..., "revid": ..., "part": ...}, ...}
        result = {
            int(key): (value["pageid"], value["revid"], value.get("part", 0))
            for key, value in data.items()
        }
        return result or None

    except (json.JSONDecodeError, IOError, KeyError, TypeError,
            AttributeError, ValueError) as e:
        # AttributeError: an entry (or the top level) is not a JSON object.
        # ValueError: a namespace key is not an integer.
        # Both mean a malformed checkpoint, which is reported like any other
        # unreadable checkpoint instead of crashing the resume.
        print(f"Warning: Could not read checkpoint file {checkpoint_path}: {e}", file=sys.stderr)
        return None
|
||||
|
||||
|
||||
def cleanup_interrupted_resume(output_file, partition_namespaces):
|
||||
"""
|
||||
Merge any leftover .resume_temp files from a previous interrupted run.
|
||||
@@ -47,7 +105,6 @@ def cleanup_interrupted_resume(output_file, partition_namespaces):
|
||||
print(f"Found leftover temp files in {partition_dir} from previous interrupted partitioned run, merging first...", file=sys.stderr)
|
||||
had_corruption = merge_partitioned_namespaces(partition_dir, temp_suffix, output_filename)
|
||||
|
||||
# Check if any valid data remains after merge
|
||||
has_valid_data = False
|
||||
for ns_dir in os.listdir(partition_dir):
|
||||
if ns_dir.startswith('namespace='):
|
||||
@@ -58,7 +115,6 @@ def cleanup_interrupted_resume(output_file, partition_namespaces):
|
||||
break
|
||||
|
||||
if had_corruption and not has_valid_data:
|
||||
# All data was corrupted, remove checkpoint and start fresh
|
||||
checkpoint_path = get_checkpoint_path(output_file, partition_namespaces)
|
||||
if os.path.exists(checkpoint_path):
|
||||
os.remove(checkpoint_path)
|
||||
@@ -73,21 +129,17 @@ def cleanup_interrupted_resume(output_file, partition_namespaces):
|
||||
merged_path = output_file + ".merged"
|
||||
merged = merge_parquet_files(output_file, temp_output_file, merged_path)
|
||||
if merged == "original_only":
|
||||
# Temp file was invalid, just remove it
|
||||
os.remove(temp_output_file)
|
||||
elif merged == "temp_only":
|
||||
# Original was corrupted or missing, use temp as new base
|
||||
if os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
os.rename(temp_output_file, output_file)
|
||||
print("Recovered from temp file (original was corrupted or missing).", file=sys.stderr)
|
||||
elif merged == "both_invalid":
|
||||
# Both files corrupted or missing, remove both and start fresh
|
||||
if os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
if os.path.exists(temp_output_file):
|
||||
os.remove(temp_output_file)
|
||||
# Also remove stale checkpoint file
|
||||
checkpoint_path = get_checkpoint_path(output_file, partition_namespaces)
|
||||
if os.path.exists(checkpoint_path):
|
||||
os.remove(checkpoint_path)
|
||||
@@ -99,88 +151,88 @@ def cleanup_interrupted_resume(output_file, partition_namespaces):
|
||||
os.remove(temp_output_file)
|
||||
print("Previous temp file merged successfully.", file=sys.stderr)
|
||||
else:
|
||||
# Both empty - unusual
|
||||
os.remove(temp_output_file)
|
||||
|
||||
|
||||
def get_checkpoint_path(output_file, partition_namespaces=False):
|
||||
"""Get the path to the checkpoint file for a given output file.
|
||||
def get_jsonl_resume_point(output_file, input_file=None):
|
||||
"""Get resume point from last complete line of JSONL file.
|
||||
|
||||
For partitioned output, the checkpoint is placed outside the partition directory
|
||||
to avoid pyarrow trying to read it as a parquet file. The filename includes
|
||||
the output filename to keep it unique per input file (for parallel jobs).
|
||||
For .jsonl.d directories, derives the file path from input_file using get_output_filename.
|
||||
"""
|
||||
if partition_namespaces:
|
||||
# output_file is like partition_dir/output.parquet
|
||||
# checkpoint should be at parent level: parent/output.parquet.checkpoint
|
||||
partition_dir = os.path.dirname(output_file)
|
||||
output_filename = os.path.basename(output_file)
|
||||
parent_dir = os.path.dirname(partition_dir)
|
||||
return os.path.join(parent_dir, output_filename + ".checkpoint")
|
||||
return str(output_file) + ".checkpoint"
|
||||
# Handle .jsonl.d directory output
|
||||
if output_file.endswith('.jsonl.d'):
|
||||
if input_file is None:
|
||||
return None
|
||||
if os.path.isdir(output_file):
|
||||
# Import here to avoid circular import
|
||||
from wikiq import get_output_filename
|
||||
jsonl_filename = os.path.basename(get_output_filename(input_file, 'jsonl'))
|
||||
output_file = os.path.join(output_file, jsonl_filename)
|
||||
print(f"Looking for resume point in: {output_file}", file=sys.stderr)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def read_checkpoint(output_file, partition_namespaces=False):
|
||||
"""
|
||||
Read resume point from checkpoint file if it exists.
|
||||
|
||||
Checkpoint format:
|
||||
Single file: {"pageid": 54, "revid": 325}
|
||||
Partitioned: {"0": {"pageid": 54, "revid": 325}, "1": {"pageid": 123, "revid": 456}}
|
||||
|
||||
Returns:
|
||||
For single files: A tuple (pageid, revid), or None if not found.
|
||||
For partitioned: A dict mapping namespace -> (pageid, revid), or None.
|
||||
"""
|
||||
checkpoint_path = get_checkpoint_path(output_file, partition_namespaces)
|
||||
if not os.path.exists(checkpoint_path):
|
||||
if not os.path.exists(output_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(checkpoint_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
# Track positions of last two valid lines for potential truncation
|
||||
valid_lines = deque(maxlen=2) # (end_position, record)
|
||||
with open(output_file, 'rb') as f:
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line:
|
||||
break
|
||||
try:
|
||||
record = json.loads(line.decode('utf-8'))
|
||||
valid_lines.append((f.tell(), record))
|
||||
except (json.JSONDecodeError, KeyError, UnicodeDecodeError):
|
||||
pass
|
||||
|
||||
if not data:
|
||||
if not valid_lines:
|
||||
return None
|
||||
|
||||
# Single-file format: {"pageid": ..., "revid": ...}
|
||||
if "pageid" in data and "revid" in data:
|
||||
return (data["pageid"], data["revid"])
|
||||
last_valid_pos, last_valid_record = valid_lines[-1]
|
||||
|
||||
# Partitioned format: {"0": {"pageid": ..., "revid": ...}, ...}
|
||||
result = {}
|
||||
for key, value in data.items():
|
||||
result[int(key)] = (value["pageid"], value["revid"])
|
||||
# Truncate if file extends past last valid line (corrupted trailing data)
|
||||
file_size = os.path.getsize(output_file)
|
||||
if last_valid_pos < file_size:
|
||||
print(f"Truncating corrupted data from {output_file} ({file_size - last_valid_pos} bytes)", file=sys.stderr)
|
||||
with open(output_file, 'r+b') as f:
|
||||
f.truncate(last_valid_pos)
|
||||
|
||||
return result if result else None
|
||||
|
||||
except (json.JSONDecodeError, IOError, KeyError, TypeError) as e:
|
||||
print(f"Warning: Could not read checkpoint file {checkpoint_path}: {e}", file=sys.stderr)
|
||||
return (last_valid_record['articleid'], last_valid_record['revid'])
|
||||
except IOError as e:
|
||||
print(f"Warning: Could not read {output_file}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def get_resume_point(output_file, partition_namespaces=False):
|
||||
def get_resume_point(output_file, partition_namespaces=False, input_file=None):
|
||||
"""
|
||||
Find the resume point(s) from existing parquet output.
|
||||
Find the resume point(s) from existing output.
|
||||
|
||||
First checks for a checkpoint file (fast), then falls back to scanning
|
||||
the parquet output (slow, for backwards compatibility).
|
||||
For JSONL: reads last line of file (no checkpoint needed).
|
||||
For Parquet: checks checkpoint file, falls back to scanning parquet.
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file. For single files, this is the parquet file path.
|
||||
For partitioned namespaces, this is the path like dir/dump.parquet where
|
||||
namespace=* subdirectories are in the parent dir.
|
||||
output_file: Path to the output file.
|
||||
partition_namespaces: Whether the output uses namespace partitioning.
|
||||
input_file: Path to input file (needed for .jsonl.d directory output).
|
||||
|
||||
Returns:
|
||||
For single files: A tuple (pageid, revid) for the row with the highest pageid,
|
||||
or None if not found.
|
||||
For partitioned: A dict mapping namespace -> (pageid, revid) for each partition,
|
||||
or None if no partitions exist.
|
||||
For single files: A tuple (pageid, revid) or (pageid, revid, part), or None.
|
||||
For partitioned: A dict mapping namespace -> (pageid, revid, part), or None.
|
||||
"""
|
||||
# First try checkpoint file (fast)
|
||||
# For JSONL, read resume point directly from last line (no checkpoint needed)
|
||||
if output_file.endswith('.jsonl') or output_file.endswith('.jsonl.d'):
|
||||
result = get_jsonl_resume_point(output_file, input_file)
|
||||
if result:
|
||||
print(f"Resume point found from JSONL: pageid={result[0]}, revid={result[1]}", file=sys.stderr)
|
||||
return result
|
||||
|
||||
# For Parquet, use checkpoint file (fast)
|
||||
checkpoint_path = get_checkpoint_path(output_file, partition_namespaces)
|
||||
checkpoint_result = read_checkpoint(output_file, partition_namespaces)
|
||||
checkpoint_result = read_checkpoint(checkpoint_path, partition_namespaces)
|
||||
if checkpoint_result is not None:
|
||||
print(f"Resume point found in checkpoint file {checkpoint_path}", file=sys.stderr)
|
||||
return checkpoint_result
|
||||
@@ -198,11 +250,7 @@ def get_resume_point(output_file, partition_namespaces=False):
|
||||
|
||||
|
||||
def _get_last_row_resume_point(pq_path):
|
||||
"""Get resume point by reading only the last row group of a parquet file.
|
||||
|
||||
Since data is written in page/revision order, the last row group contains
|
||||
the highest pageid/revid, and the last row in that group is the resume point.
|
||||
"""
|
||||
"""Get resume point by reading only the last row group of a parquet file."""
|
||||
pf = pq.ParquetFile(pq_path)
|
||||
if pf.metadata.num_row_groups == 0:
|
||||
return None
|
||||
@@ -214,20 +262,11 @@ def _get_last_row_resume_point(pq_path):
|
||||
|
||||
max_pageid = table['articleid'][-1].as_py()
|
||||
max_revid = table['revid'][-1].as_py()
|
||||
return (max_pageid, max_revid)
|
||||
return (max_pageid, max_revid, 0)
|
||||
|
||||
|
||||
def _get_resume_point_partitioned(output_file):
|
||||
"""Find per-namespace resume points from partitioned output.
|
||||
|
||||
Only looks for the specific output file in each namespace directory.
|
||||
Returns a dict mapping namespace -> (max_pageid, max_revid) for each partition
|
||||
where the output file exists.
|
||||
|
||||
Args:
|
||||
output_file: Path like 'dir/output.parquet' where namespace=* subdirectories
|
||||
contain files named 'output.parquet'.
|
||||
"""
|
||||
"""Find per-namespace resume points from partitioned output."""
|
||||
partition_dir = os.path.dirname(output_file)
|
||||
output_filename = os.path.basename(output_file)
|
||||
|
||||
@@ -270,14 +309,13 @@ def _get_resume_point_single_file(output_file):
|
||||
|
||||
def merge_parquet_files(original_path, temp_path, merged_path):
|
||||
"""
|
||||
Merge two parquet files by streaming row groups from original and temp into merged.
|
||||
Merge two parquet files by streaming row groups.
|
||||
|
||||
This is memory-efficient: only one row group is loaded at a time.
|
||||
Returns:
|
||||
"merged" - merged file was created from both sources
|
||||
"original_only" - temp was invalid, keep original unchanged
|
||||
"temp_only" - original was corrupted but temp is valid, use temp
|
||||
"both_invalid" - both files invalid, delete both and start fresh
|
||||
"temp_only" - original was corrupted but temp is valid
|
||||
"both_invalid" - both files invalid
|
||||
False - both files were valid but empty
|
||||
"""
|
||||
original_valid = False
|
||||
@@ -293,12 +331,12 @@ def merge_parquet_files(original_path, temp_path, merged_path):
|
||||
|
||||
try:
|
||||
if not os.path.exists(temp_path):
|
||||
print(f"Note: Temp file {temp_path} does not exist (namespace had no records after resume point)", file=sys.stderr)
|
||||
print(f"Note: Temp file {temp_path} does not exist", file=sys.stderr)
|
||||
else:
|
||||
temp_pq = pq.ParquetFile(temp_path)
|
||||
temp_valid = True
|
||||
except Exception:
|
||||
print(f"Note: No new data in temp file {temp_path} (file exists but is invalid)", file=sys.stderr)
|
||||
print(f"Note: No new data in temp file {temp_path}", file=sys.stderr)
|
||||
|
||||
if not original_valid and not temp_valid:
|
||||
print(f"Both original and temp files are invalid, will start fresh", file=sys.stderr)
|
||||
@@ -313,7 +351,6 @@ def merge_parquet_files(original_path, temp_path, merged_path):
|
||||
|
||||
merged_writer = None
|
||||
|
||||
# Copy all row groups from the original file
|
||||
for i in range(original_pq.num_row_groups):
|
||||
row_group = original_pq.read_row_group(i)
|
||||
if merged_writer is None:
|
||||
@@ -324,7 +361,6 @@ def merge_parquet_files(original_path, temp_path, merged_path):
|
||||
)
|
||||
merged_writer.write_table(row_group)
|
||||
|
||||
# Append all row groups from the temp file
|
||||
for i in range(temp_pq.num_row_groups):
|
||||
row_group = temp_pq.read_row_group(i)
|
||||
if merged_writer is None:
|
||||
@@ -335,7 +371,6 @@ def merge_parquet_files(original_path, temp_path, merged_path):
|
||||
)
|
||||
merged_writer.write_table(row_group)
|
||||
|
||||
# Close the writer
|
||||
if merged_writer is not None:
|
||||
merged_writer.close()
|
||||
return "merged"
|
||||
@@ -346,16 +381,6 @@ def merge_partitioned_namespaces(partition_dir, temp_suffix, file_filter):
|
||||
"""
|
||||
Merge partitioned namespace directories after resume.
|
||||
|
||||
For partitioned namespaces, temp files are written alongside the original files
|
||||
in each namespace directory with the temp suffix appended to the filename.
|
||||
E.g., original: namespace=0/file.parquet, temp: namespace=0/file.parquet.resume_temp
|
||||
|
||||
Args:
|
||||
partition_dir: The partition directory containing namespace=* subdirs
|
||||
temp_suffix: The suffix appended to temp files (e.g., '.resume_temp')
|
||||
file_filter: Only process temp files matching this base name
|
||||
(e.g., 'enwiki-20250123-pages-meta-history24-p53238682p53445302.parquet')
|
||||
|
||||
Returns:
|
||||
True if at least one namespace has valid data after merge
|
||||
False if all namespaces ended up with corrupted/deleted data
|
||||
@@ -371,49 +396,40 @@ def merge_partitioned_namespaces(partition_dir, temp_suffix, file_filter):
|
||||
if not os.path.exists(temp_path):
|
||||
continue
|
||||
|
||||
# Original file is the temp file without the suffix
|
||||
original_file = file_filter
|
||||
original_path = os.path.join(ns_path, original_file)
|
||||
|
||||
if os.path.exists(original_path):
|
||||
# Merge the files
|
||||
merged_path = original_path + ".merged"
|
||||
merged = merge_parquet_files(original_path, temp_path, merged_path)
|
||||
|
||||
if merged == "original_only":
|
||||
# Temp file was invalid (no new data), keep original unchanged
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
elif merged == "temp_only":
|
||||
# Original was corrupted, use temp as new base
|
||||
os.remove(original_path)
|
||||
os.rename(temp_path, original_path)
|
||||
elif merged == "both_invalid":
|
||||
# Both files corrupted, remove both
|
||||
if os.path.exists(original_path):
|
||||
os.remove(original_path)
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
had_corruption = True
|
||||
elif merged == "merged":
|
||||
# Replace the original file with the merged file
|
||||
os.remove(original_path)
|
||||
os.rename(merged_path, original_path)
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
else:
|
||||
# Both files were empty (False), just remove them
|
||||
if os.path.exists(original_path):
|
||||
os.remove(original_path)
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
else:
|
||||
# No original file, rename temp to original only if valid
|
||||
try:
|
||||
pq.ParquetFile(temp_path)
|
||||
os.rename(temp_path, original_path)
|
||||
except Exception:
|
||||
# Temp file invalid or missing, just remove it if it exists
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
had_corruption = True
|
||||
@@ -429,55 +445,36 @@ def finalize_resume_merge(
|
||||
):
|
||||
"""
|
||||
Finalize the resume by merging temp output with original output.
|
||||
|
||||
Args:
|
||||
original_output_file: Path to the original output file
|
||||
temp_output_file: Path to the temp output file written during resume
|
||||
partition_namespaces: Whether using partitioned namespace output
|
||||
original_partition_dir: The partition directory (for partitioned output)
|
||||
|
||||
Raises:
|
||||
Exception: If merge fails (temp file is preserved for recovery)
|
||||
"""
|
||||
import shutil
|
||||
|
||||
print("Merging resumed data with existing output...", file=sys.stderr)
|
||||
try:
|
||||
if partition_namespaces and original_partition_dir is not None:
|
||||
# For partitioned namespaces, temp files are written alongside originals
|
||||
# with '.resume_temp' suffix in each namespace directory.
|
||||
# Only merge temp files for the current dump file, not other concurrent jobs.
|
||||
file_filter = os.path.basename(original_output_file)
|
||||
merge_partitioned_namespaces(original_partition_dir, ".resume_temp", file_filter)
|
||||
# Clean up the empty temp directory we created
|
||||
if os.path.exists(temp_output_file) and os.path.isdir(temp_output_file):
|
||||
shutil.rmtree(temp_output_file)
|
||||
else:
|
||||
# Merge single parquet files
|
||||
merged_output_file = original_output_file + ".merged"
|
||||
merged = merge_parquet_files(original_output_file, temp_output_file, merged_output_file)
|
||||
|
||||
if merged == "original_only":
|
||||
# Temp file was invalid (no new data), keep original unchanged
|
||||
if os.path.exists(temp_output_file):
|
||||
os.remove(temp_output_file)
|
||||
elif merged == "temp_only":
|
||||
# Original was corrupted, use temp as new base
|
||||
os.remove(original_output_file)
|
||||
os.rename(temp_output_file, original_output_file)
|
||||
elif merged == "both_invalid":
|
||||
# Both files corrupted, remove both
|
||||
os.remove(original_output_file)
|
||||
if os.path.exists(temp_output_file):
|
||||
os.remove(temp_output_file)
|
||||
elif merged == "merged":
|
||||
# Replace the original file with the merged file
|
||||
os.remove(original_output_file)
|
||||
os.rename(merged_output_file, original_output_file)
|
||||
if os.path.exists(temp_output_file):
|
||||
os.remove(temp_output_file)
|
||||
else:
|
||||
# Both files were empty (False) - unusual, but clean up
|
||||
os.remove(original_output_file)
|
||||
if os.path.exists(temp_output_file):
|
||||
os.remove(temp_output_file)
|
||||
@@ -491,11 +488,7 @@ def finalize_resume_merge(
|
||||
|
||||
def setup_resume_temp_output(output_file, partition_namespaces):
|
||||
"""
|
||||
Set up temp output for resume mode.
|
||||
|
||||
Args:
|
||||
output_file: The original output file path
|
||||
partition_namespaces: Whether using partitioned namespace output
|
||||
Set up temp output for resume mode (Parquet only).
|
||||
|
||||
Returns:
|
||||
Tuple of (original_output_file, temp_output_file, original_partition_dir)
|
||||
@@ -507,7 +500,6 @@ def setup_resume_temp_output(output_file, partition_namespaces):
|
||||
temp_output_file = None
|
||||
original_partition_dir = None
|
||||
|
||||
# For partitioned namespaces, check if the specific output file exists in any namespace
|
||||
if partition_namespaces:
|
||||
partition_dir = os.path.dirname(output_file)
|
||||
output_filename = os.path.basename(output_file)
|
||||
@@ -527,9 +519,6 @@ def setup_resume_temp_output(output_file, partition_namespaces):
|
||||
original_output_file = output_file
|
||||
temp_output_file = output_file + ".resume_temp"
|
||||
|
||||
# Note: cleanup_interrupted_resume() should have been called before this
|
||||
# to merge any leftover temp files from a previous interrupted run.
|
||||
# Here we just clean up any remaining temp directory markers.
|
||||
if os.path.exists(temp_output_file):
|
||||
if os.path.isdir(temp_output_file):
|
||||
shutil.rmtree(temp_output_file)
|
||||
|
||||
@@ -17,9 +17,6 @@ T = TypeVar('T')
|
||||
|
||||
|
||||
class RevisionField(ABC, Generic[T]):
|
||||
def __init__(self):
|
||||
self.data: list[T] = []
|
||||
|
||||
"""
|
||||
Abstract type which represents a field in a table of page revisions.
|
||||
"""
|
||||
@@ -43,14 +40,6 @@ class RevisionField(ABC, Generic[T]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> None:
|
||||
self.data.append(self.extract(page, revisions))
|
||||
|
||||
def pop(self) -> list[T]:
|
||||
data = self.data
|
||||
self.data = []
|
||||
return data
|
||||
|
||||
|
||||
class RevisionTable:
|
||||
columns: list[RevisionField]
|
||||
@@ -58,19 +47,15 @@ class RevisionTable:
|
||||
def __init__(self, columns: list[RevisionField]):
|
||||
self.columns = columns
|
||||
|
||||
def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]):
|
||||
for column in self.columns:
|
||||
column.add(page=page, revisions=revisions)
|
||||
|
||||
def schema(self) -> pa.Schema:
|
||||
return pa.schema([c.field for c in self.columns])
|
||||
|
||||
def pop(self) -> dict:
|
||||
data = dict()
|
||||
for column in self.columns:
|
||||
data[column.field.name] = column.pop()
|
||||
|
||||
return data
|
||||
def extract_row(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> dict:
    """Build one output row for ``page`` and ``revisions``.

    Each column contributes a single entry, keyed by the column's field
    name and holding whatever that column's ``extract`` returns.
    """
    row = {}
    for column in self.columns:
        # Read the key before extracting, as the dict comprehension did.
        name = column.field.name
        row[name] = column.extract(page, revisions)
    return row
|
||||
|
||||
|
||||
class RevisionId(RevisionField[int]):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Shared wikitext parser with caching to avoid duplicate parsing."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
|
||||
import mwparserfromhell
|
||||
|
||||
PARSER_TIMEOUT = 60 # seconds
|
||||
@@ -22,22 +23,28 @@ class WikitextParser:
|
||||
self._cached_wikicode = None
|
||||
self.last_parse_timed_out: bool = False
|
||||
|
||||
async def _parse_async(self, text: str):
|
||||
"""Parse wikitext with timeout protection."""
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
asyncio.to_thread(mwparserfromhell.parse, text),
|
||||
timeout=PARSER_TIMEOUT
|
||||
)
|
||||
return result, False
|
||||
except TimeoutError:
|
||||
return None, True
|
||||
def _timeout_handler(self, signum, frame):
|
||||
raise TimeoutError("mwparserfromhell parse exceeded timeout")
|
||||
|
||||
def _get_wikicode(self, text: str):
|
||||
"""Parse text and cache result. Returns cached result if text unchanged."""
|
||||
if text != self._cached_text:
|
||||
if text == self._cached_text:
|
||||
return self._cached_wikicode
|
||||
|
||||
old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
|
||||
signal.alarm(PARSER_TIMEOUT)
|
||||
try:
|
||||
self._cached_wikicode = mwparserfromhell.parse(text)
|
||||
self._cached_text = text
|
||||
self._cached_wikicode, self.last_parse_timed_out = asyncio.run(self._parse_async(text))
|
||||
self.last_parse_timed_out = False
|
||||
except TimeoutError:
|
||||
self._cached_wikicode = None
|
||||
self._cached_text = text
|
||||
self.last_parse_timed_out = True
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
|
||||
return self._cached_wikicode
|
||||
|
||||
def extract_external_links(self, text: str | None) -> list[str] | None:
|
||||
|
||||
@@ -7,10 +7,12 @@ from io import StringIO
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.json as pj
|
||||
import pytest
|
||||
from pandas import DataFrame
|
||||
from pandas.testing import assert_frame_equal, assert_series_equal
|
||||
|
||||
from wikiq import build_table, build_schema
|
||||
from wikiq_test_utils import (
|
||||
BASELINE_DIR,
|
||||
IKWIKI,
|
||||
@@ -34,6 +36,17 @@ def setup():
|
||||
setup()
|
||||
|
||||
|
||||
def read_jsonl_with_schema(filepath: str, **schema_kwargs) -> pd.DataFrame:
    """Load a JSONL file into a DataFrame using wikiq's explicit schema.

    ``schema_kwargs`` are passed to both ``build_table`` and ``build_schema``
    to produce the schema, so columns are read with the declared types
    rather than inferred ones.
    """
    revision_table, _ = build_table(**schema_kwargs)
    explicit_schema = build_schema(revision_table, **schema_kwargs)
    parse_options = pj.ParseOptions(explicit_schema=explicit_schema)
    return pj.read_json(filepath, parse_options=parse_options).to_pandas()
|
||||
|
||||
|
||||
# with / without pwr DONE
|
||||
# with / without url encode DONE
|
||||
# with / without collapse user DONE
|
||||
@@ -124,7 +137,62 @@ def test_noargs():
|
||||
test = pd.read_table(tester.output)
|
||||
baseline = pd.read_table(tester.baseline_file)
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
|
||||
|
||||
|
||||
def test_jsonl_noargs():
    """Run wikiq with no extra options and JSONL output, then compare with the JSONL baseline."""
    tester = WikiqTester(
        SAILORMOON,
        "noargs",
        in_compression="7z",
        out_format="jsonl",
        baseline_format="jsonl",
    )

    try:
        tester.call_wikiq()
    except subprocess.CalledProcessError as exc:
        # Surface wikiq's stderr as the failure message.
        pytest.fail(exc.stderr.decode("utf8"))

    produced = read_jsonl_with_schema(tester.output)
    expected = read_jsonl_with_schema(tester.baseline_file)
    assert_frame_equal(produced, expected, check_like=True)
|
||||
|
||||
|
||||
def test_jsonl_tsv_equivalence():
    """Check that TSV and JSONL runs over the same dump hold the same data."""
    tester_tsv = WikiqTester(
        SAILORMOON, "equiv_tsv", in_compression="7z", out_format="tsv"
    )
    tester_jsonl = WikiqTester(
        SAILORMOON, "equiv_jsonl", in_compression="7z", out_format="jsonl"
    )

    try:
        tester_tsv.call_wikiq()
        tester_jsonl.call_wikiq()
    except subprocess.CalledProcessError as exc:
        pytest.fail(exc.stderr.decode("utf8"))

    tsv_df = pd.read_table(tester_tsv.output)
    jsonl_df = read_jsonl_with_schema(tester_jsonl.output)

    # Same number of rows and the same set of columns in both outputs.
    assert len(tsv_df) == len(jsonl_df), f"Row count mismatch: TSV={len(tsv_df)}, JSONL={len(jsonl_df)}"
    assert set(tsv_df.columns) == set(jsonl_df.columns), \
        f"Column mismatch: TSV={set(tsv_df.columns)}, JSONL={set(jsonl_df.columns)}"

    # Align rows by revid so the frames can be compared position by position.
    tsv_df = tsv_df.sort_values("revid").reset_index(drop=True)
    jsonl_df = jsonl_df.sort_values("revid").reset_index(drop=True)

    # TSV represents missing values as nan, schema-based JSONL as None.
    jsonl_df = jsonl_df.replace({None: np.nan})

    timestamp_format = "%Y-%m-%dT%H:%M:%SZ"
    for col in tsv_df.columns:
        if col != "date_time":
            # dtypes may differ (TSV infers int64, the schema uses int32).
            assert_series_equal(tsv_df[col], jsonl_df[col], check_names=False, check_dtype=False)
            continue

        # TSV yields strings and JSONL yields datetimes; compare both as
        # formatted strings.
        tsv_dates = pd.to_datetime(tsv_df[col]).dt.strftime(timestamp_format)
        jsonl_dates = jsonl_df[col].dt.strftime(timestamp_format)
        assert_series_equal(tsv_dates, jsonl_dates, check_names=False)
|
||||
|
||||
|
||||
def test_collapse_user():
|
||||
tester = WikiqTester(SAILORMOON, "collapse-user", in_compression="7z")
|
||||
|
||||
@@ -137,19 +205,6 @@ def test_collapse_user():
|
||||
baseline = pd.read_table(tester.baseline_file)
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
|
||||
def test_partition_namespaces():
|
||||
tester = WikiqTester(SAILORMOON, "collapse-user", in_compression="7z", out_format='parquet', baseline_format='parquet')
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--collapse-user", "--fandom-2020", "--partition-namespaces")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(os.path.join(tester.output,"namespace=10/sailormoon.parquet"))
|
||||
baseline = pd.read_parquet(tester.baseline_file)
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
|
||||
|
||||
def test_pwr_wikidiff2():
|
||||
tester = WikiqTester(SAILORMOON, "persistence_wikidiff2", in_compression="7z")
|
||||
|
||||
@@ -201,46 +256,43 @@ def test_pwr():
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
|
||||
def test_diff():
|
||||
tester = WikiqTester(SAILORMOON, "diff", in_compression="7z", out_format='parquet', baseline_format='parquet')
|
||||
tester = WikiqTester(SAILORMOON, "diff", in_compression="7z", out_format='jsonl')
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--diff", "--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
baseline = pd.read_parquet(tester.baseline_file)
|
||||
|
||||
test = test.reindex(columns=sorted(test.columns))
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
assert "diff" in test.columns, "diff column should exist"
|
||||
assert "diff_timeout" in test.columns, "diff_timeout column should exist"
|
||||
assert len(test) > 0, "Should have output rows"
|
||||
|
||||
def test_diff_plus_pwr():
|
||||
tester = WikiqTester(SAILORMOON, "diff_pwr", in_compression="7z", out_format='parquet', baseline_format='parquet')
|
||||
tester = WikiqTester(SAILORMOON, "diff_pwr", in_compression="7z", out_format='jsonl')
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--diff --persistence wikidiff2", "--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
baseline = pd.read_parquet(tester.baseline_file)
|
||||
|
||||
test = test.reindex(columns=sorted(test.columns))
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
assert "diff" in test.columns, "diff column should exist"
|
||||
assert "token_revs" in test.columns, "token_revs column should exist"
|
||||
assert len(test) > 0, "Should have output rows"
|
||||
|
||||
def test_text():
|
||||
tester = WikiqTester(SAILORMOON, "text", in_compression="7z", out_format='parquet', baseline_format='parquet')
|
||||
tester = WikiqTester(SAILORMOON, "text", in_compression="7z", out_format='jsonl')
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--diff", "--text","--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
baseline = pd.read_parquet(tester.baseline_file)
|
||||
|
||||
test = test.reindex(columns=sorted(test.columns))
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
assert "text" in test.columns, "text column should exist"
|
||||
assert "diff" in test.columns, "diff column should exist"
|
||||
assert len(test) > 0, "Should have output rows"
|
||||
|
||||
def test_malformed_noargs():
|
||||
tester = WikiqTester(wiki=TWINPEAKS, case_name="noargs", in_compression="7z")
|
||||
@@ -339,51 +391,11 @@ def test_capturegroup_regex():
|
||||
baseline = pd.read_table(tester.baseline_file)
|
||||
assert_frame_equal(test, baseline, check_like=True)
|
||||
|
||||
def test_parquet():
|
||||
tester = WikiqTester(IKWIKI, "noargs", out_format="parquet")
|
||||
|
||||
try:
|
||||
tester.call_wikiq()
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
# as a test let's make sure that we get equal data frames
|
||||
test: DataFrame = pd.read_parquet(tester.output)
|
||||
# test = test.drop(['reverteds'], axis=1)
|
||||
|
||||
baseline: DataFrame = pd.read_table(tester.baseline_file)
|
||||
|
||||
# Pandas does not read timestamps as the desired datetime type.
|
||||
baseline["date_time"] = pd.to_datetime(baseline["date_time"])
|
||||
# Split strings to the arrays of reverted IDs so they can be compared.
|
||||
baseline["revert"] = baseline["revert"].replace(np.nan, None)
|
||||
baseline["reverteds"] = baseline["reverteds"].replace(np.nan, None)
|
||||
# baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']]
|
||||
baseline["sha1"] = baseline["sha1"].replace(np.nan, None)
|
||||
baseline["editor"] = baseline["editor"].replace(np.nan, None)
|
||||
baseline["anon"] = baseline["anon"].replace(np.nan, None)
|
||||
|
||||
for index, row in baseline.iterrows():
|
||||
if row["revert"] != test["revert"][index]:
|
||||
print(row["revid"], ":", row["revert"], "!=", test["revert"][index])
|
||||
|
||||
for col in baseline.columns:
|
||||
try:
|
||||
assert_series_equal(
|
||||
test[col], baseline[col], check_like=True, check_dtype=False
|
||||
)
|
||||
except ValueError as exc:
|
||||
print(f"Error comparing column {col}")
|
||||
pytest.fail(exc)
|
||||
|
||||
# assert_frame_equal(test, baseline, check_like=True, check_dtype=False)
|
||||
|
||||
|
||||
def test_external_links_only():
|
||||
"""Test that --external-links extracts external links correctly."""
|
||||
import mwparserfromhell
|
||||
|
||||
tester = WikiqTester(SAILORMOON, "external_links_only", in_compression="7z", out_format="parquet")
|
||||
tester = WikiqTester(SAILORMOON, "external_links_only", in_compression="7z", out_format="jsonl")
|
||||
|
||||
try:
|
||||
# Also include --text so we can verify extraction against actual wikitext
|
||||
@@ -391,7 +403,7 @@ def test_external_links_only():
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
|
||||
# Verify external_links column exists
|
||||
assert "external_links" in test.columns, "external_links column should exist"
|
||||
@@ -438,7 +450,7 @@ def test_citations_only():
|
||||
import mwparserfromhell
|
||||
from wikiq.wikitext_parser import WikitextParser
|
||||
|
||||
tester = WikiqTester(SAILORMOON, "citations_only", in_compression="7z", out_format="parquet")
|
||||
tester = WikiqTester(SAILORMOON, "citations_only", in_compression="7z", out_format="jsonl")
|
||||
|
||||
try:
|
||||
# Also include --text so we can verify extraction against actual wikitext
|
||||
@@ -446,7 +458,7 @@ def test_citations_only():
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
|
||||
# Verify citations column exists
|
||||
assert "citations" in test.columns, "citations column should exist"
|
||||
@@ -490,7 +502,7 @@ def test_external_links_and_citations():
|
||||
import mwparserfromhell
|
||||
from wikiq.wikitext_parser import WikitextParser
|
||||
|
||||
tester = WikiqTester(SAILORMOON, "external_links_and_citations", in_compression="7z", out_format="parquet")
|
||||
tester = WikiqTester(SAILORMOON, "external_links_and_citations", in_compression="7z", out_format="jsonl")
|
||||
|
||||
try:
|
||||
# Also include --text so we can verify extraction against actual wikitext
|
||||
@@ -498,7 +510,7 @@ def test_external_links_and_citations():
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
|
||||
# Verify both columns exist
|
||||
assert "external_links" in test.columns, "external_links column should exist"
|
||||
@@ -564,14 +576,14 @@ def test_external_links_and_citations():
|
||||
|
||||
def test_no_wikitext_columns():
|
||||
"""Test that neither external_links nor citations columns exist without flags."""
|
||||
tester = WikiqTester(SAILORMOON, "no_wikitext_columns", in_compression="7z", out_format="parquet")
|
||||
tester = WikiqTester(SAILORMOON, "no_wikitext_columns", in_compression="7z", out_format="jsonl")
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
|
||||
# Verify neither column exists
|
||||
assert "external_links" not in test.columns, "external_links column should NOT exist without --external-links flag"
|
||||
@@ -584,14 +596,14 @@ def test_wikilinks():
|
||||
"""Test that --wikilinks extracts internal wikilinks correctly."""
|
||||
import mwparserfromhell
|
||||
|
||||
tester = WikiqTester(SAILORMOON, "wikilinks", in_compression="7z", out_format="parquet")
|
||||
tester = WikiqTester(SAILORMOON, "wikilinks", in_compression="7z", out_format="jsonl")
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--wikilinks", "--text", "--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
|
||||
# Verify wikilinks column exists
|
||||
assert "wikilinks" in test.columns, "wikilinks column should exist"
|
||||
@@ -625,14 +637,14 @@ def test_templates():
|
||||
"""Test that --templates extracts templates correctly."""
|
||||
import mwparserfromhell
|
||||
|
||||
tester = WikiqTester(SAILORMOON, "templates", in_compression="7z", out_format="parquet")
|
||||
tester = WikiqTester(SAILORMOON, "templates", in_compression="7z", out_format="jsonl")
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--templates", "--text", "--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
|
||||
# Verify templates column exists
|
||||
assert "templates" in test.columns, "templates column should exist"
|
||||
@@ -675,14 +687,14 @@ def test_headings():
|
||||
"""Test that --headings extracts section headings correctly."""
|
||||
import mwparserfromhell
|
||||
|
||||
tester = WikiqTester(SAILORMOON, "headings", in_compression="7z", out_format="parquet")
|
||||
tester = WikiqTester(SAILORMOON, "headings", in_compression="7z", out_format="jsonl")
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--headings", "--text", "--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
|
||||
test = pd.read_json(tester.output, lines=True)
|
||||
|
||||
# Verify headings column exists
|
||||
assert "headings" in test.columns, "headings column should exist"
|
||||
@@ -712,3 +724,37 @@ def test_headings():
|
||||
print(f"Headings test passed! {len(test)} rows processed")
|
||||
|
||||
|
||||
def test_parquet_output():
|
||||
"""Test that Parquet output format works correctly."""
|
||||
tester = WikiqTester(SAILORMOON, "parquet_output", in_compression="7z", out_format="parquet")
|
||||
|
||||
try:
|
||||
tester.call_wikiq("--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
# Verify output file exists
|
||||
assert os.path.exists(tester.output), f"Parquet output file should exist at {tester.output}"
|
||||
|
||||
# Read and verify content
|
||||
test = pd.read_parquet(tester.output)
|
||||
|
||||
# Verify expected columns exist
|
||||
assert "revid" in test.columns
|
||||
assert "articleid" in test.columns
|
||||
assert "title" in test.columns
|
||||
assert "namespace" in test.columns
|
||||
|
||||
# Verify row count matches JSONL output
|
||||
tester_jsonl = WikiqTester(SAILORMOON, "parquet_compare", in_compression="7z", out_format="jsonl")
|
||||
try:
|
||||
tester_jsonl.call_wikiq("--fandom-2020")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
pytest.fail(exc.stderr.decode("utf8"))
|
||||
|
||||
test_jsonl = pd.read_json(tester_jsonl.output, lines=True)
|
||||
assert len(test) == len(test_jsonl), f"Parquet and JSONL should have same row count: {len(test)} vs {len(test_jsonl)}"
|
||||
|
||||
print(f"Parquet output test passed! {len(test)} rows")
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -42,8 +42,20 @@ class WikiqTester:
|
||||
else:
|
||||
shutil.rmtree(self.output)
|
||||
|
||||
if out_format == "parquet":
|
||||
os.makedirs(self.output, exist_ok=True)
|
||||
# Also clean up resume-related files
|
||||
for temp_suffix in [".resume_temp", ".checkpoint", ".merged"]:
|
||||
temp_path = self.output + temp_suffix
|
||||
if os.path.exists(temp_path):
|
||||
if os.path.isfile(temp_path):
|
||||
os.remove(temp_path)
|
||||
else:
|
||||
shutil.rmtree(temp_path)
|
||||
|
||||
# For JSONL and Parquet, self.output is a file path. Create parent directory if needed.
|
||||
if out_format in ("jsonl", "parquet"):
|
||||
parent_dir = os.path.dirname(self.output)
|
||||
if parent_dir:
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
if suffix is None:
|
||||
self.wikiq_baseline_name = "{0}.{1}".format(wiki, baseline_format)
|
||||
|
||||
Reference in New Issue
Block a user