output parquet files in chunks to avoid memory issues when closing large parquet writers.

Nathan TeBlunthuis
2025-12-20 21:45:39 -08:00
parent 6a4bf81e1a
commit 6988a281dc
3 changed files with 254 additions and 36 deletions


@@ -253,6 +253,7 @@ class WikiqParser:
templates: bool = False,
headings: bool = False,
time_limit_seconds: Union[float, None] = None,
max_revisions_per_file: int = 0,
):
"""
Parameters:
@@ -261,6 +262,7 @@ class WikiqParser:
or a dict mapping namespace -> (pageid, revid) for partitioned output.
For single-file: skip all revisions up to
and including this point
max_revisions_per_file : if > 0, close and rotate output files after this many revisions
"""
self.input_file = input_file
@@ -279,6 +281,7 @@ class WikiqParser:
self.headings = headings
self.shutdown_requested = False
self.time_limit_seconds = time_limit_seconds
self.max_revisions_per_file = max_revisions_per_file
if namespaces is not None:
self.namespace_filter = set(namespaces)
else:
@@ -335,6 +338,14 @@ class WikiqParser:
if timer is not None:
timer.cancel()
def _get_part_path(self, base_path, part_num):
"""Generate path with part number inserted before extension.
Example: output.parquet -> output.part0.parquet
"""
path = Path(base_path)
return path.parent / f"{path.stem}.part{part_num}{path.suffix}"
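# Illustrative sketch, not part of the diff: the part-file naming above, reproduced
# standalone with pathlib (paths below are made-up examples).
from pathlib import Path

def example_part_path(base_path, part_num):
    # Insert ".partN" between the file stem and its extension, as _get_part_path does.
    p = Path(base_path)
    return p.parent / f"{p.stem}.part{part_num}{p.suffix}"

assert example_part_path("out/enwiki.parquet", 0) == Path("out/enwiki.part0.parquet")
assert example_part_path("out/enwiki.parquet", 3) == Path("out/enwiki.part3.parquet")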
def _open_checkpoint(self, output_file):
"""Open checkpoint file for writing. Keeps file open for performance."""
if not self.output_parquet or output_file == sys.stdout.buffer:
@@ -344,14 +355,14 @@ class WikiqParser:
self.checkpoint_file = open(checkpoint_path, 'w')
print(f"Checkpoint file opened: {checkpoint_path}", file=sys.stderr)
def _update_checkpoint(self, pageid, revid, namespace=None):
def _update_checkpoint(self, pageid, revid, namespace=None, part=0):
"""Update checkpoint state and write to file."""
if self.checkpoint_file is None:
return
if self.partition_namespaces:
self.checkpoint_state[namespace] = {"pageid": pageid, "revid": revid}
self.checkpoint_state[namespace] = {"pageid": pageid, "revid": revid, "part": part}
else:
self.checkpoint_state = {"pageid": pageid, "revid": revid}
self.checkpoint_state = {"pageid": pageid, "revid": revid, "part": part}
self.checkpoint_file.seek(0)
self.checkpoint_file.truncate()
json.dump(self.checkpoint_state, self.checkpoint_file)
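# Illustrative sketch, not part of the diff: the checkpoint JSON shapes that
# _update_checkpoint now writes (field values below are made-up examples; namespace
# keys become strings once serialized by json.dump).
import json

# Single-file output: one global resume point, now carrying the current part number.
single_file_state = {"pageid": 12345, "revid": 987654, "part": 2}

# Partitioned output: one resume point per namespace, each with its own part counter.
partitioned_state = {
    0: {"pageid": 12345, "revid": 987654, "part": 2},
    1: {"pageid": 222, "revid": 3333, "part": 0},
}

print(json.dumps(single_file_state))
print(json.dumps(partitioned_state))  # int keys serialize as "0", "1"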
@@ -370,22 +381,32 @@ class WikiqParser:
else:
print(f"Checkpoint file preserved for resume: {checkpoint_path}", file=sys.stderr)
def _write_batch(self, row_buffer, schema, writer, pq_writers, ns_paths, sorting_cols, namespace=None):
def _write_batch(self, row_buffer, schema, writer, pq_writers, ns_base_paths, sorting_cols, namespace=None, part_numbers=None):
"""Write a batch of rows to the appropriate writer.
For partitioned output, creates writer lazily if needed.
Returns the writer used (for non-partitioned output, same as input).
Returns (writer, num_rows) - writer used and number of rows written.
Args:
ns_base_paths: For partitioned output, maps namespace -> base path (without part number)
part_numbers: For partitioned output, maps namespace -> current part number
"""
num_rows = len(row_buffer.get("revid", []))
if self.partition_namespaces and namespace is not None:
if namespace not in pq_writers:
ns_path = ns_paths[namespace]
base_path = ns_base_paths[namespace]
part_num = part_numbers.get(namespace, 0) if part_numbers else 0
if self.max_revisions_per_file > 0:
ns_path = self._get_part_path(base_path, part_num)
else:
ns_path = base_path
Path(ns_path).parent.mkdir(exist_ok=True, parents=True)
pq_writers[namespace] = pq.ParquetWriter(
ns_path, schema, flavor="spark", sorting_columns=sorting_cols
)
writer = pq_writers[namespace]
writer.write(pa.record_batch(row_buffer, schema=schema))
return writer
return writer, num_rows
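# Illustrative sketch, not part of the diff: the output paths that lazy writer creation
# plus part rotation produce when namespace partitioning and file splitting are both
# enabled. The file and directory names below are hypothetical examples.
from pathlib import Path

output_file = Path("out/enwiki.parquet")
for ns in (0, 1):
    # Mirrors ns_base_paths: namespace subdirectory, then the original file name.
    base = output_file.parent / f"namespace={ns}" / output_file.name
    for part in (0, 1):
        # Mirrors _get_part_path: part number inserted before the extension.
        print(base.parent / f"{base.stem}.part{part}{base.suffix}")
# e.g. out/namespace=0/enwiki.part0.parquet, out/namespace=0/enwiki.part1.parquet, ...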
def make_matchmake_pairs(self, patterns, labels) -> list[RegexPair]:
if (patterns is not None and labels is not None) and (
@@ -565,14 +586,32 @@ class WikiqParser:
revid_sortingcol = pq.SortingColumn(schema.get_field_index("pageid"))
sorting_cols = [pageid_sortingcol, revid_sortingcol]
# Part tracking for file splitting.
# Keys are namespace integers for partitioned output, None for non-partitioned.
part_numbers = {}
revisions_in_part = {}
# Initialize part numbers from resume point if resuming
if self.resume_point is not None:
if self.partition_namespaces:
for ns, resume_data in self.resume_point.items():
part_numbers[ns] = resume_data[2] if len(resume_data) > 2 else 0
else:
part_numbers[None] = self.resume_point[2] if len(self.resume_point) > 2 else 0
if self.partition_namespaces is False:
# Generate path with part number if file splitting is enabled
if self.max_revisions_per_file > 0:
output_path_with_part = self._get_part_path(self.output_file, part_numbers.get(None, 0))
else:
output_path_with_part = self.output_file
writer = pq.ParquetWriter(
self.output_file,
output_path_with_part,
schema,
flavor="spark",
sorting_columns=sorting_cols,
)
ns_paths = {}
ns_base_paths = {}
pq_writers = {}
else:
output_path = Path(self.output_file)
@@ -580,10 +619,16 @@ class WikiqParser:
namespaces = self.namespace_filter
else:
namespaces = self.namespaces.values()
ns_paths = {
# Store base paths - actual paths with part numbers generated in _write_batch
ns_base_paths = {
ns: (output_path.parent / f"namespace={ns}") / output_path.name
for ns in namespaces
}
# Initialize part numbers for each namespace (from resume point or 0)
for ns in namespaces:
if ns not in part_numbers:
part_numbers[ns] = 0
revisions_in_part[ns] = 0
# Writers are created lazily when first needed to avoid empty files on early exit
pq_writers = {}
writer = None # Not used for partitioned output
@@ -594,9 +639,12 @@ class WikiqParser:
schema,
write_options=pacsv.WriteOptions(delimiter="\t"),
)
ns_paths = {}
ns_base_paths = {}
pq_writers = {}
sorting_cols = None
# Part tracking not needed for CSV (no OOM issue during close)
part_numbers = {}
revisions_in_part = {}
regex_matches = {}
@@ -625,7 +673,8 @@ class WikiqParser:
# No resume point for this namespace, process normally
found_resume_point[page_ns] = True
else:
resume_pageid, resume_revid = self.resume_point[page_ns]
resume_data = self.resume_point[page_ns]
resume_pageid, resume_revid = resume_data[0], resume_data[1]
if page_id < resume_pageid:
continue
elif page_id == resume_pageid:
@@ -636,7 +685,7 @@ class WikiqParser:
else:
# Single-file resume: global resume point
if not found_resume_point:
resume_pageid, resume_revid = self.resume_point
resume_pageid, resume_revid = self.resume_point[0], self.resume_point[1]
if page_id < resume_pageid:
continue
elif page_id == resume_pageid:
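# Illustrative sketch, not part of the diff: why the resume point is indexed instead of
# unpacked above. Checkpoints written before this change hold (pageid, revid); newer ones
# hold (pageid, revid, part), so indexing the first two fields accepts both shapes.
# The helper name and values below are hypothetical.
def split_resume_point(resume_data):
    pageid, revid = resume_data[0], resume_data[1]
    part = resume_data[2] if len(resume_data) > 2 else 0
    return pageid, revid, part

assert split_resume_point((12345, 987654)) == (12345, 987654, 0)
assert split_resume_point((12345, 987654, 2)) == (12345, 987654, 2)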
@@ -932,11 +981,36 @@ class WikiqParser:
# Write batch if there are rows
if should_write and len(row_buffer.get("revid", [])) > 0:
namespace = page.mwpage.namespace if self.partition_namespaces else None
self._write_batch(row_buffer, schema, writer, pq_writers, ns_paths, sorting_cols, namespace)
writer, num_rows = self._write_batch(
row_buffer, schema, writer, pq_writers, ns_base_paths, sorting_cols,
namespace, part_numbers
)
# Update checkpoint with last written position
last_pageid = row_buffer["articleid"][-1]
last_revid = row_buffer["revid"][-1]
self._update_checkpoint(last_pageid, last_revid, namespace)
# Track revisions and check if file rotation is needed.
# namespace is None for non-partitioned output.
revisions_in_part[namespace] = revisions_in_part.get(namespace, 0) + num_rows
current_part_num = part_numbers.get(namespace, 0)
self._update_checkpoint(last_pageid, last_revid, namespace, current_part_num)
# Rotate file if limit exceeded
if self.max_revisions_per_file > 0 and revisions_in_part[namespace] >= self.max_revisions_per_file:
part_numbers[namespace] = current_part_num + 1
revisions_in_part[namespace] = 0
if self.partition_namespaces:
pq_writers[namespace].close()
del pq_writers[namespace]
else:
writer.close()
output_path_with_part = self._get_part_path(self.output_file, part_numbers[namespace])
writer = pq.ParquetWriter(
output_path_with_part,
schema,
flavor="spark",
sorting_columns=sorting_cols,
)
gc.collect()
# If shutdown was requested, break from page loop
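# Illustrative sketch, not part of the diff: the rotation decision above reduced to a pure
# function (names and numbers are hypothetical). Each written batch adds its row count to
# the current part's running total; once the configured limit is reached, the part number
# advances and the counter resets, which is what triggers closing the current writer and
# opening the next ".partN" file.
def advance_part(revisions_in_part, num_rows, part_num, max_per_file):
    revisions_in_part += num_rows
    if max_per_file > 0 and revisions_in_part >= max_per_file:
        return 0, part_num + 1, True  # reset counter, move to next part, rotate now
    return revisions_in_part, part_num, False

assert advance_part(0, 600, 0, 1000) == (600, 0, False)
assert advance_part(600, 500, 0, 1000) == (0, 1, True)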
@@ -981,6 +1055,41 @@ def match_archive_suffix(input_filename):
return cmd
def detect_existing_part_files(output_file, partition_namespaces):
"""
Detect whether existing output uses part file naming.
Returns:
True if part files exist (e.g., output.part0.parquet)
False if non-part file exists (e.g., output.parquet)
None if no existing output found
"""
output_dir = os.path.dirname(output_file) or '.'
stem = os.path.basename(output_file).rsplit('.', 1)[0]
part0_name = f"{stem}.part0.parquet"
if partition_namespaces:
if not os.path.isdir(output_dir):
return None
for ns_dir in os.listdir(output_dir):
if not ns_dir.startswith('namespace='):
continue
ns_path = os.path.join(output_dir, ns_dir)
if os.path.exists(os.path.join(ns_path, part0_name)):
return True
if os.path.exists(os.path.join(ns_path, os.path.basename(output_file))):
return False
return None
else:
if os.path.exists(os.path.join(output_dir, part0_name)):
return True
if os.path.exists(output_file):
return False
return None
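# Illustrative sketch, not part of the diff: exercising detect_existing_part_files above
# against a throwaway directory to show its three outcomes (True when part files exist,
# False when a plain output file exists, None when neither is found). Assumes the function
# is in scope; file names are made-up examples.
import os
import tempfile

with tempfile.TemporaryDirectory() as d:
    out = os.path.join(d, "enwiki.parquet")
    assert detect_existing_part_files(out, partition_namespaces=False) is None
    open(os.path.join(d, "enwiki.part0.parquet"), "w").close()
    assert detect_existing_part_files(out, partition_namespaces=False) is True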
def open_input_file(input_filename, fandom_2020=False):
cmd = match_archive_suffix(input_filename)
if fandom_2020:
@@ -1215,6 +1324,14 @@ def main():
help="Time limit in hours before graceful shutdown. Set to 0 to disable (default).",
)
parser.add_argument(
"--max-revisions-per-file",
dest="max_revisions_per_file",
type=int,
default=0,
help="Max revisions per output file before creating a new part file. Set to 0 to disable (default: disabled).",
)
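# Illustrative sketch, not part of the diff: how the new option parses in isolation.
# The parser below is a hypothetical stand-in for the script's real argument parser.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--max-revisions-per-file", dest="max_revisions_per_file", type=int, default=0)
assert p.parse_args([]).max_revisions_per_file == 0  # default: splitting disabled
assert p.parse_args(["--max-revisions-per-file", "2000000"]).max_revisions_per_file == 2000000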
args = parser.parse_args()
# set persistence method
@@ -1272,11 +1389,24 @@ def main():
ns_list = sorted(resume_point.keys())
print(f"Resuming with per-namespace resume points for {len(ns_list)} namespaces", file=sys.stderr)
for ns in ns_list:
pageid, revid = resume_point[ns]
print(f" namespace={ns}: pageid={pageid}, revid={revid}", file=sys.stderr)
pageid, revid, part = resume_point[ns]
print(f" namespace={ns}: pageid={pageid}, revid={revid}, part={part}", file=sys.stderr)
else:
pageid, revid = resume_point
print(f"Resuming from last written point: pageid={pageid}, revid={revid}", file=sys.stderr)
pageid, revid, part = resume_point
print(f"Resuming from last written point: pageid={pageid}, revid={revid}, part={part}", file=sys.stderr)
# Detect mismatch between existing file naming and current settings
existing_uses_parts = detect_existing_part_files(output_file, args.partition_namespaces)
current_uses_parts = args.max_revisions_per_file > 0
if existing_uses_parts is not None and existing_uses_parts != current_uses_parts:
if existing_uses_parts:
print(f"WARNING: Existing output uses part files but --max-revisions-per-file is 0. "
f"New data will be written to a non-part file.", file=sys.stderr)
else:
print(f"WARNING: Existing output does not use part files but --max-revisions-per-file={args.max_revisions_per_file}. "
f"Disabling file splitting for consistency with existing output.", file=sys.stderr)
args.max_revisions_per_file = 0
else:
# resume_point is None - check if file exists but is corrupt
if args.partition_namespaces:
@@ -1332,6 +1462,7 @@ def main():
templates=args.templates,
headings=args.headings,
time_limit_seconds=time_limit_seconds,
max_revisions_per_file=args.max_revisions_per_file,
)
# Register signal handlers for graceful shutdown (CLI only)