"""Map wikidiff2 JSON diffs onto ``deltas`` operations.

Diffs for a whole sequence of revisions are requested once from an HTTP diff
service (``compute_diffs``) and then replayed one revision at a time through
``WikiDiffMatcher.Processor``, a ``deltas.DiffEngine.Processor`` subclass.
"""
import json
import sys
from collections import namedtuple
from itertools import chain
from typing import Dict, Generator, List, Optional, Tuple

import requests
from deltas import (Delete, DiffEngine, Equal, Insert, Operation,
                    RegexTokenizer, Token, tokenizers)
from sortedcontainers import SortedDict

TOKENIZER = tokenizers.wikitext_split

# Sort key substituted for a missing (None) offset so that such operations
# order after every real offset instead of raising TypeError on comparison.
_MISSING_OFFSET_SORT_KEY = 1e32


def _none_last(value):
    """Return *value*, or a very large number when it is None (sort helper)."""
    return _MISSING_OFFSET_SORT_KEY if value is None else value


def compute_diffs(url: str, texts: list[str]) -> list:
    """POST *texts* to the diff service at *url* and return its decoded JSON.

    Args:
        url: Address of the diff service.
        texts: Revision texts, sent as the JSON request body.

    Returns:
        The decoded JSON response. ``WikiDiffMatcher.Processor`` consumes it
        one element per processed revision and passes each element to
        ``json.loads`` (so presumably a list of JSON strings -- confirm
        against the service).

    Raises:
        requests.exceptions.RequestException (or a subclass) on connection,
        HTTP-status or JSON-decoding failure, after printing a diagnostic to
        stderr.
    """
    response = None
    try:
        response = requests.post(url, json=texts)
        response.raise_for_status()
        incremental_diffs = response.json()
    except requests.exceptions.ConnectionError as e:
        print(
            f"Connection Error: Could not connect to the server at {url}. Make sure your local server is running.",
            file=sys.stderr,
        )
        print(e, file=sys.stderr)
        raise
    except requests.exceptions.HTTPError as e:
        print(f"HTTP Error: {e}", file=sys.stderr)
        if response is not None:
            print(f"Response Body: {response.text}", file=sys.stderr)
        raise
    except requests.exceptions.JSONDecodeError as e:
        # Must come before RequestException as JSONDecodeError is
        # a subclass.
        print(f"JSON Decode Error: {e}", file=sys.stderr)
        if response is not None:
            print(f"Response Body: {response.text}", file=sys.stderr)
        raise
    except requests.exceptions.RequestException as e:
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        raise
    return incremental_diffs


