"""Label the sentences of Phabricator task comments with information types using OLMo-2.

For each comment in the unified CSV, the script does the following:
- cleans the text, replacing code, quotes, URLs and @-mentions with placeholders, per
  https://arxiv.org/pdf/1902.07093;
- splits the text into sentences;
- asks the model to categorize each sentence, using a prompt built from
  ``prompt_template.txt`` and ``info_definitions.txt``.
"""
import csv
import re
import sys
from pathlib import Path

import pandas as pd
import torch

#import os
#os.environ['BNB_CUDA_VERSION'] = ''
#import bitsandbytes

cache_directory = "/projects/p32852/cache/"
MODEL_NAME = "allenai/OLMo-2-1124-13B"
INFO_DEFINITIONS_PATH = '/home/nws8519/git/mw-lifecycle-analysis/p2/quest/python_scripts/olmo_labeling/info_definitions.txt'
PROMPT_TEMPLATE_PATH = '/home/nws8519/git/mw-lifecycle-analysis/p2/quest/python_scripts/olmo_labeling/prompt_template.txt'
INPUT_CSV_PATH = "/home/nws8519/git/mw-lifecycle-analysis/analysis_data/102725_unified.csv"
BATCH_SIZE = 2
# Debug limit: stop after the CSV row with this index (header is index 0).
MAX_DATA_ROWS = 20

_BLOCK_CODE_RE = re.compile(r'```[\s\S]+?```')
_INLINE_CODE_RE = re.compile(r'`[^`]+`')
_GERRIT_URL_RE = re.compile(r'https://gerrit\.wikimedia\.org/r/\d+')
_URL_RE = re.compile(r'https?://[^\s]+')
_SCREEN_NAME_RE = re.compile(r'(^|\s)@\w+')
_RESPONSE_RE = re.compile(r"Response: \s*(.*)")
_CATEGORY_NOISE_RE = re.compile(r"[(),\d]")


def clean_comment_text(raw_text: str) -> str:
    """Replace code, quotes, URLs and @-mentions in *raw_text* with placeholder tokens.

    Steps (per https://arxiv.org/pdf/1902.07093):
      1. fenced and inline code   -> ``CODE``
      2. lines starting with '>'  -> ``QUOTE``
      3. Gerrit change URLs       -> ``GERRIT_URL``; any other http(s) URL -> ``URL``
      4. ``@name`` after start/whitespace -> ``SCREEN_NAME``
    """
    # 1. Fenced blocks first: running the inline pattern first would consume the
    #    inner backticks of a ``` fence and leave a stray "``CODE``" behind.
    text = _BLOCK_CODE_RE.sub('CODE', raw_text)
    text = _INLINE_CODE_RE.sub('CODE', text)
    # 2. Whole quoted lines collapse to a single token.
    text = '\n'.join(
        'QUOTE' if line.strip().startswith('>') else line
        for line in text.split('\n')
    )
    # 3. Gerrit URLs before the generic URL pattern, which would otherwise match them.
    text = _GERRIT_URL_RE.sub('GERRIT_URL', text)
    text = _URL_RE.sub('URL', text)
    # 4. Keep the captured leading whitespace (\1) so the mention doesn't fuse
    #    with the preceding word.
    return _SCREEN_NAME_RE.sub(r'\1SCREEN_NAME', text)


def parse_category(response_txt: str) -> str:
    """Extract the category after the first "Response:" in a decoded generation.

    Parentheses, commas and digits are stripped from the captured text. Returns
    "NO CATEGORY" when no "Response:" marker is present.
    """
    match = _RESPONSE_RE.search(response_txt)
    if match is None:
        return "NO CATEGORY"
    return _CATEGORY_NOISE_RE.sub("", match.group(1)).strip()


def _row_to_record(row):
    """Map one CSV row to an output record; key order fixes the DataFrame column order."""
    return {
        'id': row[0],
        'task_title': row[1],
        'comment_text': row[2],
        'date_created': row[3],
        'comment_type': row[6],
        'TaskPHID': row[5],
        'AuthorPHID': row[4],
    }


def _raw_text(record):
    """Return the text to label; task descriptions are prefixed with the task title."""
    if record['comment_type'] == "task_description":
        return record['task_title'] + ". \n\n" + record['comment_text']
    return record['comment_text']


def _categorize_prompts(model, tokenizer, device, prompts, batch_size=BATCH_SIZE):
    """Greedily generate for *prompts* in batches and return one parsed category per prompt."""
    categories = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', return_token_type_ids=False,
                           padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
        # NOTE(review): the decoded text includes the prompt, and parse_category takes
        # the FIRST "Response:". Confirm the prompt template contains that marker only
        # where the answer is expected.
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        categories.extend(parse_category(text) for text in decoded)
        if device.type == "cuda":
            torch.cuda.empty_cache()
    return categories


def main():
    """Label the first MAX_DATA_ROWS comments of the input CSV and return them as a DataFrame."""
    # Imported here so the text helpers above can be imported (e.g. for tests)
    # without transformers/nltk installed and without triggering downloads.
    import nltk
    from transformers import AutoModelForCausalLM, AutoTokenizer, OlmoForCausalLM  # noqa: F401

    nltk.download('punkt_tab')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    # These calls need a CUDA device; skipping them keeps the CPU fallback usable.
    if device.type == "cuda":
        print(torch.cuda.get_device_name(0))
        print(torch.cuda.get_device_properties(0))

    # Previously tried: allenai/OLMo-2-0325-32B (torch_dtype=torch.float16,
    # load_in_8bit=True) and allenai/OLMo-2-0325-32B-Instruct-GGUF.
    olmo = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=cache_directory).to(device)
    # Left padding keeps each prompt's end adjacent to its generated tokens in a batch.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left',
                                              cache_dir=cache_directory)
    if tokenizer.pad_token is None:
        # padding=True fails without a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    information_types = Path(INFO_DEFINITIONS_PATH).read_text(encoding="utf-8")
    prompt_template = Path(PROMPT_TEMPLATE_PATH).read_text(encoding="utf-8")

    # Comment bodies can exceed the csv module's default field size limit.
    csv.field_size_limit(sys.maxsize)

    array_of_categorizations = []
    with open(INPUT_CSV_PATH, mode='r', newline='', encoding='utf-8') as file:
        for index, row in enumerate(csv.reader(file)):
            if index == 0:  # header row
                continue
            text_dict = _row_to_record(row)
            task_title = text_dict['task_title']

            comment_sentences = nltk.sent_tokenize(clean_comment_text(_raw_text(text_dict)))
            text_dict['cleaned_sentences'] = comment_sentences

            prompts = [
                prompt_template.format_map({"info_definitions": information_types,
                                            "sent": sent,
                                            "task_title": task_title})
                for sent in comment_sentences
            ]
            results = _categorize_prompts(olmo, tokenizer, device, prompts)
            text_dict['sentence_categories'] = results
            print(results)
            array_of_categorizations.append(text_dict)
            if index == MAX_DATA_ROWS:
                break

    df = pd.DataFrame(array_of_categorizations)
    #df.to_csv('all_110525_olmo_batched_categorized.csv', index=False)
    return df


if __name__ == "__main__":
    main()