Get regex working
Signed-off-by: Will Beason <willbeason@gmail.com>
This commit is contained in:
parent
89465b29f4
commit
b50c51a215
@ -206,6 +206,3 @@ class RevisionCollapsed(RevisionField[int]):
|
||||
|
||||
def extract(self, page: mwtypes.Page, revisions: list[mwxml.Revision]) -> int:
|
||||
return len(revisions)
|
||||
|
||||
|
||||
|
||||
|
198
wikiq
198
wikiq
@ -149,8 +149,7 @@ class RegexPair(object):
|
||||
def _make_key(self, cap_group):
|
||||
return "{}_{}".format(self.label, cap_group)
|
||||
|
||||
def matchmake(self, content: str, rev_data):
|
||||
|
||||
def matchmake(self, content: str) -> dict:
|
||||
temp_dict = {}
|
||||
# if there are named capture groups in the regex
|
||||
if self.has_groups:
|
||||
@ -191,11 +190,7 @@ class RegexPair(object):
|
||||
else:
|
||||
temp_dict[self.label] = None
|
||||
|
||||
# update rev_data with our new columns
|
||||
for k, v in temp_dict.items():
|
||||
setattr(rev_data, k, v)
|
||||
|
||||
return rev_data
|
||||
return temp_dict
|
||||
|
||||
|
||||
def pa_schema() -> pa.Schema:
|
||||
@ -285,59 +280,6 @@ class Revision:
|
||||
lists = [[d[field.name]] for field in self.pa_schema_fields]
|
||||
return pa.record_batch(lists, schema=pa.schema(self.pa_schema_fields))
|
||||
|
||||
|
||||
"""
|
||||
|
||||
If collapse=True we'll use a RevDataCollapse dataclass.
|
||||
This class inherits from RevDataBase. This means that it has all the same fields and functions.
|
||||
|
||||
It just adds a new field and updates the pyarrow schema.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@dc.dataclass()
|
||||
class RevDataCollapse(Revision):
|
||||
collapsed_revs: int = None
|
||||
|
||||
pa_collapsed_revs_schema = pa.field('collapsed_revs', pa.int64())
|
||||
pa_schema_fields = Revision.pa_schema_fields + [pa_collapsed_revs_schema]
|
||||
|
||||
|
||||
"""
|
||||
|
||||
If persistence data is to be computed we'll need the fields added by RevDataPersistence.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@dc.dataclass()
|
||||
class RevDataPersistence(Revision):
|
||||
token_revs: int = None
|
||||
tokens_added: int = None
|
||||
tokens_removed: int = None
|
||||
tokens_window: int = None
|
||||
|
||||
pa_persistence_schema_fields = [
|
||||
pa.field("token_revs", pa.int64()),
|
||||
pa.field("tokens_added", pa.int64()),
|
||||
pa.field("tokens_removed", pa.int64()),
|
||||
pa.field("tokens_window", pa.int64())]
|
||||
|
||||
pa_schema_fields = Revision.pa_schema_fields + pa_persistence_schema_fields
|
||||
|
||||
|
||||
"""
|
||||
class RevDataCollapsePersistence uses multiple inheritance to make a class that has both persistence and collapse fields.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@dc.dataclass()
|
||||
class RevDataCollapsePersistence(RevDataCollapse, RevDataPersistence):
|
||||
pa_schema_fields = RevDataCollapse.pa_schema_fields + RevDataPersistence.pa_persistence_schema_fields
|
||||
|
||||
|
||||
class WikiqParser:
|
||||
def __init__(self,
|
||||
input_file: TextIOWrapper | IO[Any] | IO[bytes],
|
||||
@ -369,34 +311,8 @@ class WikiqParser:
|
||||
self.namespace_filter = None
|
||||
|
||||
self.regex_schemas = []
|
||||
self.regex_revision_pairs = self.make_matchmake_pairs(regex_match_revision, regex_revision_label)
|
||||
self.regex_comment_pairs = self.make_matchmake_pairs(regex_match_comment, regex_comment_label)
|
||||
|
||||
# This is where we set the type for revdata.
|
||||
|
||||
if self.collapse_user is True:
|
||||
if self.persist == PersistMethod.none:
|
||||
revdata_type = RevDataCollapse
|
||||
else:
|
||||
revdata_type = RevDataCollapsePersistence
|
||||
elif self.persist != PersistMethod.none:
|
||||
revdata_type = RevDataPersistence
|
||||
else:
|
||||
revdata_type = Revision
|
||||
|
||||
# if there are regex fields, we need to add them to the revdata type.
|
||||
regex_fields = [(field.name, list[str], dc.field(default=None)) for field in self.regex_schemas]
|
||||
|
||||
# make_dataclass is a function that defines a new dataclass type.
|
||||
# here we extend the type we have already chosen and add the regular expression types
|
||||
self.revdata_type: type = dc.make_dataclass('RevData_Parser',
|
||||
fields=regex_fields,
|
||||
bases=(revdata_type,))
|
||||
|
||||
# we also need to make sure that we have the right pyarrow schema
|
||||
self.revdata_type.pa_schema_fields = revdata_type.pa_schema_fields + self.regex_schemas
|
||||
|
||||
self.schema: Final[Schema] = pa.schema(self.revdata_type.pa_schema_fields)
|
||||
self.regex_revision_pairs: list[RegexPair] = self.make_matchmake_pairs(regex_match_revision, regex_revision_label)
|
||||
self.regex_comment_pairs: list[RegexPair] = self.make_matchmake_pairs(regex_match_comment, regex_comment_label)
|
||||
|
||||
# here we initialize the variables we need for output.
|
||||
if output_parquet is True:
|
||||
@ -414,10 +330,10 @@ class WikiqParser:
|
||||
self.output_file = open(output_file, 'wb')
|
||||
self.output_parquet = False
|
||||
|
||||
def make_matchmake_pairs(self, patterns, labels):
|
||||
def make_matchmake_pairs(self, patterns, labels) -> list[RegexPair]:
|
||||
if (patterns is not None and labels is not None) and \
|
||||
(len(patterns) == len(labels)):
|
||||
result = []
|
||||
result: list[RegexPair] = []
|
||||
for pattern, label in zip(patterns, labels):
|
||||
rp = RegexPair(pattern, label)
|
||||
result.append(rp)
|
||||
@ -428,22 +344,25 @@ class WikiqParser:
|
||||
else:
|
||||
sys.exit('Each regular expression *must* come with a corresponding label and vice versa.')
|
||||
|
||||
def matchmake_revision(self, rev: mwxml.Revision, rev_data: Revision):
|
||||
rev_data = self.matchmake_text(rev.text, rev_data)
|
||||
rev_data = self.matchmake_comment(rev.comment, rev_data)
|
||||
return rev_data
|
||||
def matchmake_revision(self, rev: mwxml.Revision):
|
||||
result = self.matchmake_text(rev.text)
|
||||
for k, v in self.matchmake_comment(rev.comment).items():
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
def matchmake_text(self, text: str, rev_data: Revision):
|
||||
return self.matchmake_pairs(text, rev_data, self.regex_revision_pairs)
|
||||
def matchmake_text(self, text: str):
|
||||
return self.matchmake_pairs(text, self.regex_revision_pairs)
|
||||
|
||||
def matchmake_comment(self, comment: str, rev_data: Revision):
|
||||
return self.matchmake_pairs(comment, rev_data, self.regex_comment_pairs)
|
||||
def matchmake_comment(self, comment: str):
|
||||
return self.matchmake_pairs(comment, self.regex_comment_pairs)
|
||||
|
||||
@staticmethod
|
||||
def matchmake_pairs(text, rev_data, pairs):
|
||||
def matchmake_pairs(text, pairs):
|
||||
result = {}
|
||||
for pair in pairs:
|
||||
rev_data = pair.matchmake(text, rev_data)
|
||||
return rev_data
|
||||
for k, v in pair.matchmake(text).items():
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
def __get_namespace_from_title(self, title):
|
||||
default_ns = None
|
||||
@ -502,11 +421,29 @@ class WikiqParser:
|
||||
schema = table.schema()
|
||||
schema = schema.append(pa.field('revert', pa.bool_(), nullable=True))
|
||||
|
||||
# Add regex fields to the schema.
|
||||
for pair in self.regex_revision_pairs:
|
||||
for field in pair.get_pyarrow_fields():
|
||||
schema = schema.append(field)
|
||||
|
||||
for pair in self.regex_comment_pairs:
|
||||
for field in pair.get_pyarrow_fields():
|
||||
schema = schema.append(field)
|
||||
|
||||
if self.persist != PersistMethod.none:
|
||||
table.columns.append(tables.RevisionText())
|
||||
schema = schema.append(pa.field('token_revs', pa.int64(), nullable=True))
|
||||
schema = schema.append(pa.field('tokens_added', pa.int64(), nullable=True))
|
||||
schema = schema.append(pa.field('tokens_removed', pa.int64(), nullable=True))
|
||||
schema = schema.append(pa.field('tokens_window', pa.int64(), nullable=True))
|
||||
|
||||
if self.output_parquet:
|
||||
writer = pq.ParquetWriter(self.output_file, schema, flavor='spark')
|
||||
else:
|
||||
writer = pc.CSVWriter(self.output_file, schema, write_options=pc.WriteOptions(delimiter='\t'))
|
||||
|
||||
regex_matches = {}
|
||||
|
||||
# Iterate through pages
|
||||
for page in dump:
|
||||
|
||||
@ -545,6 +482,12 @@ class WikiqParser:
|
||||
|
||||
rev_count += 1
|
||||
|
||||
regex_dict = self.matchmake_revision(rev)
|
||||
for k, v in regex_dict.items():
|
||||
if regex_matches.get(k) is None:
|
||||
regex_matches[k] = []
|
||||
regex_matches[k].append(v)
|
||||
|
||||
buffer = table.pop()
|
||||
|
||||
is_revert_column: list[bool | None] = []
|
||||
@ -556,6 +499,59 @@ class WikiqParser:
|
||||
|
||||
buffer['revert'] = is_revert_column
|
||||
|
||||
for k, v in regex_matches.items():
|
||||
buffer[k] = v
|
||||
regex_matches = {}
|
||||
|
||||
if self.persist != PersistMethod.none:
|
||||
window = deque(maxlen=PERSISTENCE_RADIUS)
|
||||
|
||||
buffer['token_revs'] = []
|
||||
buffer['tokens_added'] = []
|
||||
buffer['tokens_removed'] = []
|
||||
buffer['tokens_window'] = []
|
||||
|
||||
if self.persist == PersistMethod.sequence:
|
||||
state = mwpersistence.DiffState(SequenceMatcher(tokenizer=wikitext_split), revert_radius=PERSISTENCE_RADIUS)
|
||||
elif self.persist == PersistMethod.segment:
|
||||
state = mwpersistence.DiffState(SegmentMatcher(tokenizer=wikitext_split), revert_radius=PERSISTENCE_RADIUS)
|
||||
else:
|
||||
from mw.lib import persistence
|
||||
state = persistence.State()
|
||||
|
||||
for idx, text in enumerate(buffer['text']):
|
||||
rev_id = buffer['revid'][idx]
|
||||
if self.persist != PersistMethod.legacy:
|
||||
_, tokens_added, tokens_removed = state.update(text, rev_id)
|
||||
else:
|
||||
_, tokens_added, tokens_removed = state.process(text, rev_id)
|
||||
|
||||
window.append((rev_id, tokens_added, tokens_removed))
|
||||
|
||||
if len(window) == PERSISTENCE_RADIUS:
|
||||
old_rev_id, old_tokens_added, old_tokens_removed = window.popleft()
|
||||
num_token_revs, num_tokens = calculate_persistence(old_tokens_added)
|
||||
|
||||
buffer['token_revs'].append(num_token_revs)
|
||||
buffer['tokens_added'].append(num_tokens)
|
||||
buffer['tokens_removed'].append(len(old_tokens_removed))
|
||||
buffer['tokens_window'].append(PERSISTENCE_RADIUS - 1)
|
||||
|
||||
del buffer['text']
|
||||
|
||||
# print out metadata for the last RADIUS revisions
|
||||
for i, item in enumerate(window):
|
||||
# if the window was full, we've already printed item 0
|
||||
if len(window) == PERSISTENCE_RADIUS and i == 0:
|
||||
continue
|
||||
|
||||
rev_id, tokens_added, tokens_removed = item
|
||||
num_token_revs, num_tokens = calculate_persistence(tokens_added)
|
||||
|
||||
buffer['token_revs'].append(num_token_revs)
|
||||
buffer['tokens_added'].append(num_tokens)
|
||||
buffer['tokens_removed'].append(len(tokens_removed))
|
||||
buffer['tokens_window'].append(len(window) - (i+1))
|
||||
|
||||
writer.write(pa.table(buffer, schema=schema))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user