add the rest of the code.

2024-03-12 09:39:12 -07:00
parent 29abd26b97
commit 2c733a8788
15 changed files with 1909 additions and 0 deletions

sample_training_labels.py Executable file

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
'''
Take a stratified sample of article quality labels.
For now we stratify only by label type; later we might also stratify by date or by wikiproject.
A key limitation of this approach is that we can only sample at the level of the page.
We would really like to sample at the level of the edit session,
but that isn't possible because of how article assessments work.
'''
from itertools import islice, chain
from pathlib import Path
import pandas as pd
import numpy as np
random = np.random.RandomState(1968)
import json
import pyarrow.feather as feather
import fire
from collections import Counter
from pyRemembeR import Remember
from enum import IntEnum, unique
from datetime import datetime
from dataclasses import dataclass, asdict
from multiprocessing import Pool
from urllib.parse import unquote
from pyspark.sql import functions as f
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from numpy import dtype
import csv
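# One-off conversion of wikiq TSV output: write each dump file to parquet in
# parallel, then clean the result with Spark and repartition it by articleid.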
def wikiq_to_parquet():
path = Path("/gscratch/comdata/users/nathante/wikiqRunning/wikiq_output/")
outpath = Path("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_parquet/")
files = list(map(Path,path.glob("*.tsv")))
dumpfile = files[0]
def wikiq_tsv_to_parquet(dumpfile, outpath = Path("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante.parquet/")):
outfile = outpath / (dumpfile.name + ".parquet")
outpath.mkdir(parents=True, exist_ok=True)
_wikiq_tsv_to_parquet(dumpfile,outfile)
dumpfile = Path("/gscratch/comdata/users/nathante/wikiqRunning/wikiq_output/enwiki-20200301-pages-meta-history12-p4980874p5038451.tsv")
def _wikiq_tsv_to_parquet(dumpfile, outfile):
dtypes = {'anon': dtype('O'), 'articleid': dtype('int64'), 'deleted': dtype('bool'), 'editor': dtype('O'), 'editor_id': dtype('float64'), 'minor': dtype('bool'), 'namespace': dtype('int64'), 'revert': dtype('O'), 'reverteds': dtype('O'), 'revid': dtype('int64'), 'sha1': dtype('O'), 'text_chars': dtype('float64'), 'title': dtype('O')}
print(dumpfile)
df = pd.read_csv(dumpfile,sep='\t',quoting=csv.QUOTE_NONE,error_bad_lines=False, warn_bad_lines=True,parse_dates=['date_time'],dtype=dtypes)
df.to_parquet(outfile)
with Pool(28) as pool:
jobs = pool.imap_unordered(wikiq_tsv_to_parquet, files)
list(jobs)
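# Clean up column types with Spark: convert "TRUE"/"FALSE" strings to booleans,
# cast numeric columns to int, and decode URL-encoded title and editor names.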
spark = SparkSession.builder.getOrCreate()
@udf(StringType())
def decode_strip_udf(val):
if val is None:
return ""
else:
return unquote(val).strip('\"')
df = spark.read.parquet('/gscratch/comdata/output/wikiq_enwiki_20200301_nathante.parquet')
df = df.withColumnRenamed("anon","anonRaw")
df = df.withColumn("anon",f.when(f.col("anonRaw")=="TRUE",True).otherwise(False))
df = df.drop("anonRaw")
df = df.withColumnRenamed("text_chars","text_chars_raw")
df = df.withColumn("text_chars",f.col("text_chars_raw").cast('int'))
df = df.drop("text_chars_raw")
df = df.withColumnRenamed("editor_id",'editor_id_raw')
df = df.withColumn("editor_id",f.col("editor_id_raw").cast("int"))
df = df.drop("editor_id_raw")
df = df.withColumnRenamed("revert","revert_raw")
df = df.withColumn("revert",f.when(f.col("revert_raw")=="TRUE",True).otherwise(False))
df = df.drop("revert_raw")
df = df.withColumnRenamed("title","title_raw")
df = df.withColumn("title", decode_strip_udf(f.col("title_raw")))
df = df.drop("title_raw")
df = df.withColumnRenamed("editor","editor_raw")
df = df.withColumn("editor", decode_strip_udf(f.col("editor_raw")))
df = df.drop("editor_raw")
df = df.repartition(400,'articleid')
df.write.parquet("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_partitioned.parquet",mode='overwrite')
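# Ordinal encoding of the WP 1.0 article assessment classes, with helpers for
# converting to and from the string labels used in the labelling JSON.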
@unique
class WP10(IntEnum):
start = 1
stub = 2
c = 3
b = 4
a = 5
ga = 6
fa = 7
@staticmethod
def from_string(s):
return {'start':WP10.start,
'stub':WP10.stub,
'c':WP10.c,
'b':WP10.b,
'a':WP10.a,
'ga':WP10.ga,
'fa':WP10.fa}.get(s,None)
def to_string(self):
return {WP10.start:'start',
WP10.stub:'stub',
WP10.c:'c',
WP10.b:'b',
WP10.a:'a',
WP10.ga:'ga',
WP10.fa:'fa'}[self]
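# Dataclasses for moving labels between the labelling JSON, pandas/Spark rows,
# and the JSON-lines sample written at the end of main().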
@dataclass
class PageLabel:
timestamp:datetime
wp10:WP10
@staticmethod
def from_json(obj):
timestamp = obj.get('timestamp',None)
if timestamp is not None:
timestamp = datetime.strptime(obj['timestamp'],'%Y%m%d%H%M%S')
else:
timestamp = None
return PageLabel(timestamp=timestamp,
wp10=WP10.from_string(obj.get('wp10')))
@staticmethod
def from_row(row):
return PageLabel(timestamp = row.timestamp,
wp10 = WP10(row.wp10))
def to_json(self):
d = asdict(self)
if self.timestamp is not None:
d['timestamp'] = self.timestamp.strftime('%Y%m%d%H%M%S')
if self.wp10 is not None:
d['wp10'] = self.wp10.to_string()
return json.dumps(d)
@dataclass
class TalkPageLabel(PageLabel):
dump_talk_page_title:str
talk_page_id:int
project:str
@staticmethod
def from_json(obj):
res = PageLabel.from_json(obj)
return TalkPageLabel(dump_talk_page_title=obj.get('dump_talk_page_title',None),
talk_page_id=obj.get('talk_page_id',None),
project=obj.get("project",None),
**asdict(res)
)
@staticmethod
def from_row(row):
res = PageLabel.from_row(row)
return TalkPageLabel(dump_talk_page_title = row.dump_talk_page_title,
talk_page_id = row.talk_page_id,
project = row.project,
**asdict(res))
@dataclass
class ArticlePageLabel(PageLabel):
'''class representing labels to a page'''
title: str
articleid: int
revid:int
@staticmethod
def from_json(obj):
res = PageLabel.from_json(obj)
return ArticlePageLabel(title=obj.get('title',None),
articleid=obj.get('articleid',None),
revid=obj.get('revid',None),
**asdict(res)
)
@staticmethod
def from_row(row):
res = PageLabel.from_row(row)
return ArticlePageLabel(title = row.title,
articleid = row.articleid,
revid = row.revid,
**asdict(res))
infiles="enwiki-20200301-pages-meta-history*.xml-p*.7z_article_labelings.json"; samplesize=5000*7
def main(infiles="enwiki-20200301-pages-meta-history*.xml-p*.7z_article_labelings.json", samplesize=5000*7):
path = Path('data')
infiles = path.glob(infiles)
pool = Pool(28)
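# Stream label JSON lines from every input file and parse them in parallel.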
lines = chain(* map(lambda f: open(f,'r'), infiles))
parsed = pool.imap_unordered(json.loads, lines, chunksize=int(1e3))
formatted = pool.imap_unordered(TalkPageLabel.from_json, parsed, chunksize=int(1e3))
dicted = pool.imap_unordered(asdict,formatted, chunksize=int(1e3))
# reduce to one row per talk page: its highest label, the most recent timestamp for that label, and a single project.
df = pd.DataFrame(dicted)
df = df.loc[df.timestamp <= datetime(2019,1,1)]
groups = df.groupby(["talk_page_id"])
max_labels = groups.wp10.max().reset_index()
df2 = pd.merge(df,max_labels,on=['talk_page_id','wp10'],how='right')
last_timestamp = df2.groupby(['talk_page_id']).timestamp.max().reset_index()
df2 = pd.merge(df2, last_timestamp, on=['talk_page_id','timestamp'], how='right')
first_project = df2.groupby(['talk_page_id']).project.first()
df2 = pd.merge(df2, first_project,on=['talk_page_id','project'], how='right')
tpid = df2
#.wp10.max().reset_index()
tpid = tpid.loc[~tpid.dump_talk_page_title.isna()]
# pick out just the samples we want.
spark = SparkSession.builder.getOrCreate()
sparkdf = spark.read.parquet("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_partitioned.parquet")
tpid['timestamp'] = tpid['timestamp'].dt.tz_localize('utc')
labels = spark.createDataFrame(tpid)
talks = sparkdf.filter(sparkdf.namespace==1)
articles = sparkdf.filter(sparkdf.namespace==0)
# labels = labels.join(talks,on=[labels.talk_page_id == talks.articleid],how='left_outer')
talks = talks.join(labels,on=[labels.talk_page_id == talks.articleid])
#talks.filter(talks.wp10==7).select('talk_page_id').distinct().count()
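# For each talk page, compute the day difference between the label timestamp and
# each talk-page revision, keep revisions with a non-positive difference, and
# retain the revision whose difference is largest (closest to zero).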
talks = talks.withColumn('timediff', f.datediff(talks.timestamp, talks.date_time))
talks = talks.filter(talks.timediff <= 0)
win = Window.partitionBy("talk_page_id")
talks = talks.withColumn('best_timediff', f.max('timediff').over(win))
talks = talks.filter(talks.timediff == talks.best_timediff)
talks = talks.withColumn('article_title',f.substring_index(f.col("title"),':',-1))
talks = talks.select(['article_title','wp10',f.col('timestamp').alias('timestamp'),'talk_page_id']).distinct()
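# Join the labelled talk pages onto article revisions by title, drop revisions
# flagged as reverts, and repeat the same closest-revision matching per article.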
articles = articles.join(talks,on=[talks.article_title == articles.title])
articles = articles.withColumn('timediff', f.datediff(articles.timestamp, articles.date_time))
articles = articles.filter(articles.timediff <= 0)
win2 = Window.partitionBy("articleid")
articles = articles.filter(f.col("revert")==False)
articles = articles.withColumn('best_timediff', f.max('timediff').over(win2))
articles = articles.filter(articles.timediff == articles.best_timediff)
articles = articles.select(['revid','timestamp','wp10','articleid','title'])
articles = articles.groupby(['timestamp','wp10','articleid','title']).agg(f.first(f.col("revid")).alias("revid"))
articles.write.parquet("data/article_quality_data.parquet",mode='overwrite')
tpid = pd.read_parquet("data/article_quality_data.parquet")
# we want to sample /pages/, not /labels/,
# so we need to do a /full/ groupby over pages.
# this is why we have a lot of RAM!
# we need the number of pages per label to decide whether to downsample that label.
label_counts = {}
sample_page_ids = {}
label_max_samplesize = int(samplesize / len(WP10))
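# Sample up to label_max_samplesize article ids per label class; classes with
# fewer pages than the cap are kept in full.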
sample_chunks = []
for lab in WP10:
print(lab)
page_ids = tpid.loc[tpid.wp10==lab].articleid
label_counts[lab] = len(page_ids)
print(lab,label_counts)
if(label_counts[lab] <= label_max_samplesize):
sample_page_ids[lab] = page_ids
else:
sample_page_ids[lab] = random.choice(page_ids,label_max_samplesize,replace=False)
# get the labels for each sampled article
sample_data_lab = tpid.loc[(tpid.articleid.isin(sample_page_ids[lab]))]
sample_chunks.append(sample_data_lab)
remember = Remember(f='remember_sample_quality_labels.RDS')
remember(label_max_samplesize, 'label_max_samplesize')
# Note that different wikiprojects can have different labels
sample = pd.concat(sample_chunks,ignore_index=True)
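# Count each labelled article's revisions during 2019 and aggregate by label
# class, so we can remember the number of articles and revisions per class.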
revisions_per_article = sparkdf.filter(sparkdf.namespace==0).select(['revid','articleid','date_time','title'])
revisions_per_article = revisions_per_article.filter(f.col("date_time") >= datetime(2019,1,1))
revisions_per_article = revisions_per_article.filter(f.col("date_time") <= datetime(2019,12,31))
revisions_per_article = revisions_per_article.groupby(["articleid",'title']).count().toPandas()
revisions_per_article['title'] = revisions_per_article.title.apply(lambda s: unquote(s).strip('\"'))
revisions_per_article = pd.merge(revisions_per_article,tpid,left_on='articleid',right_on='articleid')
revisions_per_class = revisions_per_article.groupby('wp10').agg({'count':'sum'}).reset_index()
revisions_per_class['wp10'] = revisions_per_class.wp10.apply(lambda s: WP10(s).to_string())
label_counts = pd.DataFrame({'wp10':map(lambda x: x.to_string(),label_counts.keys()),'n_articles':label_counts.values()})
label_counts = pd.merge(label_counts,revisions_per_class,left_on='wp10',right_on='wp10')
label_counts = label_counts.rename(columns={'count':'n_revisions'})
remember(label_counts, 'label_sample_counts')
sample.to_feather("data/20200301_article_labelings_sample.feather")
sample = pd.read_feather("data/20200301_article_labelings_sample.feather")
sample_counts = sample.articleid.groupby(sample.wp10).count().reset_index()
remember(sample_counts,'sample_counts')
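# Convert the sampled rows to ArticlePageLabel objects and write them out as
# one JSON object per line.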
sample_labels = sample.apply(ArticlePageLabel.from_row,axis=1)
sample_labels = map(PageLabel.to_json, sample_labels)
with open("data/20200301_article_labelings_sample.json",'w') as of:
of.writelines((l + '\n' for l in sample_labels))
pool.close()
if __name__ == "__main__":
fire.Fire(main)