# mediawiki_dump_tools/wiki_diff_matcher.py
# 2025-07-02 13:31:32 -07:00
#
# 436 lines
# 17 KiB
# Python

import json
import sys
from collections import namedtuple
from itertools import chain
from typing import Dict, Generator, List, Optional, Tuple
import requests
from deltas import (Delete, DiffEngine, Equal, Insert, Operation, Token,
RegexTokenizer, tokenizers)
# Default tokenizer (deltas' wikitext splitter); used by WikiDiffMatcher and its
# Processor whenever no tokenizer is passed in.
TOKENIZER = tokenizers.wikitext_split
def find_greatest_le_key(target_key, data_dict):
    """Return the largest key of *data_dict* that is ``<= target_key``.

    Does not assume the keys were inserted in sorted order.

    Args:
        target_key: Upper bound (inclusive) for the key search.
        data_dict: Mapping whose keys are mutually comparable with target_key.

    Returns:
        The greatest key ``<= target_key``. If no key qualifies, the smallest
        key in *data_dict* is returned as a fallback. Returns ``None`` only
        when *data_dict* is empty.
    """
    if not data_dict:
        return None
    # Compare against None explicitly: 0 is a perfectly valid (and very
    # common) byte-offset key and must not be treated as "not found".
    best = max((key for key in data_dict if key <= target_key), default=None)
    if best is not None:
        return best
    # Nothing precedes target_key; fall back to the smallest available key.
    return min(data_dict)
def compute_diffs(url: str, texts: list[str], timeout: Optional[float] = None) -> list:
    """POST *texts* to the diff server at *url* and return its decoded JSON reply.

    Args:
        url: Endpoint of the diff server.
        texts: Revision texts, sent as the JSON request body.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) waits
            indefinitely, which matches the previous behaviour.

    Returns:
        The decoded JSON response. Presumably this is a list holding one
        JSON-encoded diff string per revision, since DiffToOperationMap
        ``json.loads`` each element (verify against the server).

    Raises:
        requests.exceptions.ConnectionError, HTTPError, JSONDecodeError or
        RequestException. Each is re-raised unchanged after a diagnostic has
        been printed to stderr.
    """
    response = None
    try:
        response = requests.post(url, json=texts, timeout=timeout)
        response.raise_for_status()
        incremental_diffs = response.json()
    except requests.exceptions.ConnectionError as e:
        # All diagnostics go to stderr so they never mix with data on stdout.
        print(
            f"Connection Error: Could not connect to the server at {url}. Make sure your local server is running.",
            file=sys.stderr,
        )
        print(e, file=sys.stderr)
        raise
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error: {e}", file=sys.stderr)
        if response is not None:
            print(f"Response Body: {response.text}", file=sys.stderr)
        raise
    except requests.exceptions.JSONDecodeError as e:
        # Must come before RequestException as JSONDecodeError is
        # a subclass.
        print(f"JSON Decode Error: {e}", file=sys.stderr)
        if response is not None:
            print(f"Response Body: {response.text}", file=sys.stderr)
        raise
    except requests.exceptions.RequestException as e:
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        raise
    return incremental_diffs
