1
0
mw-lifecycle-analysis/text_analysis/case1/coref_resolution.ipynb

220 lines
7.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "b270bd36-529e-4595-a780-ef6c8151c31f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/gscratch/scrubbed/mjilg/envs/coref-notebook/lib/python3.10/site-packages/torch/cuda/__init__.py:734: UserWarning: Can't initialize NVML\n",
" warnings.warn(\"Can't initialize NVML\")\n"
]
}
],
"source": [
"import pandas as pd \n",
"import spacy"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f6448c6f-2b5d-45f5-a32e-b3b47c16ef85",
"metadata": {},
"outputs": [],
"source": [
"# Read the comments CSV at phab_path into a DataFrame (absolute, cluster-specific path).\n",
"phab_path = \"/mmfs1/gscratch/comdata/users/mjilg/mw-repo-lifecycles/case1/0228_ve_phab_comments.csv\"\n",
"phab_df = pd.read_csv(filepath_or_buffer=phab_path)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f32f6eed-3aeb-4b05-8d40-7ed85e7235c5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/gscratch/scrubbed/mjilg/envs/coref-notebook/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/gscratch/scrubbed/mjilg/envs/coref-notebook/lib/python3.10/site-packages/spacy/util.py:910: UserWarning: [W095] Model 'en_coreference_web_trf' (3.4.0a2) was trained with spaCy v3.3.0 and may not be 100% compatible with the current version (3.7.5). If you see errors or degraded performance, download a newer compatible model or retrain your custom model with the current spaCy version. For more details and available updates, run: python -m spacy validate\n",
" warnings.warn(warn_msg)\n"
]
},
{
"data": {
"text/plain": [
"<spacy_experimental.coref.span_resolver_component.SpanResolver at 0x1495edce13c0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Two pipelines are loaded: the general English transformer pipeline, and the\n",
"# coreference pipeline whose components are sourced into it below.\n",
"nlp = spacy.load(\"en_core_web_trf\")\n",
"nlp_coref = spacy.load(\"en_coreference_web_trf\")\n",
"\n",
"# use replace_listeners for the coref components (must happen before they are\n",
"# sourced into `nlp`)\n",
"for _coref_component in (\"coref\", \"span_resolver\"):\n",
"    nlp_coref.replace_listeners(\"transformer\", _coref_component, [\"model.tok2vec\"])\n",
"\n",
"# we won't copy over the span cleaner - this keeps the head cluster information, which we want\n",
"nlp.add_pipe(\"merge_entities\")\n",
"nlp.add_pipe(\"coref\", source=nlp_coref)\n",
"nlp.add_pipe(\"span_resolver\", source=nlp_coref)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "a5b062d8-2d26-4a3e-a84c-ba0eaf6eb436",
"metadata": {},
"outputs": [],
"source": [
"# References for the coref setup used in this notebook:\n",
"#   https://github.com/explosion/spaCy/discussions/13572\n",
"#   https://github.com/explosion/spaCy/issues/13111\n",
"#   https://explosion.ai/blog/coref\n",
"#   https://gist.github.com/thomashacker/b5dd6042c092e0a22c2b9243a64a2466\n",
"# Run the combined pipeline once on a hand-written sentence.\n",
"sample_sentence = \"John is frustrated with the VisualEditor project, he thinks it doesn't work.\"\n",
"doc = nlp(sample_sentence)\n"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "999e1656-0036-4ba2-bedf-f54493f67790",
"metadata": {},
"outputs": [],
"source": [
"# https://gist.github.com/thomashacker/b5dd6042c092e0a22c2b9243a64a2466\n",
"from spacy.tokens import Doc\n",
"# Define lightweight function for resolving references in text\n",
"def resolve_references(doc: Doc) -> str:\n",
"    \"\"\"Function for resolving references with the coref output.\n",
"\n",
"    Every span group in ``doc.spans`` whose key starts with \"coref_cluster\" is\n",
"    treated as one cluster. The first span of a cluster is used as the\n",
"    replacement text for every later span of that cluster; tokens outside any\n",
"    replaced span are emitted unchanged.\n",
"\n",
"    doc (Doc): The Doc object processed by the coref pipeline\n",
"    RETURNS (str): The Doc string with resolved references\n",
"    \"\"\"\n",
"    # token.idx -> text to emit for that token (\"\" drops the token)\n",
"    token_mention_mapper = {}\n",
"    clusters = [\n",
"        val for key, val in doc.spans.items() if key.startswith(\"coref_cluster\")\n",
"    ]\n",
"\n",
"    # Iterate through every found cluster\n",
"    for cluster in clusters:\n",
"        mentions = list(cluster)\n",
"        # An empty cluster has no first mention to substitute; skip it\n",
"        # instead of raising IndexError.\n",
"        if not mentions:\n",
"            continue\n",
"        first_mention = mentions[0]\n",
"\n",
"        # Iterate through every other span in the cluster\n",
"        for mention_span in mentions[1:]:\n",
"            # A zero-length span covers no tokens, so there is nothing to replace.\n",
"            if len(mention_span) == 0:\n",
"                continue\n",
"\n",
"            # The replacement is written at the first token of the mention, but\n",
"            # the whitespace that follows it must be the whitespace after the\n",
"            # LAST token of the mention; using the first token's whitespace\n",
"            # turns \"<multi-token mention>,\" into \"<replacement> ,\".\n",
"            token_mention_mapper[mention_span[0].idx] = (\n",
"                first_mention.text + mention_span[-1].whitespace_\n",
"            )\n",
"\n",
"            # Set empty string for all the other tokens in mention_span\n",
"            for token in mention_span[1:]:\n",
"                token_mention_mapper[token.idx] = \"\"\n",
"\n",
"    # Emit the mapped text where one exists, otherwise the original token text\n",
"    # plus its trailing whitespace; join once instead of repeated string +=.\n",
"    pieces = [\n",
"        token_mention_mapper.get(token.idx, token.text + token.whitespace_)\n",
"        for token in doc\n",
"    ]\n",
"    return \"\".join(pieces)\n"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "be476647-624b-4e95-ab62-9c6b08f85368",
"metadata": {},
"outputs": [],
"source": [
"def resolving_comment(text):\n",
"    \"\"\"Run `text` through the `nlp` pipeline and return it with references resolved.\n",
"\n",
"    text: The raw string passed to `nlp`.\n",
"    RETURNS: The string produced by `resolve_references` for the parsed text.\n",
"    \"\"\"\n",
"    return resolve_references(nlp(text))"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "a9628b54-a1df-49cd-a365-9cba59de3421",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'i hate ve.interface, ve.interface always messes up i browser'"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"resolving_comment(\"i hate ve.interface, it always messes up my browser\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46873641-8e88-4829-9e24-4dd5e6749bd1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/gscratch/scrubbed/mjilg/envs/coref-notebook/lib/python3.10/site-packages/thinc/shims/pytorch.py:114: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
" with torch.cuda.amp.autocast(self._mixed_precision):\n"
]
}
],
"source": [
"# Coerce every comment to str before parsing (str() also stringifies missing values),\n",
"# then resolve references comment by comment.\n",
"comment_strings = phab_df['comment_text'].map(str)\n",
"phab_df['text'] = comment_strings\n",
"phab_df['resolved_text'] = phab_df['text'].map(resolving_comment)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b583feb-1c62-4c96-9ba0-2996d72e70d3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 5
}