diff --git a/src/wikiq/__init__.py b/src/wikiq/__init__.py
index 9597cf9..e1ce26c 100755
--- a/src/wikiq/__init__.py
+++ b/src/wikiq/__init__.py
@@ -49,14 +49,16 @@ class PersistMethod:
 
 
 async def diff_async(differ, last_text, text):
+    """Returns (result, timed_out) tuple."""
     try:
         loop = asyncio.get_running_loop()
-        return await asyncio.wait_for(
+        result = await asyncio.wait_for(
             asyncio.to_thread(differ.inline_json_diff, last_text, text),
             timeout=DIFF_TIMEOUT
         )
+        return result, False
     except TimeoutError as e:
-        return None
+        return None, True
 
 
 def calculate_persistence(tokens_added):
@@ -246,6 +248,8 @@ class WikiqParser:
         external_links: bool = False,
         citations: bool = False,
         wikilinks: bool = False,
+        templates: bool = False,
+        headings: bool = False,
     ):
         """
         Parameters:
@@ -265,6 +269,8 @@ class WikiqParser:
         self.external_links = external_links
         self.citations = citations
         self.wikilinks = wikilinks
+        self.templates = templates
+        self.headings = headings
         if namespaces is not None:
             self.namespace_filter = set(namespaces)
         else:
@@ -405,7 +411,7 @@ class WikiqParser:
             table.columns.append(tables.RevisionCollapsed())
 
         # Create shared parser if any wikitext feature is enabled
-        if self.external_links or self.citations or self.wikilinks:
+        if self.external_links or self.citations or self.wikilinks or self.templates or self.headings:
             wikitext_parser = WikitextParser()
 
         if self.external_links:
@@ -417,6 +423,16 @@ class WikiqParser:
         if self.wikilinks:
             table.columns.append(tables.RevisionWikilinks(wikitext_parser))
 
+        if self.templates:
+            table.columns.append(tables.RevisionTemplates(wikitext_parser))
+
+        if self.headings:
+            table.columns.append(tables.RevisionHeadings(wikitext_parser))
+
+        # Add parser timeout tracking if any wikitext feature is enabled
+        if self.external_links or self.citations or self.wikilinks or self.templates or self.headings:
+            table.columns.append(tables.RevisionParserTimeout(wikitext_parser))
+
         # extract list of namespaces
         self.namespaces = {
             ns.name: ns.id for ns in dump.mwiterator.site_info.namespaces
@@ -434,6 +450,7 @@ class WikiqParser:
             from wikiq.diff_pyarrow_schema import diff_field
 
             schema = schema.append(diff_field)
+            schema = schema.append(pa.field("diff_timeout", pa.bool_()))
 
         if self.diff and self.persist == PersistMethod.none:
             table.columns.append(tables.RevisionText())
@@ -746,12 +763,14 @@ class WikiqParser:
                 if self.diff:
                     last_text = last_rev_text
                     new_diffs = []
+                    diff_timeouts = []
                     for i, text in enumerate(row_buffer["text"]):
-                        diff = asyncio.run(diff_async(differ, last_text, text))
-                        if diff is None:
+                        diff, timed_out = asyncio.run(diff_async(differ, last_text, text))
+                        if timed_out:
                             print(f"WARNING! wikidiff2 timeout for rev: {row_buffer['revid'][i]}. Falling back to default limits.", file=sys.stderr)
                             diff = fast_differ.inline_json_diff(last_text, text)
                         new_diffs.append(diff)
+                        diff_timeouts.append(timed_out)
                         last_text = text
                     row_buffer["diff"] = [
                         [
@@ -761,6 +780,7 @@ class WikiqParser:
                         ]
                         for diff in new_diffs
                     ]
+                    row_buffer["diff_timeout"] = diff_timeouts
 
                 # end persistence logic
                 if self.diff or self.persist != PersistMethod.none:
@@ -1179,6 +1199,22 @@ def main():
         help="Extract internal wikilinks from each revision.",
     )
 
+    parser.add_argument(
+        "--templates",
+        dest="templates",
+        action="store_true",
+        default=False,
+        help="Extract templates with their parameters from each revision.",
+    )
+
+    parser.add_argument(
+        "--headings",
+        dest="headings",
+        action="store_true",
+        default=False,
+        help="Extract section headings from each revision.",
+    )
+
     parser.add_argument(
         "-PNS",
         "--partition-namespaces",
@@ -1286,6 +1322,8 @@ def main():
             external_links=args.external_links,
             citations=args.citations,
             wikilinks=args.wikilinks,
+            templates=args.templates,
+            headings=args.headings,
         )
 
         wikiq.process()
@@ -1316,6 +1354,8 @@ def main():
             external_links=args.external_links,
             citations=args.citations,
             wikilinks=args.wikilinks,
+            templates=args.templates,
+            headings=args.headings,
        )
 
         wikiq.process()
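
Note on the diff_async change above: asyncio.wait_for cannot cancel the thread started by asyncio.to_thread, so on timeout the wikidiff2 call keeps running in the background while the caller falls back to fast_differ. A minimal, self-contained sketch of the same pattern, with a hypothetical slow_diff standing in for differ.inline_json_diff (assumes Python 3.11+, where wait_for raises the builtin TimeoutError):

    import asyncio
    import time

    DIFF_TIMEOUT = 10  # stand-in value; wikiq defines its own

    def slow_diff(last_text: str, text: str) -> str:
        # Hypothetical stand-in for differ.inline_json_diff.
        time.sleep(0.1)
        return f"diff of {len(last_text)} -> {len(text)} chars"

    async def diff_with_timeout(last_text: str, text: str):
        try:
            result = await asyncio.wait_for(
                asyncio.to_thread(slow_diff, last_text, text),
                timeout=DIFF_TIMEOUT,
            )
            return result, False
        except TimeoutError:  # asyncio.TimeoutError on Python < 3.11
            return None, True

    print(asyncio.run(diff_with_timeout("old text", "new text")))
    # ('diff of 8 -> 8 chars', False)
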
diff --git a/src/wikiq/tables.py b/src/wikiq/tables.py
index 1ea1eaf..269af6a 100644
--- a/src/wikiq/tables.py
+++ b/src/wikiq/tables.py
@@ -272,3 +272,56 @@ class RevisionWikilinks(RevisionField[Union[list[dict], None]]):
         if revision.deleted.text:
             return None
         return self.wikitext_parser.extract_wikilinks(revision.text)
+
+
+class RevisionTemplates(RevisionField[Union[list[dict], None]]):
+    """Extract all templates from revision text."""
+
+    # Struct type with name and params map
+    field = pa.field("templates", pa.list_(pa.struct([
+        pa.field("name", pa.string()),
+        pa.field("params", pa.map_(pa.string(), pa.string())),
+    ])), nullable=True)
+
+    def __init__(self, wikitext_parser: "WikitextParser"):
+        super().__init__()
+        self.wikitext_parser = wikitext_parser
+
+    def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> Union[list[dict], None]:
+        revision = revisions[-1]
+        if revision.deleted.text:
+            return None
+        return self.wikitext_parser.extract_templates(revision.text)
+
+
+class RevisionHeadings(RevisionField[Union[list[dict], None]]):
+    """Extract all section headings from revision text."""
+
+    # Struct type with level and text
+    field = pa.field("headings", pa.list_(pa.struct([
+        pa.field("level", pa.int8()),
+        pa.field("text", pa.string()),
+    ])), nullable=True)
+
+    def __init__(self, wikitext_parser: "WikitextParser"):
+        super().__init__()
+        self.wikitext_parser = wikitext_parser
+
+    def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> Union[list[dict], None]:
+        revision = revisions[-1]
+        if revision.deleted.text:
+            return None
+        return self.wikitext_parser.extract_headings(revision.text)
+
+
+class RevisionParserTimeout(RevisionField[bool]):
+    """Track whether the wikitext parser timed out for this revision."""
+
+    field = pa.field("parser_timeout", pa.bool_())
+
+    def __init__(self, wikitext_parser: "WikitextParser"):
+        super().__init__()
+        self.wikitext_parser = wikitext_parser
+
+    def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool:
+        return self.wikitext_parser.last_parse_timed_out
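
Note on the params schema above: pa.map_(pa.string(), pa.string()) round-trips through pyarrow/pandas as a sequence of (key, value) pairs rather than a dict, which is why the tests below convert with dict(item["params"]). A small sketch of that round-trip (the example values are made up):

    import pyarrow as pa

    templates_type = pa.list_(pa.struct([
        pa.field("name", pa.string()),
        pa.field("params", pa.map_(pa.string(), pa.string())),
    ]))

    # Map values are supplied as (key, value) tuples.
    arr = pa.array(
        [[{"name": "Infobox character", "params": [("species", "Human")]}]],
        type=templates_type,
    )
    print(arr.to_pylist())
    # [[{'name': 'Infobox character', 'params': [('species', 'Human')]}]]
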
"""Shared wikitext parser with caching to avoid duplicate parsing.""" from __future__ import annotations +import asyncio import mwparserfromhell +PARSER_TIMEOUT = 60 # seconds + class WikitextParser: """Caches parsed wikicode to avoid re-parsing the same text.""" @@ -17,12 +20,24 @@ class WikitextParser: def __init__(self): self._cached_text: str | None = None self._cached_wikicode = None + self.last_parse_timed_out: bool = False + + async def _parse_async(self, text: str): + """Parse wikitext with timeout protection.""" + try: + result = await asyncio.wait_for( + asyncio.to_thread(mwparserfromhell.parse, text), + timeout=PARSER_TIMEOUT + ) + return result, False + except TimeoutError: + return None, True def _get_wikicode(self, text: str): """Parse text and cache result. Returns cached result if text unchanged.""" if text != self._cached_text: self._cached_text = text - self._cached_wikicode = mwparserfromhell.parse(text) + self._cached_wikicode, self.last_parse_timed_out = asyncio.run(self._parse_async(text)) return self._cached_wikicode def extract_external_links(self, text: str | None) -> list[str] | None: @@ -75,3 +90,37 @@ class WikitextParser: return result except Exception: return None + + def extract_templates(self, text: str | None) -> list[dict] | None: + """Extract all templates with their names and parameters.""" + if text is None: + return None + try: + wikicode = self._get_wikicode(text) + result = [] + for template in wikicode.filter_templates(): + name = str(template.name).strip() + params = {} + for param in template.params: + param_name = str(param.name).strip() + param_value = str(param.value).strip() + params[param_name] = param_value + result.append({"name": name, "params": params}) + return result + except Exception: + return None + + def extract_headings(self, text: str | None) -> list[dict] | None: + """Extract all section headings with their levels.""" + if text is None: + return None + try: + wikicode = self._get_wikicode(text) + result = [] + for heading in wikicode.filter_headings(): + level = heading.level + heading_text = str(heading.title).strip() + result.append({"level": level, "text": heading_text}) + return result + except Exception: + return None diff --git a/test/Wikiq_Unit_Test.py b/test/Wikiq_Unit_Test.py index d139074..f37ef6c 100644 --- a/test/Wikiq_Unit_Test.py +++ b/test/Wikiq_Unit_Test.py @@ -872,3 +872,94 @@ def test_wikilinks(): assert actual_dicts == expected, f"Row {idx}: wikilinks mismatch" print(f"Wikilinks test passed! 
{len(test)} rows processed") + + +def test_templates(): + """Test that --templates extracts templates correctly.""" + import mwparserfromhell + + tester = WikiqTester(SAILORMOON, "templates", in_compression="7z", out_format="parquet") + + try: + tester.call_wikiq("--templates", "--text", "--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + + # Verify templates column exists + assert "templates" in test.columns, "templates column should exist" + + # Verify column has list/array type + assert test["templates"].apply(lambda x: x is None or hasattr(x, '__len__')).all() + + # Verify extraction matches mwparserfromhell for sample rows + rows_with_templates = test[test["templates"].apply(lambda x: x is not None and len(x) > 0)] + if len(rows_with_templates) > 0: + sample = rows_with_templates.head(5) + for idx, row in sample.iterrows(): + text = row["text"] + if text: + wikicode = mwparserfromhell.parse(text) + expected = [] + for template in wikicode.filter_templates(): + name = str(template.name).strip() + params = {} + for param in template.params: + param_name = str(param.name).strip() + param_value = str(param.value).strip() + params[param_name] = param_value + expected.append({"name": name, "params": params}) + + actual = list(row["templates"]) + # Convert to comparable format + actual_list = [] + for item in actual: + actual_list.append({ + "name": item["name"], + "params": dict(item["params"]) if item["params"] else {} + }) + assert actual_list == expected, f"Row {idx}: templates mismatch" + + print(f"Templates test passed! {len(test)} rows processed") + + +def test_headings(): + """Test that --headings extracts section headings correctly.""" + import mwparserfromhell + + tester = WikiqTester(SAILORMOON, "headings", in_compression="7z", out_format="parquet") + + try: + tester.call_wikiq("--headings", "--text", "--fandom-2020") + except subprocess.CalledProcessError as exc: + pytest.fail(exc.stderr.decode("utf8")) + + test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet") + + # Verify headings column exists + assert "headings" in test.columns, "headings column should exist" + + # Verify column has list/array type + assert test["headings"].apply(lambda x: x is None or hasattr(x, '__len__')).all() + + # Verify extraction matches mwparserfromhell for sample rows + rows_with_headings = test[test["headings"].apply(lambda x: x is not None and len(x) > 0)] + if len(rows_with_headings) > 0: + sample = rows_with_headings.head(5) + for idx, row in sample.iterrows(): + text = row["text"] + if text: + wikicode = mwparserfromhell.parse(text) + expected = [] + for heading in wikicode.filter_headings(): + level = heading.level + heading_text = str(heading.title).strip() + expected.append({"level": level, "text": heading_text}) + + actual = list(row["headings"]) + # Convert to comparable format + actual_list = [{"level": item["level"], "text": item["text"]} for item in actual] + assert actual_list == expected, f"Row {idx}: headings mismatch" + + print(f"Headings test passed! {len(test)} rows processed") diff --git a/test/baseline_output/diff_sailormoon.parquet b/test/baseline_output/diff_sailormoon.parquet index e70cc1d..1b14699 100644 Binary files a/test/baseline_output/diff_sailormoon.parquet and b/test/baseline_output/diff_sailormoon.parquet differ