add the rest of the code.
sample_training_labels.py | 348 lines | Executable file
@@ -0,0 +1,348 @@
#!/usr/bin/env python3

'''
Take a stratified sample of article quality labels.

For now we just stratify by label type.
Later we might add date.
Later we might stratify by wikiproject too.

A key limitation of this approach is that we can only sample at the level of the page.
We'd really like to be able to sample at the level of the edit session,
but that isn't possible because of how article assessments work.
'''
from itertools import islice, chain
from pathlib import Path
import pandas as pd
import numpy as np
random = np.random.RandomState(1968)  # seeded RNG so the sample is reproducible
import json
import pyarrow.feather as feather
import fire
from collections import Counter
from pyRemembeR import Remember
from enum import IntEnum, unique
from datetime import datetime
from dataclasses import dataclass, asdict
from multiprocessing import Pool
from urllib.parse import unquote
from pyspark.sql import functions as f
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from numpy import dtype
import csv

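# One-off conversion of the wikiq TSV output into parquet, then into a single
# partitioned parquet dataset with cleaned-up column types. This function is not
# called from main(); it appears to record how the input data was prepared.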
def wikiq_to_parquet():

    path = Path("/gscratch/comdata/users/nathante/wikiqRunning/wikiq_output/")
    outpath = Path("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_parquet/")
    files = list(map(Path,path.glob("*.tsv")))
    dumpfile = files[0]

    def wikiq_tsv_to_parquet(dumpfile, outpath = Path("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante.parquet/")):
        outfile = outpath / (dumpfile.name + ".parquet")
        outpath.mkdir(parents=True, exist_ok=True)
        _wikiq_tsv_to_parquet(dumpfile,outfile)

    dumpfile = Path("/gscratch/comdata/users/nathante/wikiqRunning/wikiq_output/enwiki-20200301-pages-meta-history12-p4980874p5038451.tsv")

    def _wikiq_tsv_to_parquet(dumpfile, outfile):

        dtypes = {'anon': dtype('O'), 'articleid': dtype('int64'), 'deleted': dtype('bool'), 'editor': dtype('O'), 'editor_id': dtype('float64'), 'minor': dtype('bool'), 'namespace': dtype('int64'), 'revert': dtype('O'), 'reverteds': dtype('O'), 'revid': dtype('int64'), 'sha1': dtype('O'), 'text_chars': dtype('float64'), 'title': dtype('O')}

        print(dumpfile)
        df = pd.read_csv(dumpfile,sep='\t',quoting=csv.QUOTE_NONE,error_bad_lines=False, warn_bad_lines=True,parse_dates=['date_time'],dtype=dtypes)

        df.to_parquet(outfile)

    with Pool(28) as pool:
        jobs = pool.imap_unordered(wikiq_tsv_to_parquet, files)
        list(jobs)

    spark = SparkSession.builder.getOrCreate()

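    # Normalize the raw wikiq columns: booleans stored as "TRUE"/"FALSE" strings,
    # numeric columns read as floats, and percent-encoded, quote-wrapped titles and
    # editor names.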
    @udf(StringType())
    def decode_strip_udf(val):
        if val is None:
            return ""
        else:
            return unquote(val).strip('\"')

    df = spark.read.parquet('/gscratch/comdata/output/wikiq_enwiki_20200301_nathante.parquet')
    df = df.withColumnRenamed("anon","anonRaw")
    df = df.withColumn("anon",f.when(f.col("anonRaw")=="TRUE",True).otherwise(False))
    df = df.drop("anonRaw")
    df = df.withColumnRenamed("text_chars","text_chars_raw")
    df = df.withColumn("text_chars",f.col("text_chars_raw").cast('int'))
    df = df.drop("text_chars_raw")
    df = df.withColumnRenamed("editor_id",'editor_id_raw')
    df = df.withColumn("editor_id",f.col("editor_id_raw").cast("int"))
    df = df.drop("editor_id_raw")
    df = df.withColumnRenamed("revert","revert_raw")
    df = df.withColumn("revert",f.when(f.col("revert_raw")=="TRUE",True).otherwise(False))
    df = df.drop("revert_raw")
    df = df.withColumnRenamed("title","title_raw")
    df = df.withColumn("title", decode_strip_udf(f.col("title_raw")))
    df = df.drop("title_raw")
    df = df.withColumnRenamed("editor","editor_raw")
    df = df.withColumn("editor", decode_strip_udf(f.col("editor_raw")))
    df = df.drop("editor_raw")
    df = df.repartition(400,'articleid')
    df.write.parquet("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_partitioned.parquet",mode='overwrite')

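# The WP10 article quality classes used by WikiProject assessments, encoded as an
# IntEnum so that max() can be used to pick a single label when a page has several
# assessments.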
@unique
class WP10(IntEnum):
    start = 1
    stub = 2
    c = 3
    b = 4
    a = 5
    ga = 6
    fa = 7

    @staticmethod
    def from_string(s):
        return {'start':WP10.start,
                'stub':WP10.stub,
                'c':WP10.c,
                'b':WP10.b,
                'a':WP10.a,
                'ga':WP10.ga,
                'fa':WP10.fa}.get(s,None)

    def to_string(self):
        return {WP10.start:'start',
                WP10.stub:'stub',
                WP10.c:'c',
                WP10.b:'b',
                WP10.a:'a',
                WP10.ga:'ga',
                WP10.fa:'fa'}[self]

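# Dataclasses describing a quality label. PageLabel holds the fields common to all
# labels; TalkPageLabel and ArticlePageLabel (below) attach them to a talk page or to
# an article revision respectively.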
@dataclass
class PageLabel:
    timestamp:datetime
    wp10:WP10

    @staticmethod
    def from_json(obj):
        timestamp = obj.get('timestamp',None)
        if timestamp is not None:
            timestamp = datetime.strptime(obj['timestamp'],'%Y%m%d%H%M%S')
        else:
            timestamp = None

        return PageLabel(timestamp=timestamp,
                         wp10=WP10.from_string(obj.get('wp10')))

    @staticmethod
    def from_row(row):
        return PageLabel(timestamp = row.timestamp,
                         wp10 = WP10(row.wp10))

    def to_json(self):
        d = asdict(self)

        if self.timestamp is not None:
            d['timestamp'] = self.timestamp.strftime('%Y%m%d%H%M%S')

        if self.wp10 is not None:
            d['wp10'] = self.wp10.to_string()

        return json.dumps(d)

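# A label parsed from the talk-page side of the dump: the wikiproject assessment
# template lives on the talk page, so we also record which talk page and project it
# came from.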
@dataclass
class TalkPageLabel(PageLabel):
    dump_talk_page_title:str
    talk_page_id:int
    project:str

    @staticmethod
    def from_json(obj):
        res = PageLabel.from_json(obj)

        return TalkPageLabel(dump_talk_page_title=obj.get('dump_talk_page_title',None),
                             talk_page_id=obj.get('talk_page_id',None),
                             project=obj.get("project",None),
                             **asdict(res)
                             )

    @staticmethod
    def from_row(row):
        res = PageLabel.from_row(row)
        return TalkPageLabel(dump_talk_page_title = row.dump_talk_page_title,
                             talk_page_id = row.talk_page_id,
                             project = row.project,
                             **asdict(res))

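# A label attached to an article (main-namespace) revision; used when writing out the
# sampled article revisions at the end of main().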
@dataclass
class ArticlePageLabel(PageLabel):
    '''class representing a label applied to an article page'''
    title: str
    articleid: int
    revid:int

    @staticmethod
    def from_json(obj):
        res = PageLabel.from_json(obj)

        return ArticlePageLabel(title=obj.get('title',None),
                                articleid=obj.get('articleid',None),
                                revid=obj.get('revid',None),
                                **asdict(res)
                                )

    @staticmethod
    def from_row(row):
        res = PageLabel.from_row(row)
        return ArticlePageLabel(title = row.title,
                                articleid = row.articleid,
                                revid = row.revid,
                                **asdict(res))


infiles="enwiki-20200301-pages-meta-history*.xml-p*.7z_article_labelings.json"; samplesize=5000*7

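# Pipeline: read the talk-page labelings extracted from the dump, keep one pre-2019
# assessment per talk page (the highest label, breaking ties by the most recent
# timestamp), join those labels to the wikiq revision history in Spark to find the
# last article revision made before each assessment, then draw a stratified sample of
# article pages (capped per quality class) and write it out as feather and JSON.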
def main(infiles="enwiki-20200301-pages-meta-history*.xml-p*.7z_article_labelings.json", samplesize=5000*7):
    path = Path('data')
    infiles = path.glob(infiles)

    pool = Pool(28)

    lines = chain(* map(lambda f: open(f,'r'), infiles))

    parsed = pool.imap_unordered(json.loads, lines, chunksize=int(1e3))
    formatted = pool.imap_unordered(TalkPageLabel.from_json, parsed, chunksize=int(1e3))
    dicted = pool.imap_unordered(asdict,formatted, chunksize=int(1e3))

    # data frame of the latest labels.
    df = pd.DataFrame(dicted)

    df = df.loc[df.timestamp <= datetime(2019,1,1)]

    groups = df.groupby(["talk_page_id"])
    max_labels = groups.wp10.max().reset_index()

    df2 = pd.merge(df,max_labels,on=['talk_page_id','wp10'],how='right')
    last_timestamp = df2.groupby(['talk_page_id']).timestamp.max().reset_index()

    df2 = pd.merge(df2, last_timestamp, on=['talk_page_id','timestamp'], how='right')
    first_project = df2.groupby(['talk_page_id']).project.first().reset_index()
    df2 = pd.merge(df2, first_project,on=['talk_page_id','project'], how='right')

    tpid = df2

    #.wp10.max().reset_index()
    tpid = tpid.loc[~tpid.dump_talk_page_title.isna()]

    # pick out just the samples we want.
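    # Join the labels to the revision history: for each labelled talk page, find the
    # corresponding article and the last non-revert article revision made on or before
    # the assessment timestamp.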
    spark = SparkSession.builder.getOrCreate()

    sparkdf = spark.read.parquet("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_partitioned.parquet")

    tpid['timestamp'] = tpid['timestamp'].dt.tz_localize('utc')
    labels = spark.createDataFrame(tpid)
    talks = sparkdf.filter(sparkdf.namespace==1)
    articles = sparkdf.filter(sparkdf.namespace==0)

    # labels = labels.join(talks,on=[labels.talk_page_id == talks.articleid],how='left_outer')

    talks = talks.join(labels,on=[labels.talk_page_id == talks.articleid])

    #talks.filter(talks.wp10==7).select('talk_page_id').distinct().count()

    talks = talks.withColumn('timediff', f.datediff(talks.timestamp, talks.date_time))

    talks = talks.filter(talks.timediff <= 0)

    win = Window.partitionBy("talk_page_id")
    talks = talks.withColumn('best_timediff', f.max('timediff').over(win))
    talks = talks.filter(talks.timediff == talks.best_timediff)

    talks = talks.withColumn('article_title',f.substring_index(f.col("title"),':',-1))
    talks = talks.select(['article_title','wp10',f.col('timestamp').alias('timestamp'),'talk_page_id']).distinct()

    articles = articles.join(talks,on=[talks.article_title == articles.title])

    articles = articles.withColumn('timediff', f.datediff(articles.timestamp, articles.date_time))
    articles = articles.filter(articles.timediff <= 0)

    win2 = Window.partitionBy("articleid")
    articles = articles.filter(f.col("revert")==False)
    articles = articles.withColumn('best_timediff', f.max('timediff').over(win2))
    articles = articles.filter(articles.timediff == articles.best_timediff)
    articles = articles.select(['revid','timestamp','wp10','articleid','title'])

    articles = articles.groupby(['timestamp','wp10','articleid','title']).agg(f.first(f.col("revid")).alias("revid"))

    articles.write.parquet("data/article_quality_data.parquet",mode='overwrite')

    tpid = pd.read_parquet("data/article_quality_data.parquet")

    # we want to sample /pages/ not /labels/,
    # so we need to do a /full/ groupby over pages.
    # this is why we have a lot of RAM!
    # we need the number of pages with each label.
    label_counts = {}
    sample_page_ids = {}
    label_max_samplesize = int(samplesize / len(WP10))
    sample_chunks = []

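    # Stratified sampling: take up to samplesize / len(WP10) (here 5000) article ids
    # per quality class, keeping every page for classes with fewer articles than that.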
    for lab in WP10:
        print(lab)
        page_ids = tpid.loc[tpid.wp10==lab].articleid
        label_counts[lab] = len(page_ids)
        print(lab,label_counts)
        if label_counts[lab] <= label_max_samplesize:
            sample_page_ids[lab] = page_ids
        else:
            sample_page_ids[lab] = random.choice(page_ids,label_max_samplesize,replace=False)

        # get the labels for each sampled article
        sample_data_lab = tpid.loc[(tpid.articleid.isin(sample_page_ids[lab]))]

        sample_chunks.append(sample_data_lab)

    remember = Remember(f='remember_sample_quality_labels.RDS')

    remember(label_max_samplesize, 'label_max_samplesize')

    # Note that different wikiprojects can have different labels
    sample = pd.concat(sample_chunks,ignore_index=True)

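    # For record keeping: count 2019 revisions per article and aggregate them per
    # quality class alongside the number of labelled articles in each class.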
    revisions_per_article = sparkdf.filter(sparkdf.namespace==0).select(['revid','articleid','date_time','title'])
    revisions_per_article = revisions_per_article.filter(f.col("date_time") >= datetime(2019,1,1))
    revisions_per_article = revisions_per_article.filter(f.col("date_time") <= datetime(2019,12,31))
    revisions_per_article = revisions_per_article.groupby(["articleid",'title']).count().toPandas()

    revisions_per_article['title'] = revisions_per_article.title.apply(lambda s: unquote(s).strip('\"'))

    revisions_per_article = pd.merge(revisions_per_article,tpid,left_on='articleid',right_on='articleid')

    revisions_per_class = revisions_per_article.groupby('wp10').agg({'count':'sum'}).reset_index()
    revisions_per_class['wp10'] = revisions_per_class.wp10.apply(lambda s: WP10(s).to_string())

    label_counts = pd.DataFrame({'wp10':[lab.to_string() for lab in label_counts.keys()],
                                 'n_articles':list(label_counts.values())})
    label_counts = pd.merge(label_counts,revisions_per_class,left_on='wp10',right_on='wp10')
    label_counts = label_counts.rename(columns={'count':'n_revisions'})

    remember(label_counts, 'label_sample_counts')

    sample.to_feather("data/20200301_article_labelings_sample.feather")

    sample = pd.read_feather("data/20200301_article_labelings_sample.feather")
    sample_counts = sample.articleid.groupby(sample.wp10).count().reset_index()
    remember(sample_counts,'sample_counts')

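    # Serialize the sampled labels as one JSON object per line for the next stage of
    # the pipeline.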
    sample_labels = sample.apply(ArticlePageLabel.from_row,axis=1)
    sample_labels = map(PageLabel.to_json, sample_labels)

    with open("data/20200301_article_labelings_sample.json",'w') as of:
        of.writelines((l + '\n' for l in sample_labels))

    pool.close()


if __name__ == "__main__":
    fire.Fire(main)