diff --git a/ngrams/term_frequencies.py b/ngrams/term_frequencies.py index 30e1336..9703f3c 100755 --- a/ngrams/term_frequencies.py +++ b/ngrams/term_frequencies.py @@ -119,9 +119,10 @@ def weekly_tf(partition, if reddit_dataset == 'comments': tf_func = tf_comments + nullable_schema = False elif reddit_dataset == 'posts': tf_func = tf_posts - + nullable_schema = True dataset = ds.dataset(f"{input_parquet}/{partition}", format='parquet') if not os.path.exists(output_10p_sample_path): os.mkdir(output_10p_sample_path) @@ -140,16 +141,16 @@ def weekly_tf(partition, if reddit_dataset == 'posts': batches = dataset.to_batches(columns=['CreatedAt','subreddit','title','author']) - schema = pa.schema([pa.field('subreddit', pa.string(), nullable=False), - pa.field('term', pa.string(), nullable=False), - pa.field('week', pa.date32(), nullable=False), - pa.field('tf', pa.int64(), nullable=False)] + schema = pa.schema([pa.field('subreddit', pa.string(), nullable=nullable_schema), + pa.field('term', pa.string(), nullable=nullable_schema), + pa.field('week', pa.date32(), nullable=nullable_schema), + pa.field('tf', pa.int64(), nullable=nullable_schema)] ) - author_schema = pa.schema([pa.field('subreddit', pa.string(), nullable=False), - pa.field('author', pa.string(), nullable=False), - pa.field('week', pa.date32(), nullable=False), - pa.field('tf', pa.int64(), nullable=False)] + author_schema = pa.schema([pa.field('subreddit', pa.string(), nullable=nullable_schema), + pa.field('author', pa.string(), nullable=nullable_schema), + pa.field('week', pa.date32(), nullable=nullable_schema), + pa.field('tf', pa.int64(), nullable=nullable_schema)] ) dfs = (b.to_pandas() for b in batches)