From 6988a281dc0d640d24a9e26fc7a2c82eba621726 Mon Sep 17 00:00:00 2001 From: Nathan TeBlunthuis Date: Sat, 20 Dec 2025 21:45:39 -0800 Subject: [PATCH] output parquet files in chunks to avoid memory issues with parquet. --- src/wikiq/__init__.py | 169 +++++++++++++++++++++++++++++++++++++----- src/wikiq/resume.py | 34 +++++---- test/test_resume.py | 87 +++++++++++++++++++++- 3 files changed, 254 insertions(+), 36 deletions(-) diff --git a/src/wikiq/__init__.py b/src/wikiq/__init__.py index ae1efc0..44f13cd 100755 --- a/src/wikiq/__init__.py +++ b/src/wikiq/__init__.py @@ -253,6 +253,7 @@ class WikiqParser: templates: bool = False, headings: bool = False, time_limit_seconds: Union[float, None] = None, + max_revisions_per_file: int = 0, ): """ Parameters: @@ -261,6 +262,7 @@ class WikiqParser: or a dict mapping namespace -> (pageid, revid) for partitioned output. For single-file: skip all revisions up to and including this point + max_revisions_per_file : if > 0, close and rotate output files after this many revisions """ self.input_file = input_file @@ -279,6 +281,7 @@ class WikiqParser: self.headings = headings self.shutdown_requested = False self.time_limit_seconds = time_limit_seconds + self.max_revisions_per_file = max_revisions_per_file if namespaces is not None: self.namespace_filter = set(namespaces) else: @@ -335,6 +338,14 @@ class WikiqParser: if timer is not None: timer.cancel() + def _get_part_path(self, base_path, part_num): + """Generate path with part number inserted before extension. + + Example: output.parquet -> output.part0.parquet + """ + path = Path(base_path) + return path.parent / f"{path.stem}.part{part_num}{path.suffix}" + def _open_checkpoint(self, output_file): """Open checkpoint file for writing. Keeps file open for performance.""" if not self.output_parquet or output_file == sys.stdout.buffer: @@ -344,14 +355,14 @@ class WikiqParser: self.checkpoint_file = open(checkpoint_path, 'w') print(f"Checkpoint file opened: {checkpoint_path}", file=sys.stderr) - def _update_checkpoint(self, pageid, revid, namespace=None): + def _update_checkpoint(self, pageid, revid, namespace=None, part=0): """Update checkpoint state and write to file.""" if self.checkpoint_file is None: return if self.partition_namespaces: - self.checkpoint_state[namespace] = {"pageid": pageid, "revid": revid} + self.checkpoint_state[namespace] = {"pageid": pageid, "revid": revid, "part": part} else: - self.checkpoint_state = {"pageid": pageid, "revid": revid} + self.checkpoint_state = {"pageid": pageid, "revid": revid, "part": part} self.checkpoint_file.seek(0) self.checkpoint_file.truncate() json.dump(self.checkpoint_state, self.checkpoint_file) @@ -370,22 +381,32 @@ class WikiqParser: else: print(f"Checkpoint file preserved for resume: {checkpoint_path}", file=sys.stderr) - def _write_batch(self, row_buffer, schema, writer, pq_writers, ns_paths, sorting_cols, namespace=None): + def _write_batch(self, row_buffer, schema, writer, pq_writers, ns_base_paths, sorting_cols, namespace=None, part_numbers=None): """Write a batch of rows to the appropriate writer. For partitioned output, creates writer lazily if needed. - Returns the writer used (for non-partitioned output, same as input). + Returns (writer, num_rows) - writer used and number of rows written. + + Args: + ns_base_paths: For partitioned output, maps namespace -> base path (without part number) + part_numbers: For partitioned output, maps namespace -> current part number """ + num_rows = len(row_buffer.get("revid", [])) if self.partition_namespaces and namespace is not None: if namespace not in pq_writers: - ns_path = ns_paths[namespace] + base_path = ns_base_paths[namespace] + part_num = part_numbers.get(namespace, 0) if part_numbers else 0 + if self.max_revisions_per_file > 0: + ns_path = self._get_part_path(base_path, part_num) + else: + ns_path = base_path Path(ns_path).parent.mkdir(exist_ok=True, parents=True) pq_writers[namespace] = pq.ParquetWriter( ns_path, schema, flavor="spark", sorting_columns=sorting_cols ) writer = pq_writers[namespace] writer.write(pa.record_batch(row_buffer, schema=schema)) - return writer + return writer, num_rows def make_matchmake_pairs(self, patterns, labels) -> list[RegexPair]: if (patterns is not None and labels is not None) and ( @@ -565,14 +586,32 @@ class WikiqParser: revid_sortingcol = pq.SortingColumn(schema.get_field_index("pageid")) sorting_cols = [pageid_sortingcol, revid_sortingcol] + # Part tracking for file splitting. + # Keys are namespace integers for partitioned output, None for non-partitioned. + part_numbers = {} + revisions_in_part = {} + + # Initialize part numbers from resume point if resuming + if self.resume_point is not None: + if self.partition_namespaces: + for ns, resume_data in self.resume_point.items(): + part_numbers[ns] = resume_data[2] if len(resume_data) > 2 else 0 + else: + part_numbers[None] = self.resume_point[2] if len(self.resume_point) > 2 else 0 + if self.partition_namespaces is False: + # Generate path with part number if file splitting is enabled + if self.max_revisions_per_file > 0: + output_path_with_part = self._get_part_path(self.output_file, part_numbers.get(None, 0)) + else: + output_path_with_part = self.output_file writer = pq.ParquetWriter( - self.output_file, + output_path_with_part, schema, flavor="spark", sorting_columns=sorting_cols, ) - ns_paths = {} + ns_base_paths = {} pq_writers = {} else: output_path = Path(self.output_file) @@ -580,10 +619,16 @@ class WikiqParser: namespaces = self.namespace_filter else: namespaces = self.namespaces.values() - ns_paths = { + # Store base paths - actual paths with part numbers generated in _write_batch + ns_base_paths = { ns: (output_path.parent / f"namespace={ns}") / output_path.name for ns in namespaces } + # Initialize part numbers for each namespace (from resume point or 0) + for ns in namespaces: + if ns not in part_numbers: + part_numbers[ns] = 0 + revisions_in_part[ns] = 0 # Writers are created lazily when first needed to avoid empty files on early exit pq_writers = {} writer = None # Not used for partitioned output @@ -594,9 +639,12 @@ class WikiqParser: schema, write_options=pacsv.WriteOptions(delimiter="\t"), ) - ns_paths = {} + ns_base_paths = {} pq_writers = {} sorting_cols = None + # Part tracking not needed for CSV (no OOM issue during close) + part_numbers = {} + revisions_in_part = {} regex_matches = {} @@ -625,7 +673,8 @@ class WikiqParser: # No resume point for this namespace, process normally found_resume_point[page_ns] = True else: - resume_pageid, resume_revid = self.resume_point[page_ns] + resume_data = self.resume_point[page_ns] + resume_pageid, resume_revid = resume_data[0], resume_data[1] if page_id < resume_pageid: continue elif page_id == resume_pageid: @@ -636,7 +685,7 @@ class WikiqParser: else: # Single-file resume: global resume point if not found_resume_point: - resume_pageid, resume_revid = self.resume_point + resume_pageid, resume_revid = self.resume_point[0], self.resume_point[1] if page_id < resume_pageid: continue elif page_id == resume_pageid: @@ -932,11 +981,36 @@ class WikiqParser: # Write batch if there are rows if should_write and len(row_buffer.get("revid", [])) > 0: namespace = page.mwpage.namespace if self.partition_namespaces else None - self._write_batch(row_buffer, schema, writer, pq_writers, ns_paths, sorting_cols, namespace) + writer, num_rows = self._write_batch( + row_buffer, schema, writer, pq_writers, ns_base_paths, sorting_cols, + namespace, part_numbers + ) # Update checkpoint with last written position last_pageid = row_buffer["articleid"][-1] last_revid = row_buffer["revid"][-1] - self._update_checkpoint(last_pageid, last_revid, namespace) + + # Track revisions and check if file rotation is needed. + # namespace is None for non-partitioned output. + revisions_in_part[namespace] = revisions_in_part.get(namespace, 0) + num_rows + current_part_num = part_numbers.get(namespace, 0) + self._update_checkpoint(last_pageid, last_revid, namespace, current_part_num) + + # Rotate file if limit exceeded + if self.max_revisions_per_file > 0 and revisions_in_part[namespace] >= self.max_revisions_per_file: + part_numbers[namespace] = current_part_num + 1 + revisions_in_part[namespace] = 0 + if self.partition_namespaces: + pq_writers[namespace].close() + del pq_writers[namespace] + else: + writer.close() + output_path_with_part = self._get_part_path(self.output_file, part_numbers[namespace]) + writer = pq.ParquetWriter( + output_path_with_part, + schema, + flavor="spark", + sorting_columns=sorting_cols, + ) gc.collect() # If shutdown was requested, break from page loop @@ -981,6 +1055,41 @@ def match_archive_suffix(input_filename): return cmd +def detect_existing_part_files(output_file, partition_namespaces): + """ + Detect whether existing output uses part file naming. + + Returns: + True if part files exist (e.g., output.part0.parquet) + False if non-part file exists (e.g., output.parquet) + None if no existing output found + """ + output_dir = os.path.dirname(output_file) or '.' + stem = os.path.basename(output_file).rsplit('.', 1)[0] + part0_name = f"{stem}.part0.parquet" + + if partition_namespaces: + if not os.path.isdir(output_dir): + return None + + for ns_dir in os.listdir(output_dir): + if not ns_dir.startswith('namespace='): + continue + ns_path = os.path.join(output_dir, ns_dir) + + if os.path.exists(os.path.join(ns_path, part0_name)): + return True + if os.path.exists(os.path.join(ns_path, os.path.basename(output_file))): + return False + return None + else: + if os.path.exists(os.path.join(output_dir, part0_name)): + return True + if os.path.exists(output_file): + return False + return None + + def open_input_file(input_filename, fandom_2020=False): cmd = match_archive_suffix(input_filename) if fandom_2020: @@ -1215,6 +1324,14 @@ def main(): help="Time limit in hours before graceful shutdown. Set to 0 to disable (default).", ) + parser.add_argument( + "--max-revisions-per-file", + dest="max_revisions_per_file", + type=int, + default=0, + help="Max revisions per output file before creating a new part file. Set to 0 to disable (default: disabled).", + ) + args = parser.parse_args() # set persistence method @@ -1272,11 +1389,24 @@ def main(): ns_list = sorted(resume_point.keys()) print(f"Resuming with per-namespace resume points for {len(ns_list)} namespaces", file=sys.stderr) for ns in ns_list: - pageid, revid = resume_point[ns] - print(f" namespace={ns}: pageid={pageid}, revid={revid}", file=sys.stderr) + pageid, revid, part = resume_point[ns] + print(f" namespace={ns}: pageid={pageid}, revid={revid}, part={part}", file=sys.stderr) else: - pageid, revid = resume_point - print(f"Resuming from last written point: pageid={pageid}, revid={revid}", file=sys.stderr) + pageid, revid, part = resume_point + print(f"Resuming from last written point: pageid={pageid}, revid={revid}, part={part}", file=sys.stderr) + + # Detect mismatch between existing file naming and current settings + existing_uses_parts = detect_existing_part_files(output_file, args.partition_namespaces) + current_uses_parts = args.max_revisions_per_file > 0 + + if existing_uses_parts is not None and existing_uses_parts != current_uses_parts: + if existing_uses_parts: + print(f"WARNING: Existing output uses part files but --max-revisions-per-file is 0. " + f"New data will be written to a non-part file.", file=sys.stderr) + else: + print(f"WARNING: Existing output does not use part files but --max-revisions-per-file={args.max_revisions_per_file}. " + f"Disabling file splitting for consistency with existing output.", file=sys.stderr) + args.max_revisions_per_file = 0 else: # resume_point is None - check if file exists but is corrupt if args.partition_namespaces: @@ -1332,6 +1462,7 @@ def main(): templates=args.templates, headings=args.headings, time_limit_seconds=time_limit_seconds, + max_revisions_per_file=args.max_revisions_per_file, ) # Register signal handlers for graceful shutdown (CLI only) diff --git a/src/wikiq/resume.py b/src/wikiq/resume.py index 67adeb9..ff9e05c 100644 --- a/src/wikiq/resume.py +++ b/src/wikiq/resume.py @@ -125,12 +125,14 @@ def read_checkpoint(output_file, partition_namespaces=False): Read resume point from checkpoint file if it exists. Checkpoint format: - Single file: {"pageid": 54, "revid": 325} - Partitioned: {"0": {"pageid": 54, "revid": 325}, "1": {"pageid": 123, "revid": 456}} + Single file: {"pageid": 54, "revid": 325, "part": 2} + Partitioned: {"0": {"pageid": 54, "revid": 325, "part": 1}, ...} Returns: - For single files: A tuple (pageid, revid), or None if not found. - For partitioned: A dict mapping namespace -> (pageid, revid), or None. + For single files: A tuple (pageid, revid, part), or None if not found. + For partitioned: A dict mapping namespace -> (pageid, revid, part), or None. + + Note: part defaults to 0 for checkpoints without part numbers (backwards compat). """ checkpoint_path = get_checkpoint_path(output_file, partition_namespaces) if not os.path.exists(checkpoint_path): @@ -143,14 +145,16 @@ def read_checkpoint(output_file, partition_namespaces=False): if not data: return None - # Single-file format: {"pageid": ..., "revid": ...} + # Single-file format: {"pageid": ..., "revid": ..., "part": ...} if "pageid" in data and "revid" in data: - return (data["pageid"], data["revid"]) + part = data.get("part", 0) + return (data["pageid"], data["revid"], part) - # Partitioned format: {"0": {"pageid": ..., "revid": ...}, ...} + # Partitioned format: {"0": {"pageid": ..., "revid": ..., "part": ...}, ...} result = {} for key, value in data.items(): - result[int(key)] = (value["pageid"], value["revid"]) + part = value.get("part", 0) + result[int(key)] = (value["pageid"], value["revid"], part) return result if result else None @@ -173,10 +177,9 @@ def get_resume_point(output_file, partition_namespaces=False): partition_namespaces: Whether the output uses namespace partitioning. Returns: - For single files: A tuple (pageid, revid) for the row with the highest pageid, - or None if not found. - For partitioned: A dict mapping namespace -> (pageid, revid) for each partition, - or None if no partitions exist. + For single files: A tuple (pageid, revid, part) or None if not found. + For partitioned: A dict mapping namespace -> (pageid, revid, part), or None. + When falling back to parquet scanning, part defaults to 0. """ # First try checkpoint file (fast) checkpoint_path = get_checkpoint_path(output_file, partition_namespaces) @@ -202,6 +205,7 @@ def _get_last_row_resume_point(pq_path): Since data is written in page/revision order, the last row group contains the highest pageid/revid, and the last row in that group is the resume point. + Returns (pageid, revid, part) with part=0 (scanning can't determine part). """ pf = pq.ParquetFile(pq_path) if pf.metadata.num_row_groups == 0: @@ -214,15 +218,15 @@ def _get_last_row_resume_point(pq_path): max_pageid = table['articleid'][-1].as_py() max_revid = table['revid'][-1].as_py() - return (max_pageid, max_revid) + return (max_pageid, max_revid, 0) def _get_resume_point_partitioned(output_file): """Find per-namespace resume points from partitioned output. Only looks for the specific output file in each namespace directory. - Returns a dict mapping namespace -> (max_pageid, max_revid) for each partition - where the output file exists. + Returns a dict mapping namespace -> (max_pageid, max_revid, part=0) for each + partition where the output file exists. Args: output_file: Path like 'dir/output.parquet' where namespace=* subdirectories diff --git a/test/test_resume.py b/test/test_resume.py index b37cc30..bc8a38b 100644 --- a/test/test_resume.py +++ b/test/test_resume.py @@ -369,7 +369,7 @@ def test_cleanup_interrupted_resume_original_corrupted_temp_valid(): resume_point = get_resume_point(output_file, partition_namespaces=False) assert resume_point is not None, "Should find resume point from recovered file" - assert resume_point == (30, 300), f"Expected (30, 300), got {resume_point}" + assert resume_point == (30, 300, 0), f"Expected (30, 300, 0), got {resume_point}" print("Cleanup with original corrupted, temp valid test passed!") @@ -396,7 +396,7 @@ def test_cleanup_original_missing_temp_valid_no_checkpoint(): resume_point = get_resume_point(output_file, partition_namespaces=False) assert resume_point is not None, "Should find resume point from recovered file" - assert resume_point == (30, 300), f"Expected (30, 300), got {resume_point}" + assert resume_point == (30, 300, 0), f"Expected (30, 300, 0), got {resume_point}" print("Original missing, temp valid, no checkpoint test passed!") @@ -464,3 +464,86 @@ def test_concurrent_jobs_different_input_files(): assert orig2_ns1.num_rows == 2, "file2 ns1 should still have 2 rows" print("Concurrent jobs with different input files test passed!") + + +def test_max_revisions_per_file_creates_parts(): + """Test that --max-revisions-per-file creates multiple part files.""" + import re + tester = WikiqTester(SAILORMOON, "max_revs_parts", in_compression="7z", out_format="parquet") + + max_revs = 50 + try: + # Use a very small limit to force multiple parts + tester.call_wikiq("--fandom-2020", "--max-revisions-per-file", str(max_revs)) + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + output_dir = tester.output + all_parquet = [f for f in os.listdir(output_dir) if f.endswith(".parquet") and ".part" in f] + + # Sort by part number numerically + def get_part_num(filename): + match = re.search(r'\.part(\d+)\.parquet$', filename) + return int(match.group(1)) if match else 0 + + part_files = sorted(all_parquet, key=get_part_num) + + assert len(part_files) > 1, f"Expected multiple part files, got {part_files}" + + # Read all parts and verify total rows + total_rows = 0 + for f in part_files: + table = pq.read_table(os.path.join(output_dir, f)) + total_rows += len(table) + + assert total_rows > 0, "Should have some rows across all parts" + + # Each part (except the last) should have at least max_revisions rows + # (rotation happens after the batch that hits the limit is written) + for f in part_files[:-1]: + table = pq.read_table(os.path.join(output_dir, f)) + assert len(table) >= max_revs, f"Part file {f} should have at least {max_revs} rows, got {len(table)}" + + print(f"max-revisions-per-file test passed! Created {len(part_files)} parts with {total_rows} total rows") + + +def test_max_revisions_per_file_with_partitioned(): + """Test that --max-revisions-per-file works with partitioned namespace output.""" + import re + tester = WikiqTester(SAILORMOON, "max_revs_partitioned", in_compression="7z", out_format="parquet") + + max_revs = 20 + try: + # Use a small limit to force parts, with partitioned output + tester.call_wikiq("--fandom-2020", "--partition-namespaces", "--max-revisions-per-file", str(max_revs)) + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + output_dir = tester.output + + # Find namespace directories + ns_dirs = [d for d in os.listdir(output_dir) if d.startswith("namespace=")] + assert len(ns_dirs) > 0, "Should have namespace directories" + + def get_part_num(filename): + match = re.search(r'\.part(\d+)\.parquet$', filename) + return int(match.group(1)) if match else 0 + + # Check that at least one namespace has multiple parts + found_multi_part = False + for ns_dir in ns_dirs: + ns_path = os.path.join(output_dir, ns_dir) + parquet_files = [f for f in os.listdir(ns_path) if f.endswith(".parquet")] + part_files = [f for f in parquet_files if ".part" in f] + if len(part_files) > 1: + found_multi_part = True + # Sort by part number and verify each part (except last) has at least limit rows + sorted_parts = sorted(part_files, key=get_part_num) + for f in sorted_parts[:-1]: + pf = pq.ParquetFile(os.path.join(ns_path, f)) + num_rows = pf.metadata.num_rows + assert num_rows >= max_revs, f"Part file {f} in {ns_dir} should have at least {max_revs} rows, got {num_rows}" + + assert found_multi_part, "At least one namespace should have multiple part files" + + print(f"max-revisions-per-file with partitioned output test passed!")