class DiffToOperationMap:
    """Convert one wikidiff2 JSON diff into ``deltas`` operations.

    ``to_operations`` yields ``(operation, from_tokens, to_tokens)`` triples.
    The operations carry *byte* offsets into the from/to texts (wikidiff2
    indexes in bytes), with ``a2``/``b2`` left as None.
    ``WikiDiffMatcher.Processor`` later rewrites them into token indices.

    Entry ``type`` codes, as handled below:
        0 equal line, 1 line only in 'to', 2 line only in 'from',
        3 changed line described by ``highlightRanges``,
        4 / 5 the 'from' / 'to' halves of a paragraph move, paired through
        ``moveInfo.id`` / ``moveInfo.linkId``.
    """

    def __init__(self, diff, tokenizer):
        self.tokenizer = tokenizer
        self.diff = json.loads(diff)
        # Paragraph-move entries (types 4 and 5), keyed by moveInfo id, held
        # until the matching half of the move is seen.
        self.par_move_dict = {}
        # we need to keep track of the bytes of line numbers to recover when
        # wikidiff2 loses offsets. Maps lineNumber -> byte offset reached on
        # that line.
        self.to_linenumber_bytes_map: SortedDict[int, int] = SortedDict()
        self.from_linenumber_bytes_map: SortedDict[int, int] = SortedDict()

    def tokenize(self, bytes):
        """Decode *bytes* as UTF-8 and tokenize it with the configured tokenizer."""
        return self.tokenizer.tokenize(bytes.decode("utf-8"))

    @staticmethod
    def _to_bytes(segment):
        """Return *segment* as bytes (str is UTF-8 encoded); raise ValueError otherwise."""
        if isinstance(segment, str):
            return segment.encode()
        if isinstance(segment, bytes):
            return segment
        raise ValueError(segment)

    def to_operations(self):
        """Yield ``(operation, from_tokens, to_tokens)`` for every diff entry.

        Raises:
            ValueError: for an entry whose ``type`` is not one of 0-5.
        """
        for entry in self.diff["diff"]:
            # add back the newline. The entry itself is updated because
            # paragraph-move entries are stored and re-read by doParMove.
            entry["text"] += "\n"
            text = entry["text"]
            offset = entry["offset"]
            entry_type = entry["type"]

            # Remember where each line ends so that later entries lacking an
            # offset can be placed. Offset 0 (the first line) is a valid
            # offset, hence the explicit None checks.
            line_number = entry.get("lineNumber")
            if line_number is not None:
                n_line_bytes = len(text.encode())
                if offset["from"] is not None and entry_type in (0, 2, 3, 4):
                    self.from_linenumber_bytes_map[line_number] = (
                        offset["from"] + n_line_bytes
                    )
                if offset["to"] is not None and entry_type in (0, 1, 3, 5):
                    self.to_linenumber_bytes_map[line_number] = (
                        offset["to"] + n_line_bytes
                    )

            if entry_type == 0:
                yield from self.doEqual(text, offset)
            # a line included in the 'to' revision, but not in the 'from' revision
            elif entry_type == 1:
                yield from self.doInsert(text, offset)
            # a line included in the 'from' revision, but not in the 'to' revision
            elif entry_type == 2:
                yield from self.doDelete(text, offset)
            elif entry_type == 3:
                yield from self.doHighlightRange(
                    text, entry["highlightRanges"], offset, entry["lineNumber"]
                )
            elif entry_type == 4:
                self.par_move_dict[entry["moveInfo"]["id"]] = entry
                link_id = entry["moveInfo"]["linkId"]
                if link_id in self.par_move_dict:
                    yield from self.doParMove(entry, self.par_move_dict[link_id])
            elif entry_type == 5:
                link_id = entry["moveInfo"]["linkId"]
                if link_id in self.par_move_dict:
                    yield from self.doParMove(self.par_move_dict[link_id], entry)
                else:
                    self.par_move_dict[entry["moveInfo"]["id"]] = entry
            else:
                # The 'type' isn't one of the known ones.
                raise ValueError(entry)

    # mwpersistence expects differences to be represented in order from the
    # result's perspective ("to"), not the previous text. Thus, if a line
    # is moved earlier then its insertion should appear before its deletion.
    # As a rule of thumb, the "to" segments should be non-overlapping and
    # strictly increasing, while the "from" segments should merely be
    # non-overlapping.

    def doEqual(self, equal_segment, offset, update_idx="all"):
        """Yield one Equal op at ``offset['from']``/``offset['to']`` (bytes).

        *equal_segment* may be str or bytes; *update_idx* is currently unused.
        """
        tokens = self.tokenize(self._to_bytes(equal_segment))
        yield (Equal(offset["from"], None, offset["to"], None), tokens, tokens)

    def doInsert(self, insert_segment, offset, update_idx="all"):
        """Yield one Insert op at byte offset ``offset['to']``.

        *insert_segment* may be str or bytes; *update_idx* is currently unused.
        """
        tokens = self.tokenize(self._to_bytes(insert_segment))
        yield (Insert(None, None, offset["to"], None), [], tokens)

    def doDelete(self, delete_segment, offset, update_idx="all", type=str):
        """Yield one Delete op at byte offset ``offset['from']``.

        *delete_segment* may be str or bytes; *update_idx* and *type* are
        currently unused (kept for interface compatibility).
        """
        tokens = self.tokenize(self._to_bytes(delete_segment))
        yield (Delete(offset["from"], None, None, None), tokens, [])

    def doHighlightRange(
        self, highlight_text, highlightRanges, offset, lineNumber, update_idx="all"
    ):
        """Yield Equal/Insert/Delete ops for a changed line.

        The text field is an overlapping mix of both the from and to, so it
        is handled highlight-by-highlight. Bytes between highlights (e.g. when
        a word is deleted from the middle of a line) are emitted as Equal.
        Highlight ``type`` 0 is an insertion, 1 a deletion.

        *offset* is not modified; a copy is advanced instead.

        Raises:
            ValueError: for a highlight range of unknown type.
        """
        highlight_bytes = highlight_text.encode()
        highlight_end = 0
        # Work on a copy so the caller's (diff entry's) offsets stay intact.
        highlight_offset = dict(offset)

        # it's possible for offset['to'] to be null; recover it from the
        # greatest recorded line number <= lineNumber, or fall back to 0.
        if highlight_offset["to"] is None:
            keyidx = self.to_linenumber_bytes_map.bisect_right(lineNumber) - 1
            if keyidx >= 0:
                key = self.to_linenumber_bytes_map.keys()[keyidx]
                highlight_offset["to"] = self.to_linenumber_bytes_map[key]
            else:
                highlight_offset["to"] = 0

        # note that diffs are token-level, but the indexes are byte-level
        for highlight_range in highlightRanges:
            highlight_start = highlight_range["start"]
            # equal bytes in between highlights
            if highlight_start > highlight_end:
                equal_bytes = highlight_bytes[highlight_end:highlight_start]
                n_equal_bytes = len(equal_bytes)
                yield from self.doEqual(
                    equal_bytes, highlight_offset, update_idx=update_idx
                )
                highlight_offset["from"] += n_equal_bytes
                highlight_offset["to"] += n_equal_bytes
                self.to_linenumber_bytes_map[lineNumber] = highlight_offset["to"]

            # handle highlighted insert / delete
            highlight_end = highlight_start + highlight_range["length"]
            range_bytes = highlight_bytes[highlight_start:highlight_end]
            n_range_bytes = len(range_bytes)
            if highlight_range["type"] == 0:
                yield from self.doInsert(
                    range_bytes, highlight_offset, update_idx=update_idx
                )
                highlight_offset["to"] += n_range_bytes
                self.to_linenumber_bytes_map[lineNumber] = highlight_offset["to"]
            elif highlight_range["type"] == 1:
                yield from self.doDelete(
                    range_bytes, highlight_offset, update_idx=update_idx
                )
                highlight_offset["from"] += n_range_bytes
            else:
                raise ValueError(highlight_range)

        # handle the rest of the line which is equal
        # NOTE(review): the trailing equal bytes are not added to
        # to_linenumber_bytes_map, so the recorded offset for this line stops
        # before them -- confirm this is intended.
        if highlight_end < len(highlight_bytes):
            range_bytes = highlight_bytes[highlight_end:]
            yield from self.doEqual(range_bytes, highlight_offset, update_idx=update_idx)

    def doParMove(self, from_diff, to_diff):
        """Yield the ops for a moved paragraph.

        Starts from the 'from' half's ``from`` offset and the 'to' half's
        ``to`` offset, then replays the 'to' half's highlight ranges.
        """
        # the tricky part here is to put the tokens in the right spots.
        offset = {"from": from_diff["offset"]["from"], "to": to_diff["offset"]["to"]}
        yield from self.doHighlightRange(
            to_diff["text"], to_diff["highlightRanges"], offset, to_diff["lineNumber"]
        )


