updated the labels to try to store the neurobiber predictions in a better format (a single list-valued column instead of 96 separate columns)
This commit is contained in:
parent
7e8fb1982b
commit
43fb346318
203736
p2/quest/071525_neurobiber_labels.csv
Normal file
203736
p2/quest/071525_neurobiber_labels.csv
Normal file
File diff suppressed because one or more lines are too long
6
p2/quest/neurobiber-categorization.log
Normal file
6
p2/quest/neurobiber-categorization.log
Normal file
@ -0,0 +1,6 @@
|
||||
starting the job at: Tue Jul 15 14:09:10 CDT 2025
|
||||
setting up the environment
|
||||
running the neurobiber labeling script
|
||||
neurobiber labeling pau
|
||||
job finished, cleaning up
|
||||
job pau at: Tue Jul 15 14:12:26 CDT 2025
|
@ -110,9 +110,13 @@ if __name__ == "__main__":
|
||||
model, tokenizer = load_model_and_tokenizer()
|
||||
preds = predict_batch(model, tokenizer, docs)
|
||||
#new columns in the df for the predicted neurobiber items
|
||||
preds_cols = [f"neurobiber_{i+1}" for i in range(96)]
|
||||
preds_df = pd.DataFrame(preds, columns=preds_cols, index=first_discussion_df.index)
|
||||
final_discussion_df = pd.concat([first_discussion_df, preds_df], axis=1)
|
||||
#preds_cols = [f"neurobiber_{i+1}" for i in range(96)]
|
||||
#preds_df = pd.DataFrame(preds, columns=preds_cols, index=first_discussion_df.index)
|
||||
#final_discussion_df = pd.concat([first_discussion_df, preds_df], axis=1)
|
||||
|
||||
#assigning the predictions as a new column
|
||||
final_discussion_df = first_discussion_df.copy()
|
||||
final_discussion_df["neurobiber_preds"] = list(preds)
|
||||
#assert that order has been preserved
|
||||
for _ in range(10):
|
||||
random_index = random.choice(first_discussion_df.index)
|
||||
@ -120,7 +124,7 @@ if __name__ == "__main__":
|
||||
#assert that there are the same number of rows in first_discussion_df and final_discussion_df
|
||||
assert len(first_discussion_df) == len(final_discussion_df)
|
||||
# if passing the prior asserts, let's write to a csv
|
||||
final_discussion_df.to_csv("/home/nws8519/git/mw-lifecycle-analysis/p2/quest/071425_neurobiber_labels.csv", index=False)
|
||||
final_discussion_df.to_csv("/home/nws8519/git/mw-lifecycle-analysis/p2/quest/071525_neurobiber_labels.csv", index=False)
|
||||
print('neurobiber labeling pau')
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user