# mw-lifecycle-analysis/p2/quest/python_scripts/olmo_parallel_cat.py

import os
import re

import nltk
import pandas as pd
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer

# sentence tokenizer data used by nltk.sent_tokenize (quiet, since every rank runs this)
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
# ----------------- prompts for LLM
priming = "For the **GIVEN SENTENCE**, please categorize it into one of the defined [[CATEGORIES]]. Each [[CATEGORY]] is described in the TYPOLOGY for reference. Your task is to match the **GIVEN SENTENCE** to the **[[CATEGORY]]** that most accurately describes the content of the comment. Only provide the category as your output. Do not provide any text beyond the category name."
#the typology descriptions are taken straight from https://arxiv.org/pdf/1902.07093
typology = """
TYPOLOGY:
[[EXPECTED BEHAVIOR]], in which stakeholders discuss, from the user's perspective, the expected or ideal situation affected by the issue. For example, a participant commented: “My suggestion/request in the near term would be to have an option to make the vocabulary read only so that users who want to be able to leave spacy alone to do streaming data processing don't need to worry about changing memory requirements.”
[[MOTIVATION]], in which stakeholders elaborate on why the issue needs to be fixed or a feature needs to be added. For example, in support of redesigning TensorFlow's input pipeline one participant wrote: “Right now, this method starves my GPU all the time, which is a shame because most other [deep learning] frameworks manage to make this much more performantly.”
[[OBSERVED BUG BEHAVIOR]], which only appears in bug reports and focuses on describing the observed behaviour of the bug. For example, one participant commented: “I found strange behavior using the pipe() method”, then started to describe this behavior.
[[BUG REPRODUCTION]], which also only appears in bug reports and focuses on any report, request, and/or question regarding the reproduction of the bug. For example, one participant commented that a bug was reproducible: “Same problem here, working on Windows 10 with German text.”
[[INVESTIGATION AND EXPLORATION]], in which OSS stakeholders discuss their exploration of ideas about the problem that was thought to have caused the issue. For example, “This result confirms my hypothesis but also shows that the memory increase really isn't all that significant... But it still points to a potential flaw in the design of the library.”
[[SOLUTION DISCUSSION]] is framed around the solution space from the developers' point of view, in which participants discuss design ideas and implementation details, as well as suggestions, constraints, challenges, and useful references around such topics. For example, “I know there are multiple ways of approaching this however I strongly recommend node-gyp for performance.”
[[CONTRIBUTION AND COMMITMENT]], in which participants call for contributors and/or voice willingness or unwillingness to contribute to resolving the issue. For example, one potential collaborator said: “I will gladly contribute in any way I can, however, this is something I will not be able to do alone. Would be best if a few other people is interested as well...”
[[TASK PROGRESS]], in which stakeholders request or report progress of tasks and sub-tasks towards the solution of the issue. For example, “I made an initial stab at it... - this is just a proof of concept that gets the version string into nodejs. I'll start working on adding the swig interfaces...”
[[TESTING]], in which participants discuss the testing procedure and results, as well as the system environment, code, data, and feedback involved in testing. For example, “Tested on 0.101 and master - the issue seems to be fixed on master not just for the example document, but for the entire corpus...”
[[FUTURE PLAN]], in which participants discuss the long-term plan related to the issue; such plans usually involve work/ideas that are not required to close the current issue. For example, “For the futures, stay tuned, as we're prototyping something in this direction.”
[[POTENTIAL NEW ISSUES AND REQUESTS]], in which participants identify and discuss new bugs or needed features while investigating and addressing the current issue. For example, when discussing a bug in scikit-learn about parallel execution that causes process hanging, one participant said: “As a side point, I note there seems to be a lot more joblib parallelisation overhead in master... that wasn't there in 0.14.”
[[SOLUTION USAGE]] was usually discussed once a full or partial solution of the issue was released and stakeholders asked questions or provided suggestions about how to use the library with the new solution update. For example, “Please help me how to continue training the model [with the new release].”
[[WORKAROUNDS]] focus on discussions about temporary or alternative solutions that can help overcome the issue until the official fix or enhancement is released. For example, in a discussion regarding memory growth for streamed data, one participant expressed his temporary solution: “For now workaround with reloading / collecting nlp object works quite ok in production.”
[[ISSUE CONTENT MANAGEMENT]] focuses on redirecting the discussions and controlling the quality of the comments with respect to the issue. For example, “We might want to move this discussion to here: [link to another issue]”
[[ACTION ON ISSUE]], in which participants comment on the proper actions to perform on the issue itself. For example, “I'm going to close this issue because it's old and most of the information here is now out of date.”
[[SOCIAL CONVERSATION]], in which participants express emotions such as appreciation, disappointment, annoyance, regret, etc. or engage in small talk. For example, “I'm so glad that this has received so much thought and attention!”
"""
instructions="The sentence's category is: "
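
# Hypothetical helpers, not called by main(): a minimal sketch, assuming every
# category in the typology keeps the [[LABEL]] form used above.
# extract_category_names collects the labels from the typology text, and
# normalize_category maps a model response onto that closed label set
# (case-insensitive, ignoring ** and [[ ]] around the first line), returning
# "NO CATEGORY" for anything else instead of storing free text.
import re


def extract_category_names(typology_text):
    names = []
    for name in re.findall(r"\[\[([^\]]+)\]\]", typology_text):
        if name not in names:  # keep first-seen order, drop repeats
            names.append(name)
    return names


def normalize_category(response, category_names, fallback="NO CATEGORY"):
    lines = response.strip().splitlines()
    if not lines:
        return fallback
    # only the first line of the response is treated as the answer
    candidate = lines[0].strip().strip("[]* ").upper()
    for name in category_names:
        if candidate == name.upper():
            return name
    return fallback

# e.g. normalize_category(response, extract_category_names(typology))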
# ----------------- distributed setup
def setup_ddp():
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
# cleanup is dist.destroy_process_group()
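
# Hypothetical helper, not called by setup_ddp(): torchrun exports RANK,
# WORLD_SIZE and LOCAL_RANK to each worker it launches, which is why setup_ddp()
# raises KeyError when the script is started with plain `python`. A minimal
# sketch, assuming a missing variable should mean "single process": read the
# three values with defaults and sanity-check them. init_process_group would
# still need MASTER_ADDR and MASTER_PORT for its default env:// initialization,
# so this only covers the reading step.
import os


def read_dist_env(env=None):
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    if not 0 <= rank < world_size:
        raise ValueError(f"RANK={rank} is outside [0, {world_size})")
    return rank, world_size, local_rank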
# ----------------- distributed data set
class SentenceDataset(Dataset):
    def __init__(self, comments, comment_types, priming, typology, instructions):
        # one sample per sentence: (index of the source comment, sentence, full prompt)
        self.samples = []
        for idx, comment in enumerate(comments):
            cleaned_comment = preprocess_comment(comment)
            sentences = split_to_sentences(cleaned_comment)
            for sentence in sentences:
                given_data = f"**GIVEN SENTENCE: \n ' Type -{comment_types[idx]} \n Text -{sentence}**'\n"
                prompt = f"{priming}\n{typology}\n\n{given_data}\n{instructions}"
                self.samples.append((idx, sentence, prompt))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
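
# Hypothetical helpers, not called by main(): with shuffle=False,
# DistributedSampler gives rank r the dataset positions r, r + world_size,
# r + 2 * world_size, ... and, unless drop_last=True, first pads the position
# list by wrapping around to its start so that every rank receives the same
# number of samples. padded_shard is a plain-Python sketch of that assignment
# (written for illustration, not PyTorch's own code); unpadded_shard is a
# strided split without padding, in which every sentence is labeled exactly
# once and shard sizes may differ by one.
import math


def padded_shard(num_samples, rank, world_size):
    if num_samples == 0:
        return []
    # round the total up to a multiple of world_size, repeating positions from the start
    total = math.ceil(num_samples / world_size) * world_size
    positions = (list(range(num_samples)) * math.ceil(total / num_samples))[:total]
    return positions[rank:total:world_size]


def unpadded_shard(num_samples, rank, world_size):
    return list(range(rank, num_samples, world_size))

# e.g. 10 sentences on 4 ranks: padded_shard gives rank 2 the positions [2, 6, 0]
# and rank 3 [3, 7, 1], so sentences 0 and 1 are labeled twice; unpadded_shard
# gives rank 2 only [2, 6].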
# ----------------- data handling functions
def preprocess_comment(raw_text):
    # 1. replace code with CODE; fenced blocks first, otherwise the inline pattern
    #    matches inside the triple backticks and leaves ``CODE`` behind
    comment_text = re.sub(r'```[\s\S]+?```', 'CODE', raw_text)  # Block code
    comment_text = re.sub(r'`[^`]+`', 'CODE', comment_text)  # Inline code
    # 2. replace quoted lines (starting with '>') with QUOTE
    lines = comment_text.split('\n')
    lines = ['QUOTE' if line.strip().startswith('>') else line for line in lines]
    comment_text = '\n'.join(lines)
    # 3. replace Gerrit URLs with GERRIT_URL
    gerrit_url_pattern = r'https://gerrit\.wikimedia\.org/r/\d+'
    comment_text = re.sub(gerrit_url_pattern, 'GERRIT_URL', comment_text)
    # replace the remaining URLs with URL
    url_pattern = r'https?://[^\s]+'
    comment_text = re.sub(url_pattern, 'URL', comment_text)
    # 4. replace @mentions with SCREEN_NAME, keeping the whitespace before the mention
    cleaned_text = re.sub(r'(^|\s)@\w+', r'\1SCREEN_NAME', comment_text)
    return cleaned_text


def split_to_sentences(text):
    return nltk.sent_tokenize(text)
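
# Hypothetical helpers, not called by main(): DataFrame.to_csv would write each
# sentence_labels cell (a list of (sentence, category) tuples) as its Python
# string form, which is awkward to parse back. A minimal sketch, assuming JSON
# is an acceptable cell format: store the list as a JSON array of
# [sentence, category] pairs and rebuild the tuples when reading the file.
import json


def sentence_labels_to_json(sentence_labels):
    # tuples become two-element arrays; ensure_ascii=False keeps non-ASCII text readable
    return json.dumps([[sentence, category] for sentence, category in sentence_labels], ensure_ascii=False)


def sentence_labels_from_json(text):
    return [(sentence, category) for sentence, category in json.loads(text)]

# e.g. before writing the TODO output file:
# out_df['sentence_labels'] = out_df['sentence_labels'].map(sentence_labels_to_json)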
# ----------------- distributed inference
def main():
    # https://github.com/nuitrcs/examplejobs/blob/master/python/pytorch_ddp/multinode_torchrun.py
    # prep DDP setting
    rank, world_size, local_rank = setup_ddp()
    device = torch.device(f"cuda:{local_rank}")

    # load in data
    df = pd.read_csv("/home/nws8519/git/mw-lifecycle-analysis/p2/quest/072525_pp_biberplus_labels.csv")
    # TODO comment out below
    df = df.iloc[:50].copy()
    comment_texts = df['comment_text'].tolist()
    comment_types = df['comment_type'].tolist()
    dataset = SentenceDataset(comment_texts, comment_types, priming, typology, instructions)

    # split data up across processes
    # DistributedSampler pads its index list so every rank gets the same number of
    # samples, so a few sentences are processed on two ranks; with shuffle=False the
    # order is fixed and batch i covers sample_indices[i*batch_size:(i+1)*batch_size]
    batch_size = 4
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    sample_indices = list(sampler)

    # load model; generation only runs forward passes, so each rank keeps a plain
    # model replica (no DDP wrapper, which exists to synchronize gradients)
    model_name = "allenai/OLMo-2-1124-13B"
    cache_directory = "/projects/p32852/cache/"
    # rank 0 downloads into the shared cache first, the other ranks then read from it
    if rank == 0:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_directory)
        olmo = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_directory).to(device)
    dist.barrier()
    if rank != 0:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_directory, local_files_only=True)
        olmo = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_directory, local_files_only=True).to(device)
    olmo.eval()
    # prompts in a batch differ in length and must be padded; pad on the left so that
    # every prompt ends where generation starts
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # collect results keyed by dataset position, so padded duplicates collapse on merge
    results = dict()
    with torch.no_grad():
        for batch_num, batch in enumerate(dataloader):
            comment_idxs, sentences, prompts = batch
            batch_positions = sample_indices[batch_num * batch_size:(batch_num + 1) * batch_size]
            # categorize the batch
            inputs = tokenizer(list(prompts), return_tensors='pt', padding=True, return_token_type_ids=False).to(device)
            outputs = olmo.generate(**inputs, max_new_tokens=256, do_sample=False)
            # decode only the newly generated tokens, not the echoed prompt
            new_tokens = outputs[:, inputs['input_ids'].shape[1]:]
            decoded = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
            for i, response in enumerate(decoded):
                # the category is expected on the first non-empty line of the response
                lines = [line.strip() for line in response.splitlines() if line.strip()]
                category = (lines[0].strip("[]* ") if lines else "") or "NO CATEGORY"
                results[batch_positions[i]] = (int(comment_idxs[i]), sentences[i], category)

    # bring all together
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, results)
    if rank == 0:
        merged = dict()
        for partial in gathered:
            merged.update(partial)
        # regroup by comment, keeping sentences in their original order
        by_comment = dict()
        for position in sorted(merged):
            comment_idx, sentence, category = merged[position]
            by_comment.setdefault(comment_idx, []).append((sentence, category))
        out_rows = []
        for comment_idx, sentence_labels in by_comment.items():
            out_rows.append({
                'id': df['id'].iloc[comment_idx],
                'task_title': df['task_title'].iloc[comment_idx],
                'comment_text': df['comment_text'].iloc[comment_idx],
                'AuthorPHID': df['AuthorPHID'].iloc[comment_idx],
                'sentence_labels': sentence_labels,
            })
        out_df = pd.DataFrame(out_rows)
        print(out_df.head())
        # TODO out_df.to_csv("090325_olmo_sentence_categorized.csv")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
    print('all pau; internal to the script')