Fix tests

Surprisingly, replacing list<str> with str doesn't break anything, even baselines.

Signed-off-by: Will Beason <willbeason@gmail.com>

parent 032fec3198
commit f9383440a0
@@ -364,7 +364,7 @@ class WikiqTestCase(unittest.TestCase):
         baseline['date_time'] = pd.to_datetime(baseline['date_time'])
         # Split strings to the arrays of reverted IDs so they can be compared.
         baseline['revert'] = baseline['revert'].replace(np.nan, None)
-        baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']]
+        # baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']]
         baseline['sha1'] = baseline['sha1'].replace(np.nan, None)
         baseline['editor'] = baseline['editor'].replace(np.nan, None)
         baseline['anon'] = baseline['anon'].replace(np.nan, None)
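Since 'reverteds' now stays a comma-joined string on both sides, the baseline no longer needs the split-and-cast step; only the NaN-to-None normalization remains. A minimal sketch of the comparison this enables, with invented output/baseline frames:

    import numpy as np
    import pandas as pd

    # Both sides keep 'reverteds' as a comma-joined string, so no
    # per-row parsing is needed before comparing.
    output = pd.DataFrame({'reverteds': ['101,102', np.nan]})
    baseline = pd.DataFrame({'reverteds': ['101,102', np.nan]})

    # Normalize NaN to None on both sides, mirroring the replace() calls above.
    output['reverteds'] = output['reverteds'].replace(np.nan, None)
    baseline['reverteds'] = baseline['reverteds'].replace(np.nan, None)

    # The string columns now compare equal row-for-row.
    pd.testing.assert_frame_equal(output, baseline)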
wikiq (67 lines changed)
@@ -145,10 +145,10 @@ class RegexPair(object):
 
     def get_pyarrow_fields(self):
         if self.has_groups:
-            fields = [pa.field(self._make_key(cap_group), pa.list_(pa.string()))
+            fields = [pa.field(self._make_key(cap_group), pa.string())
                       for cap_group in self.capture_groups]
         else:
-            fields = [pa.field(self.label, pa.list_(pa.string()))]
+            fields = [pa.field(self.label, pa.string())]
 
         return fields
 
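The regex capture columns drop from list<string> to plain string. A quick sketch of the declared schema under the new code, using an invented label and group names standing in for whatever _make_key() produces:

    import pyarrow as pa

    # One string column per capture group (previously pa.list_(pa.string())).
    fields = [pa.field(name, pa.string()) for name in ('mylabel_1', 'mylabel_2')]
    print(pa.schema(fields))
    # mylabel_1: string
    # mylabel_2: string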
@@ -461,6 +461,8 @@ class WikiqParser:
         if self.output_parquet:
             writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
         else:
+            print(self.output_file, file=sys.stderr)
+            print(self.schema, file=sys.stderr)
             writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t'))
 
         # Iterate through pages
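The new stderr prints expose the destination and schema when the TSV path is taken. For reference, a minimal standalone sketch of that path, assuming pyarrow.csv is imported as pc; the file name and schema here are invented:

    import sys
    import pyarrow as pa
    import pyarrow.csv as pc

    schema = pa.schema([pa.field('revid', pa.int64()), pa.field('title', pa.string())])

    # Echo the destination and schema to stderr, as the debugging prints above do.
    print('out.tsv', file=sys.stderr)
    print(schema, file=sys.stderr)

    writer = pc.CSVWriter('out.tsv', schema,
                          write_options=pc.WriteOptions(delimiter='\t'))
    writer.write(pa.record_batch([[1], ['Main Page']], schema=schema))
    writer.close()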
@@ -500,7 +502,8 @@ class WikiqParser:
             editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id
             # create a new data object instead of a dictionary.
             rev_data: Revision = self.revdata_type(revid=rev.id,
-                                                   date_time=datetime.fromtimestamp(rev.timestamp.unix(), tz=timezone.utc),
+                                                   date_time=datetime.fromtimestamp(rev.timestamp.unix(),
+                                                                                    tz=timezone.utc),
                                                    articleid=page.id,
                                                    editorid=editorid,
                                                    title=page.title,
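Only the line wrapping changes here; the call still converts the revision's Unix timestamp into a timezone-aware UTC datetime. A tiny sketch of that conversion, with an arbitrary timestamp value:

    from datetime import datetime, timezone

    ts = 1072915200  # arbitrary example: 2004-01-01T00:00:00Z
    dt = datetime.fromtimestamp(ts, tz=timezone.utc)
    print(dt.isoformat())  # 2004-01-01T00:00:00+00:00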
@@ -530,11 +533,9 @@ class WikiqParser:
             if rev_detector is not None:
                 revert = rev_detector.process(text_sha1, rev.id)
 
+                rev_data.revert = revert is not None
                 if revert:
-                    rev_data.revert = True
                     rev_data.reverteds = ",".join([str(s) for s in revert.reverteds])
-                else:
-                    rev_data.revert = False
 
                 # if the fact that the edit was minor can be hidden, this might be an issue
                 rev_data.minor = rev.minor
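The if/else that set revert to True or False collapses into a single boolean expression, and reverteds is still only populated when a revert was detected. A reduced sketch of the same pattern; the Revert class is a stand-in for the detector's result object:

    # Stand-in for the detector result: None when the edit reverts nothing,
    # otherwise an object carrying the reverted revision IDs.
    class Revert:
        def __init__(self, reverteds):
            self.reverteds = reverteds

    def summarize(revert):
        data = {'revert': revert is not None, 'reverteds': None}
        if revert:
            # Join IDs into the comma-separated string form used above.
            data['reverteds'] = ",".join(str(s) for s in revert.reverteds)
        return data

    print(summarize(None))                # {'revert': False, 'reverteds': None}
    print(summarize(Revert([101, 102])))  # {'revert': True, 'reverteds': '101,102'}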
@@ -577,7 +578,7 @@ class WikiqParser:
                             old_rev_data.tokens_removed = len(old_tokens_removed)
                             old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1
 
-                            writer.write(rev_data.to_pyarrow())
+                            writer.write(old_rev_data.to_pyarrow())
 
                 else:
                     writer.write(rev_data.to_pyarrow())
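This hunk sits inside the persistence window: once a revision falls PERSISTENCE_RADIUS revisions behind the current one, its token statistics are final, so it is the old revision that gets written, not the current one. A toy sketch of that window discipline; the constant's value and the rows are stand-ins:

    from collections import deque

    PERSISTENCE_RADIUS = 7  # stand-in value for the constant used above

    window = deque()
    written = []

    def push(rev_row):
        # The current row only enters the buffer; the oldest row is
        # finalized and emitted once the window is full.
        window.append(rev_row)
        if len(window) > PERSISTENCE_RADIUS:
            written.append(window.popleft())

    for revid in range(1, 20):
        push({'revid': revid})

    print(written[0])  # {'revid': 1} -- the old row is written, not the newest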
@@ -605,57 +606,7 @@ class WikiqParser:
         print("Done: %s revisions and %s pages." % (rev_count, page_count),
               file=sys.stderr)
 
-        # remember to flush the parquet_buffer if we're done
-        if self.output_parquet is True:
-            self.flush_parquet_buffer()
-            self.pq_writer.close()
-
-        else:
-            self.output_file.close()
-
-    @staticmethod
-    def rows_to_table(rg, schema: Schema) -> Table:
-        cols = []
-        first = rg[0]
-        for col in first:
-            cols.append([col])
-
-        for row in rg[1:]:
-            for j in range(len(cols)):
-                cols[j].append(row[j])
-
-        arrays: list[Array] = []
-
-        typ: DataType
-        for i, (col, typ) in enumerate(zip(cols, schema.types)):
-            try:
-                arrays.append(pa.array(col, typ))
-            except pa.ArrowInvalid as exc:
-                print("column index:", i, "type:", typ, file=sys.stderr)
-                print("data:", col, file=sys.stderr)
-                print("schema:", file=sys.stderr)
-                print(schema, file=sys.stderr)
-                raise exc
-        return pa.Table.from_arrays(arrays, schema=schema)
-
-
-    """
-    Function that actually writes data to the parquet file.
-    It needs to transpose the data from row-by-row to column-by-column
-    """
-    def flush_parquet_buffer(self):
-
-        """
-        Returns the pyarrow table that we'll write
-        """
-
-        outtable = self.rows_to_table(self.parquet_buffer, self.schema)
-        if self.pq_writer is None:
-            self.pq_writer: ParquetWriter = (
-                pq.ParquetWriter(self.output_file, self.schema, flavor='spark'))
-
-        self.pq_writer.write_table(outtable)
-        self.parquet_buffer = []
+        writer.close()
 
 
 def match_archive_suffix(input_filename):
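With the row-buffered transpose gone, a single writer object is held open for the whole run and closed once at the end; the deleted rows_to_table/flush_parquet_buffer pair is no longer needed because record batches are written column-oriented as they are produced. A condensed sketch of the new lifecycle, with an invented file name and sample data:

    import pyarrow as pa
    import pyarrow.parquet as pq

    schema = pa.schema([pa.field('revid', pa.int64()), pa.field('title', pa.string())])
    writer = pq.ParquetWriter('out.parquet', schema, flavor='spark')

    # Rows leave as column-oriented record batches as they are produced;
    # there is no row buffer to transpose and flush at the end of the run.
    writer.write(pa.record_batch([[1, 2], ['Main Page', 'Talk:Main Page']],
                                 schema=schema))
    writer.close()  # the single close() replacing the old flush-and-close branches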