Get regex working

Signed-off-by: Will Beason <willbeason@gmail.com>
This commit is contained in:
Will Beason 2025-06-03 16:02:18 -05:00
parent 89465b29f4
commit b50c51a215
2 changed files with 97 additions and 104 deletions

View File

@ -206,6 +206,3 @@ class RevisionCollapsed(RevisionField[int]):
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int: def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int:
return len(revisions) return len(revisions)

198
wikiq
View File

@ -149,8 +149,7 @@ class RegexPair(object):
def _make_key(self, cap_group): def _make_key(self, cap_group):
return "{}_{}".format(self.label, cap_group) return "{}_{}".format(self.label, cap_group)
def matchmake(self, content: str, rev_data): def matchmake(self, content: str) -> dict:
temp_dict = {} temp_dict = {}
# if there are named capture groups in the regex # if there are named capture groups in the regex
if self.has_groups: if self.has_groups:
@ -191,11 +190,7 @@ class RegexPair(object):
else: else:
temp_dict[self.label] = None temp_dict[self.label] = None
# update rev_data with our new columns return temp_dict
for k, v in temp_dict.items():
setattr(rev_data, k, v)
return rev_data
def pa_schema() -> pa.Schema: def pa_schema() -> pa.Schema:
@ -285,59 +280,6 @@ class Revision:
lists = [[d[field.name]] for field in self.pa_schema_fields] lists = [[d[field.name]] for field in self.pa_schema_fields]
return pa.record_batch(lists, schema=pa.schema(self.pa_schema_fields)) return pa.record_batch(lists, schema=pa.schema(self.pa_schema_fields))
"""
If collapse=True we'll use a RevDataCollapse dataclass.
This class inherits from RevDataBase. This means that it has all the same fields and functions.
It just adds a new field and updates the pyarrow schema.
"""
@dc.dataclass()
class RevDataCollapse(Revision):
collapsed_revs: int = None
pa_collapsed_revs_schema = pa.field('collapsed_revs', pa.int64())
pa_schema_fields = Revision.pa_schema_fields + [pa_collapsed_revs_schema]
"""
If persistence data is to be computed we'll need the fields added by RevDataPersistence.
"""
@dc.dataclass()
class RevDataPersistence(Revision):
token_revs: int = None
tokens_added: int = None
tokens_removed: int = None
tokens_window: int = None
pa_persistence_schema_fields = [
pa.field("token_revs", pa.int64()),
pa.field("tokens_added", pa.int64()),
pa.field("tokens_removed", pa.int64()),
pa.field("tokens_window", pa.int64())]
pa_schema_fields = Revision.pa_schema_fields + pa_persistence_schema_fields
"""
class RevDataCollapsePersistence uses multiple inheritance to make a class that has both persistence and collapse fields.
"""
@dc.dataclass()
class RevDataCollapsePersistence(RevDataCollapse, RevDataPersistence):
pa_schema_fields = RevDataCollapse.pa_schema_fields + RevDataPersistence.pa_persistence_schema_fields
class WikiqParser: class WikiqParser:
def __init__(self, def __init__(self,
input_file: TextIOWrapper | IO[Any] | IO[bytes], input_file: TextIOWrapper | IO[Any] | IO[bytes],
@ -369,34 +311,8 @@ class WikiqParser:
self.namespace_filter = None self.namespace_filter = None
self.regex_schemas = [] self.regex_schemas = []
self.regex_revision_pairs = self.make_matchmake_pairs(regex_match_revision, regex_revision_label) self.regex_revision_pairs: list[RegexPair] = self.make_matchmake_pairs(regex_match_revision, regex_revision_label)
self.regex_comment_pairs = self.make_matchmake_pairs(regex_match_comment, regex_comment_label) self.regex_comment_pairs: list[RegexPair] = self.make_matchmake_pairs(regex_match_comment, regex_comment_label)
# This is where we set the type for revdata.
if self.collapse_user is True:
if self.persist == PersistMethod.none:
revdata_type = RevDataCollapse
else:
revdata_type = RevDataCollapsePersistence
elif self.persist != PersistMethod.none:
revdata_type = RevDataPersistence
else:
revdata_type = Revision
# if there are regex fields, we need to add them to the revdata type.
regex_fields = [(field.name, list[str], dc.field(default=None)) for field in self.regex_schemas]
# make_dataclass is a function that defines a new dataclass type.
# here we extend the type we have already chosen and add the regular expression types
self.revdata_type: type = dc.make_dataclass('RevData_Parser',
fields=regex_fields,
bases=(revdata_type,))
# we also need to make sure that we have the right pyarrow schema
self.revdata_type.pa_schema_fields = revdata_type.pa_schema_fields + self.regex_schemas
self.schema: Final[Schema] = pa.schema(self.revdata_type.pa_schema_fields)
# here we initialize the variables we need for output. # here we initialize the variables we need for output.
if output_parquet is True: if output_parquet is True:
@ -414,10 +330,10 @@ class WikiqParser:
self.output_file = open(output_file, 'wb') self.output_file = open(output_file, 'wb')
self.output_parquet = False self.output_parquet = False
def make_matchmake_pairs(self, patterns, labels): def make_matchmake_pairs(self, patterns, labels) -> list[RegexPair]:
if (patterns is not None and labels is not None) and \ if (patterns is not None and labels is not None) and \
(len(patterns) == len(labels)): (len(patterns) == len(labels)):
result = [] result: list[RegexPair] = []
for pattern, label in zip(patterns, labels): for pattern, label in zip(patterns, labels):
rp = RegexPair(pattern, label) rp = RegexPair(pattern, label)
result.append(rp) result.append(rp)
@ -428,22 +344,25 @@ class WikiqParser:
else: else:
sys.exit('Each regular expression *must* come with a corresponding label and vice versa.') sys.exit('Each regular expression *must* come with a corresponding label and vice versa.')
def matchmake_revision(self, rev: mwxml.Revision, rev_data: Revision): def matchmake_revision(self, rev: mwxml.Revision):
rev_data = self.matchmake_text(rev.text, rev_data) result = self.matchmake_text(rev.text)
rev_data = self.matchmake_comment(rev.comment, rev_data) for k, v in self.matchmake_comment(rev.comment).items():
return rev_data result[k] = v
return result
def matchmake_text(self, text: str, rev_data: Revision): def matchmake_text(self, text: str):
return self.matchmake_pairs(text, rev_data, self.regex_revision_pairs) return self.matchmake_pairs(text, self.regex_revision_pairs)
def matchmake_comment(self, comment: str, rev_data: Revision): def matchmake_comment(self, comment: str):
return self.matchmake_pairs(comment, rev_data, self.regex_comment_pairs) return self.matchmake_pairs(comment, self.regex_comment_pairs)
@staticmethod @staticmethod
def matchmake_pairs(text, rev_data, pairs): def matchmake_pairs(text, pairs):
result = {}
for pair in pairs: for pair in pairs:
rev_data = pair.matchmake(text, rev_data) for k, v in pair.matchmake(text).items():
return rev_data result[k] = v
return result
def __get_namespace_from_title(self, title): def __get_namespace_from_title(self, title):
default_ns = None default_ns = None
@ -502,11 +421,29 @@ class WikiqParser:
schema = table.schema() schema = table.schema()
schema = schema.append(pa.field('revert', pa.bool_(), nullable=True)) schema = schema.append(pa.field('revert', pa.bool_(), nullable=True))
# Add regex fields to the schema.
for pair in self.regex_revision_pairs:
for field in pair.get_pyarrow_fields():
schema = schema.append(field)
for pair in self.regex_comment_pairs:
for field in pair.get_pyarrow_fields():
schema = schema.append(field)
if self.persist != PersistMethod.none:
table.columns.append(tables.RevisionText())
schema = schema.append(pa.field('token_revs', pa.int64(), nullable=True))
schema = schema.append(pa.field('tokens_added', pa.int64(), nullable=True))
schema = schema.append(pa.field('tokens_removed', pa.int64(), nullable=True))
schema = schema.append(pa.field('tokens_window', pa.int64(), nullable=True))
if self.output_parquet: if self.output_parquet:
writer = pq.ParquetWriter(self.output_file, schema, flavor='spark') writer = pq.ParquetWriter(self.output_file, schema, flavor='spark')
else: else:
writer = pc.CSVWriter(self.output_file, schema, write_options=pc.WriteOptions(delimiter='\t')) writer = pc.CSVWriter(self.output_file, schema, write_options=pc.WriteOptions(delimiter='\t'))
regex_matches = {}
# Iterate through pages # Iterate through pages
for page in dump: for page in dump:
@ -545,6 +482,12 @@ class WikiqParser:
rev_count += 1 rev_count += 1
regex_dict = self.matchmake_revision(rev)
for k, v in regex_dict.items():
if regex_matches.get(k) is None:
regex_matches[k] = []
regex_matches[k].append(v)
buffer = table.pop() buffer = table.pop()
is_revert_column: list[bool | None] = [] is_revert_column: list[bool | None] = []
@ -556,6 +499,59 @@ class WikiqParser:
buffer['revert'] = is_revert_column buffer['revert'] = is_revert_column
for k, v in regex_matches.items():
buffer[k] = v
regex_matches = {}
if self.persist != PersistMethod.none:
window = deque(maxlen=PERSISTENCE_RADIUS)
buffer['token_revs'] = []
buffer['tokens_added'] = []
buffer['tokens_removed'] = []
buffer['tokens_window'] = []
if self.persist == PersistMethod.sequence:
state = mwpersistence.DiffState(SequenceMatcher(tokenizer=wikitext_split), revert_radius=PERSISTENCE_RADIUS)
elif self.persist == PersistMethod.segment:
state = mwpersistence.DiffState(SegmentMatcher(tokenizer=wikitext_split), revert_radius=PERSISTENCE_RADIUS)
else:
from mw.lib import persistence
state = persistence.State()
for idx, text in enumerate(buffer['text']):
rev_id = buffer['revid'][idx]
if self.persist != PersistMethod.legacy:
_, tokens_added, tokens_removed = state.update(text, rev_id)
else:
_, tokens_added, tokens_removed = state.process(text, rev_id)
window.append((rev_id, tokens_added, tokens_removed))
if len(window) == PERSISTENCE_RADIUS:
old_rev_id, old_tokens_added, old_tokens_removed = window.popleft()
num_token_revs, num_tokens = calculate_persistence(old_tokens_added)
buffer['token_revs'].append(num_token_revs)
buffer['tokens_added'].append(num_tokens)
buffer['tokens_removed'].append(len(old_tokens_removed))
buffer['tokens_window'].append(PERSISTENCE_RADIUS - 1)
del buffer['text']
# print out metadata for the last RADIUS revisions
for i, item in enumerate(window):
# if the window was full, we've already printed item 0
if len(window) == PERSISTENCE_RADIUS and i == 0:
continue
rev_id, tokens_added, tokens_removed = item
num_token_revs, num_tokens = calculate_persistence(tokens_added)
buffer['token_revs'].append(num_token_revs)
buffer['tokens_added'].append(num_tokens)
buffer['tokens_removed'].append(len(tokens_removed))
buffer['tokens_window'].append(len(window) - (i+1))
writer.write(pa.table(buffer, schema=schema)) writer.write(pa.table(buffer, schema=schema))