diff --git a/.gitignore b/.gitignore
index c90a397..d5257ec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@
 # Python build and test output
 __pycache__/
-test_output/
+/test/test_output/
+/test/test_output.parquet/
diff --git a/test/Wikiq_Unit_Test.py b/test/Wikiq_Unit_Test.py
index a45e9d9..c129849 100644
--- a/test/Wikiq_Unit_Test.py
+++ b/test/Wikiq_Unit_Test.py
@@ -42,16 +42,21 @@ class WikiqTester:
     ):
         self.input_file = os.path.join(TEST_DIR, "dumps", "{0}.xml.{1}".format(wiki, in_compression))
 
+        if out_format == "tsv":
+            self.output_dir = TEST_OUTPUT_DIR
+        else:
+            self.output_dir = "{0}.parquet".format(TEST_OUTPUT_DIR)
+
         if suffix is None:
             self.wikiq_out_name = "{0}.{1}".format(wiki, out_format)
         else:
             self.wikiq_out_name = "{0}_{1}.{2}".format(wiki, suffix, out_format)
 
-        self.call_output = os.path.join(TEST_OUTPUT_DIR, "{0}.{1}".format(wiki, out_format))
+        self.call_output = os.path.join(self.output_dir, "{0}.{1}".format(wiki, out_format))
 
         # If case_name is unset, there are no relevant baseline or test files.
         if case_name is not None:
             self.baseline_file = os.path.join(BASELINE_DIR, "{0}_{1}".format(case_name, self.wikiq_out_name))
-            self.test_file = os.path.join(TEST_OUTPUT_DIR, "{0}_{1}".format(case_name, self.wikiq_out_name))
+            self.test_file = os.path.join(self.output_dir, "{0}_{1}".format(case_name, self.wikiq_out_name))
             if os.path.exists(self.test_file):
                 os.remove(self.test_file)
@@ -63,7 +68,7 @@ class WikiqTester:
         :return: The output of the wikiq call.
         """
         if out:
-            call = ' '.join([WIKIQ, self.input_file, "-o", TEST_OUTPUT_DIR, *args])
+            call = ' '.join([WIKIQ, self.input_file, "-o", self.output_dir, *args])
         else:
             call = ' '.join([WIKIQ, self.input_file, *args])
@@ -314,6 +319,20 @@ class WikiqTestCase(unittest.TestCase):
         baseline = pd.read_table(tester.baseline_file)
         assert_frame_equal(test, baseline, check_like=True)
 
+    def test_parquet(self):
+        tester = WikiqTester(IKWIKI, "parquet", out_format="parquet")
+
+        try:
+            tester.call_wikiq()
+        except subprocess.CalledProcessError as exc:
+            self.fail(exc.stderr.decode("utf8"))
+
+        copyfile(tester.call_output, tester.test_file)
+
+        # As a test, let's make sure that we get equal data frames.
+        test = pd.read_parquet(tester.test_file)
+        baseline = pd.read_parquet(tester.baseline_file)
+        assert_frame_equal(test, baseline, check_like=True)
 
 if __name__ == '__main__':
     unittest.main()
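[Review note, not part of the patch] test_parquet exercises the new Parquet output path end to end. A standalone sketch of the same round-trip check, with hypothetical file paths, assuming pandas is installed with the pyarrow engine:

import pandas as pd
from pandas.testing import assert_frame_equal

# Hypothetical locations; the real test derives both from WikiqTester.
test_df = pd.read_parquet("test/test_output.parquet/ikwiki.parquet")
baseline_df = pd.read_parquet("test/baseline_output/parquet_ikwiki.parquet")

# check_like=True tolerates column-order differences, as in the test above.
assert_frame_equal(test_df, baseline_df, check_like=True)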
diff --git a/wikiq b/wikiq
index 7553d8c..605ef2a 100755
--- a/wikiq
+++ b/wikiq
@@ -9,10 +9,12 @@ import sys
 import os.path
 import re
 from datetime import datetime, timezone
+from io import TextIOWrapper
 from subprocess import Popen, PIPE
 from collections import deque
 from hashlib import sha1
+from typing import Any, IO, TextIO, Iterable
 
 from mwxml import Dump
 
@@ -21,6 +23,9 @@ import mwpersistence
 import mwreverts
 from urllib.parse import quote
 
+from pyarrow import Array, Table, Schema, DataType
+from pyarrow.parquet import ParquetWriter
+
 TO_ENCODE = ('title', 'editor')
 PERSISTENCE_RADIUS = 7
 from deltas import SequenceMatcher
@@ -275,7 +280,7 @@ class RevDataBase:
     ]
 
     # pyarrow is a columnar format, so most of the work happens in the flush_parquet_buffer function
-    def to_pyarrow(self):
+    def to_pyarrow(self) -> tuple[Any, ...]:
         return dc.astuple(self)
 
     # logic to convert each field into the wikiq tsv format goes here.
@@ -369,9 +374,20 @@ class RevDataCollapsePersistence(RevDataCollapse, RevDataPersistence):
 
 
 class WikiqParser:
-    def __init__(self, input_file, output_file, regex_match_revision, regex_match_comment, regex_revision_label,
-                 regex_comment_label, collapse_user=False, persist=None, urlencode=False, namespaces=None,
-                 revert_radius=15, output_parquet=True, parquet_buffer_size=2000):
+    def __init__(self,
+                 input_file: TextIOWrapper | IO[Any] | IO[bytes],
+                 output_file: TextIO | str,
+                 regex_match_revision: list[str],
+                 regex_match_comment: list[str],
+                 regex_revision_label: list[str],
+                 regex_comment_label: list[str],
+                 collapse_user: bool = False,
+                 persist: int | None = None,
+                 urlencode: bool = False,
+                 namespaces: list[int] | None = None,
+                 revert_radius: int = 15,
+                 output_parquet: bool = True,
+                 parquet_buffer_size: int = 2000):
         """
         Parameters:
            persist : what persistence method to use. Takes a PersistMethod value
@@ -379,9 +395,9 @@ class WikiqParser:
         self.input_file = input_file
 
         self.collapse_user = collapse_user
-        self.persist = persist
+        self.persist: int | None = persist
         self.namespaces = []
-        self.urlencode = urlencode
+        self.urlencode: bool = urlencode
         self.revert_radius = revert_radius
 
         if namespaces is not None:
@@ -410,7 +426,7 @@ class WikiqParser:
 
         # make_dataclass is a function that defines a new dataclass type.
         # here we extend the type we have already chosen and add the regular expression types
-        self.revdata_type = dc.make_dataclass('RevData_Parser',
+        self.revdata_type: type = dc.make_dataclass('RevData_Parser',
                                               fields=regex_fields,
                                               bases=(revdata_type,))
 
@@ -419,7 +435,7 @@ class WikiqParser:
 
         self.revdata_type.urlencode = self.urlencode
 
-        self.schema = pa.schema(self.revdata_type.pa_schema_fields)
+        self.schema: Schema = pa.schema(self.revdata_type.pa_schema_fields)
 
         # here we initialize the variables we need for output.
         if output_parquet is True:
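[Review note, not part of the patch] The typed self.revdata_type / self.schema pair is the core of the new output machinery: a dataclass type is extended at runtime with the regex-capture columns, and its pa_schema_fields drive the Arrow schema. A toy sketch of that pattern, with invented field names:

import dataclasses as dc
import pyarrow as pa

@dc.dataclass
class RevBase:
    revid: int
    title: str

# Class-level Arrow fields mirroring the dataclass fields (toy version
# of RevDataBase.pa_schema_fields).
RevBase.pa_schema_fields = [
    pa.field("revid", pa.int64()),
    pa.field("title", pa.string()),
]

# Extend the base type with a regex-capture column, as WikiqParser does.
RevWithRegex = dc.make_dataclass(
    "RevWithRegex",
    fields=[("comment_match", str | None, dc.field(default=None))],
    bases=(RevBase,),
)
RevWithRegex.pa_schema_fields = RevBase.pa_schema_fields + [
    pa.field("comment_match", pa.string()),
]

schema = pa.schema(RevWithRegex.pa_schema_fields)
row = RevWithRegex(revid=1, title="Main Page", comment_match="typo fix")
print(schema)
print(dc.astuple(row))  # -> (1, 'Main Page', 'typo fix'), as in to_pyarrow()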
@@ -532,7 +548,7 @@ class WikiqParser:
             rev_data = self.revdata_type(revid=rev.id,
                                          date_time=datetime.fromtimestamp(rev.timestamp.unix(), tz=timezone.utc),
                                          articleid=page.id,
-                                         editorid="" if rev.deleted.user == True or rev.user.id is None else rev.user.id,
+                                         editorid=None if rev.deleted.user == True or rev.user.id is None else rev.user.id,
                                          title=page.title,
                                          deleted=rev.deleted.text,
                                          namespace=namespace
@@ -665,7 +681,7 @@ class WikiqParser:
         Returns the pyarrow table that we'll write
         """
 
-        def rows_to_table(rg, schema):
+        def rows_to_table(rg, schema: Schema) -> Table:
             cols = []
             first = rg[0]
             for col in first:
@@ -675,14 +691,24 @@ class WikiqParser:
                 for j in range(len(cols)):
                     cols[j].append(row[j])
 
-            arrays = []
-            for col, typ in zip(cols, schema.types):
-                arrays.append(pa.array(col, typ))
+            arrays: list[Array] = []
+
+            typ: DataType
+            for i, (col, typ) in enumerate(zip(cols, schema.types)):
+                try:
+                    arrays.append(pa.array(col, typ))
+                except pa.ArrowInvalid as exc:
+                    print("column index:", i, "type:", typ, file=sys.stderr)
+                    print("data:", col, file=sys.stderr)
+                    print("schema:", file=sys.stderr)
+                    print(schema, file=sys.stderr)
+                    raise exc
             return pa.Table.from_arrays(arrays, schema=schema)
 
         outtable = rows_to_table(self.parquet_buffer, self.schema)
         if self.pq_writer is None:
-            self.pq_writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
+            self.pq_writer: ParquetWriter = (
+                pq.ParquetWriter(self.output_file, self.schema, flavor='spark'))
 
         self.pq_writer.write_table(outtable)
         self.parquet_buffer = []
@@ -705,7 +731,7 @@ class WikiqParser:
             print(line, file=self.output_file)
 
 
-def open_input_file(input_filename):
+def open_input_file(input_filename) -> TextIOWrapper | IO[Any] | IO[bytes]:
     if re.match(r'.*\.7z$', input_filename):
         cmd = ["7za", "x", "-so", input_filename, "*.xml"]
     elif re.match(r'.*\.gz$', input_filename):
@@ -721,7 +747,7 @@ def open_input_file(input_filename):
     return open(input_filename, 'r')
 
 
-def get_output_filename(input_filename, parquet=False):
+def get_output_filename(input_filename, parquet=False) -> str:
     output_filename = re.sub(r'\.(7z|gz|bz2)?$', '', input_filename)
     output_filename = re.sub(r'\.xml', '', output_filename)
     if parquet is False:
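[Review note, not part of the patch] Taken together, the Parquet path buffers row tuples and pivots them to columns on flush; the editorid change works because Arrow columns carry real nulls, unlike the empty string the TSV path used. A self-contained sketch of that flush pattern under a toy schema (hypothetical output file name):

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([
    pa.field("revid", pa.int64()),
    pa.field("editorid", pa.int64()),  # nullable: deleted editors become None
    pa.field("title", pa.string()),
])

# The buffer holds one tuple per revision, as produced by to_pyarrow().
buffer = [
    (1, 42, "Main Page"),
    (2, None, "Main Page"),  # None survives the round trip, unlike "" in TSV
]

# Pivot row tuples into columns, then build one pa.array per schema type.
cols = list(zip(*buffer))
arrays = [pa.array(col, typ) for col, typ in zip(cols, schema.types)]
table = pa.Table.from_arrays(arrays, schema=schema)

with pq.ParquetWriter("revisions.parquet", schema, flavor="spark") as writer:
    writer.write_table(table)  # called once per flushed buffer in wikiq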