class DiffToOperationMap:
    """Translate one wikidiff2 JSON diff into ``deltas`` operations.

    wikidiff2 reports positions as byte offsets into the "from" (old) and
    "to" (new) revision texts, while the yielded operations are expressed in
    token indices. This class walks the diff entries, tokenizes each piece of
    text and keeps running token counters for both sides, so every yielded
    operation carries token-level bounds: a1/a2 for "from", b1/b2 for "to".
    """

    def __init__(self, diff, tokenizer):
        """
        Args:
            diff: JSON string whose top-level object has a ``"diff"`` list of
                wikidiff2 entries.
            tokenizer: Object exposing ``tokenize(str)``, e.g. a deltas
                tokenizer.
        """
        self.tokenizer = tokenizer
        self.diff = json.loads(diff)
        # The code below works in bytes because that's how wikidiff2 indexes.
        # Byte position just past the most recently emitted segment per side.
        self.from_last_end_bytes = 0
        self.to_last_end_bytes = 0
        # NOTE(review): never read in this file; kept only in case external
        # code inspects it.
        self.from_last_to_bytes = 0
        # Running token counters for the "from" and "to" sides.
        self.n_from_start_tokens = 0
        self.n_from_end_tokens = 0
        self.n_to_start_tokens = 0
        self.n_to_end_tokens = 0
        self.last_to_start_line = 0
        self.last_from_start_line = 0
        # Byte offset of a processed segment -> token count reached at its end.
        # Lets paragraph moves re-synchronise the counters without
        # re-tokenizing the text seen so far.
        self.from_byte_token_index_map: Dict[int, int] = {}
        self.to_byte_token_index_map: Dict[int, int] = {}
        # Paragraph-move entries (types 4/5) keyed by moveInfo id, kept until
        # their partner entry shows up.
        self.par_move_dict = {}
        # lineNumber -> "to" byte offset, used to recover an offset when
        # wikidiff2 reports offset["to"] as null.
        self.to_linenumber_bytes_map = {}

    def tokenize(self, bytes):
        """Decode UTF-8 *bytes* and return the tokenizer's tokens."""
        return self.tokenizer.tokenize(bytes.decode("utf-8"))

    @staticmethod
    def _as_bytes(segment, segment_type):
        """Return *segment* as bytes. *segment_type* is ``str`` or ``bytes``.

        Raises:
            ValueError: if *segment_type* is neither ``str`` nor ``bytes``.
        """
        if segment_type is str:
            return segment.encode()
        if segment_type is bytes:
            return segment
        raise ValueError(segment)

    def newline_result(self):
        """Account for the newline ending a line as a one-token Equal on both sides.

        Advances all four token counters by one. Returns
        ``(Equal, [Token("\\n")], [Token("\\n")])``.
        """
        self.n_from_end_tokens += 1
        self.n_from_start_tokens += 1
        self.n_to_end_tokens += 1
        self.n_to_start_tokens += 1
        return (
            Equal(
                self.n_from_start_tokens - 1,
                self.n_from_end_tokens,
                self.n_to_start_tokens - 1,
                self.n_to_end_tokens,
            ),
            [Token("\n")],
            [Token("\n")],
        )

    def to_operations(self):
        """Yield ``(operation, from_tokens, to_tokens)`` for each diff entry.

        Entry ``type`` values handled:
            0: line unchanged; 1: line only in "to"; 2: line only in "from";
            3: changed line described by ``highlightRanges``;
            4 / 5: the two halves of a paragraph move, paired through
            ``moveInfo`` id/linkId. Type 4 is counted on the "from" side and
            type 5 on the "to" side.

        Raises:
            ValueError: for an entry with an unknown ``type``.
        """
        for entry in self.diff["diff"]:
            offset = entry["offset"]
            if offset["to"] is not None and "lineNumber" in entry:
                self.to_linenumber_bytes_map[entry["lineNumber"]] = offset["to"]
            text = entry["text"]
            # Ignore empty diffs. They don't have any tokens.
            if not text:
                continue
            entry_type = entry["type"]
            if entry_type == 0:
                yield from self.doEqual(text, offset)
                yield self.newline_result()
            elif entry_type == 1:
                # A line included in the 'to' revision, but not in the 'from' revision.
                yield from self.doInsert(text, offset)
                yield self.newline_result()
            elif entry_type == 2:
                # A line included in the 'from' revision, but not in the 'to' revision.
                yield from self.doDelete(text, offset)
                yield self.newline_result()
            elif entry_type == 3:
                yield from self.doHighlightRange(
                    text, entry["highlightRanges"], offset, entry["lineNumber"]
                )
                yield self.newline_result()
            elif entry_type == 4:
                self.par_move_dict[entry["moveInfo"]["id"]] = entry
                link_id = entry["moveInfo"]["linkId"]
                if link_id in self.par_move_dict:
                    yield from self.doParMove(entry, self.par_move_dict[link_id])
                    yield self.newline_result()
                else:
                    # Count the tokens in the from revision so the token index
                    # stays correct until the partner entry arrives.
                    n_tokens = len(self.tokenize(text.encode()))
                    self.n_from_end_tokens += n_tokens
                    self.n_from_start_tokens += n_tokens
            elif entry_type == 5:
                link_id = entry["moveInfo"]["linkId"]
                if link_id in self.par_move_dict:
                    yield from self.doParMove(self.par_move_dict[link_id], entry)
                    yield self.newline_result()
                else:
                    self.par_move_dict[entry["moveInfo"]["id"]] = entry
                    # Run doHighlightRange only to advance the "to" token
                    # indices. Its operations are discarded here and emitted
                    # by doParMove once the type-4 partner arrives.
                    # NOTE(review): "from" here is a token count, not a byte
                    # offset as everywhere else; kept as-is.
                    move_offset = {
                        "from": self.n_from_end_tokens,
                        "to": offset["to"],
                    }
                    for _ in self.doHighlightRange(
                        text,
                        entry["highlightRanges"],
                        move_offset,
                        entry["lineNumber"],
                        update_idx="to",
                    ):
                        pass
            else:
                # The 'type' isn't one of the known ones.
                raise ValueError(entry)

    # mwpersistence expects differences to be represented in order from the
    # result's perspective ("to"), not the previous text. Thus, if a line
    # is moved earlier then its insertion should appear before its deletion.
    # As a rule of thumb, the "to" segments should be non-overlapping and
    # strictly increasing, while the "from" segments should merely be
    # non-overlapping.

    def doEqual(self, equal_segment, offset, update_idx="all", type=str):
        """Yield one Equal operation covering *equal_segment* on both sides.

        Args:
            equal_segment: Text (``type=str``) or raw bytes (``type=bytes``).
            offset: ``{"from": int, "to": int}`` byte offsets where the
                segment starts in each revision.
            update_idx: Which running token counters to advance afterwards:
                ``"from"``, ``"to"`` or ``"all"``.
            type: ``str`` or ``bytes``, the type of *equal_segment*.

        Raises:
            ValueError: if *type* is neither ``str`` nor ``bytes``.
        """
        equal_bytes = self._as_bytes(equal_segment, type)
        tokens = self.tokenize(equal_bytes)
        n_from_end_tokens = self.n_from_start_tokens + len(tokens)
        n_to_end_tokens = self.n_to_start_tokens + len(tokens)
        # We need to keep track of the to and from last end bytes.
        self.from_last_end_bytes = offset["from"] + len(equal_bytes)
        self.to_last_end_bytes = offset["to"] + len(equal_bytes)
        yield (
            Equal(
                self.n_from_start_tokens,
                n_from_end_tokens,
                self.n_to_start_tokens,
                n_to_end_tokens,
            ),
            tokens,
            tokens,
        )
        if update_idx in ["from", "all"]:
            self.n_from_end_tokens = self.n_from_start_tokens = n_from_end_tokens
        if update_idx in ["to", "all"]:
            self.n_to_end_tokens = self.n_to_start_tokens = n_to_end_tokens
        self.from_byte_token_index_map[offset["from"]] = self.n_from_end_tokens
        self.to_byte_token_index_map[offset["to"]] = self.n_to_end_tokens

    def doInsert(self, insert_segment, offset, update_idx="all", type=str):
        """Yield one Insert operation for *insert_segment* (present only in "to").

        Arguments are as for doEqual; only the "to" counters can advance.
        """
        insert_bytes = self._as_bytes(insert_segment, type)
        tokens = self.tokenize(insert_bytes)
        n_to_end_tokens = self.n_to_start_tokens + len(tokens)
        self.to_last_end_bytes = offset["to"] + len(insert_bytes)
        yield (
            Insert(
                self.n_from_start_tokens,
                self.n_from_start_tokens,
                self.n_to_start_tokens,
                n_to_end_tokens,
            ),
            [],
            tokens,
        )
        # We have now used more of the "to" tokens.
        if update_idx in ["to", "all"]:
            self.n_to_end_tokens = self.n_to_start_tokens = n_to_end_tokens
        self.to_byte_token_index_map[offset["to"]] = self.n_to_end_tokens

    def doDelete(self, delete_segment, offset, update_idx="all", type=str):
        """Yield one Delete operation for *delete_segment* (present only in "from").

        Arguments are as for doEqual; only the "from" counters can advance.
        """
        delete_bytes = self._as_bytes(delete_segment, type)
        tokens = self.tokenize(delete_bytes)
        n_from_end_tokens = self.n_from_start_tokens + len(tokens)
        self.from_last_end_bytes = offset["from"] + len(delete_bytes)
        yield (
            Delete(
                self.n_from_start_tokens,
                n_from_end_tokens,
                self.n_to_start_tokens,
                self.n_to_start_tokens,
            ),
            tokens,
            [],
        )
        # We have now used more of the "from" tokens.
        if update_idx in ["from", "all"]:
            self.n_from_end_tokens = self.n_from_start_tokens = n_from_end_tokens
        self.from_byte_token_index_map[offset["from"]] = self.n_from_end_tokens

    def _recover_to_offset(self, lineNumber):
        """Return a best-effort "to" byte offset for a line whose offset is null.

        Uses the byte offset recorded for the closest known line number
        (preferring one <= *lineNumber*). Falls back to to_last_end_bytes
        when no line offsets have been recorded yet.
        """
        key = find_greatest_le_key(lineNumber, self.to_linenumber_bytes_map)
        if key is None:
            return self.to_last_end_bytes
        return self.to_linenumber_bytes_map[key]

    def doHighlightRange(
        self, highlight_text, highlightRanges, offset, lineNumber, update_idx="all"
    ):
        """Yield operations for a changed line, highlight range by highlight range.

        *highlight_text* interleaves bytes from both revisions. Ranges of type 0
        are insertions, ranges of type 1 are deletions, and gaps between
        ranges (e.g. around a word deleted mid-line) are equal text. Range
        ``start``/``length`` are byte positions in the encoded text.

        Args:
            highlight_text: The line text as reported by wikidiff2.
            highlightRanges: List of ``{"start", "length", "type"}`` dicts.
            offset: Starting ``{"from", "to"}`` byte offsets. It is copied,
                never mutated. A null "to" is recovered via line numbers.
            lineNumber: Line number of this entry, used for that recovery.
            update_idx: Forwarded to doEqual/doInsert/doDelete.

        Raises:
            ValueError: for a highlight range with an unknown ``type``.
        """
        highlight_bytes = highlight_text.encode()
        highlight_end = 0
        # Work on a copy so the caller's (diff entry's) offset dict is untouched.
        highlight_offset = dict(offset)
        if highlight_offset["to"] is None:
            highlight_offset["to"] = self._recover_to_offset(lineNumber)
        # Note that diffs are token-level, but the indexes are byte-level.
        for highlightRange in highlightRanges:
            highlight_start = highlightRange["start"]
            # Equal bytes in between highlights.
            if highlight_start > highlight_end:
                equal_bytes = highlight_bytes[highlight_end:highlight_start]
                yield from self.doEqual(
                    equal_bytes, highlight_offset, update_idx=update_idx, type=bytes
                )
                highlight_offset["from"] += len(equal_bytes)
                highlight_offset["to"] += len(equal_bytes)
            # Handle highlighted insert / delete.
            highlight_end = highlight_start + highlightRange["length"]
            range_bytes = highlight_bytes[highlight_start:highlight_end]
            if highlightRange["type"] == 0:
                yield from self.doInsert(
                    range_bytes, highlight_offset, update_idx=update_idx, type=bytes
                )
                highlight_offset["to"] += len(range_bytes)
            elif highlightRange["type"] == 1:
                yield from self.doDelete(
                    range_bytes, highlight_offset, update_idx=update_idx, type=bytes
                )
                highlight_offset["from"] += len(range_bytes)
            else:
                raise ValueError(highlightRange)
        # Handle the rest of the line, which is equal. Honour update_idx here
        # too, so "to"-only passes do not advance the "from" counters.
        if highlight_end < len(highlight_bytes):
            yield from self.doEqual(
                highlight_bytes[highlight_end:],
                highlight_offset,
                update_idx=update_idx,
                type=bytes,
            )

    @staticmethod
    def _token_index_at(byte_start, byte_token_index_map, last_end_bytes):
        """Map a byte offset to the token index recorded for it.

        Returns 0 for an empty map. A position at or past *last_end_bytes*
        maps to the most recently recorded token index. Otherwise the index
        recorded for the greatest byte key <= *byte_start* is used.
        """
        if not byte_token_index_map:
            return 0
        if byte_start >= last_end_bytes:
            # Relies on dict insertion order: the last value was recorded last.
            return next(reversed(byte_token_index_map.values()))
        key = find_greatest_le_key(byte_start, byte_token_index_map)
        return byte_token_index_map[key]

    def doParMove(self, from_diff, to_diff):
        """Yield operations for a moved paragraph.

        Resets the token counters to the positions matching the paragraph's
        old ("from") and new ("to") byte offsets, then diffs *to_diff*'s
        highlight ranges from there.

        Args:
            from_diff: The type-4 entry; supplies the "from" byte offset.
            to_diff: The type-5 entry; supplies the "to" byte offset, text,
                highlight ranges and line number.
        """
        from_byte_start = from_diff["offset"]["from"]
        # Read unconditionally so it is bound even when no "to" offsets have
        # been recorded yet.
        to_byte_start = to_diff["offset"]["to"]
        from_token_start = self._token_index_at(
            from_byte_start, self.from_byte_token_index_map, self.from_last_end_bytes
        )
        to_token_start = self._token_index_at(
            to_byte_start, self.to_byte_token_index_map, self.to_last_end_bytes
        )
        # Now we set the state and apply the highlights.
        self.n_from_start_tokens = self.n_from_end_tokens = from_token_start
        self.n_to_start_tokens = self.n_to_end_tokens = to_token_start
        offset = {"from": from_byte_start, "to": to_byte_start}
        yield from self.doHighlightRange(
            to_diff["text"], to_diff["highlightRanges"], offset, to_diff["lineNumber"]
        )
