Fix tests

Surprisingly, replacing list<str> with str doesn't break anything,
not even the baselines.

Signed-off-by: Will Beason <willbeason@gmail.com>
Will Beason 2025-05-30 13:56:31 -05:00
parent 032fec3198
commit f9383440a0
2 changed files with 16 additions and 65 deletions
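For context on the message above: the commit swaps PyArrow list-of-string columns for plain string columns whose values are comma-joined. A minimal sketch of the two representations (the column name is invented, not taken from the commit):

import pyarrow as pa

# Before this commit: each regex-capture column held a list of strings.
before = pa.field("regex_matches", pa.list_(pa.string()))
# After: the same matches are comma-joined into one string.
after = pa.field("regex_matches", pa.string())

as_list = pa.array([["foo", "bar"]], type=before.type)
as_str = pa.array([",".join(["foo", "bar"])], type=after.type)
print(as_list)  # one row: ["foo", "bar"]
print(as_str)   # one row: "foo,bar"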


@@ -364,7 +364,7 @@ class WikiqTestCase(unittest.TestCase):
         baseline['date_time'] = pd.to_datetime(baseline['date_time'])
         # Split strings to the arrays of reverted IDs so they can be compared.
         baseline['revert'] = baseline['revert'].replace(np.nan, None)
-        baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']]
+        # baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']]
         baseline['sha1'] = baseline['sha1'].replace(np.nan, None)
         baseline['editor'] = baseline['editor'].replace(np.nan, None)
         baseline['anon'] = baseline['anon'].replace(np.nan, None)

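With reverteds kept as a comma-joined string on both sides, the parsing that the now-commented line performed becomes unnecessary, and the baseline comparison can stay string-to-string. A toy illustration with invented values:

# Hypothetical 'reverteds' values for one row of the baseline and the output.
baseline_value = "1001,1002,1003"
output_value = "1001,1002,1003"

# Old comparison: parse both sides into integer lists first.
assert ([int(j) for j in baseline_value.split(",")]
        == [int(j) for j in output_value.split(",")])

# New comparison: the raw strings are compared directly.
assert baseline_value == output_value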
wikiq

@@ -145,10 +145,10 @@ class RegexPair(object):
     def get_pyarrow_fields(self):
         if self.has_groups:
-            fields = [pa.field(self._make_key(cap_group), pa.list_(pa.string()))
+            fields = [pa.field(self._make_key(cap_group), pa.string())
                       for cap_group in self.capture_groups]
         else:
-            fields = [pa.field(self.label, pa.list_(pa.string()))]
+            fields = [pa.field(self.label, pa.string())]
         return fields
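A sketch of what the rewritten get_pyarrow_fields now yields for a pattern with capture groups; the labels here stand in for whatever _make_key would actually return:

import pyarrow as pa

# Hypothetical capture-group keys; _make_key would derive the real ones.
labels = ["regex_0_url", "regex_0_domain"]

# After this commit: one flat string field per capture group ...
fields = [pa.field(label, pa.string()) for label in labels]
# ... instead of the previous list-of-string fields:
# fields = [pa.field(label, pa.list_(pa.string())) for label in labels]

print(pa.schema(fields))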
@@ -461,6 +461,8 @@ class WikiqParser:
         if self.output_parquet:
             writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
         else:
+            print(self.output_file, file=sys.stderr)
+            print(self.schema, file=sys.stderr)
             writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t'))

         # Iterate through pages
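One plausible reason the flat string schema matters for this branch: pyarrow's CSV writer does not support nested types, so a list<string> column could not be written as TSV at all, while a plain string column writes cleanly. A minimal sketch with an in-memory sink and an invented column name:

import io

import pyarrow as pa
import pyarrow.csv as pc

schema = pa.schema([pa.field("matches", pa.string())])
table = pa.table({"matches": ["foo,bar"]}, schema=schema)

sink = io.BytesIO()
with pc.CSVWriter(sink, schema,
                  write_options=pc.WriteOptions(delimiter="\t")) as writer:
    writer.write(table)

print(sink.getvalue().decode())  # header row plus one tab-separated row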
@@ -500,13 +502,14 @@ class WikiqParser:
             editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id

             # create a new data object instead of a dictionary.
             rev_data: Revision = self.revdata_type(revid=rev.id,
-                                                   date_time=datetime.fromtimestamp(rev.timestamp.unix(), tz=timezone.utc),
+                                                   date_time=datetime.fromtimestamp(rev.timestamp.unix(),
+                                                                                    tz=timezone.utc),
                                                    articleid=page.id,
                                                    editorid=editorid,
                                                    title=page.title,
                                                    deleted=rev.deleted.text,
                                                    namespace=namespace
                                                    )

             rev_data = self.matchmake_revision(rev, rev_data)
@@ -530,11 +533,9 @@ class WikiqParser:
             if rev_detector is not None:
                 revert = rev_detector.process(text_sha1, rev.id)
+                rev_data.revert = revert is not None
                 if revert:
-                    rev_data.revert = True
                     rev_data.reverteds = ",".join([str(s) for s in revert.reverteds])
-                else:
-                    rev_data.revert = False

             # if the fact that the edit was minor can be hidden, this might be an issue
             rev_data.minor = rev.minor
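The revert bookkeeping collapses into a single expression. A standalone sketch of that logic, assuming the detector behaves like mwreverts.Detector (returning None or an object with a reverteds list); the sample values are invented:

class FakeRevert:
    """Stand-in for the detector's result object."""
    reverteds = [1001, 1002]

for revert in (FakeRevert(), None):
    # One expression replaces the old if/else True/False assignment.
    revert_flag = revert is not None
    reverteds = (",".join(str(s) for s in revert.reverteds)
                 if revert is not None else None)
    print(revert_flag, reverteds)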
@@ -577,7 +578,7 @@ class WikiqParser:
                         old_rev_data.tokens_removed = len(old_tokens_removed)
                         old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1

-                        writer.write(rev_data.to_pyarrow())
+                        writer.write(old_rev_data.to_pyarrow())
             else:
                 writer.write(rev_data.to_pyarrow())
@@ -605,57 +606,7 @@ class WikiqParser:
         print("Done: %s revisions and %s pages." % (rev_count, page_count),
               file=sys.stderr)

-        # remember to flush the parquet_buffer if we're done
-        if self.output_parquet is True:
-            self.flush_parquet_buffer()
-            self.pq_writer.close()
-        else:
-            self.output_file.close()
-
-    @staticmethod
-    def rows_to_table(rg, schema: Schema) -> Table:
-        cols = []
-        first = rg[0]
-        for col in first:
-            cols.append([col])
-
-        for row in rg[1:]:
-            for j in range(len(cols)):
-                cols[j].append(row[j])
-
-        arrays: list[Array] = []
-        typ: DataType
-        for i, (col, typ) in enumerate(zip(cols, schema.types)):
-            try:
-                arrays.append(pa.array(col, typ))
-            except pa.ArrowInvalid as exc:
-                print("column index:", i, "type:", typ, file=sys.stderr)
-                print("data:", col, file=sys.stderr)
-                print("schema:", file=sys.stderr)
-                print(schema, file=sys.stderr)
-                raise exc
-
-        return pa.Table.from_arrays(arrays, schema=schema)
-
-    """
-    Function that actually writes data to the parquet file.
-    It needs to transpose the data from row-by-row to column-by-column
-    """
-    def flush_parquet_buffer(self):
-        """
-        Returns the pyarrow table that we'll write
-        """
-        outtable = self.rows_to_table(self.parquet_buffer, self.schema)
-
-        if self.pq_writer is None:
-            self.pq_writer: ParquetWriter = (
-                pq.ParquetWriter(self.output_file, self.schema, flavor='spark'))
-
-        self.pq_writer.write_table(outtable)
-        self.parquet_buffer = []
+        writer.close()

 def match_archive_suffix(input_filename):
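All of the deleted buffering and row-to-column transposition is unnecessary because pq.ParquetWriter.write() accepts a table or record batch directly, so the parser can write each revision as it goes and finish with a single writer.close(). A minimal sketch of that pattern with an invented two-column schema:

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([pa.field("revid", pa.int64()),
                    pa.field("title", pa.string())])
writer = pq.ParquetWriter("revisions.parquet", schema, flavor="spark")

# Each revision is written immediately as a one-row table (roughly what
# rev_data.to_pyarrow() appears to produce); there is no row buffer left
# to transpose and flush.
for revid, title in [(1, "Main Page"), (2, "Main Page")]:
    writer.write(pa.table({"revid": [revid], "title": [title]}, schema=schema))

writer.close()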