1
0

add a 'limit' parameter for testing.

This commit is contained in:
Nathan TeBlunthuis 2024-12-01 09:51:49 -08:00
parent 4218bf864b
commit 271cbea7d9

View File

@ -118,7 +118,8 @@ def weekly_tf(partition,
temp_output_tfidf_path="/gscratch/comdata/users/nathante/reddit_tfidf_test_authors.parquet_temp/",
output_terms_path="/gscratch/comdata/output/reddit_ngrams/comment_terms.parquet",
output_authors_path="/gscratch/comdata/output/reddit_ngrams/comment_authors.parquet",
reddit_dataset = 'comments'):
reddit_dataset = 'comments',
limit = None):
if reddit_dataset == 'comments':
tf_func = tf_comments
@ -195,10 +196,18 @@ def weekly_tf(partition,
Path(output_terms_path).mkdir(parents=True, exist_ok=True)
if limit is not None:
n_lines_out = 0
with pq.ParquetWriter(f"{output_terms_path}/{partition}",schema=schema,compression='snappy',flavor='spark') as writer, pq.ParquetWriter(f"{output_authors_path}/{partition}",schema=author_schema,compression='snappy',flavor='spark') as author_writer:
while True:
if limit is not None:
n_lines_left = limit - n_lines_out
if n_lines_left < outchunksize:
outchunksize = n_lines_left
chunk = islice(outrows,outchunksize)
chunk = (c for c in chunk if c[1] is not None)
pddf = pd.DataFrame(chunk, columns=["is_token"] + schema.names)
@ -218,6 +227,12 @@ def weekly_tf(partition,
author_writer.write_table(author_table)
do_break = False
if limit is not None:
if n_lines_out < limit:
n_lines_out += outchunksize
else:
do_break = True
if do_break:
break