From 0613193e9dbc55b7c3fdcb8a5cacf1f732a47893 Mon Sep 17 00:00:00 2001
From: Nathan TeBlunthuis
Date: Sat, 11 Jan 2025 18:57:02 -0800
Subject: [PATCH] support passing in a model object.

---
 similarities/similarities_helper.py        |  7 +++++--
 similarities/weekly_cosine_similarities.py | 21 ++++++++++++---------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/similarities/similarities_helper.py b/similarities/similarities_helper.py
index 0050c1d..62d5590 100644
--- a/similarities/similarities_helper.py
+++ b/similarities/similarities_helper.py
@@ -230,7 +230,7 @@ def test_lsi_sims():
 # if n_components is a list we'll return a list of similarities with different latent dimensionalities
 # if algorithm is 'randomized' instead of 'arpack' then n_iter gives the number of iterations.
 # this function takes the svd and then the column similarities of it
-def lsi_column_similarities(tfidfmat,n_components=300,n_iter=10,random_state=1968,algorithm='randomized',lsi_model_save=None,lsi_model_load=None):
+def lsi_column_similarities(tfidfmat,n_components=300,n_iter=10,random_state=1968,algorithm='randomized',lsi_model=None,lsi_model_save=None,lsi_model_load=None):
     # first compute the lsi of the matrix
     # then take the column similarities
@@ -241,7 +241,10 @@ def lsi_column_similarities(tfidfmat,n_components=300,n_iter=10,random_state=196
         svd_components = n_components[0]
 
-    if lsi_model_load is not None and Path(lsi_model_load).exists():
+    if lsi_model is not None:
+        mod = lsi_model
+
+    elif lsi_model_load is not None and Path(lsi_model_load).exists():
         print("loading LSI")
         mod = pickle.load(open(lsi_model_load ,'rb'))
         lsi_model_save = lsi_model_load
diff --git a/similarities/weekly_cosine_similarities.py b/similarities/weekly_cosine_similarities.py
index 620ed37..1e3dc39 100755
--- a/similarities/weekly_cosine_similarities.py
+++ b/similarities/weekly_cosine_similarities.py
@@ -70,7 +70,10 @@ def cosine_similarities_weekly_lsi(*args, n_components=100, lsi_model=None, **kw
     term_colname= kwargs.get('term_colname')
     # lsi_model = "/gscratch/comdata/users/nathante/competitive_exclusion_reddit/data/similarity/comment_authors_compex_LSI/1000_author_LSIMOD.pkl"
     #simfunc = partial(lsi_column_similarities,n_components=n_components,random_state=random_state,algorithm='randomized',lsi_model=lsi_model)
-    simfunc = partial(lsi_column_similarities,n_components=n_components,random_state=kwargs.get('random_state'),lsi_model_load=lsi_model)
+    if isinstance(lsi_model,str):
+        lsi_model = pickle.load(open(lsi_model,'rb'))
+
+    simfunc = partial(lsi_column_similarities,n_components=n_components,random_state=kwargs.get('random_state'),lsi_model=lsi_model)
 
     return cosine_similarities_weekly(*args, simfunc=simfunc, **kwargs)
 
@@ -92,20 +95,20 @@ def cosine_similarities_weekly(tfidf_path, outfile, term_colname, included_subre
     nterms = conn.execute(f"SELECT MAX({term_colname + '_id'}) as nterms FROM read_parquet('{tfidf_path}/*/*.parquet')").df()
     nterms = nterms.nterms.values
     nterms = int(nterms[0])
-    weeks = conn.execute(f"SELECT DISTINCT week FROM read_parquet('{tfidf_path}/*/*.parquet')").df()
+    weeks = conn.execute(f"SELECT DISTINCT CAST(CAST(week AS DATE) AS STRING) AS week FROM read_parquet('{tfidf_path}/*/*.parquet')").df()
     weeks = weeks.week.values
     conn.close()
     print(f"computing weekly similarities")
     week_similarities_helper = partial(_week_similarities,simfunc=simfunc, tfidf_path=tfidf_path, term_colname=term_colname, outdir=outfile, min_df=min_df, max_df=max_df, included_subreddits=included_subreddits, topN=None, subreddit_names=subreddit_names,nterms=nterms)
 
-    for week in weeks:
-        week_similarities_helper(week)
-    # pool = Pool(cpu_count())
-
-    # list(pool.imap(week_similarities_helper, weeks))
-    # pool.close()
-    # with Pool(cpu_count()) as pool: # maybe it can be done with 40 cores on the huge machine?
+    # for week in weeks:
+    #     week_similarities_helper(week)
+
+    with Pool(cpu_count()) as pool: # maybe it can be done with 128 cores on the huge machine?
+        list(pool.imap(week_similarities_helper, weeks))
+        pool.close()
+
 def author_cosine_similarities_weekly(outfile, infile='/gscratch/comdata/output/reddit_similarity/tfidf_weekly/comment_authors_test.parquet', min_df=2, max_df=None, included_subreddits=None, topN=500, static_tfidf_path=None):
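Below is a minimal usage sketch, not part of the patch itself, showing how the new lsi_model parameter could be exercised: unpickle a fitted model once and reuse the object across calls instead of having lsi_column_similarities re-read it from disk every time. It assumes the pickled object is a fitted scikit-learn TruncatedSVD (consistent with the 'randomized'/'arpack' options in the signature), that the module is importable as similarities.similarities_helper, and that MODEL_PATH and weekly_tfidf_matrices() are hypothetical stand-ins for the real pickle path and TF-IDF data.

import pickle
from similarities.similarities_helper import lsi_column_similarities  # assumed import path

MODEL_PATH = "/path/to/1000_author_LSIMOD.pkl"   # hypothetical path to a pickled, fitted LSI model

# Load the fitted model once, up front.
with open(MODEL_PATH, "rb") as f:
    mod = pickle.load(f)

# weekly_tfidf_matrices() is a hypothetical iterator over sparse TF-IDF matrices, one per week.
for tfidfmat in weekly_tfidf_matrices():
    # With lsi_model set, the helper uses `mod` directly and skips both the
    # pickle load (lsi_model_load) and refitting the SVD.
    sims = lsi_column_similarities(tfidfmat,
                                   n_components=100,
                                   random_state=1968,
                                   lsi_model=mod)

As the isinstance check in the patch shows, cosine_similarities_weekly_lsi still accepts a path string for backward compatibility: it unpickles the file once and forwards the resulting object, so callers can pass either a path or an already-loaded model.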