class WikiDiffMatcher:
    """deltas-compatible diff engine backed by a wikidiff2 HTTP service.

    All revision texts are diffed up front with a single request (see
    compute_diffs). Each Processor then replays those precomputed diffs one
    revision at a time.
    """

    def __init__(
        self,
        texts: Optional[list[str]] = None,
        tokenizer: Optional[RegexTokenizer] = None,
        url: Optional[str] = "http://127.0.0.1:8000",
    ):
        # Pre-compute diffs to reduce traffic overhead.
        self.diffs = compute_diffs(url, texts)
        self.tokenizer = tokenizer or TOKENIZER

    class Processor(DiffEngine.Processor):
        """Processor that returns the next precomputed diff on every process() call."""

        def __init__(self, texts, tokenizer=None):
            # *texts* are the precomputed diffs, consumed strictly in order.
            self.diffs = iter(texts)
            self.tokenizer = tokenizer or TOKENIZER
            self.last_tokens = []
            self.previous_text = ""

        def update(self, last_tokens):
            self.last_tokens = last_tokens

        def process(self, text, token_class=None):
            """Return ``(operations, previous_tokens, current_tokens)`` for *text*.

            The diff has already been computed; it is retrieved incrementally
            to recreate the behaviour DiffState expects. *token_class* is
            accepted for interface compatibility and unused.
            """
            diff = next(self.diffs)
            mapper = DiffToOperationMap(diff, self.tokenizer)
            diffops = list(zip(*mapper.to_operations()))
            if not diffops:
                self.last_tokens = []
                return [], [], []
            operations, aseq, bseq = diffops
            # Token lists can be yielded out of order (e.g. paragraph moves),
            # so rebuild each sequence ordered by its operation's a1 / b1.
            from_order = sorted(range(len(aseq)), key=lambda i: operations[i].a1)
            to_order = sorted(range(len(bseq)), key=lambda i: operations[i].b1)
            self.last_tokens = list(chain.from_iterable(aseq[i] for i in from_order))
            tokens = list(chain.from_iterable(bseq[i] for i in to_order))
            self.previous_text = text
            return operations, self.last_tokens, tokens

    def processor(self, *args, **kwargs):
        return self.Processor(self.diffs, self.tokenizer)

    def process(self):
        # DiffState checks for this method even though it is not called.
        raise Exception("Unnecessary implementation")