add templates and headings to wikiq.

Nathan TeBlunthuis 2025-12-02 17:51:08 -08:00
parent d3517ed5ca
commit 5ce9808b50
5 changed files with 239 additions and 6 deletions

View File (wikiq main module: diff_async, WikiqParser, and the CLI in main())

@@ -49,14 +49,16 @@ class PersistMethod:
 async def diff_async(differ, last_text, text):
+    """Returns (result, timed_out) tuple."""
     try:
         loop = asyncio.get_running_loop()
-        return await asyncio.wait_for(
+        result = await asyncio.wait_for(
             asyncio.to_thread(differ.inline_json_diff, last_text, text),
             timeout=DIFF_TIMEOUT
         )
+        return result, False
     except TimeoutError as e:
-        return None
+        return None, True


 def calculate_persistence(tokens_added):
@@ -246,6 +248,8 @@ class WikiqParser:
         external_links: bool = False,
         citations: bool = False,
         wikilinks: bool = False,
+        templates: bool = False,
+        headings: bool = False,
     ):
         """
         Parameters:
@@ -265,6 +269,8 @@ class WikiqParser:
         self.external_links = external_links
         self.citations = citations
         self.wikilinks = wikilinks
+        self.templates = templates
+        self.headings = headings

         if namespaces is not None:
             self.namespace_filter = set(namespaces)
         else:
@@ -405,7 +411,7 @@ class WikiqParser:
             table.columns.append(tables.RevisionCollapsed())

         # Create shared parser if any wikitext feature is enabled
-        if self.external_links or self.citations or self.wikilinks:
+        if self.external_links or self.citations or self.wikilinks or self.templates or self.headings:
             wikitext_parser = WikitextParser()

         if self.external_links:
@@ -417,6 +423,16 @@ class WikiqParser:
         if self.wikilinks:
             table.columns.append(tables.RevisionWikilinks(wikitext_parser))

+        if self.templates:
+            table.columns.append(tables.RevisionTemplates(wikitext_parser))
+
+        if self.headings:
+            table.columns.append(tables.RevisionHeadings(wikitext_parser))
+
+        # Add parser timeout tracking if any wikitext feature is enabled
+        if self.external_links or self.citations or self.wikilinks or self.templates or self.headings:
+            table.columns.append(tables.RevisionParserTimeout(wikitext_parser))
+
         # extract list of namespaces
         self.namespaces = {
             ns.name: ns.id for ns in dump.mwiterator.site_info.namespaces
@@ -434,6 +450,7 @@ class WikiqParser:
             from wikiq.diff_pyarrow_schema import diff_field

             schema = schema.append(diff_field)
+            schema = schema.append(pa.field("diff_timeout", pa.bool_()))

         if self.diff and self.persist == PersistMethod.none:
             table.columns.append(tables.RevisionText())
@@ -746,12 +763,14 @@ class WikiqParser:
         if self.diff:
             last_text = last_rev_text
             new_diffs = []
+            diff_timeouts = []
             for i, text in enumerate(row_buffer["text"]):
-                diff = asyncio.run(diff_async(differ, last_text, text))
-                if diff is None:
+                diff, timed_out = asyncio.run(diff_async(differ, last_text, text))
+                if timed_out:
                     print(f"WARNING! wikidiff2 timeout for rev: {row_buffer['revid'][i]}. Falling back to default limits.", file=sys.stderr)
                     diff = fast_differ.inline_json_diff(last_text, text)
                 new_diffs.append(diff)
+                diff_timeouts.append(timed_out)
                 last_text = text

             row_buffer["diff"] = [
                 [
@@ -761,6 +780,7 @@ class WikiqParser:
                 ]
                 for diff in new_diffs
             ]
+            row_buffer["diff_timeout"] = diff_timeouts

         # end persistence logic
         if self.diff or self.persist != PersistMethod.none:
@@ -1179,6 +1199,22 @@ def main():
         help="Extract internal wikilinks from each revision.",
     )

+    parser.add_argument(
+        "--templates",
+        dest="templates",
+        action="store_true",
+        default=False,
+        help="Extract templates with their parameters from each revision.",
+    )
+
+    parser.add_argument(
+        "--headings",
+        dest="headings",
+        action="store_true",
+        default=False,
+        help="Extract section headings from each revision.",
+    )
+
     parser.add_argument(
         "-PNS",
         "--partition-namespaces",
@@ -1286,6 +1322,8 @@ def main():
             external_links=args.external_links,
             citations=args.citations,
             wikilinks=args.wikilinks,
+            templates=args.templates,
+            headings=args.headings,
         )

         wikiq.process()
@@ -1316,6 +1354,8 @@ def main():
             external_links=args.external_links,
             citations=args.citations,
             wikilinks=args.wikilinks,
+            templates=args.templates,
+            headings=args.headings,
         )

         wikiq.process()
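Net effect of the changes in this file: wikiq now records a boolean `diff_timeout` column alongside `diff`, instead of signaling a timeout with a `None` diff. A minimal sketch of inspecting it downstream, assuming a hypothetical parquet output path:

    import pandas as pd

    # Hypothetical output path; wikiq writes one parquet file per input dump.
    df = pd.read_parquet("output/enwiki.parquet")

    # Revisions where wikidiff2 hit DIFF_TIMEOUT and wikiq fell back to
    # fast_differ with default limits.
    fallback = df[df["diff_timeout"]]
    print(f"{len(fallback)} of {len(df)} revisions used the fallback differ")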

View File (tables module: RevisionField column definitions)

@@ -272,3 +272,56 @@ class RevisionWikilinks(RevisionField[Union[list[dict], None]]):
         if revision.deleted.text:
             return None
         return self.wikitext_parser.extract_wikilinks(revision.text)
+
+
+class RevisionTemplates(RevisionField[Union[list[dict], None]]):
+    """Extract all templates from revision text."""
+
+    # Struct type with name and params map
+    field = pa.field("templates", pa.list_(pa.struct([
+        pa.field("name", pa.string()),
+        pa.field("params", pa.map_(pa.string(), pa.string())),
+    ])), nullable=True)
+
+    def __init__(self, wikitext_parser: "WikitextParser"):
+        super().__init__()
+        self.wikitext_parser = wikitext_parser
+
+    def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> Union[list[dict], None]:
+        revision = revisions[-1]
+        if revision.deleted.text:
+            return None
+        return self.wikitext_parser.extract_templates(revision.text)
+
+
+class RevisionHeadings(RevisionField[Union[list[dict], None]]):
+    """Extract all section headings from revision text."""
+
+    # Struct type with level and text
+    field = pa.field("headings", pa.list_(pa.struct([
+        pa.field("level", pa.int8()),
+        pa.field("text", pa.string()),
+    ])), nullable=True)
+
+    def __init__(self, wikitext_parser: "WikitextParser"):
+        super().__init__()
+        self.wikitext_parser = wikitext_parser
+
+    def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> Union[list[dict], None]:
+        revision = revisions[-1]
+        if revision.deleted.text:
+            return None
+        return self.wikitext_parser.extract_headings(revision.text)
+
+
+class RevisionParserTimeout(RevisionField[bool]):
+    """Track whether the wikitext parser timed out for this revision."""
+
+    field = pa.field("parser_timeout", pa.bool_())
+
+    def __init__(self, wikitext_parser: "WikitextParser"):
+        super().__init__()
+        self.wikitext_parser = wikitext_parser
+
+    def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool:
+        return self.wikitext_parser.last_parse_timed_out
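The new columns declare nested Arrow types: `templates` is a list of structs holding a string `name` and a string-to-string `params` map; `headings` is a list of structs holding an int8 `level` and a string `text`. A small pyarrow sketch of the `templates` type with made-up values (map entries are passed as key/value tuples):

    import pyarrow as pa

    # The same nested type RevisionTemplates declares above.
    templates_type = pa.list_(pa.struct([
        pa.field("name", pa.string()),
        pa.field("params", pa.map_(pa.string(), pa.string())),
    ]))

    # One revision whose text contained a single template.
    rows = [[{"name": "Infobox character", "params": [("name", "Luna")]}]]
    arr = pa.array(rows, type=templates_type)
    print(arr.type)  # list<struct<name, params: map<string, string>>>
    print(arr[0])    # the templates list for that revision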

View File (shared wikitext parser module: WikitextParser)

@@ -1,8 +1,11 @@
 """Shared wikitext parser with caching to avoid duplicate parsing."""

 from __future__ import annotations

+import asyncio
+
 import mwparserfromhell

+PARSER_TIMEOUT = 60  # seconds

 class WikitextParser:
     """Caches parsed wikicode to avoid re-parsing the same text."""
@@ -17,12 +20,24 @@ class WikitextParser:
     def __init__(self):
         self._cached_text: str | None = None
         self._cached_wikicode = None
+        self.last_parse_timed_out: bool = False
+
+    async def _parse_async(self, text: str):
+        """Parse wikitext with timeout protection."""
+        try:
+            result = await asyncio.wait_for(
+                asyncio.to_thread(mwparserfromhell.parse, text),
+                timeout=PARSER_TIMEOUT
+            )
+            return result, False
+        except TimeoutError:
+            return None, True

     def _get_wikicode(self, text: str):
         """Parse text and cache result. Returns cached result if text unchanged."""
         if text != self._cached_text:
             self._cached_text = text
-            self._cached_wikicode = mwparserfromhell.parse(text)
+            self._cached_wikicode, self.last_parse_timed_out = asyncio.run(self._parse_async(text))
         return self._cached_wikicode

     def extract_external_links(self, text: str | None) -> list[str] | None:
@@ -75,3 +90,37 @@ class WikitextParser:
             return result
         except Exception:
             return None
+
+    def extract_templates(self, text: str | None) -> list[dict] | None:
+        """Extract all templates with their names and parameters."""
+        if text is None:
+            return None
+        try:
+            wikicode = self._get_wikicode(text)
+            result = []
+            for template in wikicode.filter_templates():
+                name = str(template.name).strip()
+                params = {}
+                for param in template.params:
+                    param_name = str(param.name).strip()
+                    param_value = str(param.value).strip()
+                    params[param_name] = param_value
+                result.append({"name": name, "params": params})
+            return result
+        except Exception:
+            return None
+
+    def extract_headings(self, text: str | None) -> list[dict] | None:
+        """Extract all section headings with their levels."""
+        if text is None:
+            return None
+        try:
+            wikicode = self._get_wikicode(text)
+            result = []
+            for heading in wikicode.filter_headings():
+                level = heading.level
+                heading_text = str(heading.title).strip()
+                result.append({"level": level, "text": heading_text})
+            return result
+        except Exception:
+            return None
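A quick usage sketch of the new methods; the import path is an assumption and the example values are made up:

    from wikiq.wikitext_parser import WikitextParser  # assumed module path

    parser = WikitextParser()
    text = "== History ==\n{{Infobox character|name=Luna|species=cat}}"

    print(parser.extract_headings(text))
    # [{'level': 2, 'text': 'History'}]

    print(parser.extract_templates(text))
    # Cached wikicode from the previous call is reused; the text is parsed once.
    # [{'name': 'Infobox character', 'params': {'name': 'Luna', 'species': 'cat'}}]

    print(parser.last_parse_timed_out)  # False unless parsing exceeded PARSER_TIMEOUT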

View File (test suite)

@@ -872,3 +872,94 @@ def test_wikilinks():
         assert actual_dicts == expected, f"Row {idx}: wikilinks mismatch"

     print(f"Wikilinks test passed! {len(test)} rows processed")
+
+
+def test_templates():
+    """Test that --templates extracts templates correctly."""
+    import mwparserfromhell
+
+    tester = WikiqTester(SAILORMOON, "templates", in_compression="7z", out_format="parquet")
+
+    try:
+        tester.call_wikiq("--templates", "--text", "--fandom-2020")
+    except subprocess.CalledProcessError as exc:
+        pytest.fail(exc.stderr.decode("utf8"))
+
+    test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
+
+    # Verify templates column exists
+    assert "templates" in test.columns, "templates column should exist"
+
+    # Verify column has list/array type
+    assert test["templates"].apply(lambda x: x is None or hasattr(x, '__len__')).all()
+
+    # Verify extraction matches mwparserfromhell for sample rows
+    rows_with_templates = test[test["templates"].apply(lambda x: x is not None and len(x) > 0)]
+    if len(rows_with_templates) > 0:
+        sample = rows_with_templates.head(5)
+        for idx, row in sample.iterrows():
+            text = row["text"]
+            if text:
+                wikicode = mwparserfromhell.parse(text)
+                expected = []
+                for template in wikicode.filter_templates():
+                    name = str(template.name).strip()
+                    params = {}
+                    for param in template.params:
+                        param_name = str(param.name).strip()
+                        param_value = str(param.value).strip()
+                        params[param_name] = param_value
+                    expected.append({"name": name, "params": params})
+                actual = list(row["templates"])
+                # Convert to comparable format
+                actual_list = []
+                for item in actual:
+                    actual_list.append({
+                        "name": item["name"],
+                        "params": dict(item["params"]) if item["params"] else {}
+                    })
+                assert actual_list == expected, f"Row {idx}: templates mismatch"
+
+    print(f"Templates test passed! {len(test)} rows processed")
+
+
+def test_headings():
+    """Test that --headings extracts section headings correctly."""
+    import mwparserfromhell
+
+    tester = WikiqTester(SAILORMOON, "headings", in_compression="7z", out_format="parquet")
+
+    try:
+        tester.call_wikiq("--headings", "--text", "--fandom-2020")
+    except subprocess.CalledProcessError as exc:
+        pytest.fail(exc.stderr.decode("utf8"))
+
+    test = pd.read_parquet(tester.output + f"/{SAILORMOON}.parquet")
+
+    # Verify headings column exists
+    assert "headings" in test.columns, "headings column should exist"
+
+    # Verify column has list/array type
+    assert test["headings"].apply(lambda x: x is None or hasattr(x, '__len__')).all()
+
+    # Verify extraction matches mwparserfromhell for sample rows
+    rows_with_headings = test[test["headings"].apply(lambda x: x is not None and len(x) > 0)]
+    if len(rows_with_headings) > 0:
+        sample = rows_with_headings.head(5)
+        for idx, row in sample.iterrows():
+            text = row["text"]
+            if text:
+                wikicode = mwparserfromhell.parse(text)
+                expected = []
+                for heading in wikicode.filter_headings():
+                    level = heading.level
+                    heading_text = str(heading.title).strip()
+                    expected.append({"level": level, "text": heading_text})
+                actual = list(row["headings"])
+                # Convert to comparable format
+                actual_list = [{"level": item["level"], "text": item["text"]} for item in actual]
+                assert actual_list == expected, f"Row {idx}: headings mismatch"
+
+    print(f"Headings test passed! {len(test)} rows processed")