Get parquet libraries writing files

Tests are broken due to URL encoding, which can likely now be removed.

Signed-off-by: Will Beason <willbeason@gmail.com>
Will Beason 2025-05-30 13:06:26 -05:00
parent 4dde25c508
commit 0d56267ae0
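
The gist of the change, as a minimal sketch (assumptions: the Row dataclass, field names, and output paths below are illustrative, not wikiq's own): each revision dataclass is converted into a one-row pyarrow RecordBatch and streamed through a writer, so pyarrow.parquet.ParquetWriter and pyarrow.csv.CSVWriter can share the same write() call for the Parquet and TSV output paths.

# Sketch only: illustrates the one-row RecordBatch pattern, not wikiq code.
import dataclasses as dc
import pyarrow as pa
import pyarrow.csv as pacsv
import pyarrow.parquet as pq

@dc.dataclass
class Row:                      # hypothetical stand-in for wikiq's Revision
    revid: int
    title: str

SCHEMA = pa.schema([
    pa.field("revid", pa.int64()),
    pa.field("title", pa.string()),
])

def to_record_batch(row: Row) -> pa.RecordBatch:
    d = dc.asdict(row)
    # One single-element array per column: a batch holding exactly one row.
    columns = [pa.array([d[field.name]], type=field.type) for field in SCHEMA]
    return pa.record_batch(columns, schema=SCHEMA)

# Both writers accept RecordBatches through write(), so the caller picks a
# writer once and streams rows as they are parsed.
with pq.ParquetWriter("revisions.parquet", SCHEMA) as writer:
    writer.write(to_record_batch(Row(revid=1, title="Example")))

with pacsv.CSVWriter("revisions.tsv", SCHEMA,
                     write_options=pacsv.WriteOptions(delimiter="\t")) as writer:
    writer.write(to_record_batch(Row(revid=2, title="Another example")))

The diff below applies this same pattern: Revision.to_pyarrow() now returns a RecordBatch, and the TSV formatting helpers are replaced by a CSVWriter configured with a tab delimiter.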

wikiq (148 lines changed)

@@ -15,14 +15,13 @@ from itertools import groupby
 from subprocess import Popen, PIPE
 from collections import deque
 from hashlib import sha1
-from typing import Any, IO, TextIO
+from typing import Any, IO, TextIO, Final
 from mwxml import Dump
 from deltas.tokenizers import wikitext_split
 import mwpersistence
 import mwreverts
-from urllib.parse import quote
 from pyarrow import Array, Table, Schema, DataType
 from pyarrow.parquet import ParquetWriter
@@ -36,6 +35,7 @@ import dataclasses as dc
 from dataclasses import dataclass
 import pyarrow as pa
 import pyarrow.parquet as pq
+import pyarrow.csv as pc
 class PersistMethod:
@@ -233,7 +233,7 @@ class Revision:
     deleted: bool
     text_chars: int | None = None
     revert: bool | None = None
-    reverteds: list[int] = None
+    reverteds: str = None
     sha1: str | None = None
     minor: bool | None = None
     editor: str | None = None
@@ -248,7 +248,7 @@ class Revision:
     # this isn't a dataclass field since it doesn't have a type annotation
     pa_schema_fields = [
         pa.field("revid", pa.int64()),
-        pa.field("date_time", pa.timestamp('ms')),
+        pa.field("date_time", pa.timestamp('s')),
         pa.field("articleid", pa.int64()),
         pa.field("editorid", pa.int64(), nullable=True),
         pa.field("title", pa.string()),
@@ -256,7 +256,8 @@ class Revision:
         pa.field("deleted", pa.bool_()),
         pa.field("text_chars", pa.int32()),
         pa.field("revert", pa.bool_(), nullable=True),
-        pa.field("reverteds", pa.list_(pa.int64()), nullable=True),
+        # reverteds is a string which contains a comma-separated list of reverted revision ids.
+        pa.field("reverteds", pa.string(), nullable=True),
         pa.field("sha1", pa.string()),
         pa.field("minor", pa.bool_()),
         pa.field("editor", pa.string()),
@@ -264,45 +265,10 @@ class Revision:
     ]
     # pyarrow is a columnar format, so most of the work happens in the flush_parquet_buffer function
-    def to_pyarrow(self):
-        return dc.astuple(self)
-    # logic to convert each field into the wikiq tsv format goes here.
-    def to_tsv_row(self):
-        row = []
-        for f in dc.fields(self):
-            val = getattr(self, f.name)
-            if getattr(self, f.name) is None:
-                row.append("")
-            elif f.type == bool:
-                row.append("TRUE" if val else "FALSE")
-            elif f.type == datetime:
-                row.append(val.strftime('%Y-%m-%d %H:%M:%S'))
-            elif f.name in {'editor', 'title'}:
-                s = '"' + val + '"'
-                if self.urlencode and f.name in TO_ENCODE:
-                    row.append(quote(str(s)))
-                else:
-                    row.append(s)
-            elif f.type == list[int]:
-                row.append('"' + ",".join([str(x) for x in val]) + '"')
-            elif f.type == str:
-                if self.urlencode and f.name in TO_ENCODE:
-                    row.append(quote(str(val)))
-                else:
-                    row.append(val)
-            else:
-                row.append(val)
-        return '\t'.join(map(str, row))
-    def header_row(self):
-        return '\t'.join(map(lambda f: f.name, dc.fields(self)))
+    def to_pyarrow(self) -> pa.RecordBatch:
+        d = dc.asdict(self)
+        lists = [[d[field.name]] for field in self.pa_schema_fields]
+        return pa.record_batch(lists, schema=pa.schema(self.pa_schema_fields))
 """
@@ -434,7 +400,7 @@ class WikiqParser:
             self.output_file = output_file
         else:
-            self.output_file = open(output_file, 'w')
+            self.output_file = open(output_file, 'wb')
             self.output_parquet = False
     def make_matchmake_pairs(self, patterns, labels):
@@ -498,6 +464,12 @@ class WikiqParser:
         page_count = 0
         rev_count = 0
+        writer: pq.ParquetWriter | pc.CSVWriter
+        if self.output_parquet:
+            writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
+        else:
+            writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t'))
         # Iterate through pages
         for page in dump:
             namespace = page.namespace if page.namespace is not None else self.__get_namespace_from_title(page.title)
@@ -534,7 +506,7 @@ class WikiqParser:
                 editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id
                 # create a new data object instead of a dictionary.
-                rev_data = self.revdata_type(revid=rev.id,
+                rev_data: Revision = self.revdata_type(revid=rev.id,
                                              date_time=datetime.fromtimestamp(rev.timestamp.unix(), tz=timezone.utc),
                                              articleid=page.id,
                                              editorid=editorid,
@@ -567,7 +539,7 @@ class WikiqParser:
                 if revert:
                     rev_data.revert = True
-                    rev_data.reverteds = revert.reverteds
+                    rev_data.reverteds = ",".join([str(s) for s in revert.reverteds])
                 else:
                     rev_data.revert = False
@@ -612,10 +584,10 @@ class WikiqParser:
                         old_rev_data.tokens_removed = len(old_tokens_removed)
                         old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1
-                        self.print_rev_data(old_rev_data)
+                        writer.write(rev_data.to_pyarrow())
                 else:
-                    self.print_rev_data(rev_data)
+                    writer.write(rev_data.to_pyarrow())
                 rev_count += 1
@@ -633,7 +605,7 @@ class WikiqParser:
                     rev_data.tokens_added = num_tokens
                     rev_data.tokens_removed = len(tokens_removed)
                     rev_data.tokens_window = len(window) - (i + 1)
-                    self.print_rev_data(rev_data)
+                    writer.write(rev_data.to_pyarrow())
             page_count += 1
@@ -648,54 +620,43 @@ class WikiqParser:
         else:
             self.output_file.close()
-    """
-    For performance reasons it's better to write parquet in batches instead of one row at a time.
-    So this function just puts the data on a buffer. If the buffer is full, then it gets flushed (written).
-    """
-    def write_parquet_row(self, rev_data):
-        padata = rev_data.to_pyarrow()
-        self.parquet_buffer.append(padata)
-        if len(self.parquet_buffer) >= self.parquet_buffer_size:
-            self.flush_parquet_buffer()
+    @staticmethod
+    def rows_to_table(rg, schema: Schema) -> Table:
+        cols = []
+        first = rg[0]
+        for col in first:
+            cols.append([col])
+        for row in rg[1:]:
+            for j in range(len(cols)):
+                cols[j].append(row[j])
+        arrays: list[Array] = []
+        typ: DataType
+        for i, (col, typ) in enumerate(zip(cols, schema.types)):
+            try:
+                arrays.append(pa.array(col, typ))
+            except pa.ArrowInvalid as exc:
+                print("column index:", i, "type:", typ, file=sys.stderr)
+                print("data:", col, file=sys.stderr)
+                print("schema:", file=sys.stderr)
+                print(schema, file=sys.stderr)
+                raise exc
+        return pa.Table.from_arrays(arrays, schema=schema)
     """
     Function that actually writes data to the parquet file.
     It needs to transpose the data from row-by-row to column-by-column
     """
     def flush_parquet_buffer(self):
         """
         Returns the pyarrow table that we'll write
         """
-        def rows_to_table(rg, schema: Schema) -> Table:
-            cols = []
-            first = rg[0]
-            for col in first:
-                cols.append([col])
-            for row in rg[1:]:
-                for j in range(len(cols)):
-                    cols[j].append(row[j])
-            arrays: list[Array] = []
-            typ: DataType
-            for i, (col, typ) in enumerate(zip(cols, schema.types)):
-                try:
-                    arrays.append(pa.array(col, typ))
-                except pa.ArrowInvalid as exc:
-                    print("column index:", i, "type:", typ, file=sys.stderr)
-                    print("data:", col, file=sys.stderr)
-                    print("schema:", file=sys.stderr)
-                    print(schema, file=sys.stderr)
-                    raise exc
-            return pa.Table.from_arrays(arrays, schema=schema)
-        outtable = rows_to_table(self.parquet_buffer, self.schema)
+        outtable = self.rows_to_table(self.parquet_buffer, self.schema)
         if self.pq_writer is None:
             self.pq_writer: ParquetWriter = (
                 pq.ParquetWriter(self.output_file, self.schema, flavor='spark'))
@@ -703,23 +664,6 @@ class WikiqParser:
         self.pq_writer.write_table(outtable)
         self.parquet_buffer = []
-    # depending on if we are configured to write tsv or parquet, we'll call a different function.
-    def print_rev_data(self, rev_data):
-        if self.output_parquet is False:
-            printfunc = self.write_tsv_row
-        else:
-            printfunc = self.write_parquet_row
-        printfunc(rev_data)
-    def write_tsv_row(self, rev_data):
-        if self.print_header:
-            print(rev_data.header_row(), file=self.output_file)
-            self.print_header = False
-        line = rev_data.to_tsv_row()
-        print(line, file=self.output_file)
 def match_archive_suffix(input_filename):
     if re.match(r'.*\.7z$', input_filename):