updated olmo categorization
This commit is contained in:
parent
a9ec0b19ef
commit
cec9d82d41
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@ -83,7 +83,7 @@ with open("/home/nws8519/git/mw-lifecycle-analysis/analysis_data/102725_unified.
|
||||
prompts.append(prompt)
|
||||
inputs = tokenizer(prompts, return_tensors='pt', return_token_type_ids=False, padding=True, truncation=True).to(device)
|
||||
with torch.no_grad():
|
||||
outputs = olmo.generate(**inputs, max_new_tokens=256, do_sample=False)
|
||||
outputs = olmo.generate(**inputs, max_new_tokens=256, do_sample=False, temperature=0)
|
||||
decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
for response_txt in decoded:
|
||||
match = re.search(r"Response: \s*(.*)", response_txt)
|
||||
@ -104,7 +104,7 @@ with open("/home/nws8519/git/mw-lifecycle-analysis/analysis_data/102725_unified.
|
||||
# break
|
||||
df = pd.DataFrame(array_of_categorizations)
|
||||
#print(df.head())
|
||||
df.to_csv('all_110525_olmo_batched_categorized.csv', index=False)
|
||||
df.to_csv('all_120525_olmo_batched_categorized.csv', index=False)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
#SBATCH --mem=64G
|
||||
#SBATCH --cpus-per-task=4
|
||||
#SBATCH --job-name=batched-MW-info-typology
|
||||
#SBATCH --output=110525-batched-mw-olmo-info-cat.log
|
||||
#SBATCH --output=120525-batched-mw-olmo-info-cat.log
|
||||
#SBATCH --mail-type=BEGIN,END,FAIL
|
||||
#SBATCH --mail-user=gaughan@u.northwestern.edu
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user