from transformers import AutoModelForCausalLM, AutoTokenizer, OlmoForCausalLM
import torch
import csv
import pandas as pd
import re
import sys
from pathlib import Path

#import os
#os.environ['BNB_CUDA_VERSION'] = ''
#import bitsandbytes

import nltk
nltk.download('punkt_tab')
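# ('punkt_tab' is the sentence-tokenizer data that nltk.sent_tokenize below depends on
# in recent NLTK releases, so it is downloaded up front)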

cache_directory = "/projects/p32852/cache/"

# load in the different models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_properties(0))

#olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0325-32B", torch_dtype=torch.float16, load_in_8bit=True, cache_dir=cache_directory).to(device)
#olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0325-32B-Instruct-GGUF", cache_dir=cache_directory).to(device)
#tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0325-32B-Instruct-GGUF", cache_dir=cache_directory)
olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-13B-Instruct", cache_dir=cache_directory).to(device)
# left padding so that batched prompts all end at the same position for decoder-only generation
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-1124-13B-Instruct", padding_side='left')

information_types = Path('/home/nws8519/git/mw-lifecycle-analysis/p2/quest/python_scripts/olmo_labeling/info_definitions.txt').read_text(encoding="utf-8")
prompt_template = Path('/home/nws8519/git/mw-lifecycle-analysis/p2/quest/python_scripts/olmo_labeling/prompt_template_nofs.txt').read_text(encoding="utf-8")
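# assumption: prompt_template_nofs.txt contains the named placeholders
# {info_definitions}, {sent}, and {task_title}, matching the keys passed to
# format_map() in the labeling loop below; any literal braces in the template
# would need to be escaped as {{ }} for format_map to work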

csv.field_size_limit(sys.maxsize)
with open("/home/nws8519/git/mw-lifecycle-analysis/analysis_data/102725_unified.csv", mode='r', newline='') as file:
    reader = csv.reader(file)
    array_of_categorizations = []
    index = -1
    for row in reader:
        index += 1
        # skip the header row
        if index <= 0:
            continue
        text_dict = {}
        # organizing the data from each citation
        text_dict['id'] = row[0]
        text_dict['task_title'] = row[1]
        task_title = text_dict['task_title']
        text_dict['comment_text'] = row[2]
        text_dict['date_created'] = row[3]
        text_dict['comment_type'] = row[6]
        text_dict['TaskPHID'] = row[5]
        text_dict['AuthorPHID'] = row[4]
        # task descriptions get the task title prepended; other comments are used as-is
        if text_dict['comment_type'] == "task_description":
            raw_text = text_dict['task_title'] + ". \n\n" + text_dict['comment_text']
        else:
            raw_text = text_dict['comment_text']

        # comment_text preprocessing per https://arxiv.org/pdf/1902.07093
        # 1. replace code with CODE (block code first so the inline pattern doesn't eat the ``` fences)
        comment_text = re.sub(r'```[\s\S]+?```', 'CODE', raw_text)  # Block code
        comment_text = re.sub(r'`[^`]+`', 'CODE', comment_text)  # Inline code
        # 2. replace quotes with QUOTE
        lines = comment_text.split('\n')
        lines = ['QUOTE' if line.strip().startswith('>') else line for line in lines]
        comment_text = '\n'.join(lines)
        # 3. replace Gerrit URLs with GERRIT_URL
        gerrit_url_pattern = r'https://gerrit\.wikimedia\.org/r/\d+'
        comment_text = re.sub(gerrit_url_pattern, 'GERRIT_URL', comment_text)
        # replace any remaining URL with URL
        url_pattern = r'https?://[^\s]+'
        comment_text = re.sub(url_pattern, 'URL', comment_text)
        # 4. if possible, replace @mentions with SCREEN_NAME (keep the leading whitespace)
        comment_text = re.sub(r'(^|\s)@\w+', r'\1SCREEN_NAME', comment_text)
        # 5. split into an array of sentences
        comment_sentences = nltk.sent_tokenize(comment_text)
        text_dict['cleaned_sentences'] = comment_sentences
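        # illustrative example (hypothetical input) of the substitutions above:
        #   "Fixed in https://gerrit.wikimedia.org/r/123456, see `parse()` and https://example.org"
        #   -> "Fixed in GERRIT_URL, see CODE and URL"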

        # label each cleaned sentence with OLMo, a few prompts at a time
        results = []
        batch_size = 2
        for i in range(0, len(comment_sentences), batch_size):
            batch = comment_sentences[i:i+batch_size]
            prompts = []
            for sent in batch:
                prompt = prompt_template.format_map({"info_definitions": information_types, "sent": sent, "task_title": task_title})
                prompts.append(prompt)
            inputs = tokenizer(prompts, return_tensors='pt', return_token_type_ids=False, padding=True, truncation=True).to(device)
            with torch.no_grad():
                # greedy decoding; temperature has no effect when do_sample=False
                outputs = olmo.generate(**inputs, max_new_tokens=256, do_sample=False)
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            for response_txt in decoded:
                # the decoded text includes the prompt, so grab whatever follows "Response:"
                match = re.search(r"Response:\s*(.*)", response_txt)
                #print(match)
                if match:
                    # strip punctuation and digits from the model's label
                    category = re.sub(r'[()",\d]', "", match.group(1)).strip()
                    category = category.replace("_", " ")
                else:
                    category = "NO CATEGORY"
                results.append(category)
            torch.cuda.empty_cache()

        #print(comment_sentences)
        text_dict['sentence_categories'] = results
        #print(results)
        array_of_categorizations.append(text_dict)
        #if index == 200:
        #    break

df = pd.DataFrame(array_of_categorizations)
#print(df.head())
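# the output CSV has one row per comment: the original metadata columns plus
# 'cleaned_sentences' (the preprocessed sentences) and 'sentence_categories'
# (the OLMo label for each sentence, in the same order)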
df.to_csv('all_120525_olmo_batched_categorized.csv', index=False)