Get parquet libraries writing files

Tests are broken due to URL encoding, which can likely now be removed.

Signed-off-by: Will Beason <willbeason@gmail.com>
Will Beason 2025-05-30 13:06:26 -05:00
parent 4dde25c508
commit 0d56267ae0

wikiq (148 changed lines)

@@ -15,14 +15,13 @@ from itertools import groupby
from subprocess import Popen, PIPE
from collections import deque
from hashlib import sha1
from typing import Any, IO, TextIO
from typing import Any, IO, TextIO, Final
from mwxml import Dump
from deltas.tokenizers import wikitext_split
import mwpersistence
import mwreverts
from urllib.parse import quote
from pyarrow import Array, Table, Schema, DataType
from pyarrow.parquet import ParquetWriter
@@ -36,6 +35,7 @@ import dataclasses as dc
from dataclasses import dataclass
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.csv as pc
class PersistMethod:
@@ -233,7 +233,7 @@ class Revision:
deleted: bool
text_chars: int | None = None
revert: bool | None = None
reverteds: list[int] = None
reverteds: str = None
sha1: str | None = None
minor: bool | None = None
editor: str | None = None
@@ -248,7 +248,7 @@ class Revision:
# this isn't a dataclass field since it doesn't have a type annotation
pa_schema_fields = [
pa.field("revid", pa.int64()),
pa.field("date_time", pa.timestamp('ms')),
pa.field("date_time", pa.timestamp('s')),
pa.field("articleid", pa.int64()),
pa.field("editorid", pa.int64(), nullable=True),
pa.field("title", pa.string()),
@@ -256,7 +256,8 @@ class Revision:
pa.field("deleted", pa.bool_()),
pa.field("text_chars", pa.int32()),
pa.field("revert", pa.bool_(), nullable=True),
pa.field("reverteds", pa.list_(pa.int64()), nullable=True),
# reverteds is a string which contains a comma-separated list of reverted revision ids.
pa.field("reverteds", pa.string(), nullable=True),
pa.field("sha1", pa.string()),
pa.field("minor", pa.bool_()),
pa.field("editor", pa.string()),
@@ -264,45 +265,10 @@ class Revision:
]
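Aside (not part of the commit): a minimal sketch of how the two changed schema fields behave under the new definitions, assuming a recent pyarrow. Only a subset of the real fields is shown and the values are made up: date_time now uses second resolution, and reverteds is a plain string of comma-separated ids rather than list<int64>.

```python
# Illustrative only: a subset of the Revision schema with the two changed fields.
import datetime as dt
import pyarrow as pa

schema = pa.schema([
    pa.field("revid", pa.int64()),
    pa.field("date_time", pa.timestamp('s')),           # seconds, not milliseconds
    pa.field("reverteds", pa.string(), nullable=True),  # e.g. "101,102", not list<int64>
])

# One row expressed as one single-element list per column.
batch = pa.record_batch(
    [[1234], [dt.datetime(2025, 5, 30, 13, 6, 26)], ["101,102"]],
    schema=schema,
)
print(batch.to_pydict())
```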
# pyarrow is a columnar format, so most of the work happens in the flush_parquet_buffer function
def to_pyarrow(self):
return dc.astuple(self)
# logic to convert each field into the wikiq tsv format goes here.
def to_tsv_row(self):
row = []
for f in dc.fields(self):
val = getattr(self, f.name)
if getattr(self, f.name) is None:
row.append("")
elif f.type == bool:
row.append("TRUE" if val else "FALSE")
elif f.type == datetime:
row.append(val.strftime('%Y-%m-%d %H:%M:%S'))
elif f.name in {'editor', 'title'}:
s = '"' + val + '"'
if self.urlencode and f.name in TO_ENCODE:
row.append(quote(str(s)))
else:
row.append(s)
elif f.type == list[int]:
row.append('"' + ",".join([str(x) for x in val]) + '"')
elif f.type == str:
if self.urlencode and f.name in TO_ENCODE:
row.append(quote(str(val)))
else:
row.append(val)
else:
row.append(val)
return '\t'.join(map(str, row))
def header_row(self):
return '\t'.join(map(lambda f: f.name, dc.fields(self)))
def to_pyarrow(self) -> pa.RecordBatch:
d = dc.asdict(self)
lists = [[d[field.name]] for field in self.pa_schema_fields]
return pa.record_batch(lists, schema=pa.schema(self.pa_schema_fields))
"""
@@ -434,7 +400,7 @@ class WikiqParser:
self.output_file = output_file
else:
self.output_file = open(output_file, 'w')
self.output_file = open(output_file, 'wb')
self.output_parquet = False
def make_matchmake_pairs(self, patterns, labels):
@@ -498,6 +464,12 @@ class WikiqParser:
page_count = 0
rev_count = 0
writer: pq.ParquetWriter | pc.CSVWriter
if self.output_parquet:
writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
else:
writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t'))
# Iterate through pages
for page in dump:
namespace = page.namespace if page.namespace is not None else self.__get_namespace_from_title(page.title)
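Both output paths now go through an Arrow record-batch writer with the same write() call, which is also why the output file is opened in 'wb' above: ParquetWriter and pyarrow.csv.CSVWriter both expect a binary sink. A minimal sketch with placeholder file names and a stand-in schema, assuming a pyarrow version whose CSV WriteOptions supports a custom delimiter:

```python
# Sketch of the writer selection above; "out.parquet" and "out.tsv" are
# placeholder paths, and the schema stands in for self.schema.
import pyarrow as pa
import pyarrow.csv as pc
import pyarrow.parquet as pq

schema = pa.schema([pa.field("revid", pa.int64()), pa.field("title", pa.string())])
batch = pa.record_batch([[1], ["Example"]], schema=schema)

# Parquet path: ParquetWriter accepts a path or a binary file object.
with pq.ParquetWriter("out.parquet", schema, flavor='spark') as writer:
    writer.write(batch)

# TSV path: CSVWriter also takes a binary sink, hence open(..., 'wb') in the diff.
with open("out.tsv", "wb") as sink:
    with pc.CSVWriter(sink, schema, write_options=pc.WriteOptions(delimiter='\t')) as writer:
        writer.write(batch)
```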
@@ -534,7 +506,7 @@ class WikiqParser:
editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id
# create a new data object instead of a dictionary.
rev_data = self.revdata_type(revid=rev.id,
rev_data: Revision = self.revdata_type(revid=rev.id,
date_time=datetime.fromtimestamp(rev.timestamp.unix(), tz=timezone.utc),
articleid=page.id,
editorid=editorid,
@@ -567,7 +539,7 @@ class WikiqParser:
if revert:
rev_data.revert = True
rev_data.reverteds = revert.reverteds
rev_data.reverteds = ",".join([str(s) for s in revert.reverteds])
else:
rev_data.revert = False
@@ -612,10 +584,10 @@ class WikiqParser:
old_rev_data.tokens_removed = len(old_tokens_removed)
old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1
self.print_rev_data(old_rev_data)
writer.write(rev_data.to_pyarrow())
else:
self.print_rev_data(rev_data)
writer.write(rev_data.to_pyarrow())
rev_count += 1
@@ -633,7 +605,7 @@ class WikiqParser:
rev_data.tokens_added = num_tokens
rev_data.tokens_removed = len(tokens_removed)
rev_data.tokens_window = len(window) - (i + 1)
self.print_rev_data(rev_data)
writer.write(rev_data.to_pyarrow())
page_count += 1
@@ -648,54 +620,43 @@ class WikiqParser:
else:
self.output_file.close()
"""
For performance reasons it's better to write parquet in batches instead of one row at a time.
So this function just puts the data on a buffer. If the buffer is full, then it gets flushed (written).
"""
@staticmethod
def rows_to_table(rg, schema: Schema) -> Table:
cols = []
first = rg[0]
for col in first:
cols.append([col])
def write_parquet_row(self, rev_data):
padata = rev_data.to_pyarrow()
self.parquet_buffer.append(padata)
for row in rg[1:]:
for j in range(len(cols)):
cols[j].append(row[j])
arrays: list[Array] = []
typ: DataType
for i, (col, typ) in enumerate(zip(cols, schema.types)):
try:
arrays.append(pa.array(col, typ))
except pa.ArrowInvalid as exc:
print("column index:", i, "type:", typ, file=sys.stderr)
print("data:", col, file=sys.stderr)
print("schema:", file=sys.stderr)
print(schema, file=sys.stderr)
raise exc
return pa.Table.from_arrays(arrays, schema=schema)
if len(self.parquet_buffer) >= self.parquet_buffer_size:
self.flush_parquet_buffer()
"""
Function that actually writes data to the parquet file.
It needs to transpose the data from row-by-row to column-by-column
"""
def flush_parquet_buffer(self):
"""
Returns the pyarrow table that we'll write
"""
def rows_to_table(rg, schema: Schema) -> Table:
cols = []
first = rg[0]
for col in first:
cols.append([col])
for row in rg[1:]:
for j in range(len(cols)):
cols[j].append(row[j])
arrays: list[Array] = []
typ: DataType
for i, (col, typ) in enumerate(zip(cols, schema.types)):
try:
arrays.append(pa.array(col, typ))
except pa.ArrowInvalid as exc:
print("column index:", i, "type:", typ, file=sys.stderr)
print("data:", col, file=sys.stderr)
print("schema:", file=sys.stderr)
print(schema, file=sys.stderr)
raise exc
return pa.Table.from_arrays(arrays, schema=schema)
outtable = rows_to_table(self.parquet_buffer, self.schema)
outtable = self.rows_to_table(self.parquet_buffer, self.schema)
if self.pq_writer is None:
self.pq_writer: ParquetWriter = (
pq.ParquetWriter(self.output_file, self.schema, flavor='spark'))
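The rows_to_table/flush_parquet_buffer code above only converts buffered row tuples to Arrow data when the buffer is flushed: the rows are transposed into one Python list per column, each list is converted with pa.array against the schema type, and the resulting Table is written in one go. A standalone sketch of that transposition, with made-up data and a placeholder output path:

```python
# Standalone sketch of buffered row-to-column transposition; names, data,
# and the output path are illustrative only.
import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([pa.field("revid", pa.int64()), pa.field("title", pa.string())])

# Rows arrive one tuple at a time and are buffered...
buffer = [(1, "First"), (2, "Second"), (3, "Third")]

# ...then transposed: one Python sequence per column, converted with pa.array.
columns = list(zip(*buffer))
arrays = [pa.array(col, typ) for col, typ in zip(columns, schema.types)]
table = pa.Table.from_arrays(arrays, schema=schema)

with pq.ParquetWriter("buffered.parquet", schema, flavor='spark') as writer:
    writer.write_table(table)
```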
@@ -703,23 +664,6 @@ class WikiqParser:
self.pq_writer.write_table(outtable)
self.parquet_buffer = []
# depending on if we are configured to write tsv or parquet, we'll call a different function.
def print_rev_data(self, rev_data):
if self.output_parquet is False:
printfunc = self.write_tsv_row
else:
printfunc = self.write_parquet_row
printfunc(rev_data)
def write_tsv_row(self, rev_data):
if self.print_header:
print(rev_data.header_row(), file=self.output_file)
self.print_header = False
line = rev_data.to_tsv_row()
print(line, file=self.output_file)
def match_archive_suffix(input_filename):
if re.match(r'.*\.7z$', input_filename):