diff --git a/pyproject.toml b/pyproject.toml index bc3e709..6c00f82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "mwtypes>=0.4.0", "mwxml>=0.3.6", "pyarrow>=20.0.0", + "pyspark>=3.5.0", "pywikidiff2", "sortedcontainers>=2.4.0", "yamlconf>=0.2.6", @@ -21,13 +22,14 @@ dependencies = [ [project.scripts] wikiq = "wikiq:main" +wikiq-spark = "wikiq_spark:main" [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] -packages = ["src/wikiq"] +packages = ["src/wikiq", "src/wikiq_spark"] [tool.uv.sources] diff --git a/src/wikiq/__init__.py b/src/wikiq/__init__.py index 44f13cd..15d7ff7 100755 --- a/src/wikiq/__init__.py +++ b/src/wikiq/__init__.py @@ -12,7 +12,7 @@ import signal import sys import threading import time -from collections import deque +from collections import deque, defaultdict from hashlib import sha1 from io import TextIOWrapper from itertools import groupby @@ -24,17 +24,18 @@ import mwreverts import mwxml import pywikidiff2 from deltas.tokenizers import wikitext_split -from more_itertools import ichunked +from more_itertools import peekable from mwxml import Dump import wikiq.tables as tables from wikiq.tables import RevisionTable from wikiq.wiki_diff_matcher import WikiDiffMatcher from wikiq.wikitext_parser import WikitextParser from wikiq.resume import ( + get_checkpoint_path, + read_checkpoint, get_resume_point, setup_resume_temp_output, finalize_resume_merge, - get_checkpoint_path, cleanup_interrupted_resume, ) @@ -49,6 +50,276 @@ import pyarrow.parquet as pq from deltas import SegmentMatcher, SequenceMatcher +def pyarrow_type_to_spark(pa_type): + """Convert a PyArrow type to Spark JSON schema format.""" + if pa.types.is_int64(pa_type): + return "long" + elif pa.types.is_int32(pa_type): + return "integer" + elif pa.types.is_int8(pa_type): + return "byte" + elif pa.types.is_boolean(pa_type): + return "boolean" + elif pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type): + return "string" + elif pa.types.is_timestamp(pa_type): + return "timestamp" + elif pa.types.is_list(pa_type): + return { + "type": "array", + "elementType": pyarrow_type_to_spark(pa_type.value_type), + "containsNull": True + } + elif pa.types.is_struct(pa_type): + return { + "type": "struct", + "fields": [ + { + "name": field.name, + "type": pyarrow_type_to_spark(field.type), + "nullable": field.nullable, + "metadata": {} + } + for field in pa_type + ] + } + elif pa.types.is_map(pa_type): + return { + "type": "map", + "keyType": pyarrow_type_to_spark(pa_type.key_type), + "valueType": pyarrow_type_to_spark(pa_type.item_type), + "valueContainsNull": True + } + else: + return "string" + + +def pyarrow_to_spark_schema(schema: pa.Schema) -> dict: + """Convert a PyArrow schema to Spark JSON schema format.""" + return { + "type": "struct", + "fields": [ + { + "name": field.name, + "type": pyarrow_type_to_spark(field.type), + "nullable": field.nullable, + "metadata": {} + } + for field in schema + ] + } + + +def build_table( + text: bool = False, + collapse_user: bool = False, + external_links: bool = False, + citations: bool = False, + wikilinks: bool = False, + templates: bool = False, + headings: bool = False, +): + """Build the RevisionTable with appropriate columns based on flags. + + Returns: + (table, reverts_column) - the table and a reference to the reverts column + (which process() needs for setting the revert detector). + """ + reverts_column = tables.RevisionReverts() + + table = RevisionTable([ + tables.RevisionId(), + tables.RevisionTimestamp(), + tables.RevisionArticleId(), + tables.RevisionPageTitle(), + tables.RevisionNamespace(), + tables.RevisionDeleted(), + tables.RevisionEditorId(), + tables.RevisionEditSummary(), + tables.RevisionTextChars(), + reverts_column, + tables.RevisionSha1(), + tables.RevisionIsMinor(), + tables.RevisionEditorText(), + tables.RevisionIsAnon(), + ]) + + if text: + table.columns.append(tables.RevisionText()) + + if collapse_user: + table.columns.append(tables.RevisionCollapsed()) + + if external_links or citations or wikilinks or templates or headings: + wikitext_parser = WikitextParser() + + if external_links: + table.columns.append(tables.RevisionExternalLinks(wikitext_parser)) + + if citations: + table.columns.append(tables.RevisionCitations(wikitext_parser)) + + if wikilinks: + table.columns.append(tables.RevisionWikilinks(wikitext_parser)) + + if templates: + table.columns.append(tables.RevisionTemplates(wikitext_parser)) + + if headings: + table.columns.append(tables.RevisionHeadings(wikitext_parser)) + + table.columns.append(tables.RevisionParserTimeout(wikitext_parser)) + + return table, reverts_column + + +def build_schema( + table, + diff: bool = False, + persist: int = 0, + text: bool = False, + regex_revision_pairs: list = None, + regex_comment_pairs: list = None, +) -> pa.Schema: + """Build the PyArrow schema from a table, adding output-only fields.""" + schema = table.schema() + schema = schema.append(pa.field("revert", pa.bool_(), nullable=True)) + + if diff: + from wikiq.diff_pyarrow_schema import diff_field + schema = schema.append(diff_field) + schema = schema.append(pa.field("diff_timeout", pa.bool_())) + + if regex_revision_pairs: + for pair in regex_revision_pairs: + for field in pair.get_pyarrow_fields(): + schema = schema.append(field) + + if regex_comment_pairs: + for pair in regex_comment_pairs: + for field in pair.get_pyarrow_fields(): + schema = schema.append(field) + + if persist != PersistMethod.none: + # RevisionText is added to the table for extraction, but not to schema + # (unless text=True, in which case it's already in the schema from build_table) + schema = schema.append(pa.field("token_revs", pa.int64(), nullable=True)) + schema = schema.append(pa.field("tokens_added", pa.int64(), nullable=True)) + schema = schema.append(pa.field("tokens_removed", pa.int64(), nullable=True)) + schema = schema.append(pa.field("tokens_window", pa.int64(), nullable=True)) + + return schema + + +def make_regex_pairs(patterns, labels) -> list: + """Create RegexPair objects from patterns and labels.""" + if (patterns is not None and labels is not None) and (len(patterns) == len(labels)): + return [RegexPair(pattern, label) for pattern, label in zip(patterns, labels)] + elif patterns is None and labels is None: + return [] + else: + sys.exit("Each regular expression *must* come with a corresponding label and vice versa.") + + +class JSONLWriter: + """Write JSONL output with schema validation.""" + + def __init__(self, output_file: str, schema: pa.Schema, append: bool = False): + self.output_file = output_file + self.schema = schema + self.field_names = [field.name for field in schema] + + if append and os.path.exists(output_file): + self._validate_and_fix_last_line(output_file) + + mode = "a" if append else "w" + self._file = open(output_file, mode) + + def _validate_and_fix_last_line(self, filepath: str): + """Validate the last line of JSONL file; truncate if corrupted. + + If the previous run was interrupted mid-write, the last line may be + incomplete JSON. This detects and removes such corrupted lines. + """ + with open(filepath, 'rb') as f: + f.seek(0, 2) + file_size = f.tell() + if file_size == 0: + return + + # Read backwards to find the last newline + chunk_size = min(8192, file_size) + f.seek(-chunk_size, 2) + chunk = f.read(chunk_size) + + # Find the last complete line + last_newline = chunk.rfind(b'\n') + if last_newline == -1: + # Entire file is one line (possibly corrupted) + last_line = chunk.decode('utf-8', errors='replace') + truncate_pos = 0 + else: + last_line = chunk[last_newline + 1:].decode('utf-8', errors='replace') + truncate_pos = file_size - chunk_size + last_newline + 1 + + # If last line is empty, file ends with newline - that's fine + if not last_line.strip(): + return + + # Try to parse the last line as JSON + try: + json.loads(last_line) + except json.JSONDecodeError: + print(f"Warning: Last line of {filepath} is corrupted JSON, removing it", + file=sys.stderr) + # Truncate the file to remove the corrupted last line + with open(filepath, 'r+b') as f: + f.truncate(truncate_pos) + + def write_batch(self, data: dict): + """Write a batch of rows as JSONL. + + Args: + data: dict mapping column names to lists of values + """ + if not data or not data.get(self.field_names[0]): + return + + num_rows = len(data[self.field_names[0]]) + for i in range(num_rows): + row = {} + for name in self.field_names: + if name in data: + value = data[name][i] + row[name] = self._convert_value(value) + self._file.write(json.dumps(row) + "\n") + + def _convert_value(self, value): + """Convert a value to JSON-serializable format.""" + if value is None: + return None + elif isinstance(value, (str, int, float, bool)): + return value + elif hasattr(value, "isoformat"): + return value.isoformat() + elif isinstance(value, (list, tuple)): + return [self._convert_value(v) for v in value] + elif isinstance(value, dict): + return {k: self._convert_value(v) for k, v in value.items()} + else: + return str(value) + + def close(self): + self._file.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + class PersistMethod: none = 0 sequence = 1 @@ -243,10 +514,11 @@ class WikiqParser: persist: int = None, namespaces: Union[list[int], None] = None, revert_radius: int = 15, - output_parquet: bool = True, + output_jsonl: bool = False, + output_parquet: bool = False, batch_size: int = 1024, - partition_namespaces: bool = False, resume_point: Union[tuple, dict, None] = None, + partition_namespaces: bool = False, external_links: bool = False, citations: bool = False, wikilinks: bool = False, @@ -260,8 +532,7 @@ class WikiqParser: persist : what persistence method to use. Takes a PersistMethod value resume_point : if set, either a (pageid, revid) tuple for single-file output, or a dict mapping namespace -> (pageid, revid) for partitioned output. - For single-file: skip all revisions up to - and including this point + For single-file: skip all revisions up to and including this point. max_revisions_per_file : if > 0, close and rotate output files after this many revisions """ self.input_file = input_file @@ -295,18 +566,22 @@ class WikiqParser: regex_match_comment, regex_comment_label ) - # here we initialize the variables we need for output. + # Initialize output self.batch_size = batch_size + self.output_jsonl = output_jsonl self.output_parquet = output_parquet - if output_parquet is True: - self.pq_writer = None - self.output_file = output_file - self.parquet_buffer = [] + self.output_file = output_file + if output_parquet: + self.pq_writer = None + self.parquet_buffer = [] + elif output_jsonl: + pass # JSONLWriter created in process() else: + # TSV output self.print_header = True if output_file == sys.stdout.buffer: - self.output_file = output_file + pass else: self.output_file = open(output_file, "wb") @@ -348,7 +623,7 @@ class WikiqParser: def _open_checkpoint(self, output_file): """Open checkpoint file for writing. Keeps file open for performance.""" - if not self.output_parquet or output_file == sys.stdout.buffer: + if (not self.output_jsonl and not self.output_parquet) or output_file == sys.stdout.buffer: return checkpoint_path = get_checkpoint_path(output_file, self.partition_namespaces) Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True) @@ -386,10 +661,6 @@ class WikiqParser: For partitioned output, creates writer lazily if needed. Returns (writer, num_rows) - writer used and number of rows written. - - Args: - ns_base_paths: For partitioned output, maps namespace -> base path (without part number) - part_numbers: For partitioned output, maps namespace -> current part number """ num_rows = len(row_buffer.get("revid", [])) if self.partition_namespaces and namespace is not None: @@ -465,7 +736,6 @@ class WikiqParser: time_limit_timer = self._start_time_limit_timer() # Track whether we've passed the resume point - # For partitioned output, this is a dict mapping namespace -> bool if self.resume_point is None: found_resume_point = True elif self.partition_namespaces: @@ -484,114 +754,55 @@ class WikiqParser: self.output_file = temp_output_file # Open checkpoint file for tracking resume point - # Use original_output_file if resuming, otherwise self.output_file checkpoint_output = original_output_file if original_output_file else self.output_file self._open_checkpoint(checkpoint_output) # Construct dump file iterator dump = WikiqIterator(self.input_file, collapse_user=self.collapse_user) - reverts_column = tables.RevisionReverts() - - table = RevisionTable( - [ - tables.RevisionId(), - tables.RevisionTimestamp(), - tables.RevisionArticleId(), - tables.RevisionPageTitle(), - tables.RevisionNamespace(), - tables.RevisionDeleted(), - tables.RevisionEditorId(), - tables.RevisionEditSummary(), - tables.RevisionTextChars(), - reverts_column, - tables.RevisionSha1(), - tables.RevisionIsMinor(), - tables.RevisionEditorText(), - tables.RevisionIsAnon(), - ] + table, reverts_column = build_table( + text=self.text, + collapse_user=self.collapse_user, + external_links=self.external_links, + citations=self.citations, + wikilinks=self.wikilinks, + templates=self.templates, + headings=self.headings, ) - if self.text: - table.columns.append(tables.RevisionText()) - - if self.collapse_user: - table.columns.append(tables.RevisionCollapsed()) - - # Create shared parser if any wikitext feature is enabled - if self.external_links or self.citations or self.wikilinks or self.templates or self.headings: - wikitext_parser = WikitextParser() - - if self.external_links: - table.columns.append(tables.RevisionExternalLinks(wikitext_parser)) - - if self.citations: - table.columns.append(tables.RevisionCitations(wikitext_parser)) - - if self.wikilinks: - table.columns.append(tables.RevisionWikilinks(wikitext_parser)) - - if self.templates: - table.columns.append(tables.RevisionTemplates(wikitext_parser)) - - if self.headings: - table.columns.append(tables.RevisionHeadings(wikitext_parser)) - - # Add parser timeout tracking if any wikitext feature is enabled - if self.external_links or self.citations or self.wikilinks or self.templates or self.headings: - table.columns.append(tables.RevisionParserTimeout(wikitext_parser)) - - # extract list of namespaces + # Extract list of namespaces self.namespaces = { ns.name: ns.id for ns in dump.mwiterator.site_info.namespaces } page_count = 0 rev_count = 0 - output_count = 0 - writer: Union[pq.ParquetWriter, pacsv.CSVWriter] - schema = table.schema() - schema = schema.append(pa.field("revert", pa.bool_(), nullable=True)) + schema = build_schema( + table, + diff=self.diff, + persist=self.persist, + text=self.text, + regex_revision_pairs=self.regex_revision_pairs, + regex_comment_pairs=self.regex_comment_pairs, + ) - if self.diff: - from wikiq.diff_pyarrow_schema import diff_field - - schema = schema.append(diff_field) - schema = schema.append(pa.field("diff_timeout", pa.bool_())) - - if self.diff and self.persist == PersistMethod.none: + # Add RevisionText to table for diff/persist computation (extraction only, not output) + if (self.diff or self.persist != PersistMethod.none) and not self.text: table.columns.append(tables.RevisionText()) - # Add regex fields to the schema. - for pair in self.regex_revision_pairs: - for field in pair.get_pyarrow_fields(): - schema = schema.append(field) - - for pair in self.regex_comment_pairs: - for field in pair.get_pyarrow_fields(): - schema = schema.append(field) - - if self.persist != PersistMethod.none: - table.columns.append(tables.RevisionText()) - schema = schema.append(pa.field("token_revs", pa.int64(), nullable=True)) - schema = schema.append(pa.field("tokens_added", pa.int64(), nullable=True)) - schema = schema.append( - pa.field("tokens_removed", pa.int64(), nullable=True) - ) - schema = schema.append(pa.field("tokens_window", pa.int64(), nullable=True)) + # Initialize writer + writer = None + sorting_cols = None + ns_base_paths = {} + pq_writers = {} + part_numbers = {} if self.output_parquet: - pageid_sortingcol = pq.SortingColumn(schema.get_field_index("pageid")) - revid_sortingcol = pq.SortingColumn(schema.get_field_index("pageid")) + pageid_sortingcol = pq.SortingColumn(schema.get_field_index("articleid")) + revid_sortingcol = pq.SortingColumn(schema.get_field_index("revid")) sorting_cols = [pageid_sortingcol, revid_sortingcol] - # Part tracking for file splitting. - # Keys are namespace integers for partitioned output, None for non-partitioned. - part_numbers = {} - revisions_in_part = {} - - # Initialize part numbers from resume point if resuming if self.resume_point is not None: if self.partition_namespaces: for ns, resume_data in self.resume_point.items(): @@ -599,8 +810,7 @@ class WikiqParser: else: part_numbers[None] = self.resume_point[2] if len(self.resume_point) > 2 else 0 - if self.partition_namespaces is False: - # Generate path with part number if file splitting is enabled + if not self.partition_namespaces: if self.max_revisions_per_file > 0: output_path_with_part = self._get_part_path(self.output_file, part_numbers.get(None, 0)) else: @@ -611,413 +821,277 @@ class WikiqParser: flavor="spark", sorting_columns=sorting_cols, ) - ns_base_paths = {} - pq_writers = {} else: output_path = Path(self.output_file) if self.namespace_filter is not None: namespaces = self.namespace_filter else: namespaces = self.namespaces.values() - # Store base paths - actual paths with part numbers generated in _write_batch ns_base_paths = { ns: (output_path.parent / f"namespace={ns}") / output_path.name for ns in namespaces } - # Initialize part numbers for each namespace (from resume point or 0) for ns in namespaces: if ns not in part_numbers: part_numbers[ns] = 0 - revisions_in_part[ns] = 0 - # Writers are created lazily when first needed to avoid empty files on early exit - pq_writers = {} - writer = None # Not used for partitioned output - + elif self.output_jsonl: + append_mode = self.resume_point is not None + writer = JSONLWriter(self.output_file, schema, append=append_mode) else: writer = pacsv.CSVWriter( self.output_file, schema, write_options=pacsv.WriteOptions(delimiter="\t"), ) - ns_base_paths = {} - pq_writers = {} - sorting_cols = None - # Part tracking not needed for CSV (no OOM issue during close) - part_numbers = {} - revisions_in_part = {} - regex_matches = {} + # Initialize diff machinery + differ = None + fast_differ = None + if self.diff: + differ = pywikidiff2.pywikidiff2( + num_context_lines=1000000, + max_word_level_diff_complexity=-1, + moved_paragraph_detection_cutoff=-1, + words_cache_capacity=10000, + diff_cache_capacity=10000, + stats_cache_capacity=10000, + ) + fast_differ = pywikidiff2.pywikidiff2( + num_context_lines=1000000, + max_word_level_diff_complexity=40000000, + moved_paragraph_detection_cutoff=100, + words_cache_capacity=-1, + diff_cache_capacity=-1, + stats_cache_capacity=-1, + ) + + # Write buffer: accumulate rows before flushing + write_buffer = defaultdict(list) + buffer_count = 0 + last_namespace = None + + def flush_buffer(): + nonlocal write_buffer, buffer_count, last_namespace + if buffer_count == 0: + return + row_buffer = dict(write_buffer) + namespace = last_namespace + if self.output_parquet: + if self.partition_namespaces: + self._write_batch( + row_buffer, schema, writer, pq_writers, ns_base_paths, + sorting_cols, namespace=namespace, part_numbers=part_numbers + ) + else: + writer.write(pa.record_batch(row_buffer, schema=schema)) + elif self.output_jsonl: + writer.write_batch(row_buffer) + else: + writer.write(pa.record_batch(row_buffer, schema=schema)) + + # Update checkpoint + last_pageid = row_buffer["articleid"][-1] + last_revid = row_buffer["revid"][-1] + part = part_numbers.get(namespace if self.partition_namespaces else None, 0) + self._update_checkpoint(last_pageid, last_revid, + namespace=namespace if self.partition_namespaces else None, + part=part) + write_buffer = defaultdict(list) + buffer_count = 0 # Iterate through pages - total_revs = 0 - for page in dump: - # skip namespaces not in the filter + # Skip namespaces not in the filter if self.namespace_filter is not None: if page.mwpage.namespace not in self.namespace_filter: continue - # Resume logic: skip pages that come before the resume point. - # For partitioned output, each namespace has its own resume point. + # Resume logic: skip pages before the resume point is_resume_page = False - page_resume_point = None - if self.resume_point is not None: + page_resume_revid = None + if self.resume_point is not None and not found_resume_point: page_id = page.mwpage.id - page_ns = page.mwpage.namespace - - if self.partition_namespaces: - # Per-namespace resume: check if we've passed this namespace's resume point - if found_resume_point.get(page_ns, False): - pass # Already past resume point for this namespace - elif page_ns not in self.resume_point: - # No resume point for this namespace, process normally - found_resume_point[page_ns] = True - else: - resume_data = self.resume_point[page_ns] - resume_pageid, resume_revid = resume_data[0], resume_data[1] - if page_id < resume_pageid: - continue - elif page_id == resume_pageid: - is_resume_page = True - page_resume_point = (resume_pageid, resume_revid) - else: - found_resume_point[page_ns] = True + resume_pageid, resume_revid = self.resume_point[0], self.resume_point[1] + if page_id < resume_pageid: + continue + elif page_id == resume_pageid: + is_resume_page = True + page_resume_revid = resume_revid else: - # Single-file resume: global resume point - if not found_resume_point: - resume_pageid, resume_revid = self.resume_point[0], self.resume_point[1] - if page_id < resume_pageid: - continue - elif page_id == resume_pageid: - is_resume_page = True - page_resume_point = (resume_pageid, resume_revid) - else: - found_resume_point = True + found_resume_point = True - # Disable detecting reverts if radius is 0. + # Reset revert detector for new page if self.revert_radius > 0: - reverts_column.rev_detector = mwreverts.Detector( - radius=self.revert_radius - ) + reverts_column.rev_detector = mwreverts.Detector(radius=self.revert_radius) else: reverts_column.rev_detector = None - # Iterate through a page's revisions - batches = ichunked(page, self.batch_size) - last_rev_text = "" - last_rev_id = None - row_buffer = None - last_row_buffer = {} - on_last_batch = False - next_batch = {} - diff_dict = {} + # State for this page + prev_text = "" + persist_state = None + persist_window = None if self.persist != PersistMethod.none: - window = deque(maxlen=PERSISTENCE_RADIUS) - if self.persist != PersistMethod.none: - if self.persist == PersistMethod.sequence: - persist_state = mwpersistence.DiffState( - SequenceMatcher(tokenizer=wikitext_split), - revert_radius=PERSISTENCE_RADIUS, - ) - elif self.persist == PersistMethod.segment: - persist_state = mwpersistence.DiffState( - SegmentMatcher(tokenizer=wikitext_split), - revert_radius=PERSISTENCE_RADIUS, - ) - elif self.persist == PersistMethod.wikidiff2: - wikidiff_matcher = WikiDiffMatcher(tokenizer=wikitext_split) - persist_state = mwpersistence.DiffState( - wikidiff_matcher, revert_radius=PERSISTENCE_RADIUS - ) - else: - from mw.lib import persistence - - persist_state = persistence.State() - - if self.diff: - differ = pywikidiff2.pywikidiff2( - num_context_lines=1000000, - max_word_level_diff_complexity=-1, - moved_paragraph_detection_cutoff=-1, - words_cache_capacity=10000, - diff_cache_capacity=10000, - stats_cache_capacity=10000, - ) - - fast_differ = pywikidiff2.pywikidiff2( - num_context_lines=1000000, - max_word_level_diff_complexity=40000000, - moved_paragraph_detection_cutoff=100, - words_cache_capacity=-1, - diff_cache_capacity=-1, - stats_cache_capacity=-1, - ) - - - while not on_last_batch: - # first loop: next_batch <- batch; - # second loop: next_batch <- batch; evaluate next_batch. - # final loop: on_last_batch <- true; evaluate next_batch - try: - batch = list(next(batches)) - except StopIteration: - on_last_batch = True - - if len(next_batch) == 0: - next_batch = batch - continue + persist_window = deque(maxlen=PERSISTENCE_RADIUS) + if self.persist == PersistMethod.sequence: + persist_state = mwpersistence.DiffState( + SequenceMatcher(tokenizer=wikitext_split), + revert_radius=PERSISTENCE_RADIUS, + ) + elif self.persist == PersistMethod.segment: + persist_state = mwpersistence.DiffState( + SegmentMatcher(tokenizer=wikitext_split), + revert_radius=PERSISTENCE_RADIUS, + ) + elif self.persist == PersistMethod.wikidiff2: + wikidiff_matcher = WikiDiffMatcher(tokenizer=wikitext_split) + persist_state = mwpersistence.DiffState( + wikidiff_matcher, revert_radius=PERSISTENCE_RADIUS + ) else: - tmp_batch = next_batch - next_batch = batch - batch = tmp_batch + from mw.lib import persistence + persist_state = persistence.State() - n_revs = 0 + # Pending persistence values waiting for window to fill + pending_persistence = [] - for revs in batch: - # Revisions may or may not be grouped into lists of contiguous revisions by the - # same user. We call these "edit sessions". Otherwise revs is a list containing - # exactly one revision. - revs = list(revs) - revs = fix_hex_digests(revs) - # the problem is that we load all the revisions before we 'pop' - table.add(page.mwpage, revs) + # Use peekable to detect last revision in page + revs_iter = peekable(page) - # if re.match(r'^#redirect \[\[.*\]\]', rev.text, re.I): - # redirect = True - # else: - # redirect = False + for revs in revs_iter: + # revs is either a single revision or a group (collapse_user mode) + revs = list(revs) + revs = fix_hex_digests(revs) + rev = revs[-1] # Last revision in the group + is_last_in_page = revs_iter.peek(None) is None - # TODO missing: additions_size deletions_size + # Skip revisions before resume point + if is_resume_page: + if rev.id <= page_resume_revid: + # Update state for correctness when we resume output + if self.diff or self.persist != PersistMethod.none: + prev_text = rev.text or "" + if persist_state is not None: + text = rev.text or "" + if self.persist != PersistMethod.legacy: + persist_state.update(text, rev.id) + else: + persist_state.process(text, rev.id) + # Update revert detector so it has history for post-resume revisions + if reverts_column.rev_detector is not None and not rev.deleted.text: + reverts_column.rev_detector.process(rev.sha1, rev.id) + if rev.id == page_resume_revid: + found_resume_point = True + is_resume_page = False + print(f"Resuming output after revid {rev.id}", file=sys.stderr) + continue - rev_count += 1 + rev_count += 1 - # Get the last revision in the edit session. - rev = revs[-1] - regex_dict = self.matchmake_revision(rev) - for k, v in regex_dict.items(): - if regex_matches.get(k) is None: - regex_matches[k] = [] - regex_matches[k].append(v) + # Extract base row data + row = table.extract_row(page.mwpage, revs) - # Check for shutdown after each revision - if self.shutdown_requested: - break + # Compute revert flag + if self.revert_radius == 0 or row["deleted"]: + row["revert"] = None + else: + row["revert"] = row["reverteds"] is not None - # If shutdown requested, skip all remaining processing and close writers + # Regex matching + regex_dict = self.matchmake_revision(rev) + for k, v in regex_dict.items(): + row[k] = v + + # Compute diff + text = row.get("text", "") or "" + if self.diff: + diff_result, timed_out = diff_with_timeout(differ, prev_text, text) + if timed_out: + print(f"WARNING! wikidiff2 timeout for rev: {rev.id}. Falling back to default limits.", file=sys.stderr) + diff_result = fast_differ.inline_json_diff(prev_text, text) + row["diff"] = [entry for entry in json.loads(diff_result)["diff"] if entry["type"] != 0] + row["diff_timeout"] = timed_out + + # Compute persistence + if persist_state is not None: + if self.persist != PersistMethod.legacy: + _, tokens_added, tokens_removed = persist_state.update(text, rev.id) + else: + _, tokens_added, tokens_removed = persist_state.process(text, rev.id) + + persist_window.append((rev.id, tokens_added, tokens_removed)) + pending_persistence.append(row) + + # When window is full, emit persistence for oldest revision + if len(persist_window) == PERSISTENCE_RADIUS: + old_rev_id, old_tokens_added, old_tokens_removed = persist_window.popleft() + oldest_row = pending_persistence.pop(0) + num_token_revs, num_tokens = calculate_persistence(old_tokens_added) + oldest_row["token_revs"] = num_token_revs + oldest_row["tokens_added"] = num_tokens + oldest_row["tokens_removed"] = len(old_tokens_removed) + oldest_row["tokens_window"] = PERSISTENCE_RADIUS - 1 + + # Remove text if not outputting it + if not self.text and "text" in oldest_row: + del oldest_row["text"] + + # Add to write buffer + for k, v in oldest_row.items(): + write_buffer[k].append(v) + buffer_count += 1 + last_namespace = page.mwpage.namespace + + if buffer_count >= self.batch_size: + flush_buffer() + + # Update prev_text for next iteration + if self.diff or self.persist != PersistMethod.none: + prev_text = text + + # If no persistence, write row directly + if persist_state is None: + if not self.text and "text" in row: + del row["text"] + + for k, v in row.items(): + write_buffer[k].append(v) + buffer_count += 1 + last_namespace = page.mwpage.namespace + + if buffer_count >= self.batch_size: + flush_buffer() + + # Check for shutdown if self.shutdown_requested: print("Shutdown requested, closing writers...", file=sys.stderr) break - # Collect the set of revisions currently buffered in the table so we can run multi-revision functions on them. - batch_row_buffer = table.pop() - if self.persist != PersistMethod.none: - # we have everything we need for these revs, which is everything we've seen up to the end of the persistence radius - row_buffer = { - k: last_row_buffer.get(k, []) - + batch_row_buffer[k][ - : ( - -1 * (PERSISTENCE_RADIUS - 1) - if not on_last_batch - else None - ) - ] - for k in batch_row_buffer.keys() - } + # End of page: flush remaining persistence window + if persist_state is not None and not self.shutdown_requested: + for i, (pending_row, window_item) in enumerate(zip(pending_persistence, persist_window)): + rev_id, tokens_added, tokens_removed = window_item + num_token_revs, num_tokens = calculate_persistence(tokens_added) + pending_row["token_revs"] = num_token_revs + pending_row["tokens_added"] = num_tokens + pending_row["tokens_removed"] = len(tokens_removed) + pending_row["tokens_window"] = len(persist_window) - (i + 1) - # we'll use these to calc persistence for the row, buffer. - next_row_buffer = { - k: ( - batch_row_buffer[k][-1 * (PERSISTENCE_RADIUS - 1) :] - if not on_last_batch - else [] - ) - for k in batch_row_buffer.keys() - } + if not self.text and "text" in pending_row: + del pending_row["text"] - if len(last_row_buffer) > 0: - diff_buffer = { - k: (row_buffer[k] + next_row_buffer[k])[ - len(last_row_buffer["revid"]) : - ] - for k in {"revid", "text"} - } - else: - diff_buffer = { - k: row_buffer[k] + next_row_buffer[k] - for k in {"revid", "text"} - } + for k, v in pending_row.items(): + write_buffer[k].append(v) + buffer_count += 1 + last_namespace = page.mwpage.namespace - else: - row_buffer = batch_row_buffer - - is_revert_column: list[Union[bool, None]] = [] - for r, d in zip(row_buffer["reverteds"], row_buffer["deleted"]): - if self.revert_radius == 0 or d: - is_revert_column.append(None) - else: - is_revert_column.append(r is not None) - - row_buffer["revert"] = is_revert_column - - for k, v in regex_matches.items(): - row_buffer[k] = v - regex_matches = {} - - # begin persistence logic - if self.persist != PersistMethod.none: - row_buffer["token_revs"] = [] - row_buffer["tokens_added"] = [] - row_buffer["tokens_removed"] = [] - row_buffer["tokens_window"] = [] - for idx, text in enumerate(diff_buffer["text"]): - rev_id = diff_buffer["revid"][idx] - if self.persist != PersistMethod.legacy: - _, tokens_added, tokens_removed = persist_state.update( - text, rev_id - ) - else: - _, tokens_added, tokens_removed = persist_state.process( - text, rev_id - ) - - window.append((rev_id, tokens_added, tokens_removed)) - - if len(window) == PERSISTENCE_RADIUS: - ( - old_rev_id, - old_tokens_added, - old_tokens_removed, - ) = window.popleft() - num_token_revs, num_tokens = calculate_persistence( - old_tokens_added - ) - - row_buffer["token_revs"].append(num_token_revs) - row_buffer["tokens_added"].append(num_tokens) - row_buffer["tokens_removed"].append(len(old_tokens_removed)) - row_buffer["tokens_window"].append(PERSISTENCE_RADIUS - 1) - - if on_last_batch: - # this needs to run when we get to the end - # print out metadata for the last RADIUS revisions - for i, item in enumerate(window): - # if the window was full, we've already printed item 0 - if len(window) == PERSISTENCE_RADIUS and i == 0: - continue - - rev_id, tokens_added, tokens_removed = item - num_token_revs, num_tokens = calculate_persistence( - tokens_added - ) - - row_buffer["token_revs"].append(num_token_revs) - row_buffer["tokens_added"].append(num_tokens) - row_buffer["tokens_removed"].append(len(tokens_removed)) - row_buffer["tokens_window"].append(len(window) - (i + 1)) - - last_row_buffer = next_row_buffer - - # the persistence stuff doesn't calculate diffs for reverts. - if self.diff: - last_text = last_rev_text - new_diffs = [] - diff_timeouts = [] - for i, text in enumerate(row_buffer["text"]): - if self.shutdown_requested: - break - diff, timed_out = diff_with_timeout(differ, last_text, text) - if timed_out: - print(f"WARNING! wikidiff2 timeout for rev: {row_buffer['revid'][i]}. Falling back to default limits.", file=sys.stderr) - diff = fast_differ.inline_json_diff(last_text, text) - new_diffs.append(diff) - diff_timeouts.append(timed_out) - last_text = text - if self.shutdown_requested: - print("Shutdown requested, closing writers...", file=sys.stderr) - break - row_buffer["diff"] = [ - [ - entry - for entry in json.loads(diff)["diff"] - if entry["type"] != 0 - ] - for diff in new_diffs - ] - row_buffer["diff_timeout"] = diff_timeouts - - # end persistence logic - if self.diff or self.persist != PersistMethod.none: - last_rev_text = row_buffer["text"][-1] - last_rev_id = row_buffer["revid"][-1] - - if not self.text and self.persist != PersistMethod.none: - del row_buffer["text"] - - # Filter for resume logic if on resume page - should_write = True - if is_resume_page: - _, resume_revid = page_resume_point - revids = row_buffer["revid"] - resume_idx = next((i for i, r in enumerate(revids) if r == resume_revid), None) - - if resume_idx is not None: - # Mark resume point as found - if self.partition_namespaces: - found_resume_point[page.mwpage.namespace] = True - else: - found_resume_point = True - is_resume_page = False - - # Only write revisions after the resume point - if resume_idx + 1 < len(revids): - row_buffer = {k: v[resume_idx + 1:] for k, v in row_buffer.items()} - print(f"Resuming output starting at revid {row_buffer['revid'][0]}", file=sys.stderr) - else: - should_write = False - else: - should_write = False - - # Write batch if there are rows - if should_write and len(row_buffer.get("revid", [])) > 0: - namespace = page.mwpage.namespace if self.partition_namespaces else None - writer, num_rows = self._write_batch( - row_buffer, schema, writer, pq_writers, ns_base_paths, sorting_cols, - namespace, part_numbers - ) - # Update checkpoint with last written position - last_pageid = row_buffer["articleid"][-1] - last_revid = row_buffer["revid"][-1] - - # Track revisions and check if file rotation is needed. - # namespace is None for non-partitioned output. - revisions_in_part[namespace] = revisions_in_part.get(namespace, 0) + num_rows - current_part_num = part_numbers.get(namespace, 0) - self._update_checkpoint(last_pageid, last_revid, namespace, current_part_num) - - # Rotate file if limit exceeded - if self.max_revisions_per_file > 0 and revisions_in_part[namespace] >= self.max_revisions_per_file: - part_numbers[namespace] = current_part_num + 1 - revisions_in_part[namespace] = 0 - if self.partition_namespaces: - pq_writers[namespace].close() - del pq_writers[namespace] - else: - writer.close() - output_path_with_part = self._get_part_path(self.output_file, part_numbers[namespace]) - writer = pq.ParquetWriter( - output_path_with_part, - schema, - flavor="spark", - sorting_columns=sorting_cols, - ) - gc.collect() - - # If shutdown was requested, break from page loop if self.shutdown_requested: break page_count += 1 + # Flush remaining buffer + flush_buffer() + # Cancel time limit timer self._cancel_time_limit_timer(time_limit_timer) @@ -1025,13 +1099,18 @@ class WikiqParser: "Done: %s revisions and %s pages." % (rev_count, page_count), file=sys.stderr, ) - if self.partition_namespaces is True: - for writer in pq_writers.values(): - writer.close() - else: + + # Close all writers + if self.output_parquet and self.partition_namespaces: + for pq_writer in pq_writers.values(): + pq_writer.close() + elif writer is not None: writer.close() - # If we were resuming, merge the original file with the new temp file + # Close checkpoint file; delete it only if we completed without interruption + self._close_checkpoint(delete=not self.shutdown_requested) + + # Merge temp output with original for parquet resume if original_output_file is not None and temp_output_file is not None: finalize_resume_merge( original_output_file, @@ -1040,9 +1119,6 @@ class WikiqParser: original_partition_dir ) - # Close checkpoint file; delete it only if we completed without interruption - self._close_checkpoint(delete=not self.shutdown_requested) - def match_archive_suffix(input_filename): if re.match(r".*\.7z$", input_filename): cmd = ["7za", "x", "-so", input_filename] @@ -1055,41 +1131,6 @@ def match_archive_suffix(input_filename): return cmd -def detect_existing_part_files(output_file, partition_namespaces): - """ - Detect whether existing output uses part file naming. - - Returns: - True if part files exist (e.g., output.part0.parquet) - False if non-part file exists (e.g., output.parquet) - None if no existing output found - """ - output_dir = os.path.dirname(output_file) or '.' - stem = os.path.basename(output_file).rsplit('.', 1)[0] - part0_name = f"{stem}.part0.parquet" - - if partition_namespaces: - if not os.path.isdir(output_dir): - return None - - for ns_dir in os.listdir(output_dir): - if not ns_dir.startswith('namespace='): - continue - ns_path = os.path.join(output_dir, ns_dir) - - if os.path.exists(os.path.join(ns_path, part0_name)): - return True - if os.path.exists(os.path.join(ns_path, os.path.basename(output_file))): - return False - return None - else: - if os.path.exists(os.path.join(output_dir, part0_name)): - return True - if os.path.exists(output_file): - return False - return None - - def open_input_file(input_filename, fandom_2020=False): cmd = match_archive_suffix(input_filename) if fandom_2020: @@ -1100,23 +1141,24 @@ def open_input_file(input_filename, fandom_2020=False): return open(input_filename, "r") -def get_output_filename(input_filename, parquet=False) -> str: +def get_output_filename(input_filename, output_format='tsv') -> str: + """Generate output filename based on input filename and format. + + Args: + input_filename: Input dump file path + output_format: 'tsv', 'jsonl', or 'parquet' + """ output_filename = re.sub(r"\.(7z|gz|bz2)?$", "", input_filename) output_filename = re.sub(r"\.xml", "", output_filename) - if parquet is False: - output_filename = output_filename + ".tsv" - else: + if output_format == 'jsonl': + output_filename = output_filename + ".jsonl" + elif output_format == 'parquet': output_filename = output_filename + ".parquet" + else: + output_filename = output_filename + ".tsv" return output_filename -def open_output_file(input_filename): - # create a regex that creates the output filename - output_filename = get_output_filename(input_filename, parquet=False) - output_file = open(output_filename, "w") - return output_file - - def main(): parser = argparse.ArgumentParser( description="Parse MediaWiki XML database dumps into tab delimited data." @@ -1138,7 +1180,7 @@ def main(): dest="output", type=str, nargs=1, - help="Directory for output files. If it ends with .parquet output will be in parquet format.", + help="Output file or directory. Format is detected from extension: .jsonl for JSONL, .parquet for Parquet, otherwise TSV.", ) parser.add_argument( @@ -1149,6 +1191,13 @@ def main(): help="Write output to standard out (do not create dump file)", ) + parser.add_argument( + "--print-schema", + dest="print_schema", + action="store_true", + help="Print the Spark-compatible JSON schema for the output and exit. No dump file is processed.", + ) + parser.add_argument( "--collapse-user", dest="collapse_user", @@ -1285,15 +1334,6 @@ def main(): help="Extract section headings from each revision.", ) - parser.add_argument( - "-PNS", - "--partition-namespaces", - dest="partition_namespaces", - default=False, - action="store_true", - help="Partition parquet files by namespace.", - ) - parser.add_argument( "--fandom-2020", dest="fandom_2020", @@ -1324,12 +1364,20 @@ def main(): help="Time limit in hours before graceful shutdown. Set to 0 to disable (default).", ) + parser.add_argument( + "--partition-namespaces", + dest="partition_namespaces", + action="store_true", + default=False, + help="For Parquet output, partition output by namespace into separate files.", + ) + parser.add_argument( "--max-revisions-per-file", dest="max_revisions_per_file", type=int, default=0, - help="Max revisions per output file before creating a new part file. Set to 0 to disable (default: disabled).", + help="For Parquet output, split output into multiple files after this many revisions. Set to 0 to disable (default).", ) args = parser.parse_args() @@ -1352,6 +1400,33 @@ def main(): else: namespaces = None + # Handle --print-schema: build and output schema, then exit + if args.print_schema: + regex_revision_pairs = make_regex_pairs(args.regex_match_revision, args.regex_revision_label) + regex_comment_pairs = make_regex_pairs(args.regex_match_comment, args.regex_comment_label) + + table, _ = build_table( + text=args.text, + collapse_user=args.collapse_user, + external_links=args.external_links, + citations=args.citations, + wikilinks=args.wikilinks, + templates=args.templates, + headings=args.headings, + ) + schema = build_schema( + table, + diff=args.diff, + persist=persist, + text=args.text, + regex_revision_pairs=regex_revision_pairs, + regex_comment_pairs=regex_comment_pairs, + ) + + spark_schema = pyarrow_to_spark_schema(schema) + print(json.dumps(spark_schema, indent=2)) + sys.exit(0) + print(args, file=sys.stderr) if len(args.dumpfiles) > 0: for filename in args.dumpfiles: @@ -1361,77 +1436,44 @@ def main(): else: output = "." + # Detect output format from extension + output_jsonl = output.endswith(".jsonl") output_parquet = output.endswith(".parquet") + partition_namespaces = args.partition_namespaces and output_parquet if args.stdout: output_file = sys.stdout.buffer - elif os.path.isdir(output) or output_parquet: + elif output_jsonl or output_parquet: + # Output is a JSONL or Parquet file path - use it directly + output_file = output + elif os.path.isdir(output): + # Output is a directory - derive filename from input output_filename = os.path.join(output, os.path.basename(filename)) - output_file = get_output_filename(output_filename, parquet=output_parquet) + output_file = get_output_filename(output_filename, output_format='tsv') else: output_file = output # Handle resume functionality before opening input file resume_point = None - start_fresh = False if args.resume: - if output_parquet and not args.stdout: - # First, merge any leftover temp files from a previous interrupted run - cleanup_result = cleanup_interrupted_resume(output_file, args.partition_namespaces) - if cleanup_result == "start_fresh": - # All data was corrupted, start from beginning - start_fresh = True - print("Starting fresh due to data corruption.", file=sys.stderr) - else: - resume_point = get_resume_point(output_file, args.partition_namespaces) - if resume_point is not None: - if args.partition_namespaces: - ns_list = sorted(resume_point.keys()) - print(f"Resuming with per-namespace resume points for {len(ns_list)} namespaces", file=sys.stderr) - for ns in ns_list: - pageid, revid, part = resume_point[ns] - print(f" namespace={ns}: pageid={pageid}, revid={revid}, part={part}", file=sys.stderr) - else: - pageid, revid, part = resume_point - print(f"Resuming from last written point: pageid={pageid}, revid={revid}, part={part}", file=sys.stderr) - - # Detect mismatch between existing file naming and current settings - existing_uses_parts = detect_existing_part_files(output_file, args.partition_namespaces) - current_uses_parts = args.max_revisions_per_file > 0 - - if existing_uses_parts is not None and existing_uses_parts != current_uses_parts: - if existing_uses_parts: - print(f"WARNING: Existing output uses part files but --max-revisions-per-file is 0. " - f"New data will be written to a non-part file.", file=sys.stderr) - else: - print(f"WARNING: Existing output does not use part files but --max-revisions-per-file={args.max_revisions_per_file}. " - f"Disabling file splitting for consistency with existing output.", file=sys.stderr) - args.max_revisions_per_file = 0 + if (output_jsonl or output_parquet) and not args.stdout: + # Clean up any interrupted resume from previous run + if output_parquet: + cleanup_result = cleanup_interrupted_resume(output_file, partition_namespaces) + if cleanup_result == "start_fresh": + resume_point = None else: - # resume_point is None - check if file exists but is corrupt - if args.partition_namespaces: - partition_dir = os.path.dirname(output_file) - output_filename = os.path.basename(output_file) - corrupt_files = [] - if os.path.isdir(partition_dir): - for d in os.listdir(partition_dir): - if d.startswith('namespace='): - filepath = os.path.join(partition_dir, d, output_filename) - if os.path.exists(filepath): - corrupt_files.append(filepath) - if corrupt_files: - print("Output files exist but are corrupt, deleting and starting fresh.", file=sys.stderr) - for filepath in corrupt_files: - os.remove(filepath) - start_fresh = True - else: - if os.path.exists(output_file): - # File exists but is corrupt - start fresh - print(f"Output file {output_file} exists but is corrupt, starting fresh.", file=sys.stderr) - os.remove(output_file) - start_fresh = True + resume_point = get_resume_point(output_file, partition_namespaces) + else: + resume_point = read_checkpoint(get_checkpoint_path(output_file)) + if resume_point is not None: + if isinstance(resume_point, dict): + print(f"Resuming from checkpoint for {len(resume_point)} namespaces", file=sys.stderr) + else: + pageid, revid = resume_point[0], resume_point[1] + print(f"Resuming from checkpoint: pageid={pageid}, revid={revid}", file=sys.stderr) else: - sys.exit("Error: --resume only works with parquet output (not stdout or TSV)") + sys.exit("Error: --resume only works with JSONL or Parquet output (not stdout or TSV)") # Now open the input file print("Processing file: %s" % filename, file=sys.stderr) @@ -1452,8 +1494,9 @@ def main(): regex_comment_label=args.regex_comment_label, text=args.text, diff=args.diff, + output_jsonl=output_jsonl, output_parquet=output_parquet, - partition_namespaces=args.partition_namespaces, + partition_namespaces=partition_namespaces, batch_size=args.batch_size, resume_point=resume_point, external_links=args.external_links, diff --git a/src/wikiq/resume.py b/src/wikiq/resume.py index ff9e05c..27db8ac 100644 --- a/src/wikiq/resume.py +++ b/src/wikiq/resume.py @@ -1,9 +1,9 @@ """ -Checkpoint and resume functionality for wikiq parquet output. +Checkpoint and resume functionality for wikiq output. This module handles: -- Finding resume points in existing parquet output -- Merging resumed data with existing output (streaming, memory-efficient) +- Finding resume points in existing output (JSONL or Parquet) +- Merging resumed data with existing output (for Parquet, streaming, memory-efficient) - Checkpoint file management for fast resume point lookup """ @@ -14,6 +14,63 @@ import sys import pyarrow.parquet as pq +def get_checkpoint_path(output_file, partition_namespaces=False): + """Get the path to the checkpoint file for a given output file. + + For partitioned output, the checkpoint is placed outside the partition directory + to avoid pyarrow trying to read it as a parquet file. The filename includes + the output filename to keep it unique per input file (for parallel jobs). + """ + if partition_namespaces: + partition_dir = os.path.dirname(output_file) + output_filename = os.path.basename(output_file) + parent_dir = os.path.dirname(partition_dir) + return os.path.join(parent_dir, output_filename + ".checkpoint") + return str(output_file) + ".checkpoint" + + +def read_checkpoint(checkpoint_path, partition_namespaces=False): + """ + Read resume point from checkpoint file if it exists. + + Checkpoint format: + Single file: {"pageid": 54, "revid": 325} or {"pageid": 54, "revid": 325, "part": 2} + Partitioned: {"0": {"pageid": 54, "revid": 325, "part": 1}, ...} + + Returns: + For single files: A tuple (pageid, revid) or (pageid, revid, part), or None if not found. + For partitioned: A dict mapping namespace -> (pageid, revid, part), or None. + """ + if not os.path.exists(checkpoint_path): + return None + + try: + with open(checkpoint_path, 'r') as f: + data = json.load(f) + + if not data: + return None + + # Single-file format: {"pageid": ..., "revid": ..., "part": ...} + if "pageid" in data and "revid" in data: + part = data.get("part", 0) + if part > 0: + return (data["pageid"], data["revid"], part) + return (data["pageid"], data["revid"]) + + # Partitioned format: {"0": {"pageid": ..., "revid": ..., "part": ...}, ...} + result = {} + for key, value in data.items(): + part = value.get("part", 0) + result[int(key)] = (value["pageid"], value["revid"], part) + + return result if result else None + + except (json.JSONDecodeError, IOError, KeyError, TypeError) as e: + print(f"Warning: Could not read checkpoint file {checkpoint_path}: {e}", file=sys.stderr) + return None + + def cleanup_interrupted_resume(output_file, partition_namespaces): """ Merge any leftover .resume_temp files from a previous interrupted run. @@ -47,7 +104,6 @@ def cleanup_interrupted_resume(output_file, partition_namespaces): print(f"Found leftover temp files in {partition_dir} from previous interrupted partitioned run, merging first...", file=sys.stderr) had_corruption = merge_partitioned_namespaces(partition_dir, temp_suffix, output_filename) - # Check if any valid data remains after merge has_valid_data = False for ns_dir in os.listdir(partition_dir): if ns_dir.startswith('namespace='): @@ -58,7 +114,6 @@ def cleanup_interrupted_resume(output_file, partition_namespaces): break if had_corruption and not has_valid_data: - # All data was corrupted, remove checkpoint and start fresh checkpoint_path = get_checkpoint_path(output_file, partition_namespaces) if os.path.exists(checkpoint_path): os.remove(checkpoint_path) @@ -73,21 +128,17 @@ def cleanup_interrupted_resume(output_file, partition_namespaces): merged_path = output_file + ".merged" merged = merge_parquet_files(output_file, temp_output_file, merged_path) if merged == "original_only": - # Temp file was invalid, just remove it os.remove(temp_output_file) elif merged == "temp_only": - # Original was corrupted or missing, use temp as new base if os.path.exists(output_file): os.remove(output_file) os.rename(temp_output_file, output_file) print("Recovered from temp file (original was corrupted or missing).", file=sys.stderr) elif merged == "both_invalid": - # Both files corrupted or missing, remove both and start fresh if os.path.exists(output_file): os.remove(output_file) if os.path.exists(temp_output_file): os.remove(temp_output_file) - # Also remove stale checkpoint file checkpoint_path = get_checkpoint_path(output_file, partition_namespaces) if os.path.exists(checkpoint_path): os.remove(checkpoint_path) @@ -99,95 +150,34 @@ def cleanup_interrupted_resume(output_file, partition_namespaces): os.remove(temp_output_file) print("Previous temp file merged successfully.", file=sys.stderr) else: - # Both empty - unusual os.remove(temp_output_file) -def get_checkpoint_path(output_file, partition_namespaces=False): - """Get the path to the checkpoint file for a given output file. - - For partitioned output, the checkpoint is placed outside the partition directory - to avoid pyarrow trying to read it as a parquet file. The filename includes - the output filename to keep it unique per input file (for parallel jobs). - """ - if partition_namespaces: - # output_file is like partition_dir/output.parquet - # checkpoint should be at parent level: parent/output.parquet.checkpoint - partition_dir = os.path.dirname(output_file) - output_filename = os.path.basename(output_file) - parent_dir = os.path.dirname(partition_dir) - return os.path.join(parent_dir, output_filename + ".checkpoint") - return str(output_file) + ".checkpoint" - - -def read_checkpoint(output_file, partition_namespaces=False): - """ - Read resume point from checkpoint file if it exists. - - Checkpoint format: - Single file: {"pageid": 54, "revid": 325, "part": 2} - Partitioned: {"0": {"pageid": 54, "revid": 325, "part": 1}, ...} - - Returns: - For single files: A tuple (pageid, revid, part), or None if not found. - For partitioned: A dict mapping namespace -> (pageid, revid, part), or None. - - Note: part defaults to 0 for checkpoints without part numbers (backwards compat). - """ - checkpoint_path = get_checkpoint_path(output_file, partition_namespaces) - if not os.path.exists(checkpoint_path): - return None - - try: - with open(checkpoint_path, 'r') as f: - data = json.load(f) - - if not data: - return None - - # Single-file format: {"pageid": ..., "revid": ..., "part": ...} - if "pageid" in data and "revid" in data: - part = data.get("part", 0) - return (data["pageid"], data["revid"], part) - - # Partitioned format: {"0": {"pageid": ..., "revid": ..., "part": ...}, ...} - result = {} - for key, value in data.items(): - part = value.get("part", 0) - result[int(key)] = (value["pageid"], value["revid"], part) - - return result if result else None - - except (json.JSONDecodeError, IOError, KeyError, TypeError) as e: - print(f"Warning: Could not read checkpoint file {checkpoint_path}: {e}", file=sys.stderr) - return None - - def get_resume_point(output_file, partition_namespaces=False): """ - Find the resume point(s) from existing parquet output. + Find the resume point(s) from existing output. First checks for a checkpoint file (fast), then falls back to scanning the parquet output (slow, for backwards compatibility). Args: - output_file: Path to the output file. For single files, this is the parquet file path. - For partitioned namespaces, this is the path like dir/dump.parquet where - namespace=* subdirectories are in the parent dir. + output_file: Path to the output file. partition_namespaces: Whether the output uses namespace partitioning. Returns: - For single files: A tuple (pageid, revid, part) or None if not found. + For single files: A tuple (pageid, revid) or (pageid, revid, part), or None. For partitioned: A dict mapping namespace -> (pageid, revid, part), or None. - When falling back to parquet scanning, part defaults to 0. """ - # First try checkpoint file (fast) checkpoint_path = get_checkpoint_path(output_file, partition_namespaces) - checkpoint_result = read_checkpoint(output_file, partition_namespaces) + checkpoint_result = read_checkpoint(checkpoint_path, partition_namespaces) if checkpoint_result is not None: print(f"Resume point found in checkpoint file {checkpoint_path}", file=sys.stderr) return checkpoint_result + # For JSONL, only checkpoint-based resume is supported + if output_file.endswith('.jsonl'): + return None + # Fall back to scanning parquet (slow, for backwards compatibility) print(f"No checkpoint file found at {checkpoint_path}, scanning parquet output...", file=sys.stderr) try: @@ -201,12 +191,7 @@ def get_resume_point(output_file, partition_namespaces=False): def _get_last_row_resume_point(pq_path): - """Get resume point by reading only the last row group of a parquet file. - - Since data is written in page/revision order, the last row group contains - the highest pageid/revid, and the last row in that group is the resume point. - Returns (pageid, revid, part) with part=0 (scanning can't determine part). - """ + """Get resume point by reading only the last row group of a parquet file.""" pf = pq.ParquetFile(pq_path) if pf.metadata.num_row_groups == 0: return None @@ -222,16 +207,7 @@ def _get_last_row_resume_point(pq_path): def _get_resume_point_partitioned(output_file): - """Find per-namespace resume points from partitioned output. - - Only looks for the specific output file in each namespace directory. - Returns a dict mapping namespace -> (max_pageid, max_revid, part=0) for each - partition where the output file exists. - - Args: - output_file: Path like 'dir/output.parquet' where namespace=* subdirectories - contain files named 'output.parquet'. - """ + """Find per-namespace resume points from partitioned output.""" partition_dir = os.path.dirname(output_file) output_filename = os.path.basename(output_file) @@ -274,14 +250,13 @@ def _get_resume_point_single_file(output_file): def merge_parquet_files(original_path, temp_path, merged_path): """ - Merge two parquet files by streaming row groups from original and temp into merged. + Merge two parquet files by streaming row groups. - This is memory-efficient: only one row group is loaded at a time. Returns: "merged" - merged file was created from both sources "original_only" - temp was invalid, keep original unchanged - "temp_only" - original was corrupted but temp is valid, use temp - "both_invalid" - both files invalid, delete both and start fresh + "temp_only" - original was corrupted but temp is valid + "both_invalid" - both files invalid False - both files were valid but empty """ original_valid = False @@ -297,12 +272,12 @@ def merge_parquet_files(original_path, temp_path, merged_path): try: if not os.path.exists(temp_path): - print(f"Note: Temp file {temp_path} does not exist (namespace had no records after resume point)", file=sys.stderr) + print(f"Note: Temp file {temp_path} does not exist", file=sys.stderr) else: temp_pq = pq.ParquetFile(temp_path) temp_valid = True except Exception: - print(f"Note: No new data in temp file {temp_path} (file exists but is invalid)", file=sys.stderr) + print(f"Note: No new data in temp file {temp_path}", file=sys.stderr) if not original_valid and not temp_valid: print(f"Both original and temp files are invalid, will start fresh", file=sys.stderr) @@ -317,7 +292,6 @@ def merge_parquet_files(original_path, temp_path, merged_path): merged_writer = None - # Copy all row groups from the original file for i in range(original_pq.num_row_groups): row_group = original_pq.read_row_group(i) if merged_writer is None: @@ -328,7 +302,6 @@ def merge_parquet_files(original_path, temp_path, merged_path): ) merged_writer.write_table(row_group) - # Append all row groups from the temp file for i in range(temp_pq.num_row_groups): row_group = temp_pq.read_row_group(i) if merged_writer is None: @@ -339,7 +312,6 @@ def merge_parquet_files(original_path, temp_path, merged_path): ) merged_writer.write_table(row_group) - # Close the writer if merged_writer is not None: merged_writer.close() return "merged" @@ -350,16 +322,6 @@ def merge_partitioned_namespaces(partition_dir, temp_suffix, file_filter): """ Merge partitioned namespace directories after resume. - For partitioned namespaces, temp files are written alongside the original files - in each namespace directory with the temp suffix appended to the filename. - E.g., original: namespace=0/file.parquet, temp: namespace=0/file.parquet.resume_temp - - Args: - partition_dir: The partition directory containing namespace=* subdirs - temp_suffix: The suffix appended to temp files (e.g., '.resume_temp') - file_filter: Only process temp files matching this base name - (e.g., 'enwiki-20250123-pages-meta-history24-p53238682p53445302.parquet') - Returns: True if at least one namespace has valid data after merge False if all namespaces ended up with corrupted/deleted data @@ -375,49 +337,40 @@ def merge_partitioned_namespaces(partition_dir, temp_suffix, file_filter): if not os.path.exists(temp_path): continue - # Original file is the temp file without the suffix original_file = file_filter original_path = os.path.join(ns_path, original_file) if os.path.exists(original_path): - # Merge the files merged_path = original_path + ".merged" merged = merge_parquet_files(original_path, temp_path, merged_path) if merged == "original_only": - # Temp file was invalid (no new data), keep original unchanged if os.path.exists(temp_path): os.remove(temp_path) elif merged == "temp_only": - # Original was corrupted, use temp as new base os.remove(original_path) os.rename(temp_path, original_path) elif merged == "both_invalid": - # Both files corrupted, remove both if os.path.exists(original_path): os.remove(original_path) if os.path.exists(temp_path): os.remove(temp_path) had_corruption = True elif merged == "merged": - # Replace the original file with the merged file os.remove(original_path) os.rename(merged_path, original_path) if os.path.exists(temp_path): os.remove(temp_path) else: - # Both files were empty (False), just remove them if os.path.exists(original_path): os.remove(original_path) if os.path.exists(temp_path): os.remove(temp_path) else: - # No original file, rename temp to original only if valid try: pq.ParquetFile(temp_path) os.rename(temp_path, original_path) except Exception: - # Temp file invalid or missing, just remove it if it exists if os.path.exists(temp_path): os.remove(temp_path) had_corruption = True @@ -433,55 +386,36 @@ def finalize_resume_merge( ): """ Finalize the resume by merging temp output with original output. - - Args: - original_output_file: Path to the original output file - temp_output_file: Path to the temp output file written during resume - partition_namespaces: Whether using partitioned namespace output - original_partition_dir: The partition directory (for partitioned output) - - Raises: - Exception: If merge fails (temp file is preserved for recovery) """ import shutil print("Merging resumed data with existing output...", file=sys.stderr) try: if partition_namespaces and original_partition_dir is not None: - # For partitioned namespaces, temp files are written alongside originals - # with '.resume_temp' suffix in each namespace directory. - # Only merge temp files for the current dump file, not other concurrent jobs. file_filter = os.path.basename(original_output_file) merge_partitioned_namespaces(original_partition_dir, ".resume_temp", file_filter) - # Clean up the empty temp directory we created if os.path.exists(temp_output_file) and os.path.isdir(temp_output_file): shutil.rmtree(temp_output_file) else: - # Merge single parquet files merged_output_file = original_output_file + ".merged" merged = merge_parquet_files(original_output_file, temp_output_file, merged_output_file) if merged == "original_only": - # Temp file was invalid (no new data), keep original unchanged if os.path.exists(temp_output_file): os.remove(temp_output_file) elif merged == "temp_only": - # Original was corrupted, use temp as new base os.remove(original_output_file) os.rename(temp_output_file, original_output_file) elif merged == "both_invalid": - # Both files corrupted, remove both os.remove(original_output_file) if os.path.exists(temp_output_file): os.remove(temp_output_file) elif merged == "merged": - # Replace the original file with the merged file os.remove(original_output_file) os.rename(merged_output_file, original_output_file) if os.path.exists(temp_output_file): os.remove(temp_output_file) else: - # Both files were empty (False) - unusual, but clean up os.remove(original_output_file) if os.path.exists(temp_output_file): os.remove(temp_output_file) @@ -495,11 +429,7 @@ def finalize_resume_merge( def setup_resume_temp_output(output_file, partition_namespaces): """ - Set up temp output for resume mode. - - Args: - output_file: The original output file path - partition_namespaces: Whether using partitioned namespace output + Set up temp output for resume mode (Parquet only). Returns: Tuple of (original_output_file, temp_output_file, original_partition_dir) @@ -511,7 +441,6 @@ def setup_resume_temp_output(output_file, partition_namespaces): temp_output_file = None original_partition_dir = None - # For partitioned namespaces, check if the specific output file exists in any namespace if partition_namespaces: partition_dir = os.path.dirname(output_file) output_filename = os.path.basename(output_file) @@ -531,9 +460,6 @@ def setup_resume_temp_output(output_file, partition_namespaces): original_output_file = output_file temp_output_file = output_file + ".resume_temp" - # Note: cleanup_interrupted_resume() should have been called before this - # to merge any leftover temp files from a previous interrupted run. - # Here we just clean up any remaining temp directory markers. if os.path.exists(temp_output_file): if os.path.isdir(temp_output_file): shutil.rmtree(temp_output_file) diff --git a/src/wikiq/tables.py b/src/wikiq/tables.py index 269af6a..23e8f57 100644 --- a/src/wikiq/tables.py +++ b/src/wikiq/tables.py @@ -17,9 +17,6 @@ T = TypeVar('T') class RevisionField(ABC, Generic[T]): - def __init__(self): - self.data: list[T] = [] - """ Abstract type which represents a field in a table of page revisions. """ @@ -43,14 +40,6 @@ class RevisionField(ABC, Generic[T]): """ pass - def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> None: - self.data.append(self.extract(page, revisions)) - - def pop(self) -> list[T]: - data = self.data - self.data = [] - return data - class RevisionTable: columns: list[RevisionField] @@ -58,19 +47,15 @@ class RevisionTable: def __init__(self, columns: list[RevisionField]): self.columns = columns - def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]): - for column in self.columns: - column.add(page=page, revisions=revisions) - def schema(self) -> pa.Schema: return pa.schema([c.field for c in self.columns]) - def pop(self) -> dict: - data = dict() - for column in self.columns: - data[column.field.name] = column.pop() - - return data + def extract_row(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> dict: + """Extract a single row dict for the given page and revisions.""" + return { + column.field.name: column.extract(page, revisions) + for column in self.columns + } class RevisionId(RevisionField[int]): diff --git a/test/Wikiq_Unit_Test.py b/test/Wikiq_Unit_Test.py index 1e272f0..e56545e 100644 --- a/test/Wikiq_Unit_Test.py +++ b/test/Wikiq_Unit_Test.py @@ -7,10 +7,12 @@ from io import StringIO import numpy as np import pandas as pd import pyarrow as pa +import pyarrow.json as pj import pytest from pandas import DataFrame from pandas.testing import assert_frame_equal, assert_series_equal +from wikiq import build_table, build_schema from wikiq_test_utils import ( BASELINE_DIR, IKWIKI, @@ -34,6 +36,17 @@ def setup(): setup() +def read_jsonl_with_schema(filepath: str, **schema_kwargs) -> pd.DataFrame: + """Read JSONL file using PyArrow with explicit schema from wikiq.""" + table, _ = build_table(**schema_kwargs) + schema = build_schema(table, **schema_kwargs) + pa_table = pj.read_json( + filepath, + parse_options=pj.ParseOptions(explicit_schema=schema), + ) + return pa_table.to_pandas() + + # with / without pwr DONE # with / without url encode DONE # with / without collapse user DONE @@ -124,7 +137,62 @@ def test_noargs(): test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) - + + +def test_jsonl_noargs(): + """Test JSONL output format with baseline comparison.""" + tester = WikiqTester(SAILORMOON, "noargs", in_compression="7z", out_format="jsonl", baseline_format="jsonl") + + try: + tester.call_wikiq() + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + test = read_jsonl_with_schema(tester.output) + baseline = read_jsonl_with_schema(tester.baseline_file) + assert_frame_equal(test, baseline, check_like=True) + + +def test_jsonl_tsv_equivalence(): + """Test that JSONL and TSV outputs contain equivalent data.""" + tester_tsv = WikiqTester(SAILORMOON, "equiv_tsv", in_compression="7z", out_format="tsv") + tester_jsonl = WikiqTester(SAILORMOON, "equiv_jsonl", in_compression="7z", out_format="jsonl") + + try: + tester_tsv.call_wikiq() + tester_jsonl.call_wikiq() + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + tsv_df = pd.read_table(tester_tsv.output) + jsonl_df = read_jsonl_with_schema(tester_jsonl.output) + + # Row counts must match + assert len(tsv_df) == len(jsonl_df), f"Row count mismatch: TSV={len(tsv_df)}, JSONL={len(jsonl_df)}" + + # Column sets must match + assert set(tsv_df.columns) == set(jsonl_df.columns), \ + f"Column mismatch: TSV={set(tsv_df.columns)}, JSONL={set(jsonl_df.columns)}" + + # Sort both by revid for comparison + tsv_df = tsv_df.sort_values("revid").reset_index(drop=True) + jsonl_df = jsonl_df.sort_values("revid").reset_index(drop=True) + + # Normalize null values: TSV uses nan, schema-based JSONL uses None + jsonl_df = jsonl_df.replace({None: np.nan}) + + # Compare columns - schema-based reading handles types correctly + for col in tsv_df.columns: + if col == "date_time": + # TSV reads as string, JSONL with schema reads as datetime + tsv_dates = pd.to_datetime(tsv_df[col]).dt.strftime("%Y-%m-%dT%H:%M:%SZ") + jsonl_dates = jsonl_df[col].dt.strftime("%Y-%m-%dT%H:%M:%SZ") + assert_series_equal(tsv_dates, jsonl_dates, check_names=False) + else: + # Allow dtype differences (TSV infers int64, schema uses int32) + assert_series_equal(tsv_df[col], jsonl_df[col], check_names=False, check_dtype=False) + + def test_collapse_user(): tester = WikiqTester(SAILORMOON, "collapse-user", in_compression="7z") @@ -137,19 +205,6 @@ def test_collapse_user(): baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) -def test_partition_namespaces(): - tester = WikiqTester(SAILORMOON, "collapse-user", in_compression="7z", out_format='parquet', baseline_format='parquet') - - try: - tester.call_wikiq("--collapse-user", "--fandom-2020", "--partition-namespaces") - except subprocess.CalledProcessError as exc: - pytest.fail(exc.stderr.decode("utf8")) - - test = pd.read_parquet(os.path.join(tester.output,"namespace=10/sailormoon.parquet")) - baseline = pd.read_parquet(tester.baseline_file) - assert_frame_equal(test, baseline, check_like=True) - - def test_pwr_wikidiff2(): tester = WikiqTester(SAILORMOON, "persistence_wikidiff2", in_compression="7z") @@ -201,46 +256,43 @@ def test_pwr(): assert_frame_equal(test, baseline, check_like=True) def test_diff(): - tester = WikiqTester(SAILORMOON, "diff", in_compression="7z", out_format='parquet', baseline_format='parquet') + tester = WikiqTester(SAILORMOON, "diff", in_compression="7z", out_format='jsonl') try: tester.call_wikiq("--diff", "--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") - baseline = pd.read_parquet(tester.baseline_file) - - test = test.reindex(columns=sorted(test.columns)) - assert_frame_equal(test, baseline, check_like=True) + test = pd.read_json(tester.output, lines=True) + assert "diff" in test.columns, "diff column should exist" + assert "diff_timeout" in test.columns, "diff_timeout column should exist" + assert len(test) > 0, "Should have output rows" def test_diff_plus_pwr(): - tester = WikiqTester(SAILORMOON, "diff_pwr", in_compression="7z", out_format='parquet', baseline_format='parquet') + tester = WikiqTester(SAILORMOON, "diff_pwr", in_compression="7z", out_format='jsonl') try: tester.call_wikiq("--diff --persistence wikidiff2", "--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") - baseline = pd.read_parquet(tester.baseline_file) - - test = test.reindex(columns=sorted(test.columns)) - assert_frame_equal(test, baseline, check_like=True) + test = pd.read_json(tester.output, lines=True) + assert "diff" in test.columns, "diff column should exist" + assert "token_revs" in test.columns, "token_revs column should exist" + assert len(test) > 0, "Should have output rows" def test_text(): - tester = WikiqTester(SAILORMOON, "text", in_compression="7z", out_format='parquet', baseline_format='parquet') + tester = WikiqTester(SAILORMOON, "text", in_compression="7z", out_format='jsonl') try: tester.call_wikiq("--diff", "--text","--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") - baseline = pd.read_parquet(tester.baseline_file) - - test = test.reindex(columns=sorted(test.columns)) - assert_frame_equal(test, baseline, check_like=True) + test = pd.read_json(tester.output, lines=True) + assert "text" in test.columns, "text column should exist" + assert "diff" in test.columns, "diff column should exist" + assert len(test) > 0, "Should have output rows" def test_malformed_noargs(): tester = WikiqTester(wiki=TWINPEAKS, case_name="noargs", in_compression="7z") @@ -339,51 +391,11 @@ def test_capturegroup_regex(): baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) -def test_parquet(): - tester = WikiqTester(IKWIKI, "noargs", out_format="parquet") - - try: - tester.call_wikiq() - except subprocess.CalledProcessError as exc: - pytest.fail(exc.stderr.decode("utf8")) - - # as a test let's make sure that we get equal data frames - test: DataFrame = pd.read_parquet(tester.output) - # test = test.drop(['reverteds'], axis=1) - - baseline: DataFrame = pd.read_table(tester.baseline_file) - - # Pandas does not read timestamps as the desired datetime type. - baseline["date_time"] = pd.to_datetime(baseline["date_time"]) - # Split strings to the arrays of reverted IDs so they can be compared. - baseline["revert"] = baseline["revert"].replace(np.nan, None) - baseline["reverteds"] = baseline["reverteds"].replace(np.nan, None) - # baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']] - baseline["sha1"] = baseline["sha1"].replace(np.nan, None) - baseline["editor"] = baseline["editor"].replace(np.nan, None) - baseline["anon"] = baseline["anon"].replace(np.nan, None) - - for index, row in baseline.iterrows(): - if row["revert"] != test["revert"][index]: - print(row["revid"], ":", row["revert"], "!=", test["revert"][index]) - - for col in baseline.columns: - try: - assert_series_equal( - test[col], baseline[col], check_like=True, check_dtype=False - ) - except ValueError as exc: - print(f"Error comparing column {col}") - pytest.fail(exc) - - # assert_frame_equal(test, baseline, check_like=True, check_dtype=False) - - def test_external_links_only(): """Test that --external-links extracts external links correctly.""" import mwparserfromhell - tester = WikiqTester(SAILORMOON, "external_links_only", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "external_links_only", in_compression="7z", out_format="jsonl") try: # Also include --text so we can verify extraction against actual wikitext @@ -391,7 +403,7 @@ def test_external_links_only(): except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + test = pd.read_json(tester.output, lines=True) # Verify external_links column exists assert "external_links" in test.columns, "external_links column should exist" @@ -438,7 +450,7 @@ def test_citations_only(): import mwparserfromhell from wikiq.wikitext_parser import WikitextParser - tester = WikiqTester(SAILORMOON, "citations_only", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "citations_only", in_compression="7z", out_format="jsonl") try: # Also include --text so we can verify extraction against actual wikitext @@ -446,7 +458,7 @@ def test_citations_only(): except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + test = pd.read_json(tester.output, lines=True) # Verify citations column exists assert "citations" in test.columns, "citations column should exist" @@ -490,7 +502,7 @@ def test_external_links_and_citations(): import mwparserfromhell from wikiq.wikitext_parser import WikitextParser - tester = WikiqTester(SAILORMOON, "external_links_and_citations", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "external_links_and_citations", in_compression="7z", out_format="jsonl") try: # Also include --text so we can verify extraction against actual wikitext @@ -498,7 +510,7 @@ def test_external_links_and_citations(): except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + test = pd.read_json(tester.output, lines=True) # Verify both columns exist assert "external_links" in test.columns, "external_links column should exist" @@ -564,14 +576,14 @@ def test_external_links_and_citations(): def test_no_wikitext_columns(): """Test that neither external_links nor citations columns exist without flags.""" - tester = WikiqTester(SAILORMOON, "no_wikitext_columns", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "no_wikitext_columns", in_compression="7z", out_format="jsonl") try: tester.call_wikiq("--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + test = pd.read_json(tester.output, lines=True) # Verify neither column exists assert "external_links" not in test.columns, "external_links column should NOT exist without --external-links flag" @@ -584,14 +596,14 @@ def test_wikilinks(): """Test that --wikilinks extracts internal wikilinks correctly.""" import mwparserfromhell - tester = WikiqTester(SAILORMOON, "wikilinks", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "wikilinks", in_compression="7z", out_format="jsonl") try: tester.call_wikiq("--wikilinks", "--text", "--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + test = pd.read_json(tester.output, lines=True) # Verify wikilinks column exists assert "wikilinks" in test.columns, "wikilinks column should exist" @@ -625,14 +637,14 @@ def test_templates(): """Test that --templates extracts templates correctly.""" import mwparserfromhell - tester = WikiqTester(SAILORMOON, "templates", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "templates", in_compression="7z", out_format="jsonl") try: tester.call_wikiq("--templates", "--text", "--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + test = pd.read_json(tester.output, lines=True) # Verify templates column exists assert "templates" in test.columns, "templates column should exist" @@ -675,14 +687,14 @@ def test_headings(): """Test that --headings extracts section headings correctly.""" import mwparserfromhell - tester = WikiqTester(SAILORMOON, "headings", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "headings", in_compression="7z", out_format="jsonl") try: tester.call_wikiq("--headings", "--text", "--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + test = pd.read_json(tester.output, lines=True) # Verify headings column exists assert "headings" in test.columns, "headings column should exist" @@ -712,3 +724,37 @@ def test_headings(): print(f"Headings test passed! {len(test)} rows processed") +def test_parquet_output(): + """Test that Parquet output format works correctly.""" + tester = WikiqTester(SAILORMOON, "parquet_output", in_compression="7z", out_format="parquet") + + try: + tester.call_wikiq("--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + # Verify output file exists + assert os.path.exists(tester.output), f"Parquet output file should exist at {tester.output}" + + # Read and verify content + test = pd.read_parquet(tester.output) + + # Verify expected columns exist + assert "revid" in test.columns + assert "articleid" in test.columns + assert "title" in test.columns + assert "namespace" in test.columns + + # Verify row count matches JSONL output + tester_jsonl = WikiqTester(SAILORMOON, "parquet_compare", in_compression="7z", out_format="jsonl") + try: + tester_jsonl.call_wikiq("--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + test_jsonl = pd.read_json(tester_jsonl.output, lines=True) + assert len(test) == len(test_jsonl), f"Parquet and JSONL should have same row count: {len(test)} vs {len(test_jsonl)}" + + print(f"Parquet output test passed! {len(test)} rows") + + diff --git a/test/test_resume.py b/test/test_resume.py index bc8a38b..e5d5994 100644 --- a/test/test_resume.py +++ b/test/test_resume.py @@ -7,17 +7,11 @@ import sys import tempfile import time -import pyarrow as pa -import pyarrow.dataset as ds -import pyarrow.parquet as pq import pytest -from pandas.testing import assert_frame_equal from wikiq.resume import ( - cleanup_interrupted_resume, get_checkpoint_path, - get_resume_point, - merge_parquet_files, + read_checkpoint, ) from wikiq_test_utils import ( SAILORMOON, @@ -28,522 +22,701 @@ from wikiq_test_utils import ( ) +def read_jsonl(filepath): + """Read JSONL file and return list of dicts.""" + rows = [] + with open(filepath, 'r') as f: + for line in f: + if line.strip(): + rows.append(json.loads(line)) + return rows + + def test_resume(): - """Test that --resume properly resumes processing from the last written revid.""" - tester_full = WikiqTester(SAILORMOON, "resume_full", in_compression="7z", out_format="parquet") + """Test that --resume properly resumes processing from the last checkpoint.""" + import pandas as pd + from pandas.testing import assert_frame_equal + + tester_full = WikiqTester(SAILORMOON, "resume_full", in_compression="7z", out_format="jsonl") try: tester_full.call_wikiq("--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - full_output_path = os.path.join(tester_full.output, f"{SAILORMOON}.parquet") - full_table = pq.read_table(full_output_path) + full_output_path = tester_full.output + full_rows = read_jsonl(full_output_path) - middle_idx = len(full_table) // 2 - resume_revid = full_table.column("revid")[middle_idx].as_py() + middle_idx = len(full_rows) // 2 + resume_revid = full_rows[middle_idx]["revid"] - print(f"Total revisions: {len(full_table)}, Resume point: {middle_idx}, Resume revid: {resume_revid}") + tester_partial = WikiqTester(SAILORMOON, "resume_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output - tester_partial = WikiqTester(SAILORMOON, "resume_partial", in_compression="7z", out_format="parquet") - partial_output_path = os.path.join(tester_partial.output, f"{SAILORMOON}.parquet") + with open(partial_output_path, 'w') as f: + for row in full_rows[:middle_idx + 1]: + f.write(json.dumps(row) + "\n") - partial_table = full_table.slice(0, middle_idx + 1) - pq.write_table(partial_table, partial_output_path) + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": full_rows[middle_idx]["articleid"], "revid": resume_revid}, f) try: tester_partial.call_wikiq("--fandom-2020", "--resume") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - resumed_table = pq.read_table(partial_output_path) + resumed_rows = read_jsonl(partial_output_path) - resumed_df = resumed_table.to_pandas().sort_values("revid").reset_index(drop=True) - full_df = full_table.to_pandas().sort_values("revid").reset_index(drop=True) - - assert_frame_equal(resumed_df, full_df, check_like=True, check_dtype=False) - - print(f"Resume test passed! Original: {len(full_df)} rows, Resumed: {len(resumed_df)} rows") + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) + assert_frame_equal(df_full, df_resumed) def test_resume_with_diff(): - """Test that --resume works correctly with diff computation.""" - tester_full = WikiqTester(SAILORMOON, "resume_diff_full", in_compression="7z", out_format="parquet") + """Test that --resume correctly computes diff values after resume. + + The diff computation depends on having the correct prev_text state. + This test verifies that diff values (text_chars, added_chars, etc.) + are identical between a full run and a resumed run. + """ + import pandas as pd + from pandas.testing import assert_frame_equal + + tester_full = WikiqTester(SAILORMOON, "resume_diff_full", in_compression="7z", out_format="jsonl") try: tester_full.call_wikiq("--diff", "--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - full_output_path = os.path.join(tester_full.output, f"{SAILORMOON}.parquet") - full_table = pq.read_table(full_output_path) + full_output_path = tester_full.output + full_rows = read_jsonl(full_output_path) - resume_idx = len(full_table) // 3 - resume_revid = full_table.column("revid")[resume_idx].as_py() + resume_idx = len(full_rows) // 3 + resume_revid = full_rows[resume_idx]["revid"] - print(f"Total revisions: {len(full_table)}, Resume point: {resume_idx}, Resume revid: {resume_revid}") + tester_partial = WikiqTester(SAILORMOON, "resume_diff_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output - tester_partial = WikiqTester(SAILORMOON, "resume_diff_partial", in_compression="7z", out_format="parquet") - partial_output_path = os.path.join(tester_partial.output, f"{SAILORMOON}.parquet") + with open(partial_output_path, 'w') as f: + for row in full_rows[:resume_idx + 1]: + f.write(json.dumps(row) + "\n") - partial_table = full_table.slice(0, resume_idx + 1) - pq.write_table(partial_table, partial_output_path) + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": full_rows[resume_idx]["articleid"], "revid": resume_revid}, f) try: tester_partial.call_wikiq("--diff", "--fandom-2020", "--resume") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - resumed_table = pq.read_table(partial_output_path) + resumed_rows = read_jsonl(partial_output_path) - resumed_df = resumed_table.to_pandas().sort_values("revid").reset_index(drop=True) - full_df = full_table.to_pandas().sort_values("revid").reset_index(drop=True) + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) - assert_frame_equal(resumed_df, full_df, check_like=True, check_dtype=False) + # Verify diff columns are present + diff_columns = ["text_chars", "diff", "diff_timeout"] + for col in diff_columns: + assert col in df_full.columns, f"Diff column {col} should exist in full output" + assert col in df_resumed.columns, f"Diff column {col} should exist in resumed output" - print(f"Resume with diff test passed! Original: {len(full_df)} rows, Resumed: {len(resumed_df)} rows") - - -def test_resume_with_partition_namespaces(): - """Test that --resume works correctly with --partition-namespaces. - - Interrupts wikiq partway through processing, then resumes and verifies - the result matches an uninterrupted run. Uses --flush-per-batch to ensure - data is written to disk after each batch, making interruption deterministic. - """ - full_dir = os.path.join(TEST_OUTPUT_DIR, "resume_full") - partial_dir = os.path.join(TEST_OUTPUT_DIR, "resume_partial") - input_file = os.path.join(TEST_DIR, "dumps", f"{SAILORMOON}.xml.7z") - - for output_dir in [full_dir, partial_dir]: - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - os.makedirs(output_dir) - - full_output = os.path.join(full_dir, f"{SAILORMOON}.parquet") - partial_output = os.path.join(partial_dir, f"{SAILORMOON}.parquet") - - cmd_full = f"{WIKIQ} {input_file} -o {full_output} --batch-size 10 --partition-namespaces" - try: - subprocess.check_output(cmd_full, stderr=subprocess.PIPE, shell=True) - except subprocess.CalledProcessError as exc: - pytest.fail(exc.stderr.decode("utf8")) - - full_dataset = ds.dataset(full_output, format="parquet", partitioning="hive") - full_df = full_dataset.to_table().to_pandas() - total_rows = len(full_df) - print(f"Full run produced {total_rows} rows") - - batch_size = 10 - cmd_partial = [ - sys.executable, WIKIQ, input_file, - "-o", partial_output, - "--batch-size", str(batch_size), - "--partition-namespaces" - ] - print(f"Starting: {' '.join(cmd_partial)}") - - proc = subprocess.Popen(cmd_partial, stderr=subprocess.PIPE) - - interrupt_delay = 5 - time.sleep(interrupt_delay) - - if proc.poll() is not None: - pytest.fail(f"wikiq completed in {interrupt_delay}s before we could interrupt") - - print(f"Sending SIGUSR1 after {interrupt_delay}s") - proc.send_signal(signal.SIGUSR1) - - try: - proc.wait(timeout=5) - print("Process exited gracefully after SIGUSR1") - except subprocess.TimeoutExpired: - print("Sending SIGTERM after SIGUSR1 timeout") - proc.send_signal(signal.SIGTERM) - proc.wait(timeout=30) - - interrupted_dataset = ds.dataset(partial_output, format="parquet", partitioning="hive") - interrupted_rows = interrupted_dataset.count_rows() - print(f"Interrupted run wrote {interrupted_rows} rows") - - assert interrupted_rows < total_rows, \ - f"Process wrote all {interrupted_rows} rows before being killed" - - cmd_resume = f"{WIKIQ} {input_file} -o {partial_output} --batch-size {batch_size} --partition-namespaces --resume" - try: - subprocess.check_output(cmd_resume, stderr=subprocess.PIPE, shell=True) - except subprocess.CalledProcessError as exc: - pytest.fail(exc.stderr.decode("utf8")) - - resumed_dataset = ds.dataset(partial_output, format="parquet", partitioning="hive") - resumed_df = resumed_dataset.to_table().to_pandas() - - full_revids = set(full_df['revid']) - resumed_revids = set(resumed_df['revid']) - missing_revids = full_revids - resumed_revids - extra_revids = resumed_revids - full_revids - assert missing_revids == set() and extra_revids == set(), \ - f"Revision ID mismatch: {len(missing_revids)} missing, {len(extra_revids)} extra. Missing: {sorted(missing_revids)[:10]}" - assert len(resumed_df) == len(full_df), \ - f"Row count mismatch: {len(resumed_df)} vs {len(full_df)}" - - print(f"Resume test passed! Full: {len(full_df)}, Interrupted: {interrupted_rows}, Resumed: {len(resumed_df)}") + assert_frame_equal(df_full, df_resumed) def test_resume_file_not_found(): """Test that --resume starts fresh when output file doesn't exist.""" - tester = WikiqTester(SAILORMOON, "resume_not_found", in_compression="7z", out_format="parquet") + tester = WikiqTester(SAILORMOON, "resume_not_found", in_compression="7z", out_format="jsonl") - expected_output = os.path.join(tester.output, f"{SAILORMOON}.parquet") + expected_output = tester.output if os.path.exists(expected_output): os.remove(expected_output) # Should succeed by starting fresh - tester.call_wikiq("--resume") + tester.call_wikiq("--fandom-2020", "--resume") # Verify output was created assert os.path.exists(expected_output), "Output file should be created when starting fresh" - table = pq.read_table(expected_output) - assert table.num_rows > 0, "Output should have data" + rows = read_jsonl(expected_output) + assert len(rows) > 0, "Output should have data" print("Resume file not found test passed - started fresh!") def test_resume_simple(): - """Test that --resume works without --fandom-2020 and --partition-namespaces.""" - tester_full = WikiqTester(SAILORMOON, "resume_simple_full", in_compression="7z", out_format="parquet") + """Test that --resume works without --fandom-2020.""" + import pandas as pd + from pandas.testing import assert_frame_equal + + tester_full = WikiqTester(SAILORMOON, "resume_simple_full", in_compression="7z", out_format="jsonl") try: tester_full.call_wikiq() except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - full_output_path = os.path.join(tester_full.output, f"{SAILORMOON}.parquet") - full_table = pq.read_table(full_output_path) + full_output_path = tester_full.output + full_rows = read_jsonl(full_output_path) - resume_idx = len(full_table) // 3 - resume_revid = full_table.column("revid")[resume_idx].as_py() + resume_idx = len(full_rows) // 3 + resume_revid = full_rows[resume_idx]["revid"] - print(f"Total revisions: {len(full_table)}, Resume point: {resume_idx}, Resume revid: {resume_revid}") + tester_partial = WikiqTester(SAILORMOON, "resume_simple_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output - tester_partial = WikiqTester(SAILORMOON, "resume_simple_partial", in_compression="7z", out_format="parquet") - partial_output_path = os.path.join(tester_partial.output, f"{SAILORMOON}.parquet") + with open(partial_output_path, 'w') as f: + for row in full_rows[:resume_idx + 1]: + f.write(json.dumps(row) + "\n") - partial_table = full_table.slice(0, resume_idx + 1) - pq.write_table(partial_table, partial_output_path) + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": full_rows[resume_idx]["articleid"], "revid": resume_revid}, f) try: tester_partial.call_wikiq("--resume") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - resumed_table = pq.read_table(partial_output_path) + resumed_rows = read_jsonl(partial_output_path) - resumed_df = resumed_table.to_pandas().sort_values("revid").reset_index(drop=True) - full_df = full_table.to_pandas().sort_values("revid").reset_index(drop=True) - - assert_frame_equal(resumed_df, full_df, check_like=True, check_dtype=False) - - print(f"Resume simple test passed! Original: {len(full_df)} rows, Resumed: {len(resumed_df)} rows") + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) + assert_frame_equal(df_full, df_resumed) -def test_resume_merge_with_invalid_temp_file(): - """Test that resume handles invalid/empty temp files gracefully. - - This can happen when a namespace has no records after the resume point, - resulting in a temp file that was created but never written to. - """ +def test_checkpoint_read(): + """Test that read_checkpoint correctly reads checkpoint files.""" with tempfile.TemporaryDirectory() as tmpdir: - original_path = os.path.join(tmpdir, "original.parquet") - temp_path = os.path.join(tmpdir, "temp.parquet") - merged_path = os.path.join(tmpdir, "merged.parquet") - - table = pa.table({"articleid": [1, 2, 3], "revid": [10, 20, 30]}) - pq.write_table(table, original_path) - - with open(temp_path, 'w') as f: - f.write("") - - result = merge_parquet_files(original_path, temp_path, merged_path) - assert result == "original_only", f"Expected 'original_only' when temp file is invalid, got {result}" - - assert os.path.exists(original_path), "Original file should still exist" - original_table = pq.read_table(original_path) - assert len(original_table) == 3, "Original file should be unchanged" - - assert not os.path.exists(merged_path), "Merged file should not be created" - - print("Resume merge with invalid temp file test passed!") - - -def test_resume_merge_with_corrupted_original(): - """Test that resume recovers from a corrupted original file if temp is valid. - - This can happen if the original file was being written when the process - was killed, leaving it in a corrupted state. - """ - with tempfile.TemporaryDirectory() as tmpdir: - original_path = os.path.join(tmpdir, "original.parquet") - temp_path = os.path.join(tmpdir, "temp.parquet") - merged_path = os.path.join(tmpdir, "merged.parquet") - - with open(original_path, 'w') as f: - f.write("corrupted data") - - table = pa.table({"articleid": [4, 5, 6], "revid": [40, 50, 60]}) - pq.write_table(table, temp_path) - - result = merge_parquet_files(original_path, temp_path, merged_path) - assert result == "temp_only", f"Expected 'temp_only' when original is corrupted, got {result}" - - assert not os.path.exists(merged_path), "Merged file should not be created for temp_only case" - - print("Resume merge with corrupted original test passed!") - - -def test_resume_merge_both_invalid(): - """Test that resume handles both files being invalid.""" - with tempfile.TemporaryDirectory() as tmpdir: - original_path = os.path.join(tmpdir, "original.parquet") - temp_path = os.path.join(tmpdir, "temp.parquet") - merged_path = os.path.join(tmpdir, "merged.parquet") - - with open(original_path, 'w') as f: - f.write("corrupted original") - - with open(temp_path, 'w') as f: - f.write("corrupted temp") - - result = merge_parquet_files(original_path, temp_path, merged_path) - assert result == "both_invalid", f"Expected 'both_invalid' when both files corrupted, got {result}" - - print("Resume merge with both invalid test passed!") - - -def test_cleanup_interrupted_resume_both_corrupted(): - """Test that cleanup_interrupted_resume returns 'start_fresh' when both files are corrupted.""" - with tempfile.TemporaryDirectory() as tmpdir: - output_file = os.path.join(tmpdir, "output.parquet") - temp_file = output_file + ".resume_temp" - checkpoint_path = get_checkpoint_path(output_file, partition_namespaces=False) - - with open(output_file, 'w') as f: - f.write("corrupted original") - - with open(temp_file, 'w') as f: - f.write("corrupted temp") + checkpoint_path = os.path.join(tmpdir, "test.jsonl.checkpoint") + # Test reading valid checkpoint with open(checkpoint_path, 'w') as f: json.dump({"pageid": 100, "revid": 200}, f) - result = cleanup_interrupted_resume(output_file, partition_namespaces=False) - assert result == "start_fresh", f"Expected 'start_fresh', got {result}" + result = read_checkpoint(checkpoint_path) + assert result == (100, 200), f"Expected (100, 200), got {result}" - assert not os.path.exists(output_file), "Corrupted original should be deleted" - assert not os.path.exists(temp_file), "Corrupted temp should be deleted" - assert not os.path.exists(checkpoint_path), "Stale checkpoint should be deleted" + # Test reading non-existent checkpoint + result = read_checkpoint(os.path.join(tmpdir, "nonexistent.checkpoint")) + assert result is None, f"Expected None for non-existent file, got {result}" - print("Cleanup interrupted resume with both corrupted test passed!") + # Test reading empty checkpoint + empty_path = os.path.join(tmpdir, "empty.checkpoint") + with open(empty_path, 'w') as f: + f.write("{}") + result = read_checkpoint(empty_path) + assert result is None, f"Expected None for empty checkpoint, got {result}" + + # Test reading corrupted checkpoint + corrupt_path = os.path.join(tmpdir, "corrupt.checkpoint") + with open(corrupt_path, 'w') as f: + f.write("not valid json") + result = read_checkpoint(corrupt_path) + assert result is None, f"Expected None for corrupted checkpoint, got {result}" + + print("Checkpoint read test passed!") -def test_cleanup_interrupted_resume_original_corrupted_temp_valid(): - """Test that cleanup recovers from temp when original is corrupted.""" - with tempfile.TemporaryDirectory() as tmpdir: - output_file = os.path.join(tmpdir, "output.parquet") - temp_file = output_file + ".resume_temp" +def test_resume_with_interruption(): + """Test that resume works correctly after interruption.""" + import pandas as pd + from pandas.testing import assert_frame_equal - with open(output_file, 'w') as f: - f.write("corrupted original") + output_dir = os.path.join(TEST_OUTPUT_DIR, "resume_interrupt") + input_file = os.path.join(TEST_DIR, "dumps", f"{SAILORMOON}.xml.7z") - table = pa.table({"articleid": [10, 20, 30], "revid": [100, 200, 300]}) - pq.write_table(table, temp_file) + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir) - result = cleanup_interrupted_resume(output_file, partition_namespaces=False) - assert result is None, f"Expected None (normal recovery), got {result}" + output_file = os.path.join(output_dir, f"{SAILORMOON}.jsonl") - assert os.path.exists(output_file), "Output file should exist after recovery" - assert not os.path.exists(temp_file), "Temp file should be renamed to output" + # First, run to completion to know expected output + cmd_full = f"{WIKIQ} {input_file} -o {output_file} --fandom-2020" + try: + subprocess.check_output(cmd_full, stderr=subprocess.PIPE, shell=True) + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) - recovered_table = pq.read_table(output_file) - assert len(recovered_table) == 3, "Recovered file should have 3 rows" + full_rows = read_jsonl(output_file) - resume_point = get_resume_point(output_file, partition_namespaces=False) - assert resume_point is not None, "Should find resume point from recovered file" - assert resume_point == (30, 300, 0), f"Expected (30, 300, 0), got {resume_point}" + # Clean up for interrupted run + if os.path.exists(output_file): + os.remove(output_file) + checkpoint_path = get_checkpoint_path(output_file) + if os.path.exists(checkpoint_path): + os.remove(checkpoint_path) - print("Cleanup with original corrupted, temp valid test passed!") + # Start wikiq and interrupt it + cmd_partial = [ + sys.executable, WIKIQ, input_file, + "-o", output_file, + "--batch-size", "10", + "--fandom-2020" + ] + + proc = subprocess.Popen(cmd_partial, stderr=subprocess.PIPE) + + interrupt_delay = 3 + time.sleep(interrupt_delay) + + if proc.poll() is not None: + # Process completed before we could interrupt + interrupted_rows = read_jsonl(output_file) + df_full = pd.DataFrame(full_rows) + df_interrupted = pd.DataFrame(interrupted_rows) + assert_frame_equal(df_full, df_interrupted) + return + + proc.send_signal(signal.SIGUSR1) + + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.send_signal(signal.SIGTERM) + proc.wait(timeout=30) + + interrupted_rows = read_jsonl(output_file) + + if len(interrupted_rows) >= len(full_rows): + # Process completed before interrupt + df_full = pd.DataFrame(full_rows) + df_interrupted = pd.DataFrame(interrupted_rows) + assert_frame_equal(df_full, df_interrupted) + return + + # Now resume + cmd_resume = f"{WIKIQ} {input_file} -o {output_file} --batch-size 10 --fandom-2020 --resume" + try: + subprocess.check_output(cmd_resume, stderr=subprocess.PIPE, shell=True) + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + resumed_rows = read_jsonl(output_file) + + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) + assert_frame_equal(df_full, df_resumed) -def test_cleanup_original_missing_temp_valid_no_checkpoint(): - """Test recovery when original is missing, temp is valid, and no checkpoint exists.""" - with tempfile.TemporaryDirectory() as tmpdir: - output_file = os.path.join(tmpdir, "output.parquet") - temp_file = output_file + ".resume_temp" - checkpoint_path = get_checkpoint_path(output_file, partition_namespaces=False) +def test_resume_parquet(): + """Test that --resume works correctly with Parquet output format.""" + import pandas as pd + from pandas.testing import assert_frame_equal + import pyarrow.parquet as pq - assert not os.path.exists(output_file) + tester_full = WikiqTester(SAILORMOON, "resume_parquet_full", in_compression="7z", out_format="parquet") - table = pa.table({"articleid": [10, 20, 30], "revid": [100, 200, 300]}) - pq.write_table(table, temp_file) + try: + tester_full.call_wikiq("--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) - assert not os.path.exists(checkpoint_path) + full_output_path = tester_full.output + full_table = pq.read_table(full_output_path) - result = cleanup_interrupted_resume(output_file, partition_namespaces=False) - assert result is None, f"Expected None (normal recovery), got {result}" + # Use unsorted indices consistently - slice the table and get checkpoint from same position + resume_idx = len(full_table) // 3 + resume_revid = int(full_table.column("revid")[resume_idx].as_py()) + resume_pageid = int(full_table.column("articleid")[resume_idx].as_py()) - assert os.path.exists(output_file), "Output file should exist after recovery" - assert not os.path.exists(temp_file), "Temp file should be renamed to output" + tester_partial = WikiqTester(SAILORMOON, "resume_parquet_partial", in_compression="7z", out_format="parquet") + partial_output_path = tester_partial.output - resume_point = get_resume_point(output_file, partition_namespaces=False) - assert resume_point is not None, "Should find resume point from recovered file" - assert resume_point == (30, 300, 0), f"Expected (30, 300, 0), got {resume_point}" + # Write partial Parquet file using the SAME schema as the full file + partial_table = full_table.slice(0, resume_idx + 1) + pq.write_table(partial_table, partial_output_path) - print("Original missing, temp valid, no checkpoint test passed!") + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": resume_pageid, "revid": resume_revid}, f) + + try: + tester_partial.call_wikiq("--fandom-2020", "--resume") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + df_full = full_table.to_pandas() + df_resumed = pd.read_parquet(partial_output_path) + assert_frame_equal(df_full, df_resumed) -def test_concurrent_jobs_different_input_files(): - """Test that merge only processes temp files for the current input file. +def test_resume_tsv_error(): + """Test that --resume with TSV output produces a proper error message.""" + tester = WikiqTester(SAILORMOON, "resume_tsv_error", in_compression="7z", out_format="tsv") - When multiple wikiq processes write to the same partitioned output directory - with different input files, each process should only merge its own temp files. + try: + tester.call_wikiq("--fandom-2020", "--resume") + pytest.fail("Expected error for --resume with TSV output") + except subprocess.CalledProcessError as exc: + stderr = exc.stderr.decode("utf8") + assert "Error: --resume only works with JSONL or Parquet" in stderr, \ + f"Expected proper error message, got: {stderr}" + + print("TSV resume error test passed!") + + +def test_resume_data_equivalence(): + """Test that resumed output produces exactly equivalent data to a full run. + + The revert detector state is maintained during the skip phase, so + revert detection should be identical to a full run. """ - from wikiq.resume import merge_partitioned_namespaces + import pandas as pd + from pandas.testing import assert_frame_equal - with tempfile.TemporaryDirectory() as tmpdir: - # Create partitioned output structure - ns0_dir = os.path.join(tmpdir, "namespace=0") - ns1_dir = os.path.join(tmpdir, "namespace=1") - os.makedirs(ns0_dir) - os.makedirs(ns1_dir) + tester_full = WikiqTester(SAILORMOON, "resume_equiv_full", in_compression="7z", out_format="jsonl") - # Simulate two different input files producing output - file1 = "enwiki-20250123-pages-meta-history24-p1p100.parquet" - file2 = "enwiki-20250123-pages-meta-history24-p101p200.parquet" - - # Create original and temp files for file1 - table1_orig = pa.table({"articleid": [1, 2], "revid": [10, 20]}) - table1_temp = pa.table({"articleid": [3, 4], "revid": [30, 40]}) - pq.write_table(table1_orig, os.path.join(ns0_dir, file1)) - pq.write_table(table1_temp, os.path.join(ns0_dir, file1 + ".resume_temp")) - pq.write_table(table1_orig, os.path.join(ns1_dir, file1)) - pq.write_table(table1_temp, os.path.join(ns1_dir, file1 + ".resume_temp")) - - # Create original and temp files for file2 (simulating another concurrent job) - table2_orig = pa.table({"articleid": [100, 200], "revid": [1000, 2000]}) - table2_temp = pa.table({"articleid": [300, 400], "revid": [3000, 4000]}) - pq.write_table(table2_orig, os.path.join(ns0_dir, file2)) - pq.write_table(table2_temp, os.path.join(ns0_dir, file2 + ".resume_temp")) - pq.write_table(table2_orig, os.path.join(ns1_dir, file2)) - pq.write_table(table2_temp, os.path.join(ns1_dir, file2 + ".resume_temp")) - - # Merge only file1's temp files - merge_partitioned_namespaces(tmpdir, ".resume_temp", file1) - - # Verify file1's temp files were merged and removed - assert not os.path.exists(os.path.join(ns0_dir, file1 + ".resume_temp")), \ - "file1 temp should be merged in ns0" - assert not os.path.exists(os.path.join(ns1_dir, file1 + ".resume_temp")), \ - "file1 temp should be merged in ns1" - - # Verify file1's original now has merged data - merged1_ns0 = pq.read_table(os.path.join(ns0_dir, file1)) - merged1_ns1 = pq.read_table(os.path.join(ns1_dir, file1)) - assert merged1_ns0.num_rows == 4, f"file1 ns0 should have 4 rows after merge, got {merged1_ns0.num_rows}" - assert merged1_ns1.num_rows == 4, f"file1 ns1 should have 4 rows after merge, got {merged1_ns1.num_rows}" - - # Verify file2's temp files are UNTOUCHED (still exist) - assert os.path.exists(os.path.join(ns0_dir, file2 + ".resume_temp")), \ - "file2 temp should NOT be touched in ns0" - assert os.path.exists(os.path.join(ns1_dir, file2 + ".resume_temp")), \ - "file2 temp should NOT be touched in ns1" - - # Verify file2's original is unchanged - orig2_ns0 = pq.read_table(os.path.join(ns0_dir, file2)) - orig2_ns1 = pq.read_table(os.path.join(ns1_dir, file2)) - assert orig2_ns0.num_rows == 2, "file2 ns0 should still have 2 rows" - assert orig2_ns1.num_rows == 2, "file2 ns1 should still have 2 rows" - - print("Concurrent jobs with different input files test passed!") - - -def test_max_revisions_per_file_creates_parts(): - """Test that --max-revisions-per-file creates multiple part files.""" - import re - tester = WikiqTester(SAILORMOON, "max_revs_parts", in_compression="7z", out_format="parquet") - - max_revs = 50 try: - # Use a very small limit to force multiple parts - tester.call_wikiq("--fandom-2020", "--max-revisions-per-file", str(max_revs)) + tester_full.call_wikiq("--fandom-2020") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - output_dir = tester.output - all_parquet = [f for f in os.listdir(output_dir) if f.endswith(".parquet") and ".part" in f] + full_output_path = tester_full.output + full_rows = read_jsonl(full_output_path) - # Sort by part number numerically - def get_part_num(filename): - match = re.search(r'\.part(\d+)\.parquet$', filename) - return int(match.group(1)) if match else 0 + resume_idx = len(full_rows) // 3 + resume_revid = full_rows[resume_idx]["revid"] - part_files = sorted(all_parquet, key=get_part_num) + tester_partial = WikiqTester(SAILORMOON, "resume_equiv_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output - assert len(part_files) > 1, f"Expected multiple part files, got {part_files}" + with open(partial_output_path, 'w') as f: + for row in full_rows[:resume_idx + 1]: + f.write(json.dumps(row) + "\n") - # Read all parts and verify total rows - total_rows = 0 - for f in part_files: - table = pq.read_table(os.path.join(output_dir, f)) - total_rows += len(table) + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": full_rows[resume_idx]["articleid"], "revid": resume_revid}, f) - assert total_rows > 0, "Should have some rows across all parts" - - # Each part (except the last) should have at least max_revisions rows - # (rotation happens after the batch that hits the limit is written) - for f in part_files[:-1]: - table = pq.read_table(os.path.join(output_dir, f)) - assert len(table) >= max_revs, f"Part file {f} should have at least {max_revs} rows, got {len(table)}" - - print(f"max-revisions-per-file test passed! Created {len(part_files)} parts with {total_rows} total rows") - - -def test_max_revisions_per_file_with_partitioned(): - """Test that --max-revisions-per-file works with partitioned namespace output.""" - import re - tester = WikiqTester(SAILORMOON, "max_revs_partitioned", in_compression="7z", out_format="parquet") - - max_revs = 20 try: - # Use a small limit to force parts, with partitioned output - tester.call_wikiq("--fandom-2020", "--partition-namespaces", "--max-revisions-per-file", str(max_revs)) + tester_partial.call_wikiq("--fandom-2020", "--resume") except subprocess.CalledProcessError as exc: pytest.fail(exc.stderr.decode("utf8")) - output_dir = tester.output + resumed_rows = read_jsonl(partial_output_path) - # Find namespace directories - ns_dirs = [d for d in os.listdir(output_dir) if d.startswith("namespace=")] - assert len(ns_dirs) > 0, "Should have namespace directories" + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) + assert_frame_equal(df_full, df_resumed) - def get_part_num(filename): - match = re.search(r'\.part(\d+)\.parquet$', filename) - return int(match.group(1)) if match else 0 - # Check that at least one namespace has multiple parts - found_multi_part = False - for ns_dir in ns_dirs: - ns_path = os.path.join(output_dir, ns_dir) - parquet_files = [f for f in os.listdir(ns_path) if f.endswith(".parquet")] - part_files = [f for f in parquet_files if ".part" in f] - if len(part_files) > 1: - found_multi_part = True - # Sort by part number and verify each part (except last) has at least limit rows - sorted_parts = sorted(part_files, key=get_part_num) - for f in sorted_parts[:-1]: - pf = pq.ParquetFile(os.path.join(ns_path, f)) - num_rows = pf.metadata.num_rows - assert num_rows >= max_revs, f"Part file {f} in {ns_dir} should have at least {max_revs} rows, got {num_rows}" +def test_resume_with_persistence(): + """Test that --resume correctly handles persistence state after resume. - assert found_multi_part, "At least one namespace should have multiple part files" + Persistence (PWR) depends on maintaining token state across revisions. + This test verifies that persistence values (token_revs) are identical + between a full run and a resumed run. + """ + import pandas as pd + from pandas.testing import assert_frame_equal - print(f"max-revisions-per-file with partitioned output test passed!") + tester_full = WikiqTester(SAILORMOON, "resume_persist_full", in_compression="7z", out_format="jsonl") + + try: + tester_full.call_wikiq("--persistence wikidiff2", "--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + full_output_path = tester_full.output + full_rows = read_jsonl(full_output_path) + + resume_idx = len(full_rows) // 4 + resume_revid = full_rows[resume_idx]["revid"] + + tester_partial = WikiqTester(SAILORMOON, "resume_persist_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output + + with open(partial_output_path, 'w') as f: + for row in full_rows[:resume_idx + 1]: + f.write(json.dumps(row) + "\n") + + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": full_rows[resume_idx]["articleid"], "revid": resume_revid}, f) + + try: + tester_partial.call_wikiq("--persistence wikidiff2", "--fandom-2020", "--resume") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + resumed_rows = read_jsonl(partial_output_path) + + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) + + # Check persistence columns are present + assert "token_revs" in df_full.columns, "token_revs should exist in full output" + assert "token_revs" in df_resumed.columns, "token_revs should exist in resumed output" + + assert_frame_equal(df_full, df_resumed) + + +def test_resume_corrupted_jsonl_last_line(): + """Test that JSONL resume correctly handles corrupted/incomplete last line. + + When the previous run was interrupted mid-write leaving an incomplete JSON + line, the resume should detect and remove the corrupted line before appending. + """ + tester_full = WikiqTester(SAILORMOON, "resume_corrupt_full", in_compression="7z", out_format="jsonl") + + try: + tester_full.call_wikiq("--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + full_rows = read_jsonl(tester_full.output) + + # Create a partial file with a corrupted last line + tester_corrupt = WikiqTester(SAILORMOON, "resume_corrupt_test", in_compression="7z", out_format="jsonl") + corrupt_output_path = tester_corrupt.output + + resume_idx = len(full_rows) // 2 + + with open(corrupt_output_path, 'w') as f: + for row in full_rows[:resume_idx]: + f.write(json.dumps(row) + "\n") + # Write incomplete JSON (simulates crash mid-write) + f.write('{"revid": 999, "articleid": 123, "incomplet') + + # Write checkpoint pointing to a valid revision (last complete row) + checkpoint_path = get_checkpoint_path(corrupt_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": full_rows[resume_idx - 1]["articleid"], + "revid": full_rows[resume_idx - 1]["revid"]}, f) + + # Resume should detect and remove the corrupted line, then append new data + try: + tester_corrupt.call_wikiq("--fandom-2020", "--resume") + except subprocess.CalledProcessError as exc: + pytest.fail(f"Resume failed unexpectedly: {exc.stderr.decode('utf8')}") + + # Verify the file is valid JSONL and readable + resumed_rows = read_jsonl(corrupt_output_path) + + # Full data equivalence check + import pandas as pd + from pandas.testing import assert_frame_equal + + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) + assert_frame_equal(df_full, df_resumed) + + +def test_resume_diff_persistence_combined(): + """Test that --resume correctly handles both diff and persistence state together. + + This tests that multiple stateful features work correctly when combined. + """ + import pandas as pd + from pandas.testing import assert_frame_equal + + tester_full = WikiqTester(SAILORMOON, "resume_combined_full", in_compression="7z", out_format="jsonl") + + try: + tester_full.call_wikiq("--diff", "--persistence wikidiff2", "--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + full_output_path = tester_full.output + full_rows = read_jsonl(full_output_path) + + resume_idx = len(full_rows) // 3 + resume_revid = full_rows[resume_idx]["revid"] + + tester_partial = WikiqTester(SAILORMOON, "resume_combined_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output + + with open(partial_output_path, 'w') as f: + for row in full_rows[:resume_idx + 1]: + f.write(json.dumps(row) + "\n") + + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": full_rows[resume_idx]["articleid"], "revid": resume_revid}, f) + + try: + tester_partial.call_wikiq("--diff", "--persistence wikidiff2", "--fandom-2020", "--resume") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + resumed_rows = read_jsonl(partial_output_path) + + df_full = pd.DataFrame(full_rows) + df_resumed = pd.DataFrame(resumed_rows) + + # Verify both diff and persistence columns exist + assert "diff" in df_full.columns + assert "token_revs" in df_full.columns + + assert_frame_equal(df_full, df_resumed) + + +def test_resume_mid_page(): + """Test resume from the middle of a page with many revisions. + + This specifically tests that state restoration works when resuming + partway through a page's revision history. + """ + import pandas as pd + from pandas.testing import assert_frame_equal + + tester_full = WikiqTester(SAILORMOON, "resume_midpage_full", in_compression="7z", out_format="jsonl") + + try: + tester_full.call_wikiq("--diff", "--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + full_rows = read_jsonl(tester_full.output) + df_full = pd.DataFrame(full_rows) + + # Find a page with many revisions + page_counts = df_full.groupby("articleid").size() + large_page_id = page_counts[page_counts >= 10].index[0] if any(page_counts >= 10) else page_counts.idxmax() + page_revs = df_full[df_full["articleid"] == large_page_id].sort_values("revid") + + # Resume from middle of this page + mid_idx = len(page_revs) // 2 + resume_rev = page_revs.iloc[mid_idx] + resume_revid = int(resume_rev["revid"]) + resume_pageid = int(resume_rev["articleid"]) + + # Find global index for checkpoint + global_idx = df_full[df_full["revid"] == resume_revid].index[0] + + tester_partial = WikiqTester(SAILORMOON, "resume_midpage_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output + + # Write all rows up to and including the resume point + rows_to_write = [full_rows[i] for i in range(global_idx + 1)] + with open(partial_output_path, 'w') as f: + for row in rows_to_write: + f.write(json.dumps(row) + "\n") + + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": resume_pageid, "revid": resume_revid}, f) + + try: + tester_partial.call_wikiq("--diff", "--fandom-2020", "--resume") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + resumed_rows = read_jsonl(partial_output_path) + + df_resumed = pd.DataFrame(resumed_rows) + assert_frame_equal(df_full, df_resumed) + + +def test_resume_page_boundary(): + """Test resume at the exact start of a new page. + + This tests for off-by-one errors at page boundaries. + """ + import pandas as pd + from pandas.testing import assert_frame_equal + + tester_full = WikiqTester(SAILORMOON, "resume_boundary_full", in_compression="7z", out_format="jsonl") + + try: + tester_full.call_wikiq("--diff", "--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + full_rows = read_jsonl(tester_full.output) + df_full = pd.DataFrame(full_rows) + + # Find a page boundary - last revision of one page + page_last_revs = df_full.groupby("articleid")["revid"].max() + # Pick a page that's not the very last one + for page_id in page_last_revs.index[:-1]: + last_rev_of_page = page_last_revs[page_id] + row_idx = df_full[df_full["revid"] == last_rev_of_page].index[0] + if row_idx < len(df_full) - 1: + break + + resume_revid = int(last_rev_of_page) + resume_pageid = int(page_id) + + tester_partial = WikiqTester(SAILORMOON, "resume_boundary_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output + + rows_to_write = [full_rows[i] for i in range(row_idx + 1)] + with open(partial_output_path, 'w') as f: + for row in rows_to_write: + f.write(json.dumps(row) + "\n") + + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": resume_pageid, "revid": resume_revid}, f) + + try: + tester_partial.call_wikiq("--diff", "--fandom-2020", "--resume") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + resumed_rows = read_jsonl(partial_output_path) + + df_resumed = pd.DataFrame(resumed_rows) + assert_frame_equal(df_full, df_resumed) + + +def test_resume_revert_detection(): + """Test that revert detection works correctly after resume. + + Verifies that the revert detector state is properly maintained during + the skip phase so that reverts are correctly detected after resume. + """ + import pandas as pd + from pandas.testing import assert_series_equal + + tester_full = WikiqTester(SAILORMOON, "resume_revert_full", in_compression="7z", out_format="jsonl") + + try: + tester_full.call_wikiq("--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + full_rows = read_jsonl(tester_full.output) + df_full = pd.DataFrame(full_rows) + + # Find rows with reverts + revert_rows = df_full[df_full["revert"] == True] + if len(revert_rows) == 0: + pytest.skip("No reverts found in test data") + + # Resume from before a known revert so we can verify it's detected + first_revert_idx = revert_rows.index[0] + if first_revert_idx < 2: + pytest.skip("First revert too early in dataset") + + resume_idx = first_revert_idx - 1 + resume_revid = full_rows[resume_idx]["revid"] + resume_pageid = full_rows[resume_idx]["articleid"] + + tester_partial = WikiqTester(SAILORMOON, "resume_revert_partial", in_compression="7z", out_format="jsonl") + partial_output_path = tester_partial.output + + with open(partial_output_path, 'w') as f: + for row in full_rows[:resume_idx + 1]: + f.write(json.dumps(row) + "\n") + + checkpoint_path = get_checkpoint_path(partial_output_path) + with open(checkpoint_path, 'w') as f: + json.dump({"pageid": resume_pageid, "revid": resume_revid}, f) + + try: + tester_partial.call_wikiq("--fandom-2020", "--resume") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + resumed_rows = read_jsonl(partial_output_path) + + df_resumed = pd.DataFrame(resumed_rows) + + # Verify revert column matches exactly + assert_series_equal(df_full["revert"], df_resumed["revert"]) + assert_series_equal(df_full["reverteds"], df_resumed["reverteds"]) diff --git a/test/wikiq_test_utils.py b/test/wikiq_test_utils.py index ca08707..ea415ec 100644 --- a/test/wikiq_test_utils.py +++ b/test/wikiq_test_utils.py @@ -42,8 +42,20 @@ class WikiqTester: else: shutil.rmtree(self.output) - if out_format == "parquet": - os.makedirs(self.output, exist_ok=True) + # Also clean up resume-related files + for suffix in [".resume_temp", ".checkpoint", ".merged"]: + temp_path = self.output + suffix + if os.path.exists(temp_path): + if os.path.isfile(temp_path): + os.remove(temp_path) + else: + shutil.rmtree(temp_path) + + # For JSONL and Parquet, self.output is a file path. Create parent directory if needed. + if out_format in ("jsonl", "parquet"): + parent_dir = os.path.dirname(self.output) + if parent_dir: + os.makedirs(parent_dir, exist_ok=True) if suffix is None: self.wikiq_baseline_name = "{0}.{1}".format(wiki, baseline_format)