# Source: mw-lifecycle-analysis/p2/quest/python_scripts/olmo_labeling/batched_olmo_cat.py
# Last modified: 2025-12-07 10:10:15 -06:00
# Listing metadata: 113 lines, 5.0 KiB, Python

from transformers import AutoModelForCausalLM, AutoTokenizer, OlmoForCausalLM
import torch
import csv
import pandas as pd
import re
import sys
from pathlib import Path
#import os
#os.environ['BNB_CUDA_VERSION'] = ''
#import bitsandbytes
import nltk
# Fetch the 'punkt_tab' resource; nltk.sent_tokenize is used on the cleaned
# comment text further down in this script.
nltk.download('punkt_tab')
cache_directory = "/projects/p32852/cache/"
model_name = "allenai/OLMo-2-1124-13B-Instruct"

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Only query CUDA device details when CUDA is actually in use: these calls
# raise on a CPU-only host, which would defeat the CPU fallback above.
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.get_device_properties(0))

# Load the model and its tokenizer.
# Earlier configurations, kept for reference:
#olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0325-32B", torch_dtype=torch.float16, load_in_8bit=True, cache_dir=cache_directory).to(device)
#olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0325-32B-Instruct-GGUF", cache_dir=cache_directory).to(device)
#tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0325-32B-Instruct-GGUF", cache_dir=cache_directory)
olmo = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_directory).to(device)
# Prompts are padded on the left; they are tokenized in batches with
# padding=True before generation later in the script.
# NOTE(review): unlike the model, the tokenizer is loaded without
# cache_dir=cache_directory -- confirm whether that is intentional.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

# Text inserted into every prompt: the information-type definitions and the
# prompt template (filled via str.format_map in the labeling loop).
information_types = Path('/home/nws8519/git/mw-lifecycle-analysis/p2/quest/python_scripts/olmo_labeling/info_definitions.txt').read_text(encoding="utf-8")
prompt_template = Path('/home/nws8519/git/mw-lifecycle-analysis/p2/quest/python_scripts/olmo_labeling/prompt_template_nofs.txt').read_text(encoding="utf-8")

# Raise the csv module's per-field size limit to the largest value accepted
# here so that very long fields in the input file can be read.
csv.field_size_limit(sys.maxsize)
# Regexes for the comment-text preprocessing (per https://arxiv.org/pdf/1902.07093)
# and for parsing model responses; compiled once instead of once per CSV row.
_BLOCK_CODE_RE = re.compile(r'```[\s\S]+?```')
_INLINE_CODE_RE = re.compile(r'`[^`]+`')
_GERRIT_URL_RE = re.compile(r'https://gerrit\.wikimedia\.org/r/\d+')
_URL_RE = re.compile(r'https?://[^\s]+')
_SCREEN_NAME_RE = re.compile(r'(^|\s)@\w+')
_RESPONSE_RE = re.compile(r"Response: \s*(.*)")
_CATEGORY_NOISE_RE = re.compile(r'[()",\d]')


def _clean_comment_text(raw_text):
    """Return *raw_text* with code, quotes, URLs and @-mentions replaced by placeholders."""
    # 1. replace code with CODE. Fenced blocks must go first: the inline
    #    pattern would otherwise match inside the fence and leave stray
    #    backticks behind, so the block pattern would never match.
    text = _BLOCK_CODE_RE.sub('CODE', raw_text)
    text = _INLINE_CODE_RE.sub('CODE', text)
    # 2. replace quoted lines (those starting with '>') with QUOTE
    text = '\n'.join(
        'QUOTE' if line.strip().startswith('>') else line
        for line in text.split('\n')
    )
    # 3. replace Gerrit URLs with GERRIT_URL, then any remaining URL with URL
    text = _GERRIT_URL_RE.sub('GERRIT_URL', text)
    text = _URL_RE.sub('URL', text)
    # 4. replace @mentions with SCREEN_NAME; \1 keeps the whitespace that
    #    preceded the mention so it does not fuse with the previous word.
    text = _SCREEN_NAME_RE.sub(r'\1SCREEN_NAME', text)
    return text


def _parse_category(response_txt):
    """Return the category named after 'Response: ' in *response_txt*, or 'NO CATEGORY'."""
    match = _RESPONSE_RE.search(response_txt)
    if not match:
        return "NO CATEGORY"
    # Drop parentheses, quotes, commas and digits, then turn underscores into spaces.
    category = _CATEGORY_NOISE_RE.sub("", match.group(1)).strip()
    return category.replace("_", " ")


# Number of sentences sent to the model per generate() call.
batch_size = 2
array_of_categorizations = []

# Read every comment row, clean it, split it into sentences and ask the model
# for one category per sentence. The file is read as UTF-8, matching how the
# prompt files are read elsewhere in this script.
with open("/home/nws8519/git/mw-lifecycle-analysis/analysis_data/102725_unified.csv", mode='r', newline='', encoding='utf-8') as file:
    reader = csv.reader(file)
    next(reader, None)  # skip the header row
    for row in reader:
        #organizing the data from each citation
        task_title = row[1]
        text_dict = {
            'id': row[0],
            'task_title': task_title,
            'comment_text': row[2],
            'date_created': row[3],
            'comment_type': row[6],
            'TaskPHID': row[5],
            'AuthorPHID': row[4],
        }
        # Task descriptions are prefixed with the task title; other comment
        # types use the comment text alone.
        if text_dict['comment_type'] == "task_description":
            raw_text = text_dict['task_title'] + ". \n\n" + text_dict['comment_text']
        else:
            raw_text = text_dict['comment_text']

        # 5. split the cleaned text into an array of sentences
        comment_sentences = nltk.sent_tokenize(_clean_comment_text(raw_text))
        text_dict['cleaned_sentences'] = comment_sentences

        results = []
        for start in range(0, len(comment_sentences), batch_size):
            batch = comment_sentences[start:start + batch_size]
            prompts = [
                prompt_template.format_map({"info_definitions": information_types, "sent": sent, "task_title": task_title})
                for sent in batch
            ]
            inputs = tokenizer(prompts, return_tensors='pt', return_token_type_ids=False, padding=True, truncation=True).to(device)
            # NOTE(review): temperature should have no effect when
            # do_sample=False -- confirm against the installed transformers version.
            with torch.no_grad():
                outputs = olmo.generate(**inputs, max_new_tokens=256, do_sample=False, temperature=0)
            # NOTE(review): the whole output sequence is decoded and searched,
            # so the first "Response: " found may come from the prompt text if
            # the template contains one -- verify against the prompt template.
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            results.extend(_parse_category(response_txt) for response_txt in decoded)
            torch.cuda.empty_cache()

        text_dict['sentence_categories'] = results
        array_of_categorizations.append(text_dict)

# One output row per input comment, with the cleaned sentences and their categories.
df = pd.DataFrame(array_of_categorizations)
df.to_csv('all_120525_olmo_batched_categorized.csv', index=False)