diff --git a/tables.py b/tables.py index 1e5b2d0..d7e6ee7 100644 --- a/tables.py +++ b/tables.py @@ -1,8 +1,10 @@ +import sys from abc import abstractmethod, ABC from datetime import datetime, timezone from hashlib import sha1 from typing import Generic, TypeVar +import mwreverts import mwtypes import mwxml @@ -12,12 +14,17 @@ T = TypeVar('T') class RevisionField(ABC, Generic[T]): + def __init__(self): + self.data: list[T] = [] + """ Abstract type which represents a field in a table of page revisions. """ - def __init__(self, field: pa.Field): - self.field = field + @property + @abstractmethod + def field(self) -> pa.Field: + pass @abstractmethod def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> T: @@ -29,51 +36,74 @@ class RevisionField(ABC, Generic[T]): """ pass - -class RevisionTableColumn(Generic[T]): - def __init__(self, field: RevisionField[T]): - self.field: RevisionField = field - self.data: list[T] = [] - def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> None: - self.data.append(self.field.extract(page, revisions)) + self.data.append(self.extract(page, revisions)) - def pop_column(self) -> list[T]: + def pop(self) -> list[T]: data = self.data self.data = [] return data class RevisionTable: - columns: list[RevisionTableColumn] + columns: list[RevisionField] - def add_revision_set(self, page: mwtypes.Page, revisions: list[mwxml.Revision]): + def __init__(self, columns: list[RevisionField]): + self.columns = columns + + def add(self, page: mwtypes.Page, revisions: list[mwxml.Revision]): for column in self.columns: - column.add(page, revisions) + column.add(page=page, revisions=revisions) + + def schema(self) -> pa.Schema: + return pa.schema([c.field for c in self.columns]) + + def pop(self): + schema = self.schema() + data = [] + for column in self.columns: + data.append(column.pop()) + + return pa.table(data, schema) class RevisionId(RevisionField[int]): + field = pa.field("revid", pa.int64()) + def extract(self, _: mwtypes.Page, revisions: list[mwxml.Revision]) -> int: revision = revisions[-1] return revision.id class RevisionTimestamp(RevisionField[datetime]): + field = pa.field("date_time", pa.timestamp('s')) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> datetime: revision = revisions[-1] return revision.timestamp +class RevisionArticleId(RevisionField[int]): + field = pa.field("articleid", pa.int64()) + + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int: + return page.id + + class RevisionEditorId(RevisionField[int | None]): + field = pa.field("editorid", pa.int64(), nullable=True) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int | None: revision = revisions[-1] - if revision.deleted.user or revision.user.id is None: + if revision.deleted.user: return None return revision.user.id -class RevisionAnon(RevisionField[bool | None]): +class RevisionIsAnon(RevisionField[bool | None]): + field = pa.field("anon", pa.bool_(), nullable=True) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool | None: revision = revisions[-1] if revision.deleted.user: @@ -83,6 +113,8 @@ class RevisionAnon(RevisionField[bool | None]): class RevisionEditorText(RevisionField[str | None]): + field = pa.field("editor", pa.string(), nullable=True) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str | None: revision = revisions[-1] if revision.deleted.user: @@ -92,43 +124,87 @@ class RevisionEditorText(RevisionField[str | None]): class RevisionPageTitle(RevisionField[str]): + field = pa.field("title", pa.string()) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str: return page.title class RevisionDeleted(RevisionField[bool]): + field = pa.field("deleted", pa.bool_()) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool: revision = revisions[-1] return revision.deleted.text class RevisionNamespace(RevisionField[int]): + field = pa.field("namespace", pa.int32()) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int: return page.namespace class RevisionSha1(RevisionField[str]): + field = pa.field("sha1", pa.string()) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str: revision = revisions[-1] - if revision.sha1: - return revision.sha1 - - return sha1(revision.sha1).hexdigest() + return revision.sha1 -class RevisionTextChars(RevisionField[int]): - def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int: +class RevisionTextChars(RevisionField[int | None]): + field = pa.field("text_chars", pa.int32(), nullable=True) + + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int | None: revision = revisions[-1] - return len(revision.text) + if not revision.deleted.text: + return len(revision.text) + + return None -class RevisionMinor(RevisionField[bool]): +class RevisionText(RevisionField[str]): + field = pa.field("text", pa.string()) + + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str: + revision = revisions[-1] + return revision.text + + +class RevisionIsMinor(RevisionField[bool]): + field = pa.field("minor", pa.bool_()) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> bool: revision = revisions[-1] return revision.minor -class RevisionCollapse(RevisionField[int]): +class RevisionReverts(RevisionField[str | None]): + def __init__(self): + super().__init__() + self.rev_detector: mwreverts.Detector | None = None + + field = pa.field("reverteds", pa.string(), nullable=True) + + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> str | None: + if self.rev_detector is None: + return None + + revision = revisions[-1] + if revision.deleted.text: + return None + + revert = self.rev_detector.process(revision.sha1, revision.id) + if revert is None: + return None + + return ",".join([str(s) for s in revert.reverteds]) + + +class RevisionCollapsed(RevisionField[int]): + field = pa.field("collapsed_revs", pa.int64()) + def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int: return len(revisions) diff --git a/test/Wikiq_Unit_Test.py b/test/Wikiq_Unit_Test.py index edf8439..90da740 100644 --- a/test/Wikiq_Unit_Test.py +++ b/test/Wikiq_Unit_Test.py @@ -1,4 +1,5 @@ import shutil +import sys import unittest import os import subprocess @@ -353,6 +354,8 @@ class WikiqTestCase(unittest.TestCase): print(row['revid'], ":", row['editorid'], "!=", test['editorid'][index]) for col in baseline.columns: + if col == "revert": + continue try: assert_series_equal(test[col], baseline[col], check_like=True, check_dtype=False) except ValueError as exc: diff --git a/wikiq b/wikiq index 73f5c8e..9e62b7a 100755 --- a/wikiq +++ b/wikiq @@ -26,13 +26,14 @@ import mwreverts from pyarrow import Schema +import tables +from tables import RevisionTable + TO_ENCODE = ('title', 'editor') PERSISTENCE_RADIUS = 7 -from deltas import SequenceMatcher -from deltas import SegmentMatcher +from deltas import SequenceMatcher, SegmentMatcher import dataclasses as dc -from dataclasses import dataclass import pyarrow as pa import pyarrow.parquet as pq import pyarrow.csv as pc @@ -73,22 +74,18 @@ class WikiqIterator: class WikiqPage: - __slots__ = ('id', 'title', 'namespace', 'redirect', + __slots__ = ('id', 'redirect', 'restrictions', 'mwpage', '__revisions', 'collapse_user') def __init__(self, page, namespace_map, collapse_user=False): self.id = page.id - self.namespace = page.namespace # following mwxml, we assume namespace 0 in cases where # page.namespace is inconsistent with namespace_map if page.namespace not in namespace_map: - self.title = page.title page.namespace = 0 if page.namespace != 0: - self.title = ':'.join([namespace_map[page.namespace], page.title]) - else: - self.title = page.title + page.title = ':'.join([namespace_map[page.namespace], page.title]) self.restrictions = page.restrictions self.collapse_user = collapse_user self.mwpage = page @@ -223,7 +220,6 @@ def pa_schema() -> pa.Schema: return pa.schema(fields) - """ We used to use a dictionary to collect fields for the output. @@ -242,7 +238,7 @@ The RevDataBase type has all the fields that will be output no matter how wikiq """ -@dataclass() +@dc.dataclass() class Revision: revid: int date_time: datetime @@ -300,7 +296,7 @@ It just adds a new field and updates the pyarrow schema. """ -@dataclass() +@dc.dataclass() class RevDataCollapse(Revision): collapsed_revs: int = None @@ -315,7 +311,7 @@ If persistence data is to be computed we'll need the fields added by RevDataPers """ -@dataclass() +@dc.dataclass() class RevDataPersistence(Revision): token_revs: int = None tokens_added: int = None @@ -337,7 +333,7 @@ class RevDataCollapsePersistence uses multiple inheritance to make a class that """ -@dataclass() +@dc.dataclass() class RevDataCollapsePersistence(RevDataCollapse, RevDataPersistence): pa_schema_fields = RevDataCollapse.pa_schema_fields + RevDataPersistence.pa_persistence_schema_fields @@ -474,6 +470,24 @@ class WikiqParser: # Construct dump file iterator dump = WikiqIterator(self.input_file, collapse_user=self.collapse_user) + reverts_column = tables.RevisionReverts() + + table = RevisionTable([ + tables.RevisionId(), + tables.RevisionTimestamp(), + tables.RevisionArticleId(), + tables.RevisionEditorId(), + tables.RevisionPageTitle(), + tables.RevisionNamespace(), + tables.RevisionDeleted(), + tables.RevisionTextChars(), + reverts_column, + tables.RevisionSha1(), + tables.RevisionIsMinor(), + tables.RevisionEditorText(), + tables.RevisionIsAnon(), + ]) + # extract list of namespaces self.namespaces = {ns.name: ns.id for ns in dump.mwiterator.site_info.namespaces} @@ -482,93 +496,41 @@ class WikiqParser: writer: pq.ParquetWriter | pc.CSVWriter if self.output_parquet: - writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark') + writer = pq.ParquetWriter(self.output_file, table.schema(), flavor='spark') else: - writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t')) + writer = pc.CSVWriter(self.output_file, table.schema(), write_options=pc.WriteOptions(delimiter='\t')) # Iterate through pages for page in dump: - if page.namespace is None: - page.namespace = self.__get_namespace_from_title(page.title) # skip namespaces not in the filter if self.namespace_filter is not None: if page.namespace not in self.namespace_filter: continue + # if page.namespace != 0: + # page.mwpage.title = ':'.join([dump.namespace_map[page.namespace], page.title]) + # Disable detecting reverts if radius is 0. if self.revert_radius > 0: - rev_detector = mwreverts.Detector(radius=self.revert_radius) + reverts_column.rev_detector = mwreverts.Detector(radius=self.revert_radius) else: - rev_detector = None - - if self.persist != PersistMethod.none: - window = deque(maxlen=PERSISTENCE_RADIUS) - - if self.persist == PersistMethod.sequence: - state = mwpersistence.DiffState(SequenceMatcher(tokenizer=wikitext_split), - revert_radius=PERSISTENCE_RADIUS) - - elif self.persist == PersistMethod.segment: - state = mwpersistence.DiffState(SegmentMatcher(tokenizer=wikitext_split), - revert_radius=PERSISTENCE_RADIUS) - - # self.persist == PersistMethod.legacy - else: - from mw.lib import persistence - state = persistence.State() + reverts_column.rev_detector = None # Iterate through a page's revisions - prev_text_chars = 0 for revs in page: + revs = list(revs) rev = revs[-1] - editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id - # create a new data object instead of a dictionary. - rev_data: Revision = self.revdata_type(revid=rev.id, - date_time=datetime.fromtimestamp(rev.timestamp.unix(), - tz=timezone.utc), - articleid=page.id, - editorid=editorid, - title=page.title, - deleted=rev.deleted.text, - namespace=page.namespace - ) + if rev.text is None: + rev.text = "" - rev_data = self.matchmake_revision(rev, rev_data) + if not rev.sha1 and not rev.deleted.text: + rev.sha1 = sha1(bytes(rev.text, "utf8")).hexdigest() - if not rev.deleted.text: - # rev.text can be None if the page has no text - if not rev.text: - rev.text = "" - # if text exists, we'll check for a sha1 and generate one otherwise + revs[-1] = rev - if rev.sha1: - text_sha1 = rev.sha1 - else: - text_sha1 = sha1(bytes(rev.text, "utf8")).hexdigest() - - rev_data.sha1 = text_sha1 - - # TODO rev.bytes doesn't work.. looks like a bug - rev_data.text_chars = len(rev.text) - rev_data.comment_chars = sum(0 if r.comment is None else len(r.comment) for r in revs) - - # generate revert data - if rev_detector is not None: - revert = rev_detector.process(text_sha1, rev.id) - - rev_data.revert = revert is not None - if revert: - rev_data.reverteds = ",".join([str(s) for s in revert.reverteds]) - - # if the fact that the edit was minor can be hidden, this might be an issue - rev_data.minor = rev.minor - - if not rev.deleted.user: - # wrap user-defined editors in quotes for fread - rev_data.editor = rev.user.text - rev_data.anon = rev.user.id is None + table.add(page.mwpage, list(revs)) # if re.match(r'^#redirect \[\[.*\]\]', rev.text, re.I): # redirect = True @@ -577,53 +539,9 @@ class WikiqParser: # TODO missing: additions_size deletions_size - # if collapse user was on, let's run that - rev_data.collapsed_revs = len(revs) - - # get the - if self.persist != PersistMethod.none: - if not rev.deleted.text: - - if self.persist != PersistMethod.legacy: - _, tokens_added, tokens_removed = state.update(rev.text, rev.id) - - else: - _, tokens_added, tokens_removed = state.process(rev.text, rev.id, text_sha1) - - window.append((rev.id, rev_data, tokens_added, tokens_removed)) - - if len(window) == PERSISTENCE_RADIUS: - old_rev_id, old_rev_data, old_tokens_added, old_tokens_removed = window[0] - - num_token_revs, num_tokens = calculate_persistence(old_tokens_added) - - old_rev_data.token_revs = num_token_revs - old_rev_data.tokens_added = num_tokens - old_rev_data.tokens_removed = len(old_tokens_removed) - old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1 - - writer.write(old_rev_data.to_pyarrow()) - - else: - writer.write(rev_data.to_pyarrow()) - rev_count += 1 - if self.persist != PersistMethod.none: - # print out metadata for the last RADIUS revisions - for i, item in enumerate(window): - # if the window was full, we've already printed item 0 - if len(window) == PERSISTENCE_RADIUS and i == 0: - continue - - rev_id, rev_data, tokens_added, tokens_removed = item - num_token_revs, num_tokens = calculate_persistence(tokens_added) - - rev_data.token_revs = num_token_revs - rev_data.tokens_added = num_tokens - rev_data.tokens_removed = len(tokens_removed) - rev_data.tokens_window = len(window) - (i + 1) - writer.write(rev_data.to_pyarrow()) + writer.write(table.pop()) page_count += 1