diff --git a/ngrams/term_frequencies.py b/ngrams/term_frequencies.py index 4e56db9..8924ccb 100755 --- a/ngrams/term_frequencies.py +++ b/ngrams/term_frequencies.py @@ -79,7 +79,10 @@ def weekly_tf(partition, if os.path.exists(f"{output_10p_sample_path}/{ngram_output}"): os.remove(f"{output_10p_sample_path}/{ngram_output}") - batches = dataset.to_batches(columns=['CreatedAt','subreddit','body','author']) + if reddit_dataset == 'comments': + batches = dataset.to_batches(columns=['CreatedAt','subreddit','body','author']) + if reddit_dataset == 'posts': + batches = dataset.to_batches(columns=['CreatedAt','subreddit','title','author']) schema = pa.schema([pa.field('subreddit', pa.string(), nullable=False), pa.field('term', pa.string(), nullable=False),