class WikiDiffMatcher:
    """Replay precomputed wikidiff2 diffs through the ``deltas`` processor API."""

    def __init__(
        self,
        texts: Optional[list[str]] = None,
        tokenizer: Optional[RegexTokenizer] = None,
        url: Optional[str] = "http://127.0.0.1:8000",
    ):
        # Pre-compute diffs to reduce traffic overhead.
        self.diffs = compute_diffs(url, texts)
        self.tokenizer = tokenizer or TOKENIZER

    class Processor(DiffEngine.Processor):
        """Return the next precomputed diff on each ``process`` call."""

        def __init__(self, texts, tokenizer=None):
            self.diffs = iter(texts)
            self.tokenizer = tokenizer or TOKENIZER
            self.last_tokens = []
            self.previous_text = ""

        def update(self, last_tokens):
            self.last_tokens = last_tokens

        def process(self, text, token_class=None):
            """Return ``(operations, from_tokens, to_tokens)`` for the next diff.

            The diff has already been computed, but we need to incrementally
            retrieve it to recreate the behavior DiffState expects.
            *token_class* is currently unused.
            """
            diff = next(self.diffs)
            mapper = DiffToOperationMap(diff, self.tokenizer)
            diffops = list(mapper.to_operations())

            # this happens when revisions are actually equal.
            if not diffops:
                self.previous_text = text
                self.last_tokens = self.tokenizer.tokenize(text)
                n_tokens = len(self.last_tokens)
                ops = [Equal(0, n_tokens, 0, n_tokens)]
                return ops, self.last_tokens, self.last_tokens

            # We get back byte indices; now transform them to token indices.
            # Pass 1: walk the ops in 'from' order and assign a1/a2.
            diffops.sort(key=lambda t: (_none_last(t[0].a1), _none_last(t[0].b1)))
            aorder_ops = []
            token_offset = 0
            for op, tokens, _ in diffops:
                a1 = token_offset
                if isinstance(op, (Equal, Delete)):
                    token_offset += len(tokens)
                    aorder_ops.append(type(op)(a1, token_offset, op.b1, op.b1))
                else:
                    aorder_ops.append(Insert(a1, a1, op.b1, op.b1))
            _, aseq, bseq = zip(*diffops)
            diffops = list(zip(aorder_ops, aseq, bseq))

            # Pass 2: walk the ops in 'to' order and assign b1/b2.
            diffops.sort(key=lambda t: (_none_last(t[0].b1), _none_last(t[0].a1)))
            _, _, bseq = zip(*diffops)
            border_ops = []
            token_offset = 0
            for op, _, tokens in diffops:
                b1 = token_offset
                if isinstance(op, (Equal, Insert)):
                    token_offset += len(tokens)
                    border_ops.append(type(op)(op.a1, op.a2, b1, token_offset))
                else:
                    border_ops.append(type(op)(op.a1, op.a2, b1, b1))

            self.previous_text = text
            self.last_tokens = list(chain.from_iterable(aseq))
            tokens = list(chain.from_iterable(bseq))
            return border_ops, self.last_tokens, tokens

    def processor(self, *args, **kwargs):
        """Return a fresh Processor over all precomputed diffs."""
        return self.Processor(self.diffs, self.tokenizer)

    def process(self):
        # DiffState checks for this method even though it is not called.
        raise Exception("Unnecessary implementation")