Fix tests

Surprisingly, replacing list<str> with str doesn't break anything, not even the baselines.

Signed-off-by: Will Beason <willbeason@gmail.com>
parent 032fec3198
commit f9383440a0
@@ -364,7 +364,7 @@ class WikiqTestCase(unittest.TestCase):
         baseline['date_time'] = pd.to_datetime(baseline['date_time'])
         # Split strings to the arrays of reverted IDs so they can be compared.
         baseline['revert'] = baseline['revert'].replace(np.nan, None)
-        baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']]
+        # baseline['reverteds'] = [None if i is np.nan else [int(j) for j in str(i).split(",")] for i in baseline['reverteds']]
         baseline['sha1'] = baseline['sha1'].replace(np.nan, None)
         baseline['editor'] = baseline['editor'].replace(np.nan, None)
         baseline['anon'] = baseline['anon'].replace(np.nan, None)
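Note: with reverteds now typed as a plain string column, the baseline's comma-joined ID strings match wikiq's output directly, so the int-list conversion can be dropped. A minimal standalone sketch of the equivalence (the column values below are hypothetical, not taken from the real baselines):

    import numpy as np
    import pandas as pd

    # wikiq emits reverteds as one comma-joined ID string per revision, and the
    # baseline TSV stores the same strings, so the columns compare as-is.
    output = pd.Series(["384", "1,2,3", np.nan], name="reverteds")
    baseline = pd.Series(["384", "1,2,3", np.nan], name="reverteds")
    pd.testing.assert_series_equal(output, baseline)  # passes without splitting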
wikiq (67 changed lines)
@@ -145,10 +145,10 @@ class RegexPair(object):

     def get_pyarrow_fields(self):
         if self.has_groups:
-            fields = [pa.field(self._make_key(cap_group), pa.list_(pa.string()))
+            fields = [pa.field(self._make_key(cap_group), pa.string())
                       for cap_group in self.capture_groups]
         else:
-            fields = [pa.field(self.label, pa.list_(pa.string()))]
+            fields = [pa.field(self.label, pa.string())]

         return fields

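Note on the type change: regex matches are presumably stored as a single comma-joined string per revision (consistent with the reverteds handling above), so pa.string() describes the data as actually written, while the old pa.list_(pa.string()) would only accept real Python lists. A quick standalone illustration with made-up values:

    import pyarrow as pa

    matches = ["foo,bar", None, "baz"]
    pa.array(matches, type=pa.string())  # works: the values are plain strings
    # pa.array(matches, type=pa.list_(pa.string()))  # fails: strings aren't lists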
@@ -461,6 +461,8 @@ class WikiqParser:
         if self.output_parquet:
             writer = pq.ParquetWriter(self.output_file, self.schema, flavor='spark')
         else:
+            print(self.output_file, file=sys.stderr)
+            print(self.schema, file=sys.stderr)
             writer = pc.CSVWriter(self.output_file, self.schema, write_options=pc.WriteOptions(delimiter='\t'))

         # Iterate through pages
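The two new stderr prints look like temporary diagnostics for the TSV path. For reference, pyarrow.csv.CSVWriter takes the sink and schema up front and gets its tab delimiter through WriteOptions; a minimal standalone sketch (file name and columns invented):

    import pyarrow as pa
    import pyarrow.csv as pc

    schema = pa.schema([('revid', pa.int64()), ('title', pa.string())])
    batch = pa.record_batch([[1, 2], ['A', 'B']], schema=schema)
    with pc.CSVWriter('out.tsv', schema,
                      write_options=pc.WriteOptions(delimiter='\t')) as writer:
        writer.write(batch)  # header row plus two tab-separated data rows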
@@ -500,7 +502,8 @@ class WikiqParser:
                 editorid = None if rev.deleted.user or rev.user.id is None else rev.user.id
                 # create a new data object instead of a dictionary.
                 rev_data: Revision = self.revdata_type(revid=rev.id,
-                                             date_time=datetime.fromtimestamp(rev.timestamp.unix(), tz=timezone.utc),
+                                                       date_time=datetime.fromtimestamp(rev.timestamp.unix(),
+                                                                                        tz=timezone.utc),
                                                        articleid=page.id,
                                                        editorid=editorid,
                                                        title=page.title,
@@ -530,11 +533,9 @@ class WikiqParser:
                     if rev_detector is not None:
                         revert = rev_detector.process(text_sha1, rev.id)

+                        rev_data.revert = revert is not None
                         if revert:
-                            rev_data.revert = True
                             rev_data.reverteds = ",".join([str(s) for s in revert.reverteds])
-                        else:
-                            rev_data.revert = False

                 # if the fact that the edit was minor can be hidden, this might be an issue
                 rev_data.minor = rev.minor
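Collapsing the if/else into `rev_data.revert = revert is not None` works because the detector returns None when no revert is found. Assuming rev_detector is an mwreverts.Detector (which fits the process(checksum, revision) call here), a standalone sketch of the behavior:

    import mwreverts

    detector = mwreverts.Detector(radius=15)
    detector.process('aaa', 1)
    detector.process('bbb', 2)
    revert = detector.process('aaa', 3)  # revision 3 restores checksum 'aaa'
    assert revert is not None            # so rev_data.revert would be True
    print(','.join(str(s) for s in revert.reverteds))  # -> 2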
@@ -577,7 +578,7 @@ class WikiqParser:
                             old_rev_data.tokens_removed = len(old_tokens_removed)
                             old_rev_data.tokens_window = PERSISTENCE_RADIUS - 1

-                            writer.write(rev_data.to_pyarrow())
+                            writer.write(old_rev_data.to_pyarrow())

                 else:
                     writer.write(rev_data.to_pyarrow())
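Note: this hunk is a genuine bug fix rather than cleanup. The persistence branch fills in token stats on old_rev_data but previously wrote rev_data, so the row carrying those stats never reached the output.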
@@ -605,57 +606,7 @@ class WikiqParser:
         print("Done: %s revisions and %s pages." % (rev_count, page_count),
               file=sys.stderr)

-        # remember to flush the parquet_buffer if we're done
-        if self.output_parquet is True:
-            self.flush_parquet_buffer()
-            self.pq_writer.close()
-
-        else:
-            self.output_file.close()
-
-    @staticmethod
-    def rows_to_table(rg, schema: Schema) -> Table:
-        cols = []
-        first = rg[0]
-        for col in first:
-            cols.append([col])
-
-        for row in rg[1:]:
-            for j in range(len(cols)):
-                cols[j].append(row[j])
-
-        arrays: list[Array] = []
-
-        typ: DataType
-        for i, (col, typ) in enumerate(zip(cols, schema.types)):
-            try:
-                arrays.append(pa.array(col, typ))
-            except pa.ArrowInvalid as exc:
-                print("column index:", i, "type:", typ, file=sys.stderr)
-                print("data:", col, file=sys.stderr)
-                print("schema:", file=sys.stderr)
-                print(schema, file=sys.stderr)
-                raise exc
-        return pa.Table.from_arrays(arrays, schema=schema)
-
-
-    """
-    Function that actually writes data to the parquet file.
-    It needs to transpose the data from row-by-row to column-by-column
-    """
-    def flush_parquet_buffer(self):
-
-        """
-        Returns the pyarrow table that we'll write
-        """
-
-        outtable = self.rows_to_table(self.parquet_buffer, self.schema)
-        if self.pq_writer is None:
-            self.pq_writer: ParquetWriter = (
-                pq.ParquetWriter(self.output_file, self.schema, flavor='spark'))
-
-        self.pq_writer.write_table(outtable)
-        self.parquet_buffer = []
+        writer.close()


 def match_archive_suffix(input_filename):
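With each revision converted via to_pyarrow() and written immediately, the row buffer and the rows_to_table transpose are dead code, and since pq.ParquetWriter and pc.CSVWriter share the same write()/close() surface, a single writer.close() replaces the format-specific flush/close branches. A standalone sketch of that shared interface (file names and column invented):

    import pyarrow as pa
    import pyarrow.csv as pc
    import pyarrow.parquet as pq

    schema = pa.schema([('revid', pa.int64())])
    batch = pa.record_batch([[1, 2, 3]], schema=schema)

    for writer in (pq.ParquetWriter('out.parquet', schema, flavor='spark'),
                   pc.CSVWriter('out.tsv', schema,
                                write_options=pc.WriteOptions(delimiter='\t'))):
        writer.write(batch)  # same call for either format
        writer.close()       # one close, no per-format buffering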