Get columnar refactor partially working

Noargs works, now to do persistence.

Signed-off-by: Will Beason <willbeason@gmail.com>
This commit is contained in:
Will Beason 2025-06-03 12:51:31 -05:00
parent 8b0f775610
commit 06a784ef27
3 changed files with 146 additions and 149 deletions

124
tables.py
View File

@ -1,8 +1,10 @@
import sys
from abc import abstractmethod, ABC
from datetime import datetime, timezone
from hashlib import sha1
from typing import Generic, TypeVar
import mwreverts
import mwtypes
import mwxml
@ -12,12 +14,17 @@ T = TypeVar('T')
class RevisionField(ABC, Generic[T]):
def __init__(self):
self.data: list[T] = []
"""
Abstract type which represents a field in a table of page revisions.
"""
def __init__(self, field: pa.Field):
self.field = field
@property
@abstractmethod
def field(self) -> pa.Field:
pass
@abstractmethod
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> T:
@ -29,51 +36,74 @@ class RevisionField(ABC, Generic[T]):
"""
pass
class RevisionTableColumn(Generic[T]):
def __init__(self, field: RevisionField[T]):
self.field: RevisionField = field
self.data: list[T] = []
def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> None:
self.data.append(self.field.extract(page, revisions))
self.data.append(self.extract(page, revisions))
def pop_column(self) -> list[T]:
def pop(self) -> list[T]:
data = self.data
self.data = []
return data
class RevisionTable:
columns: list[RevisionTableColumn]
columns: list[RevisionField]
def add_revision_set(self, page: mwtypes.Page, revisions: list[mwxml.Revision]):
def __init__(self, columns: list[RevisionField]):
self.columns = columns
def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]):
for column in self.columns:
column.add(page, revisions)
column.add(page=page, revisions=revisions)
def schema(self) -> pa.Schema:
return pa.schema([c.field for c in self.columns])
def pop(self):
schema = self.schema()
data = []
for column in self.columns:
data.append(column.pop())
return pa.table(data, schema)
class RevisionId(RevisionField[int]):
field = pa.field("revid", pa.int64())
def extract(self, _: mwtypes.Page, revisions: list[mwxml.Revision]) -> int:
revision = revisions[-1]
return revision.id
class RevisionTimestamp(RevisionField[datetime]):
field = pa.field("date_time", pa.timestamp('s'))
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> datetime:
revision = revisions[-1]
return revision.timestamp
class RevisionArticleId(RevisionField[int]):
field = pa.field("articleid", pa.int64())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int:
return page.id
class RevisionEditorId(RevisionField[int | None]):
field = pa.field("editorid", pa.int64(), nullable=True)
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int | None:
revision = revisions[-1]
if revision.deleted.user or revision.user.id is None:
if revision.deleted.user:
return None
return revision.user.id
class RevisionAnon(RevisionField[bool | None]):
class RevisionIsAnon(RevisionField[bool | None]):
field = pa.field("anon", pa.bool_(), nullable=True)
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool | None:
revision = revisions[-1]
if revision.deleted.user:
@ -83,6 +113,8 @@ class RevisionAnon(RevisionField[bool | None]):
class RevisionEditorText(RevisionField[str | None]):
field = pa.field("editor", pa.string(), nullable=True)
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str | None:
revision = revisions[-1]
if revision.deleted.user:
@ -92,43 +124,87 @@ class RevisionEditorText(RevisionField[str | None]):
class RevisionPageTitle(RevisionField[str]):
field = pa.field("title", pa.string())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str:
return page.title
class RevisionDeleted(RevisionField[bool]):
field = pa.field("deleted", pa.bool_())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool:
revision = revisions[-1]
return revision.deleted.text
class RevisionNamespace(RevisionField[int]):
field = pa.field("namespace", pa.int32())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int:
return page.namespace
class RevisionSha1(RevisionField[str]):
field = pa.field("sha1", pa.string())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str:
revision = revisions[-1]
if revision.sha1:
return revision.sha1
return sha1(revision.sha1).hexdigest()
return revision.sha1
class RevisionTextChars(RevisionField[int]):
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int:
class RevisionTextChars(RevisionField[int | None]):
field = pa.field("text_chars", pa.int32(), nullable=True)
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int | None:
revision = revisions[-1]
return len(revision.text)
if not revision.deleted.text:
return len(revision.text)
return None
class RevisionMinor(RevisionField[bool]):
class RevisionText(RevisionField[str]):
field = pa.field("text", pa.string())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str:
revision = revisions[-1]
return revision.text
class RevisionIsMinor(RevisionField[bool]):
field = pa.field("minor", pa.bool_())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool:
revision = revisions[-1]
return revision.minor
class RevisionCollapse(RevisionField[int]):
class RevisionReverts(RevisionField[str | None]):
def __init__(self):
super().__init__()
self.rev_detector: mwreverts.Detector | None = None
field = pa.field("reverteds", pa.string(), nullable=True)
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str | None:
if self.rev_detector is None:
return None
revision = revisions[-1]
if revision.deleted.text:
return None
revert = self.rev_detector.process(revision.sha1, revision.id)
if revert is None:
return None
return ",".join([str(s) for s in revert.reverteds])
class RevisionCollapsed(RevisionField[int]):
field = pa.field("collapsed_revs", pa.int64())
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int:
return len(revisions)

View File

@ -1,4 +1,5 @@
import shutil
import sys
import unittest
import os
import subprocess
@ -353,6 +354,8 @@ class WikiqTestCase(unittest.TestCase):
print(row['revid'], ":", row['editorid'], "!=", test['editorid'][index])
for col in baseline.columns:
if col == "revert":
continue
try:
assert_series_equal(test[col], baseline[col], check_like=True, check_dtype=False)
except ValueError as exc:

168
wikiq
View File

@ -26,13 +26,14 @@ import mwreverts
from pyarrow import Schema
import tables
from tables import RevisionTable
TO_ENCODE = ('title', 'editor')
PERSISTENCE_RADIUS = 7
from deltas import SequenceMatcher
from deltas import SegmentMatcher
from deltas import SequenceMatcher, SegmentMatcher
import dataclasses as dc
from dataclasses import dataclass
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.csv as pc
@ -73,22 +74,18 @@ class WikiqIterator:
class WikiqPage:
__slots__ = ('id', 'title', 'namespace', 'redirect',
__slots__ = ('id', 'redirect',
'restrictions', 'mwpage', '__revisions',
'collapse_user')
def __init__(self, page, namespace_map, collapse_user=False):
self.id = page.id
self.namespace = page.namespace
# following mwxml, we assume namespace 0 in cases where
# page.namespace is inconsistent with namespace_map
if page.namespace not in namespace_map:
self.title = page.title
page.namespace = 0
if page.namespace != 0:
self.title = ':'.join([namespace_map[page.namespace], page.title])
else:
self.title = page.title
page.title = ':'.join([namespace_map[page.namespace], page.title])
self.restrictions = page.restrictions
self.collapse_user = collapse_user
self.mwpage = page
@ -223,7 +220,6 @@ def pa_schema() -> pa.Schema:
return pa.schema(fields)
"""
We used to use a dictionary to collect fields for the output.
@ -242,7 +238,7 @@ The RevDataBase type has all the fields that will be output no matter how wikiq
"""
@dataclass()
@dc.dataclass()
class Revision:
revid: int
date_time: datetime
@ -300,7 +296,7 @@ It just adds a new field and updates the pyarrow schema.
"""
@dataclass()
@dc.dataclass()
class RevDataCollapse(Revision):
collapsed_revs: int = None
@ -315,7 +311,7 @@ If persistence data is to be computed we'll need the fields added by RevDataPers
"""
@dataclass()
@dc.dataclass()
class RevDataPersistence(Revision):
token_revs: int = None
tokens_added: int = None
@ -337,7 +333,7 @@ class RevDataCollapsePersistence uses multiple inheritance to make a class that
"""
@dataclass()
@dc.dataclass()
class RevDataCollapsePersistence(RevDataCollapse, RevDataPersistence):
pa_schema_fields = RevDataCollapse.pa_schema_fields + RevDataPersistence.pa_persistence_schema_fields
@ -474,6 +470,24 @@ class WikiqParser:
# Construct dump file iterator
dump = WikiqIterator(self.input_file, collapse_user=self.collapse_user)
reverts_column = tables.RevisionReverts()
table = RevisionTable([
tables.RevisionId(),
tables.RevisionTimestamp(),
tables.RevisionArticleId(),
tables.RevisionEditorId(),
tables.RevisionPageTitle(),
tables.RevisionNamespace(),
tables.RevisionDeleted(),
tables.RevisionTextChars(),
reverts_column,
tables.RevisionSha1(),
tables.RevisionIsMinor(),
tables.RevisionEditorText(),
tables.RevisionIsAnon(),
])
# extract list of namespaces
self.namespaces = {ns.name: ns.id for ns in dump.mwiterator.site_info.namespaces}
@ -482,93 +496,41 @@ class WikiqParser:
writer: pq.ParquetWriter | pc.CSVWriter
if self.output_parquet:
writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
writer = pq.ParquetWriter(self.output_file, table.schema(), flavor='spark')
else:
writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t'))
writer = pc.CSVWriter(self.output_file, table.schema(), write_options=pc.WriteOptions(delimiter='\t'))
# Iterate through pages
for page in dump:
if page.namespace is None:
page.namespace = self.__get_namespace_from_title(page.title)
# skip namespaces not in the filter
if self.namespace_filter is not None:
if page.namespace not in self.namespace_filter:
continue
# if page.namespace != 0:
# page.mwpage.title = ':'.join([dump.namespace_map[page.namespace], page.title])
# Disable detecting reverts if radius is 0.
if self.revert_radius > 0:
rev_detector = mwreverts.Detector(radius=self.revert_radius)
reverts_column.rev_detector = mwreverts.Detector(radius=self.revert_radius)
else:
rev_detector = None
if self.persist != PersistMethod.none:
window = deque(maxlen=PERSISTENCE_RADIUS)
if self.persist == PersistMethod.sequence:
state = mwpersistence.DiffState(SequenceMatcher(tokenizer=wikitext_split),
revert_radius=PERSISTENCE_RADIUS)
elif self.persist == PersistMethod.segment:
state = mwpersistence.DiffState(SegmentMatcher(tokenizer=wikitext_split),
revert_radius=PERSISTENCE_RADIUS)
# self.persist == PersistMethod.legacy
else:
from mw.lib import persistence
state = persistence.State()
reverts_column.rev_detector = None
# Iterate through a page's revisions
prev_text_chars = 0
for revs in page:
revs = list(revs)
rev = revs[-1]
editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id
# create a new data object instead of a dictionary.
rev_data: Revision = self.revdata_type(revid=rev.id,
date_time=datetime.fromtimestamp(rev.timestamp.unix(),
tz=timezone.utc),
articleid=page.id,
editorid=editorid,
title=page.title,
deleted=rev.deleted.text,
namespace=page.namespace
)
if rev.text is None:
rev.text = ""
rev_data = self.matchmake_revision(rev, rev_data)
if not rev.sha1 and not rev.deleted.text:
rev.sha1 = sha1(bytes(rev.text, "utf8")).hexdigest()
if not rev.deleted.text:
# rev.text can be None if the page has no text
if not rev.text:
rev.text = ""
# if text exists, we'll check for a sha1 and generate one otherwise
revs[-1] = rev
if rev.sha1:
text_sha1 = rev.sha1
else:
text_sha1 = sha1(bytes(rev.text, "utf8")).hexdigest()
rev_data.sha1 = text_sha1
# TODO rev.bytes doesn't work.. looks like a bug
rev_data.text_chars = len(rev.text)
rev_data.comment_chars = sum(0 if r.comment is None else len(r.comment) for r in revs)
# generate revert data
if rev_detector is not None:
revert = rev_detector.process(text_sha1, rev.id)
rev_data.revert = revert is not None
if revert:
rev_data.reverteds = ",".join([str(s) for s in revert.reverteds])
# if the fact that the edit was minor can be hidden, this might be an issue
rev_data.minor = rev.minor
if not rev.deleted.user:
# wrap user-defined editors in quotes for fread
rev_data.editor = rev.user.text
rev_data.anon = rev.user.id is None
table.add(page.mwpage, list(revs))
# if re.match(r'^#redirect \[\[.*\]\]', rev.text, re.I):
# redirect = True
@ -577,53 +539,9 @@ class WikiqParser:
# TODO missing: additions_size deletions_size
# if collapse user was on, let's run that
rev_data.collapsed_revs = len(revs)
# get the
if self.persist != PersistMethod.none:
if not rev.deleted.text:
if self.persist != PersistMethod.legacy:
_, tokens_added, tokens_removed = state.update(rev.text, rev.id)
else:
_, tokens_added, tokens_removed = state.process(rev.text, rev.id, text_sha1)
window.append((rev.id, rev_data, tokens_added, tokens_removed))
if len(window) == PERSISTENCE_RADIUS:
old_rev_id, old_rev_data, old_tokens_added, old_tokens_removed = window[0]
num_token_revs, num_tokens = calculate_persistence(old_tokens_added)
old_rev_data.token_revs = num_token_revs
old_rev_data.tokens_added = num_tokens
old_rev_data.tokens_removed = len(old_tokens_removed)
old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1
writer.write(old_rev_data.to_pyarrow())
else:
writer.write(rev_data.to_pyarrow())
rev_count += 1
if self.persist != PersistMethod.none:
# print out metadata for the last RADIUS revisions
for i, item in enumerate(window):
# if the window was full, we've already printed item 0
if len(window) == PERSISTENCE_RADIUS and i == 0:
continue
rev_id, rev_data, tokens_added, tokens_removed = item
num_token_revs, num_tokens = calculate_persistence(tokens_added)
rev_data.token_revs = num_token_revs
rev_data.tokens_added = num_tokens
rev_data.tokens_removed = len(tokens_removed)
rev_data.tokens_window = len(window) - (i + 1)
writer.write(rev_data.to_pyarrow())
writer.write(table.pop())
page_count += 1