diff --git a/ngrams/term_frequencies.py b/ngrams/term_frequencies.py
index 8a15296..b83f55d 100755
--- a/ngrams/term_frequencies.py
+++ b/ngrams/term_frequencies.py
@@ -241,15 +241,16 @@ def weekly_tf(partition,
     author_writer.close()
 
 def sort_tf(input_parquet="/gscratch/comdata/output/temp_reddit_comments_by_subreddit.parquet/",
-            output_parquet="/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/"):
+            output_parquet="/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/",
+            tf_name='term'):
     from pyspark.sql import functions as f
     from pyspark.sql import SparkSession
     spark = SparkSession.builder.getOrCreate()
     df = spark.read.parquet(input_parquet)
-    df = df.repartition(2000,'term')
-    df = df.sort(['term','week','subreddit'])
-    df = df.sortWithinPartitions(['term','week','subreddit'])
+    df = df.repartition(2000,tf_name)
+    df = df.sort([tf_name,'week','subreddit'])
+    df = df.sortWithinPartitions([tf_name,'week','subreddit'])
     df.write.parquet(output_parquet,mode='overwrite',compression='snappy')
 
 
 def gen_task_list(mwe_pass='first',