1
0
cdsc_reddit/ngrams/top_comment_phrases.py
2024-12-06 08:09:02 -08:00

61 lines
2.7 KiB
Python
Executable File

#!/usr/bin/env python3
from pyspark.sql import functions as f
from pyspark.sql import Window
from pyspark.sql import SparkSession
import numpy as np
# Spark session; executor resources are hard-coded (900g memory, 128 cores) and
# Arrow-based pandas conversion is explicitly disabled.
spark = SparkSession.builder.config(map={'spark.executor.memory':'900g','spark.executor.cores':128,'spark.sql.execution.arrow.pyspark.enabled':False}).getOrCreate()
# Load the comment and post n-gram dumps. spark.read.text yields one row per input
# line in a single 'value' column (presumably one phrase occurrence per line).
df = spark.read.text("/gscratch/comdata/output/reddit_ngrams/reddit_comment_ngrams_10p_sample/")
df2 = spark.read.text("/gscratch/comdata/output/reddit_ngrams/reddit_post_ngrams_10p_sample/")
df = df.union(df2)
df = df.withColumnRenamed("value","phrase")
# Count phrase occurrences and keep only phrases seen more than 10 times.
phrases = df.groupby('phrase').count()
phrases = phrases.withColumnRenamed('count','phraseCount')
phrases = phrases.filter(phrases.phraseCount > 10)
# N is the total number of retained phrase occurrences (sum of phraseCount over
# the filtered phrases), used as the denominator for log-probabilities.
N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount
if N is None:
    # f.sum over zero rows is NULL; fail with a clear message instead of letting
    # np.log(None) raise an opaque TypeError.
    raise ValueError("no phrase occurs more than 10 times; cannot compute PMI")
print(f'analyzing PMI on a sample of {N} phrases')
logN = np.log(N)
# log P(phrase) = log(phraseCount) - log(N)
phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)
# Per-term statistics. Each phrase is split on single spaces into terms and exploded
# to one row per (phrase, term). A term's count is the sum of phraseCount over all
# such rows, so a term repeated inside one phrase counts once per repetition.
phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' '))
terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')])
win = Window.partitionBy('term')
terms = terms.withColumn('termCount',f.sum('phraseCount').over(win))
# NOTE(review): term log-probabilities reuse logN (total phrase occurrences) as the
# denominator rather than total term occurrences -- confirm this is intended.
terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN)
# Sum of log P(term_i) over the phrase's terms: the phrase's log-probability under
# an independence assumption.
terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb')
# Pointwise mutual information: log P(phrase) - sum_i log P(term_i).
terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))
# Keep only the output columns.
df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI'])
# Sort by PWMI, highest first. DataFrame.sort takes `ascending`, not `descending`:
# an unknown keyword is silently ignored, which previously produced ascending order.
# A global sort already leaves each partition internally sorted, so no separate
# sortWithinPartitions pass is needed.
df = df.sort(['phrasePWMI'],ascending=False)
df.write.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/",mode='overwrite',compression='snappy')
# Re-read the materialized parquet, presumably so the CSV export reads the stored
# result instead of re-running the whole Spark lineage.
df = spark.read.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/")
df.write.csv("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.csv/",mode='overwrite',compression='none')
# Export the frequent, high-PWMI phrases (phraseCount > 3500 AND phrasePWMI > 3)
# as feather and CSV files using pyarrow rather than Spark.
from pyarrow import csv
from pyarrow import feather
from pyarrow import parquet as pq

mwe_columns = ['phrase', 'phraseCount', 'phraseLogProb', 'phrasePWMI']
# pyarrow filters are in disjunctive normal form: one inner list means every
# predicate in it must hold.
mwe_filters = [[('phraseCount', '>', 3500), ('phrasePWMI', '>', 3)]]
table = pq.read_table(
    "/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet",
    filters=mwe_filters,
    columns=mwe_columns,
)
feather.write_feather(table, "/gscratch/comdata/output/reddit_ngrams/reddit_multiword_expressions.feather")
csv.write_csv(table, "/gscratch/comdata/output/reddit_ngrams/reddit_multiword_expressions.csv")