# mediawiki_dump_tools/wiki_diff_matcher.py

import json
import sys
from collections import namedtuple
from itertools import chain
from typing import Dict, Generator, List, Optional, Tuple
from sortedcontainers import SortedDict
import requests
from deltas import (Delete, DiffEngine, Equal, Insert, Operation,
                    RegexTokenizer, Token, tokenizers)

TOKENIZER = tokenizers.wikitext_split

# def find_greatest_le_key(target_key, data_dict):
#     found_key = None
#     for key in data_dict:  # Iterates over keys in insertion order (which is sorted)
#         if key <= target_key:
#             found_key = (
#                 key  # This is the largest key found so far that satisfies the condition
#             )
#         else:
#             # Since the dictionary is sorted, if key > target_key,
#             # all subsequent keys will also be > target_key.
#             return found_key or key


# def find_smallest_gt_key(target_key, data_dict):
#     found_key = None
#     for key in reversed(data_dict):  # Iterates over keys in insertion order (which is sorted)
#         if key >= target_key:
#             found_key = (
#                 key  # This is the largest key found so far that satisfies the condition
#             )
#         else:
#             # Since the dictionary is sorted, if key > target_key,
#             # all subsequent keys will also be > target_key.
#             return found_key or key

def compute_diffs(url: str, texts: list[str]) -> list:
    response = None
    try:
        response = requests.post(url, json=texts)
        response.raise_for_status()
        incremental_diffs = response.json()
    except requests.exceptions.ConnectionError as e:
        print(
            f"Connection Error: Could not connect to the server at {url}. Make sure your local server is running."
        )
        print(e)
        raise e
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error: {e}")
        if response is not None:
            print(f"Response Body: {response.text}")
        raise e
    except requests.exceptions.JSONDecodeError as e:
        # Must come before RequestException as JSONDecodeError is
        # a subclass.
        print(f"JSON Decode Error: {e}", file=sys.stderr)
        if response is not None:
            print(f"Response Body: {response.text}", file=sys.stderr)
        raise e
    except requests.exceptions.RequestException as e:
        print(f"An unexpected error occurred: {e}")
        raise e
    return incremental_diffs
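
# Illustrative call (an assumption about the service contract, inferred from how
# the response is consumed below: the endpoint accepts a JSON list of revision
# texts and returns one wikidiff2 structured diff per revision, each of which
# DiffToOperationMap later parses with json.loads):
#
#     diffs = compute_diffs("http://127.0.0.1:8000", ["old text\n", "new text\n"])
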

class DiffToOperationMap:
    def __init__(self, diff, tokenizer):
        self.tokenizer = tokenizer
        self.diff = json.loads(diff)
        # the code below is designed to work in bytes because that's how wikidiff2 indexes
        # self.from_last_end_bytes = 0
        # self.from_last_to_bytes = 0
        # self.n_from_start_tokens = 0
        # self.n_from_end_tokens = 0
        # self.n_from_start_tokens = 0
        # self.n_to_start_tokens = 0
        # self.from_last_end_bytes = 0
        # self.to_last_end_bytes = 0
        # keeps track of the number of tokens seen so far
        # to avoid repeated tokenization
        # self.from_byte_token_index_map: SortedDict[int, int] = SortedDict()
        # self.to_byte_token_index_map: SortedDict[int, int] = SortedDict()
        self.par_move_dict = {}
        # we need to keep track of the bytes of line numbers to recover when wikidiff2 loses offsets.
        self.to_linenumber_bytes_map: SortedDict[int, int] = SortedDict()
        self.from_linenumber_bytes_map: SortedDict[int, int] = SortedDict()
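
    # Illustrative shape of a single structured-diff entry consumed by
    # to_operations below (the field names are the ones this class accesses;
    # the concrete values here are made up):
    #
    #   {
    #     "type": 3,                # 0=equal, 1=insert, 2=delete, 3=changed line,
    #                               # 4/5=moved paragraph (old/new location)
    #     "lineNumber": 12,
    #     "text": "The changed line",
    #     "offset": {"from": 104, "to": 93},
    #     "highlightRanges": [{"start": 4, "length": 7, "type": 0}],
    #     "moveInfo": {"id": "movedpara_1_1_0", "linkId": "movedpara_1_2_0"}  # types 4/5 only
    #   }
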
    # def get_token_offset(self, byte_offset):
    #     from_token_start = None
    #     to_token_start = None
    #     from_last_end_bytes = self.from_byte_token_index_map.keys()[-1]
    #     to_last_end_bytes = self.to_byte_token_index_map.keys()[-1]
    #     if byte_offset['from'] is not None:
    #         if byte_offset['from'] < self.from_byte_token_index_map.values()[0]:
    #             from_token_start = 0
    #         else:
    #             key = self.from_byte_token_index_map.bisect_key_right(byte_offset['from'])
    #             # this could be an issue; we assume that the next tokens are inserted at the end, but maybe they could go even further below?
    #             if key > from_last_end_bytes:
    #                 from_token_start = self.from_byte_token_index_map[from_last_end_bytes]
    #             else:
    #                 from_token_
    #     if byte_offset['to'] is not None:
    #         if byte_offset['to'] < self.to_byte_token_index_map.values()[0]:
    #             to_token_start = 0
    #         else:
    #             key = self.from_byte_token_index_map.bisect_key_right(byte_offset['to'])
    #             if key >= from
    #     if len(self.from_byte_token_index_map) > 0 and byte_offset['from'] != 0:
    #         if (
    #             byte_offset['from'] >= self.from_last_end_bytes
    #         ):  # if the from paragraph is at the end
    #             from_token_start = next(
    #                 reversed(self.from_byte_token_index_map.values())
    #             )
    #         else:
    #             key = find_greatest_le_key(
    #                 byte_offset['from'], self.from_byte_token_index_map
    #             )
    #             from_token_start = self.from_byte_token_index_map[key]
    #     else:
    #         from_token_start = 0
    #     to_offset = None
    #     if byte_offset['to'] is not None:
    #         if len(self.to_byte_token_index_map) > 0:
    #             if to_byte_start >= self.to_last_end_bytes:
    #                 to_token_start = next(reversed(self.to_byte_token_index_map.values()))
    #             else:
    #                 key = find_smallest_gt_key(to_byte_start, self.to_byte_token_index_map)
    #                 to_token_start = self.to_byte_token_index_map[key]
    #         else:
    #             to_token_start = 0
    #     return {'from': from_token_start,
    #             'to': to_token_start}

    def tokenize(self, bytes):
        return self.tokenizer.tokenize(bytes.decode("utf-8"))

    def to_operations(self):
        for entry in self.diff["diff"]:
            # add back the newline that wikidiff2 strips from each line's text
            entry["text"] += "\n"
            text = entry["text"]
            offset = entry["offset"]
            if offset["from"] and entry.get("lineNumber") is not None:
                if entry["type"] in [0, 2, 3, 4]:
                    self.from_linenumber_bytes_map[entry["lineNumber"]] = offset["from"] + len(text.encode())
            if offset["to"] and entry.get("lineNumber") is not None:
                if entry["type"] in [0, 1, 3, 5]:
                    self.to_linenumber_bytes_map[entry["lineNumber"]] = offset["to"] + len(text.encode())
            # this is the first byte of the line in the 'from' revision.
            from_start_line = entry["offset"]["from"]
            # this is the first byte of the line in the 'to' revision.
            to_start_line = entry["offset"]["to"]
            # a line included in both revisions, unchanged
            if entry["type"] == 0:
                yield from self.doEqual(text, offset)
            # a line included in the 'to' revision, but not in the 'from' revision
            elif entry["type"] == 1:
                yield from self.doInsert(text, offset)
            # a line included in the 'from' revision, but not in the 'to' revision
            elif entry["type"] == 2:
                yield from self.doDelete(text, offset)
            # a line present in both revisions with changes inside it
            elif entry["type"] == 3:
                yield from self.doHighlightRange(
                    text, entry["highlightRanges"], offset, entry["lineNumber"]
                )
            # a moved paragraph; type 4 marks its old ('from') location
            elif entry["type"] == 4:
                self.par_move_dict[entry["moveInfo"]["id"]] = entry
                linkId = entry["moveInfo"]["linkId"]
                if linkId in self.par_move_dict:
                    yield from self.doParMove(entry, self.par_move_dict[linkId])
                # we need to count the tokens in the from revision so token index is correct
                # self.n_from_end_tokens += len(self.tokenize(entry["text"].encode()))
                # self.n_from_start_tokens += len(
                #     self.tokenize(entry["text"].encode())
                # )
            # a moved paragraph; type 5 marks its new ('to') location
            elif entry["type"] == 5:
                linkId = entry["moveInfo"]["linkId"]
                if linkId in self.par_move_dict:
                    yield from self.doParMove(self.par_move_dict[linkId], entry)
                else:
                    self.par_move_dict[entry["moveInfo"]["id"]] = entry
                # call doHighlightRange just to update the token indices
                # offset = {
                #     "from": self.n_from_end_tokens,
                #     "to": entry["offset"]["to"],
                # }
                # res = self.doHighlightRange(
                #     entry["text"],
                #     entry["highlightRanges"],
                #     offset,
                #     entry["lineNumber"],
                #     update_idx="to",
                # )
                # list(res)
                # self.n_to_end_tokens += len(self.tokenize(entry["text"].encode()))
                # self.n_to_start_tokens += len(
                #     self.tokenize(entry["text"].encode())
                # )
            else:
                # The 'type' isn't one of the known wikidiff2 types (0-5).
                raise ValueError(entry)

    # mwpersistence expects differences to be represented in order from the
    # result's perspective ("to"), not the previous text. Thus, if a line
    # is moved earlier then its insertion should appear before its deletion.
    # As a rule of thumb, the "to" segments should be non-overlapping and
    # strictly increasing, while the "from" segments should merely be
    # non-overlapping.

    def doEqual(self, equal_segment, offset, update_idx="all"):
        # if from_token_start is None:
        #     from_token_start = self.n_from_start_tokens
        # if to_token_start is None:
        #     to_token_start = self.n_to_start_tokens
        if isinstance(equal_segment, str):
            equal_bytes = equal_segment.encode()
        elif isinstance(equal_segment, bytes):
            equal_bytes = equal_segment
        else:
            raise ValueError(equal_segment)
        tokens = self.tokenize(equal_bytes)
        n_tokens = len(tokens)
        # token_offset = self.get_token_offset(offset)
        # n_from_end_tokens = token_offset['from'] + n_tokens
        # n_to_end_tokens = token_offset['to'] + n_tokens
        yield (
            Equal(
                offset["from"],
                None,
                offset["to"],
                None,
            ),
            tokens,
            tokens,
        )
        # if update_idx in ["from", "all"]:
        #     self.n_from_end_tokens = self.n_from_start_tokens = n_from_end_tokens
        # if update_idx in ["to", "all"]:
        #     self.n_to_end_tokens = self.n_to_start_tokens = n_to_end_tokens
        # self.from_byte_token_index_map[offset["from"]] = n_from_end_tokens
        # self.to_byte_token_index_map[offset["to"]] = n_to_end_tokens

    def doInsert(self, insert_segment, offset, update_idx="all"):
        if isinstance(insert_segment, str):
            insert_bytes = insert_segment.encode()
        elif isinstance(insert_segment, bytes):
            insert_bytes = insert_segment
        else:
            raise ValueError(insert_segment)
        tokens = self.tokenize(insert_bytes)
        # n_tokens = len(tokens)
        # token_offset = self.get_token_offset(offset)
        # n_to_end_tokens = token_offset['to'] + n_tokens
        yield (
            Insert(
                None,
                None,
                offset["to"],
                None,
            ),
            [],
            tokens,
        )
        # We have now used more of the "to" tokens.
        # self.to_byte_token_index_map[offset["to"]] = n_to_end_tokens

    def doDelete(self, delete_segment, offset, update_idx="all", type=str):
        if isinstance(delete_segment, str):
            delete_bytes = delete_segment.encode()
        elif isinstance(delete_segment, bytes):
            delete_bytes = delete_segment
        else:
            raise ValueError(delete_segment)
        tokens = self.tokenize(delete_bytes)
        # n_tokens = len(tokens)
        # token_offset = self.get_token_offset(offset)
        # n_from_end_tokens = token_offset['from'] + n_tokens
        yield (
            Delete(
                offset["from"],
                None,
                None,
                None,
            ),
            tokens,
            [],
        )
        # self.from_byte_token_index_map[offset["from"]] = n_from_end_tokens

    def doHighlightRange(
        self, highlight_text, highlightRanges, offset, lineNumber, update_idx="all"
    ):
        # The text field is an overlapping mix of both the from and to,
        # so we need to handle it highlight-by-highlight.
        # There can be gaps between highlight segments,
        # for instance, if a word is deleted from the middle of a line.
        # We need to track that.
        highlight_bytes = highlight_text.encode()
        highlight_end = 0
        # It's possible for offset['to'] to be None; in that case we recover it
        # from the byte offset recorded for the closest preceding line number.
        update_linenumber_map = True
        if offset["to"] is None:
            keyidx = self.to_linenumber_bytes_map.bisect_right(lineNumber) - 1
            if keyidx >= 0:
                key = self.to_linenumber_bytes_map.keys()[keyidx]
                offset["to"] = self.to_linenumber_bytes_map[key]
            else:
                offset["to"] = 0
        highlight_offset = offset
        # note that diffs are token-level, but the indexes are byte-level
        for highlightRange in highlightRanges:
            highlight_start = highlightRange["start"]
            # equal bytes in between highlights
            if highlight_start > highlight_end:
                equal_bytes = highlight_bytes[highlight_end:highlight_start]
                n_equal_bytes = len(equal_bytes)
                yield from self.doEqual(
                    equal_bytes, highlight_offset, update_idx=update_idx
                )
                highlight_offset["from"] += n_equal_bytes
                highlight_offset["to"] += n_equal_bytes
                if update_linenumber_map:
                    self.to_linenumber_bytes_map[lineNumber] = highlight_offset["to"]
            # handle highlighted insert / delete
            highlight_end = highlight_start + highlightRange["length"]
            range_bytes = highlight_bytes[highlight_start:highlight_end]
            n_range_bytes = len(range_bytes)
            if highlightRange["type"] == 0:
                yield from self.doInsert(
                    range_bytes, highlight_offset, update_idx=update_idx
                )
                highlight_offset["to"] += n_range_bytes
                if update_linenumber_map:
                    self.to_linenumber_bytes_map[lineNumber] = highlight_offset["to"]
            elif highlightRange["type"] == 1:
                yield from self.doDelete(
                    range_bytes, highlight_offset, update_idx=update_idx
                )
                highlight_offset["from"] += n_range_bytes
            else:
                # The highlight 'type' isn't one of the known values (0=insert, 1=delete).
                raise ValueError(highlightRange)
        # handle the rest of the line, which is equal
        if highlight_end < len(highlight_bytes):
            range_bytes = highlight_bytes[highlight_end:]
            yield from self.doEqual(range_bytes, highlight_offset)
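
    # Illustration (hypothetical values): for highlight_text "Foo bar baz\n" with
    # highlightRanges == [{"start": 4, "length": 4, "type": 1}], the loop above
    # yields Equal("Foo "), then Delete("bar "), and finally Equal("baz\n") for the
    # trailing unchanged bytes, advancing the byte offsets by each segment's length.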

    def doParMove(self, from_diff, to_diff):
        # the tricky part here is to put the tokens in the right spots.
        from_byte_start = from_diff["offset"]["from"]
        to_byte_start = to_diff["offset"]["to"]
        # combine the old location's 'from' offset with the new location's 'to' offset
        offset = {"from": from_byte_start, "to": to_byte_start}
        yield from self.doHighlightRange(
            to_diff["text"], to_diff["highlightRanges"], offset, to_diff["lineNumber"]
        )
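
    # Note: moved paragraphs arrive as linked pairs of entries; the
    # moveInfo["linkId"] of one entry names the moveInfo["id"] of its
    # counterpart, which is how to_operations pairs a type-4 (old location)
    # entry with its type-5 (new location) entry before calling doParMove.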


class WikiDiffMatcher:
    def __init__(
        self,
        texts: Optional[list[str]] = None,
        tokenizer: Optional[RegexTokenizer] = None,
        url: Optional[str] = "http://127.0.0.1:8000",
    ):
        # Pre-compute diffs to reduce traffic overhead.
        self.diffs = compute_diffs(url, texts)
        self.tokenizer = tokenizer or TOKENIZER

    class Processor(DiffEngine.Processor):
        def __init__(self, diffs, tokenizer=None):
            self.diffs = iter(diffs)
            self.tokenizer = tokenizer or TOKENIZER
            self.last_tokens = []
            self.previous_text = ""

        def update(self, last_tokens):
            self.last_tokens = last_tokens

        def process(self, text, token_class=None):
            # The diff has already been computed, but we need to incrementally
            # retrieve it to recreate the behavior DiffState expects.
            diff = next(self.diffs)
            diffToOperationsMapper = DiffToOperationMap(diff, self.tokenizer)
            diffops = list(diffToOperationsMapper.to_operations())
            # this happens when the revisions are actually equal
            if len(diffops) == 0:
                self.last_tokens = self.tokenizer.tokenize(text)
                ops = [Equal(0, len(self.last_tokens), 0, len(self.last_tokens))]
                return ops, self.last_tokens, self.last_tokens
            # we get back byte indices from wikidiff2; now transform them to token indices
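            # Worked illustration (hypothetical numbers, continuing the
            # "Foo bar baz\n" -> "Foo baz\n" example from doHighlightRange,
            # tokenized with wikitext_split): the byte-level operations
            #     Equal(0, None, 0, None), Delete(4, None, None, None), Equal(8, None, 4, None)
            # become, after the two re-indexing passes below,
            #     Equal(0, 2, 0, 2)    # "Foo "  -> from-tokens 0..2, to-tokens 0..2
            #     Equal(4, 6, 2, 4)    # "baz\n" -> from-tokens 4..6, to-tokens 2..4
            #     Delete(2, 4, 4, 4)   # "bar "  -> from-tokens 2..4, empty "to" range
            # so the "to" ranges are non-overlapping and increasing, as mwpersistence expects.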
            diffops.sort(
                key=lambda t: (t[0].a1 if t[0].a1 is not None else 1e32, t[0].b1)
            )
            aorder_ops = []
            token_offset = 0
            for op, tokens, _ in diffops:
                a1 = token_offset
                if isinstance(op, (Equal, Delete)):
                    token_offset += len(tokens)
                    a2 = token_offset
                    aorder_ops.append(type(op)(a1, a2, op.b1, op.b1))
                else:
                    aorder_ops.append(Insert(a1, a1, op.b1, op.b1))
            _, aseq, bseq = zip(*diffops)
            diffops = list(zip(aorder_ops, aseq, bseq))
            diffops.sort(
                key=lambda t: (t[0].b1 if t[0].b1 is not None else 1e32, t[0].a1)
            )
            _, _, bseq = list(zip(*diffops))
            border_ops = []
            token_offset = 0
            for op, _, tokens in diffops:
                b1 = token_offset
                if isinstance(op, (Equal, Insert)):
                    token_offset += len(tokens)
                    b2 = token_offset
                    border_ops.append(type(op)(op.a1, op.a2, b1, b2))
                else:
                    border_ops.append(type(op)(op.a1, op.a2, b1, b1))
            self.previous_text = text
            self.last_tokens = list(chain.from_iterable(aseq))
            tokens = list(chain.from_iterable(bseq))
            return border_ops, self.last_tokens, tokens

    def processor(self, *args, **kwargs):
        return self.Processor(self.diffs, self.tokenizer)

    def process(self):
        # DiffState checks for this method even though it is not called.
        raise Exception("Unnecessary implementation")
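

# A minimal usage sketch (an assumption, not part of the module's API): it
# presumes a wikidiff2-compatible HTTP service is running at the default URL
# used by WikiDiffMatcher and that it returns one structured diff per revision
# passed in, so Processor.process can be called once per revision in order.
if __name__ == "__main__":
    revisions = [
        "First sentence.\nSecond sentence.\n",
        "First sentence, now edited.\nSecond sentence.\n",
    ]
    matcher = WikiDiffMatcher(revisions)
    processor = matcher.processor()
    for revision in revisions:
        # Each call consumes the next precomputed diff and returns the deltas
        # operations plus the previous and current token lists.
        operations, previous_tokens, current_tokens = processor.process(revision)
        print(operations)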