Get parquet libraries writing files
Tests broken due to url encoding, which can likely now be removed.

Signed-off-by: Will Beason <willbeason@gmail.com>
parent 4dde25c508
commit 0d56267ae0
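For orientation, here is a minimal sketch of the write path this commit moves toward: one pyarrow schema, a Parquet or tab-separated CSV sink chosen up front, and one RecordBatch written per revision. The schema fields, helper name, and file name below are illustrative only, not taken from wikiq, and assume a reasonably recent pyarrow (the CSV delimiter option needs pyarrow >= 7).

    import pyarrow as pa
    import pyarrow.csv as pc
    import pyarrow.parquet as pq

    # Illustrative two-field schema; wikiq's real schema has many more fields.
    schema = pa.schema([
        pa.field("revid", pa.int64()),
        pa.field("editor", pa.string()),
    ])

    def write_revisions(rows, output_file, use_parquet=True):
        # Pick the sink once, then feed it RecordBatches, mirroring the
        # writer selection this commit adds to WikiqParser.
        if use_parquet:
            writer = pq.ParquetWriter(output_file, schema, flavor='spark')
        else:
            writer = pc.CSVWriter(output_file, schema,
                                  write_options=pc.WriteOptions(delimiter='\t'))
        try:
            for row in rows:
                # One single-row batch per revision, built column by column.
                batch = pa.record_batch([[row["revid"]], [row["editor"]]], schema=schema)
                writer.write(batch)
        finally:
            writer.close()

    write_revisions([{"revid": 1, "editor": "example"}], "revisions.parquet")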
wikiq: 148 changed lines

--- a/wikiq
+++ b/wikiq
@@ -15,14 +15,13 @@ from itertools import groupby
 from subprocess import Popen, PIPE
 from collections import deque
 from hashlib import sha1
-from typing import Any, IO, TextIO
+from typing import Any, IO, TextIO, Final
 
 from mwxml import Dump
 
 from deltas.tokenizers import wikitext_split
 import mwpersistence
 import mwreverts
-from urllib.parse import quote
 
 from pyarrow import Array, Table, Schema, DataType
 from pyarrow.parquet import ParquetWriter
@@ -36,6 +35,7 @@ import dataclasses as dc
 from dataclasses import dataclass
 import pyarrow as pa
 import pyarrow.parquet as pq
+import pyarrow.csv as pc
 
 
 class PersistMethod:
@@ -233,7 +233,7 @@ class Revision:
     deleted: bool
     text_chars: int | None = None
     revert: bool | None = None
-    reverteds: list[int] = None
+    reverteds: str = None
     sha1: str | None = None
     minor: bool | None = None
     editor: str | None = None
@@ -248,7 +248,7 @@ class Revision:
     # this isn't a dataclass field since it doesn't have a type annotation
     pa_schema_fields = [
         pa.field("revid", pa.int64()),
-        pa.field("date_time", pa.timestamp('ms')),
+        pa.field("date_time", pa.timestamp('s')),
         pa.field("articleid", pa.int64()),
         pa.field("editorid", pa.int64(), nullable=True),
         pa.field("title", pa.string()),
@@ -256,7 +256,8 @@ class Revision:
         pa.field("deleted", pa.bool_()),
         pa.field("text_chars", pa.int32()),
         pa.field("revert", pa.bool_(), nullable=True),
-        pa.field("reverteds", pa.list_(pa.int64()), nullable=True),
+        # reverteds is a string which contains a comma-separated list of reverted revision ids.
+        pa.field("reverteds", pa.string(), nullable=True),
         pa.field("sha1", pa.string()),
         pa.field("minor", pa.bool_()),
         pa.field("editor", pa.string()),
@@ -264,45 +265,10 @@ class Revision:
     ]
 
     # pyarrow is a columnar format, so most of the work happens in the flush_parquet_buffer function
-    def to_pyarrow(self):
-        return dc.astuple(self)
-
-    # logic to convert each field into the wikiq tsv format goes here.
-    def to_tsv_row(self):
-
-        row = []
-        for f in dc.fields(self):
-            val = getattr(self, f.name)
-            if getattr(self, f.name) is None:
-                row.append("")
-            elif f.type == bool:
-                row.append("TRUE" if val else "FALSE")
-
-            elif f.type == datetime:
-                row.append(val.strftime('%Y-%m-%d %H:%M:%S'))
-
-            elif f.name in {'editor', 'title'}:
-                s = '"' + val + '"'
-                if self.urlencode and f.name in TO_ENCODE:
-                    row.append(quote(str(s)))
-                else:
-                    row.append(s)
-
-            elif f.type == list[int]:
-                row.append('"' + ",".join([str(x) for x in val]) + '"')
-
-            elif f.type == str:
-                if self.urlencode and f.name in TO_ENCODE:
-                    row.append(quote(str(val)))
-                else:
-                    row.append(val)
-            else:
-                row.append(val)
-
-        return '\t'.join(map(str, row))
-
-    def header_row(self):
-        return '\t'.join(map(lambda f: f.name, dc.fields(self)))
+    def to_pyarrow(self) -> pa.RecordBatch:
+        d = dc.asdict(self)
+        lists = [[d[field.name]] for field in self.pa_schema_fields]
+        return pa.record_batch(lists, schema=pa.schema(self.pa_schema_fields))
 
 
     """
@@ -434,7 +400,7 @@ class WikiqParser:
 
             self.output_file = output_file
         else:
-            self.output_file = open(output_file, 'w')
+            self.output_file = open(output_file, 'wb')
             self.output_parquet = False
 
     def make_matchmake_pairs(self, patterns, labels):
@@ -498,6 +464,12 @@ class WikiqParser:
         page_count = 0
         rev_count = 0
 
+        writer: pq.ParquetWriter | pc.CSVWriter
+        if self.output_parquet:
+            writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
+        else:
+            writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t'))
+
         # Iterate through pages
         for page in dump:
             namespace = page.namespace if page.namespace is not None else self.__get_namespace_from_title(page.title)
@@ -534,7 +506,7 @@ class WikiqParser:
 
                 editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id
                 # create a new data object instead of a dictionary.
-                rev_data = self.revdata_type(revid=rev.id,
+                rev_data: Revision = self.revdata_type(revid=rev.id,
                                              date_time=datetime.fromtimestamp(rev.timestamp.unix(), tz=timezone.utc),
                                              articleid=page.id,
                                              editorid=editorid,
@@ -567,7 +539,7 @@ class WikiqParser:
 
                 if revert:
                     rev_data.revert = True
-                    rev_data.reverteds = revert.reverteds
+                    rev_data.reverteds = ",".join([str(s) for s in revert.reverteds])
                 else:
                     rev_data.revert = False
 
@@ -612,10 +584,10 @@ class WikiqParser:
                         old_rev_data.tokens_removed = len(old_tokens_removed)
                         old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1
 
-                        self.print_rev_data(old_rev_data)
+                        writer.write(rev_data.to_pyarrow())
 
                 else:
-                    self.print_rev_data(rev_data)
+                    writer.write(rev_data.to_pyarrow())
 
                 rev_count += 1
 
@@ -633,7 +605,7 @@ class WikiqParser:
                     rev_data.tokens_added = num_tokens
                     rev_data.tokens_removed = len(tokens_removed)
                     rev_data.tokens_window = len(window) - (i + 1)
-                    self.print_rev_data(rev_data)
+                    writer.write(rev_data.to_pyarrow())
 
             page_count += 1
 
@@ -648,54 +620,43 @@ class WikiqParser:
         else:
             self.output_file.close()
 
-    """
-    For performance reasons it's better to write parquet in batches instead of one row at a time.
-    So this function just puts the data on a buffer. If the buffer is full, then it gets flushed (written).
-    """
-
-    def write_parquet_row(self, rev_data):
-        padata = rev_data.to_pyarrow()
-        self.parquet_buffer.append(padata)
-
-        if len(self.parquet_buffer) >= self.parquet_buffer_size:
-            self.flush_parquet_buffer()
+    @staticmethod
+    def rows_to_table(rg, schema: Schema) -> Table:
+        cols = []
+        first = rg[0]
+        for col in first:
+            cols.append([col])
+
+        for row in rg[1:]:
+            for j in range(len(cols)):
+                cols[j].append(row[j])
+
+        arrays: list[Array] = []
+
+        typ: DataType
+        for i, (col, typ) in enumerate(zip(cols, schema.types)):
+            try:
+                arrays.append(pa.array(col, typ))
+            except pa.ArrowInvalid as exc:
+                print("column index:", i, "type:", typ, file=sys.stderr)
+                print("data:", col, file=sys.stderr)
+                print("schema:", file=sys.stderr)
+                print(schema, file=sys.stderr)
+                raise exc
+        return pa.Table.from_arrays(arrays, schema=schema)
 
     """
     Function that actually writes data to the parquet file.
     It needs to transpose the data from row-by-row to column-by-column
     """
 
     def flush_parquet_buffer(self):
         """
         Returns the pyarrow table that we'll write
         """
 
-        def rows_to_table(rg, schema: Schema) -> Table:
-            cols = []
-            first = rg[0]
-            for col in first:
-                cols.append([col])
-
-            for row in rg[1:]:
-                for j in range(len(cols)):
-                    cols[j].append(row[j])
-
-            arrays: list[Array] = []
-
-            typ: DataType
-            for i, (col, typ) in enumerate(zip(cols, schema.types)):
-                try:
-                    arrays.append(pa.array(col, typ))
-                except pa.ArrowInvalid as exc:
-                    print("column index:", i, "type:", typ, file=sys.stderr)
-                    print("data:", col, file=sys.stderr)
-                    print("schema:", file=sys.stderr)
-                    print(schema, file=sys.stderr)
-                    raise exc
-            return pa.Table.from_arrays(arrays, schema=schema)
-
-        outtable = rows_to_table(self.parquet_buffer, self.schema)
+        outtable = self.rows_to_table(self.parquet_buffer, self.schema)
         if self.pq_writer is None:
             self.pq_writer: ParquetWriter = (
                 pq.ParquetWriter(self.output_file, self.schema, flavor='spark'))
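The buffered flush path kept above transposes rows into columns before writing, since Arrow is columnar; here is a minimal standalone sketch of that transpose, using an illustrative two-field schema rather than wikiq's real one:

    import pyarrow as pa

    schema = pa.schema([pa.field("revid", pa.int64()),
                        pa.field("editor", pa.string())])

    # Rows are buffered one tuple at a time, in schema order...
    buffer = [(1, "alice"), (2, "bob")]

    # ...but Arrow wants one array per column, so transpose before building the Table.
    columns = [pa.array(list(col), typ) for col, typ in zip(zip(*buffer), schema.types)]
    table = pa.Table.from_arrays(columns, schema=schema)
    # table can then be handed to ParquetWriter.write_table(), as flush_parquet_buffer does.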
@@ -703,23 +664,6 @@ class WikiqParser:
         self.pq_writer.write_table(outtable)
         self.parquet_buffer = []
 
-    # depending on if we are configured to write tsv or parquet, we'll call a different function.
-    def print_rev_data(self, rev_data):
-        if self.output_parquet is False:
-            printfunc = self.write_tsv_row
-        else:
-            printfunc = self.write_parquet_row
-
-        printfunc(rev_data)
-
-    def write_tsv_row(self, rev_data):
-        if self.print_header:
-            print(rev_data.header_row(), file=self.output_file)
-            self.print_header = False
-
-        line = rev_data.to_tsv_row()
-        print(line, file=self.output_file)
-
 
 def match_archive_suffix(input_filename):
     if re.match(r'.*\.7z$', input_filename):