From f916af98369156b2a0f0decc58ea6a47e8b4b52e Mon Sep 17 00:00:00 2001 From: Will Beason Date: Mon, 2 Jun 2025 14:13:13 -0500 Subject: [PATCH] Allow specifying output file basename instead of just directory This is optional, and doesn't impact existing users as preexisting behavior when users specify an output directory is unchanged. This makes tests not need to copy large files as part of their execution, as they can ask files to be written to explicit locations. Signed-off-by: Will Beason --- .gitignore | 1 - test/Wikiq_Unit_Test.py | 90 ++++++++++++++++------------------------- wikiq | 46 ++++++++++++++++----- 3 files changed, 69 insertions(+), 68 deletions(-) diff --git a/.gitignore b/.gitignore index 1ae46ba..88a3586 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,3 @@ uv.lock # Python build and test output __pycache__/ /test/test_output/ -/test/test_output.parquet/ diff --git a/test/Wikiq_Unit_Test.py b/test/Wikiq_Unit_Test.py index f68ef0c..edf8439 100644 --- a/test/Wikiq_Unit_Test.py +++ b/test/Wikiq_Unit_Test.py @@ -1,3 +1,4 @@ +import shutil import unittest import os import subprocess @@ -15,7 +16,6 @@ from typing import Final TEST_DIR: Final[str] = os.path.dirname(os.path.realpath(__file__)) WIKIQ: Final[str] = os.path.join(os.path.dirname(TEST_DIR), "wikiq") TEST_OUTPUT_DIR: Final[str] = os.path.join(TEST_DIR, "test_output") -PARQUET_OUTPUT_DIR: Final[str] = os.path.join(TEST_DIR, "test_output.parquet") BASELINE_DIR: Final[str] = os.path.join(TEST_DIR, "baseline_output") IKWIKI: Final[str] = "ikwiki-20180301-pages-meta-history" @@ -30,8 +30,6 @@ def setup(): # Perform directory check and reset here as this is a one-time setup step as opposed to per-test setup. if not os.path.exists(TEST_OUTPUT_DIR): os.mkdir(TEST_OUTPUT_DIR) - if not os.path.exists(PARQUET_OUTPUT_DIR): - os.mkdir(PARQUET_OUTPUT_DIR) # Always run setup, even if this is executed via "python -m unittest" rather @@ -42,7 +40,7 @@ setup() class WikiqTester: def __init__(self, wiki: str, - case_name: str | None = None, + case_name: str, suffix: str | None = None, in_compression: str = "bz2", baseline_format: str = "tsv", @@ -50,10 +48,20 @@ class WikiqTester: ): self.input_file = os.path.join(TEST_DIR, "dumps", "{0}.xml.{1}".format(wiki, in_compression)) - if out_format == "tsv": - self.output_dir = TEST_OUTPUT_DIR - else: - self.output_dir = "{0}.parquet".format(TEST_OUTPUT_DIR) + basename = "{0}_{1}".format(case_name, wiki) + if suffix: + basename = "{0}_{1}".format(basename, suffix) + + self.output = os.path.join(TEST_OUTPUT_DIR, "{0}.{1}".format(basename, out_format)) + + if os.path.exists(self.output): + if os.path.isfile(self.output): + os.remove(self.output) + else: + shutil.rmtree(self.output) + + if out_format == "parquet": + os.makedirs(self.output, exist_ok=True) if suffix is None: self.wikiq_baseline_name = "{0}.{1}".format(wiki, baseline_format) @@ -61,14 +69,10 @@ class WikiqTester: else: self.wikiq_baseline_name = "{0}_{1}.{2}".format(wiki, suffix, baseline_format) self.wikiq_out_name = "{0}_{1}.{2}".format(wiki, suffix, out_format) - self.call_output = os.path.join(self.output_dir, "{0}.{1}".format(wiki, out_format)) # If case_name is unset, there are no relevant baseline or test files. if case_name is not None: self.baseline_file = os.path.join(BASELINE_DIR, "{0}_{1}".format(case_name, self.wikiq_baseline_name)) - self.test_file = os.path.join(self.output_dir, "{0}_{1}".format(case_name, self.wikiq_out_name)) - if os.path.exists(self.test_file): - os.remove(self.test_file) def call_wikiq(self, *args: str, out: bool = True): """ @@ -78,7 +82,7 @@ class WikiqTester: :return: The output of the wikiq call. """ if out: - call = ' '.join([WIKIQ, self.input_file, "-o", self.output_dir, *args]) + call = ' '.join([WIKIQ, self.input_file, "-o", self.output, *args]) else: call = ' '.join([WIKIQ, self.input_file, *args]) @@ -105,9 +109,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -119,10 +121,8 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - # as a test let's make sure that we get equal data frames - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) num_wrong_ns = sum(~ test.namespace.isin({0, 1})) self.assertEqual(num_wrong_ns, 0) baseline = pd.read_table(tester.baseline_file) @@ -136,10 +136,8 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - # as a test let's make sure that we get equal data frames - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) num_wrong_ns = sum(~ test.namespace.isin({0, 1})) self.assertEqual(num_wrong_ns, 0) baseline = pd.read_table(tester.baseline_file) @@ -153,10 +151,8 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - # as a test let's make sure that we get equal data frames - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) num_reverted = sum(i is None for i in test.revert) self.assertEqual(num_reverted, 0) baseline = pd.read_table(tester.baseline_file) @@ -170,8 +166,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -183,9 +178,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -197,9 +190,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -211,9 +202,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -225,9 +214,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -239,16 +226,14 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) test = test.reindex(columns=sorted(test.columns)) assert_frame_equal(test, baseline, check_like=True) def test_malformed_noargs(self): - tester = WikiqTester(wiki=TWINPEAKS, in_compression="7z") + tester = WikiqTester(wiki=TWINPEAKS, case_name="noargs", in_compression="7z") want_exception = 'xml.etree.ElementTree.ParseError: no element found: line 1369, column 0' try: @@ -263,18 +248,16 @@ class WikiqTestCase(unittest.TestCase): tester = WikiqTester(wiki=SAILORMOON, case_name="noargs", in_compression="7z") try: - outs = tester.call_wikiq( "--stdout", "--fandom-2020", out=False).decode("utf8") + outs = tester.call_wikiq("--stdout", "--fandom-2020", out=False).decode("utf8") except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - test = pd.read_table(StringIO(outs)) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) def test_bad_regex(self): - tester = WikiqTester(wiki=REGEXTEST) + tester = WikiqTester(wiki=REGEXTEST, case_name="bad_regex") # sample arguments for checking that bad arguments get terminated / test_regex_arguments bad_arguments_list = [ @@ -315,9 +298,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -337,9 +318,7 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - - test = pd.read_table(tester.test_file) + test = pd.read_table(tester.output) baseline = pd.read_table(tester.baseline_file) assert_frame_equal(test, baseline, check_like=True) @@ -352,10 +331,8 @@ class WikiqTestCase(unittest.TestCase): except subprocess.CalledProcessError as exc: self.fail(exc.stderr.decode("utf8")) - copyfile(tester.call_output, tester.test_file) - # as a test let's make sure that we get equal data frames - test: DataFrame = pd.read_parquet(tester.test_file) + test: DataFrame = pd.read_parquet(tester.output) # test = test.drop(['reverteds'], axis=1) baseline: DataFrame = pd.read_table(tester.baseline_file) @@ -364,6 +341,7 @@ class WikiqTestCase(unittest.TestCase): baseline['date_time'] = pd.to_datetime(baseline['date_time']) # Split strings to the arrays of reverted IDs so they can be compared. baseline['revert'] = baseline['revert'].replace(np.nan, None) + baseline['reverteds'] = baseline['reverteds'].replace(np.nan, None) # baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']] baseline['sha1'] = baseline['sha1'].replace(np.nan, None) baseline['editor'] = baseline['editor'].replace(np.nan, None) diff --git a/wikiq b/wikiq index ba9f6c3..f849ae6 100755 --- a/wikiq +++ b/wikiq @@ -201,6 +201,27 @@ class RegexPair(object): return rev_data +def pa_schema() -> pa.Schema: + fields = [ + pa.field("revid", pa.int64()), + pa.field("date_time", pa.timestamp('s')), + pa.field("articleid", pa.int64()), + pa.field("editorid", pa.int64(), nullable=True), + pa.field("title", pa.string()), + pa.field("namespace", pa.int32()), + pa.field("deleted", pa.bool_()), + pa.field("text_chars", pa.int32()), + pa.field("comment_chars", pa.int32()), + pa.field("revert", pa.bool_(), nullable=True), + # reverteds is a string which contains a comma-separated list of reverted revision ids. + pa.field("reverteds", pa.string(), nullable=True), + pa.field("sha1", pa.string()), + pa.field("minor", pa.bool_()), + pa.field("editor", pa.string()), + pa.field("anon", pa.bool_()) + ] + return pa.schema(fields) + """ We used to use a dictionary to collect fields for the output. @@ -229,6 +250,7 @@ class Revision: namespace: int deleted: bool text_chars: int | None = None + comment_chars: int | None = None revert: bool | None = None reverteds: str = None sha1: str | None = None @@ -249,6 +271,7 @@ class Revision: pa.field("namespace", pa.int32()), pa.field("deleted", pa.bool_()), pa.field("text_chars", pa.int32()), + # pa.field("comment_chars", pa.int32()), pa.field("revert", pa.bool_(), nullable=True), # reverteds is a string which contains a comma-separated list of reverted revision ids. pa.field("reverteds", pa.string(), nullable=True), @@ -492,6 +515,7 @@ class WikiqParser: state = persistence.State() # Iterate through a page's revisions + prev_text_chars = 0 for revs in page: revs = list(revs) rev = revs[-1] @@ -525,6 +549,7 @@ class WikiqParser: # TODO rev.bytes doesn't work.. looks like a bug rev_data.text_chars = len(rev.text) + rev_data.comment_chars = sum(0 if r.comment is None else len(r.comment) for r in revs) # generate revert data if rev_detector is not None: @@ -550,8 +575,7 @@ class WikiqParser: # TODO missing: additions_size deletions_size # if collapse user was on, let's run that - if self.collapse_user: - rev_data.collapsed_revs = len(revs) + rev_data.collapsed_revs = len(revs) # get the if self.persist != PersistMethod.none: @@ -652,7 +676,7 @@ def main(): parser.add_argument('dumpfiles', metavar="DUMPFILE", nargs="*", type=str, help="Filename of the compressed or uncompressed XML database dump. If absent, we'll look for content on stdin and output on stdout.") - parser.add_argument('-o', '--output-dir', metavar='DIR', dest='output_dir', type=str, nargs=1, + parser.add_argument('-o', '--output', metavar='OUTPUT', dest='output', type=str, nargs=1, help="Directory for output files. If it ends with .parquet output will be in parquet format.") parser.add_argument('-s', '--stdout', dest="stdout", action="store_true", @@ -714,27 +738,27 @@ def main(): namespaces = None if len(args.dumpfiles) > 0: - output_parquet = False for filename in args.dumpfiles: input_file = open_input_file(filename, args.fandom_2020) # open directory for output - if args.output_dir: - output_dir = args.output_dir[0] + if args.output: + output = args.output[0] else: - output_dir = "." + output = "." - if output_dir.endswith(".parquet"): - output_parquet = True + output_parquet = output.endswith(".parquet") print("Processing file: %s" % filename, file=sys.stderr) if args.stdout: # Parquet libraries need a binary output, so just sys.stdout doesn't work. output_file = sys.stdout.buffer - else: - filename = os.path.join(output_dir, os.path.basename(filename)) + elif os.path.isdir(output) or output_parquet: + filename = os.path.join(output, os.path.basename(filename)) output_file = get_output_filename(filename, parquet=output_parquet) + else: + output_file = output wikiq = WikiqParser(input_file, output_file,