1
0

Initial commit

#	new file:   runwikiq.sh
This commit is contained in:
2018-06-02 15:32:19 -07:00
commit 72633c193b
202 changed files with 21929 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .types import Timestamp, Namespace
__version__ = "0.4.18"

View File

@@ -0,0 +1,5 @@
from . import errors
from .session import Session
from .collections import Pages, RecentChanges, Revisions, SiteInfo, \
UserContribs, DeletedRevisions

View File

@@ -0,0 +1,7 @@
from .deleted_revisions import DeletedRevisions
from .pages import Pages
from .recent_changes import RecentChanges
from .revisions import Revisions
from .site_info import SiteInfo
from .user_contribs import UserContribs
from .users import Users

View File

@@ -0,0 +1,68 @@
import re
class Collection:
    """
    Represents a collection of items that can be queried via the API. This is
    an abstract base class that should be extended
    """

    TIMESTAMP = re.compile(r"[0-9]{4}-?[0-9]{2}-?[0-9]{2}T?" +
                           r"[0-9]{2}:?[0-9]{2}:?[0-9]{2}Z?")
    """
    A regular expression for matching the API's timestamp format.
    """

    DIRECTIONS = {'newer', 'older'}
    """
    A set of potential direction names.
    """

    def __init__(self, session):
        """
        :Parameters:
            session : `mw.api.Session`
                An api session to use for post & get.
        """
        self.session = session

    def _check_direction(self, direction):
        # None passes straight through; anything else must name a direction.
        if direction is None:
            return None

        name = str(direction)
        assert name in {None} | self.DIRECTIONS, \
            "Direction must be one of {0}".format(self.DIRECTIONS)
        return name

    def _check_timestamp(self, timestamp):
        # None passes straight through; anything else must look like an API
        # timestamp (digits with optional -, :, T and Z separators).
        if timestamp is None:
            return None

        value = str(timestamp)
        if not self.TIMESTAMP.match(value):
            raise TypeError(
                "{0} is not formatted like ".format(repr(value)) +
                "a MediaWiki timestamp."
            )
        return value

    def _items(self, items, none=True, levels=None, type=lambda val: val):
        # Joins an iterable of items into the API's pipe-delimited list
        # format, optionally validating each item against a set of levels.
        if none and items is None:
            return None

        normalized = {str(type(item)) for item in items}
        if levels is not None:
            allowed = {str(level) for level in levels}
            assert len(normalized - allowed) == 0, \
                "items {0} not in levels {1}".format(
                    normalized - allowed, allowed)

        return "|".join(normalized)

View File

@@ -0,0 +1,150 @@
import logging
import sys
from ...types import Timestamp
from ...util import none_or
from ..errors import MalformedResponse
from .collection import Collection
logger = logging.getLogger("mw.api.collections.deletedrevs")
class DeletedRevisions(Collection):
    """
    A collection of deleted revisions, queryable via the API.
    """

    PROPERTIES = {'ids', 'flags', 'timestamp', 'user', 'userid', 'size',
                  'sha1', 'contentmodel', 'comment', 'parsedcomment', 'content',
                  'tags'}

    # TODO:
    # This is *not* the right way to do this, but it should work for all queries.
    MAX_REVISIONS = 500

    def get(self, rev_id, *args, **kwargs):
        """
        Get a single deleted revision based on its ID.  Raises
        :py:class:`KeyError` if the rev_id cannot be found.

        :Parameters:
            rev_id : int
                Revision ID
            ``**kwargs``
                Passed to :py:meth:`query`

        :Returns:
            A single rev dict
        """
        rev_id = int(rev_id)

        revs = list(self.query(revids={rev_id}, **kwargs))

        if len(revs) < 1:
            raise KeyError(rev_id)
        else:
            return revs[0]

    def query(self, *args, limit=sys.maxsize, **kwargs):
        """
        Queries deleted revisions.
        See https://www.mediawiki.org/wiki/API:Deletedrevs

        :Parameters:
            titles : set(str)
                A set of page names to query (note that namespace prefix is expected)
            start : :class:`mw.Timestamp`
                A timestamp to start querying from
            end : :class:`mw.Timestamp`
                A timestamp to end querying
            from_title : str
                A title from which to start querying (alphabetically)
            to_title : str
                A title from which to stop querying (alphabetically)
            prefix : str
                A title prefix to match on
            drcontinue : str
                When more results are available, use this to continue.
                Note: may only work if drdir is set to newer.
            unique : bool
                List only one revision for each page
            tag : str
                Only list revision tagged with this tag
            user : str
                Only list revisions saved by this user_text
            excludeuser : str
                Do not list revision saved by this user_text
            namespace : int
                Only list pages in this namespace (id)
            limit : int
                Limit the number of results
            direction : str
                "newer" or "older"
            properties : set(str)
                A list of properties to include in the results:

                * ids            - The ID of the revision.
                * flags          - Revision flags (minor).
                * timestamp      - The timestamp of the revision.
                * user           - User that made the revision.
                * userid         - User ID of the revision creator.
                * size           - Length (bytes) of the revision.
                * sha1           - SHA-1 (base 16) of the revision.
                * contentmodel   - Content model ID of the revision.
                * comment        - Comment by the user for the revision.
                * parsedcomment  - Parsed comment by the user for the revision.
                * content        - Text of the revision.
                * tags           - Tags for the revision.
        """
        # `limit` means something different here: it caps the total number of
        # revisions yielded, while each individual request is bounded by
        # MAX_REVISIONS.
        kwargs['limit'] = min(limit, self.MAX_REVISIONS)

        revisions_yielded = 0
        done = False
        while not done:
            rev_docs, query_continue = self._query(*args, **kwargs)

            for doc in rev_docs:
                yield doc
                revisions_yielded += 1
                if revisions_yielded >= limit:
                    # Previously, `done` was not set here and the outer loop
                    # condition (`revisions_yielded <= limit`) allowed one
                    # more request -- and one extra yield -- past `limit`.
                    done = True
                    break

            if not done and query_continue != "" and len(rev_docs) > 0:
                kwargs['query_continue'] = query_continue
            else:
                done = True

    def _query(self, titles=None, pageids=None, revids=None,
               start=None, end=None, query_continue=None, unique=None, tag=None,
               user=None, excludeuser=None, namespace=None, limit=None,
               properties=None, direction=None):
        # Performs a single API request and returns (rev_docs, query_continue).

        params = {
            'action': "query",
            'prop': "deletedrevisions"
        }

        params['titles'] = self._items(titles)
        params['pageids'] = self._items(pageids)
        params['revids'] = self._items(revids)
        params['drvprop'] = self._items(properties, levels=self.PROPERTIES)
        params['drvlimit'] = none_or(limit, int)
        params['drvstart'] = self._check_timestamp(start)
        params['drvend'] = self._check_timestamp(end)

        params['drvdir'] = self._check_direction(direction)
        params['drvuser'] = none_or(user, str)
        # excludeuser is a user name (see docstring); it was previously
        # cast with int(), which broke for any actual user name.
        params['drvexcludeuser'] = none_or(excludeuser, str)
        params['drvtag'] = none_or(tag, str)
        params.update(query_continue or {'continue': ""})

        doc = self.session.get(params)

        try:
            if 'continue' in doc:
                query_continue = doc['continue']
            else:
                query_continue = ''

            pages = doc['query']['pages'].values()
            rev_docs = []

            for page_doc in pages:
                page_rev_docs = page_doc.get('deletedrevisions', [])

                # Strip the revision list off the page doc so each revision
                # can carry its page info without a self-reference.
                try:
                    del page_doc['deletedrevisions']
                except KeyError:
                    pass

                for rev_doc in page_rev_docs:
                    rev_doc['page'] = page_doc

                rev_docs.extend(page_rev_docs)

            return rev_docs, query_continue

        except KeyError as e:
            raise MalformedResponse(str(e), doc)

View File

@@ -0,0 +1,50 @@
import logging
from ...util import none_or
from .collection import Collection
logger = logging.getLogger("mw.api.collections.pages")
class Pages(Collection):
    """
    A collection of pages.  Currently contains only an (unfinished) helper
    for building ``action=edit`` API requests -- see the TODO at the bottom
    of :py:meth:`_edit`.
    """

    def _edit(self, title=None, pageid=None, section=None, sectiontitle=None,
              text=None, token=None, summary=None, minor=None,
              notminor=None, bot=None, basetimestamp=None,
              starttimestamp=None, recreate=None, createonly=None,
              nocreate=None, watch=None, unwatch=None, watchlist=None,
              md5=None, prependtext=None, appendtext=None, undo=None,
              undoafter=None, redirect=None, contentformat=None,
              contentmodel=None, assert_=None, nassert=None,
              captchaword=None, captchaid=None):
        # Builds the parameter dict for an ``action=edit`` API request.
        # NOTE(review): this method is unfinished -- it builds `params` but
        # never posts the request or returns anything, and several of the
        # accepted parameters (redirect, contentformat, contentmodel,
        # assert_, nassert, captchaword, captchaid) are never used.
        params = {
            'action': "edit"
        }
        params['title'] = none_or(title, str)
        params['pageid'] = none_or(pageid, int)
        # NOTE(review): `levels` is not an obvious keyword for none_or() --
        # confirm that mw.util.none_or accepts it.  It appears intended to
        # allow the special value 'new' (append a new section) in addition
        # to integer section numbers.
        params['section'] = none_or(section, int, levels={'new'})
        params['sectiontitle'] = none_or(sectiontitle, str)
        params['text'] = none_or(text, str)
        params['token'] = none_or(token, str)
        params['summary'] = none_or(summary, str)
        params['minor'] = none_or(minor, bool)
        params['notminor'] = none_or(notminor, bool)
        params['bot'] = none_or(bot, bool)
        # Timestamps are validated against the API's timestamp format.
        params['basetimestamp'] = self._check_timestamp(basetimestamp)
        params['starttimestamp'] = self._check_timestamp(starttimestamp)
        params['recreate'] = none_or(recreate, bool)
        params['createonly'] = none_or(createonly, bool)
        params['nocreate'] = none_or(nocreate, bool)
        params['watch'] = none_or(watch, bool)
        params['unwatch'] = none_or(unwatch, bool)
        params['watchlist'] = none_or(watchlist, bool)
        params['md5'] = none_or(md5, str)
        params['prependtext'] = none_or(prependtext, str)
        params['appendtext'] = none_or(appendtext, str)
        params['undo'] = none_or(undo, int)
        params['undoafter'] = none_or(undoafter, int)

        # TODO finish this

View File

@@ -0,0 +1,192 @@
import logging
import re
from ...util import none_or
from ..errors import MalformedResponse
from .collection import Collection
logger = logging.getLogger("mw.api.collections.recent_changes")
class RecentChanges(Collection):
    """
    Recent changes (revisions, page creations, registrations, moves, etc.)
    """

    RCCONTINUE = re.compile(r"([0-9]{4}-[0-9]{2}-[0-9]{2}T" +
                            r"[0-9]{2}:[0-9]{2}:[0-9]{2}Z|" +
                            r"[0-9]{14})" +
                            r"\|[0-9]+")

    PROPERTIES = {'user', 'userid', 'comment', 'timestamp', 'title',
                  'ids', 'sizes', 'redirect', 'flags', 'loginfo',
                  'tags', 'sha1'}

    SHOW = {'minor', '!minor', 'bot', '!bot', 'anon', '!anon',
            'redirect', '!redirect', 'patrolled', '!patrolled'}

    TYPES = {'edit', 'external', 'new', 'log'}

    DIRECTIONS = {'newer', 'older'}

    MAX_CHANGES = 50

    def _check_rccontinue(self, rccontinue):
        # None passes through; anything else must match the
        # "<timestamp>|<rcid>" continuation format.
        if rccontinue is None:
            return None
        elif self.RCCONTINUE.match(rccontinue):
            return rccontinue
        else:
            raise TypeError(
                "rccontinue {0} is not formatted correctly ".format(rccontinue) +
                "'%Y-%m-%dT%H:%M:%SZ|<last_rcid>'"
            )

    def query(self, *args, limit=None, **kwargs):
        """
        Enumerate recent changes.
        See `<https://www.mediawiki.org/wiki/API:Recentchanges>`_

        :Parameters:
            start : :class:`mw.Timestamp`
                The timestamp to start enumerating from
            end : :class:`mw.Timestamp`
                The timestamp to end enumerating
            direction :
                "newer" or "older"
            namespace : int
                Filter log entries to only this namespace(s)
            user : str
                Only list changes by this user
            excludeuser : str
                Don't list changes by this user
            tag : str
                Only list changes tagged with this tag
            properties : set(str)
                Include additional pieces of information

                * user           - Adds the user responsible for the edit and tags if they are an IP
                * userid         - Adds the user id responsible for the edit
                * comment        - Adds the comment for the edit
                * parsedcomment  - Adds the parsed comment for the edit
                * flags          - Adds flags for the edit
                * timestamp      - Adds timestamp of the edit
                * title          - Adds the page title of the edit
                * ids            - Adds the page ID, recent changes ID and the new and old revision ID
                * sizes          - Adds the new and old page length in bytes
                * redirect       - Tags edit if page is a redirect
                * patrolled      - Tags patrollable edits as being patrolled or unpatrolled
                * loginfo        - Adds log information (logid, logtype, etc) to log entries
                * tags           - Lists tags for the entry
                * sha1           - Adds the content checksum for entries associated with a revision
            token : set(str)
                Which tokens to obtain for each change

                * patrol
            show : set(str)
                Show only items that meet this criteria. For example, to see
                only minor edits done by logged-in users, set
                show={'minor', '!anon'}.

                * minor
                * !minor
                * bot
                * !bot
                * anon
                * !anon
                * redirect
                * !redirect
                * patrolled
                * !patrolled
                * unpatrolled
            limit : int
                How many total changes to return
            type : set(str)
                Which types of changes to show

                * edit
                * external
                * new
                * log
            toponly : bool
                Only list changes which are the latest revision
            rccontinue : str
                Use this to continue loading results from where you last left off
        """
        limit = none_or(limit, int)

        changes_yielded = 0
        done = False
        while not done:
            # Cap each request at MAX_CHANGES; `limit` bounds the total.
            if limit is None:
                kwargs['limit'] = self.MAX_CHANGES
            else:
                kwargs['limit'] = min(limit - changes_yielded,
                                      self.MAX_CHANGES)

            rc_docs, rccontinue = self._query(*args, **kwargs)

            for doc in rc_docs:
                yield doc
                changes_yielded += 1

                if limit is not None and changes_yielded >= limit:
                    done = True
                    break

            if rccontinue is not None and len(rc_docs) > 0:
                kwargs['rccontinue'] = rccontinue
            else:
                done = True

    def _query(self, start=None, end=None, direction=None, namespace=None,
               user=None, excludeuser=None, tag=None, properties=None,
               token=None, show=None, limit=None, type=None,
               toponly=None, rccontinue=None):
        # Performs a single API request and returns (rc_docs, rccontinue).

        params = {
            'action': "query",
            'list': "recentchanges"
        }

        params['rcstart'] = none_or(start, str)
        params['rcend'] = none_or(end, str)

        assert direction in {None} | self.DIRECTIONS, \
            "Direction must be one of {0}".format(self.DIRECTIONS)

        params['rcdir'] = direction
        params['rcnamespace'] = none_or(namespace, int)
        params['rcuser'] = none_or(user, str)
        params['rcexcludeuser'] = none_or(excludeuser, str)
        params['rctag'] = none_or(tag, str)
        params['rcprop'] = self._items(properties, levels=self.PROPERTIES)
        # Previously built from `tag` (copy-paste bug), so the requested
        # tokens were never sent.
        params['rctoken'] = none_or(token, str)
        params['rcshow'] = self._items(show, levels=self.SHOW)
        params['rclimit'] = none_or(limit, int)
        # TYPES was previously passed positionally as the `none` flag of
        # _items(); it is the set of allowed values, i.e. `levels`.
        params['rctype'] = self._items(type, levels=self.TYPES)
        params['rctoponly'] = none_or(toponly, bool)
        params['rccontinue'] = self._check_rccontinue(rccontinue)

        doc = self.session.get(params)

        try:
            rc_docs = doc['query']['recentchanges']

            if 'query-continue' in doc:
                rccontinue = \
                    doc['query-continue']['recentchanges']['rccontinue']
            elif len(rc_docs) > 0:
                # Synthesize a continuation token from the last change seen.
                rccontinue = "|".join([rc_docs[-1]['timestamp'],
                                       str(rc_docs[-1]['rcid'] + 1)])
            else:
                pass  # Leave the passed-in rccontinue as-is

        except KeyError as e:
            raise MalformedResponse(str(e), doc)

        return rc_docs, rccontinue

View File

@@ -0,0 +1,220 @@
import logging
from ...util import none_or
from ..errors import MalformedResponse
from .collection import Collection
logger = logging.getLogger("mw.api.collections.revisions")
class Revisions(Collection):
    """
    A collection of revisions indexed by title, page_id and user_text.
    Note that revisions of deleted pages are queriable via
    :class:`mw.api.DeletedRevs`.
    """

    PROPERTIES = {'ids', 'flags', 'timestamp', 'user', 'userid', 'size',
                  'sha1', 'contentmodel', 'comment', 'parsedcomment',
                  'content', 'tags', 'flagged'}

    DIFF_TO = {'prev', 'next', 'cur'}

    # This is *not* the right way to do this, but it should work for all queries.
    MAX_REVISIONS = 50

    def get(self, rev_id, **kwargs):
        """
        Get a single revision based on its ID.  Throws a :py:class:`KeyError`
        if the rev_id cannot be found.

        :Parameters:
            rev_id : int
                Revision ID
            ``**kwargs``
                Passed to :py:meth:`query`

        :Returns:
            A single rev dict
        """
        rev_id = int(rev_id)

        revs = list(self.query(revids={rev_id}, **kwargs))

        if len(revs) < 1:
            raise KeyError(rev_id)
        else:
            return revs[0]

    def query(self, *args, limit=None, **kwargs):
        """
        Get revision information.
        See `<https://www.mediawiki.org/wiki/API:Properties#revisions_.2F_rv>`_

        :Parameters:
            properties : set(str)
                Which properties to get for each revision:

                * ids            - The ID of the revision
                * flags          - Revision flags (minor)
                * timestamp      - The timestamp of the revision
                * user           - User that made the revision
                * userid         - User id of revision creator
                * size           - Length (bytes) of the revision
                * sha1           - SHA-1 (base 16) of the revision
                * contentmodel   - Content model id
                * comment        - Comment by the user for revision
                * parsedcomment  - Parsed comment by the user for the revision
                * content        - Text of the revision
                * tags           - Tags for the revision
            limit : int
                Limit how many revisions will be returned
                No more than 500 (5000 for bots) allowed
            start_id : int
                From which revision id to start enumeration (enum)
            end_id : int
                Stop revision enumeration on this revid
            start : :class:`mw.Timestamp`
                From which revision timestamp to start enumeration (enum)
            end : :class:`mw.Timestamp`
                Enumerate up to this timestamp
            direction : str
                "newer" or "older"
            user : str
                Only include revisions made by user_text
            excludeuser : str
                Exclude revisions made by this user_text
            tag : str
                Only list revisions tagged with this tag
            expandtemplates : bool
                Expand templates in revision content (requires "content" property)
            generatexml : bool
                Generate XML parse tree for revision content (requires "content" property)
            parse : bool
                Parse revision content (requires "content" property)
            section : int
                Only retrieve the content of this section number
            token : set(str)
                Which tokens to obtain for each revision

                * rollback - See `<https://www.mediawiki.org/wiki/API:Edit_-_Rollback#Token>`_
            rvcontinue : str
                When more results are available, use this to continue
            diffto : int
                Revision ID to diff each revision to. Use "prev", "next" and
                "cur" for the previous, next and current revision respectively
            difftotext : str
                Text to diff each revision to. Only diffs a limited number of
                revisions. Overrides diffto. If section is set, only that
                section will be diffed against this text
            contentformat : str
                Serialization format used for difftotext and expected for output of content

                * text/x-wiki
                * text/javascript
                * text/css
                * text/plain
                * application/json

        :Returns:
            An iterator of rev dicts returned from the API.
        """
        revisions_yielded = 0
        done = False
        while not done:
            # Cap each request at MAX_REVISIONS; `limit` bounds the total.
            if limit is None:
                kwargs['limit'] = self.MAX_REVISIONS
            else:
                kwargs['limit'] = min(limit - revisions_yielded,
                                      self.MAX_REVISIONS)

            rev_docs, rvcontinue = self._query(*args, **kwargs)

            for doc in rev_docs:
                yield doc
                revisions_yielded += 1

                if limit is not None and revisions_yielded >= limit:
                    done = True
                    break

            if rvcontinue is not None and len(rev_docs) > 0:
                kwargs['rvcontinue'] = rvcontinue
            else:
                done = True

    def _query(self, revids=None, titles=None, pageids=None, properties=None,
               limit=None, start_id=None, end_id=None, start=None,
               end=None, direction=None, user=None, excludeuser=None,
               tag=None, expandtemplates=None, generatexml=None,
               parse=None, section=None, token=None, rvcontinue=None,
               diffto=None, difftotext=None, contentformat=None):
        # Performs a single API request and returns (rev_docs, rvcontinue).

        params = {
            'action': "query",
            'prop': "revisions",
            'rawcontinue': ''
        }

        params['revids'] = self._items(revids, type=int)
        params['titles'] = self._items(titles)
        params['pageids'] = self._items(pageids, type=int)

        params['rvprop'] = self._items(properties, levels=self.PROPERTIES)

        if revids is None:  # Can't have a limit unless revids is none
            params['rvlimit'] = none_or(limit, int)

        params['rvstartid'] = none_or(start_id, int)
        params['rvendid'] = none_or(end_id, int)
        params['rvstart'] = self._check_timestamp(start)
        params['rvend'] = self._check_timestamp(end)

        params['rvdir'] = self._check_direction(direction)
        params['rvuser'] = none_or(user, str)
        # excludeuser is a user name (see docstring); it was previously
        # cast with int(), which broke for any actual user name.
        params['rvexcludeuser'] = none_or(excludeuser, str)
        params['rvtag'] = none_or(tag, str)
        params['rvexpandtemplates'] = none_or(expandtemplates, bool)
        params['rvgeneratexml'] = none_or(generatexml, bool)
        params['rvparse'] = none_or(parse, bool)
        params['rvsection'] = none_or(section, int)
        # NOTE(review): the docstring describes `token` as a set(str); this
        # passes it through as a single string -- confirm whether
        # self._items(token) was intended here.
        params['rvtoken'] = none_or(token, str)
        params['rvcontinue'] = none_or(rvcontinue, str)
        params['rvdiffto'] = self._check_diffto(diffto)
        params['rvdifftotext'] = none_or(difftotext, str)
        params['rvcontentformat'] = none_or(contentformat, str)

        doc = self.session.get(params)

        try:
            if 'query-continue' in doc:
                rvcontinue = doc['query-continue']['revisions']['rvcontinue']
            else:
                rvcontinue = None

            pages = doc['query'].get('pages', {}).values()
            rev_docs = []

            for page_doc in pages:
                if 'missing' in page_doc or 'revisions' not in page_doc:
                    continue

                page_rev_docs = page_doc['revisions']

                # Strip the revision list off the page doc so each revision
                # can carry its page info without a self-reference.
                del page_doc['revisions']

                for rev_doc in page_rev_docs:
                    rev_doc['page'] = page_doc

                rev_docs.extend(page_rev_docs)

            return rev_docs, rvcontinue

        except KeyError as e:
            raise MalformedResponse(str(e), doc)

    def _check_diffto(self, diffto):
        # "prev", "next", "cur" and None pass through; anything else must
        # be a revision ID.
        if diffto is None or diffto in self.DIFF_TO:
            return diffto
        else:
            return int(diffto)

View File

@@ -0,0 +1,81 @@
import logging
from ..errors import MalformedResponse
from .collection import Collection
logger = logging.getLogger("mw.api.collections.site_info")
class SiteInfo(Collection):
    """
    General information about the site.
    """

    PROPERTIES = {'general', 'namespaces', 'namespacealiases',
                  'specialpagealiases', 'magicwords', 'interwikimap',
                  'dbrepllag', 'statistics', 'usergroups', 'extensions',
                  'fileextensions', 'rightsinfo', 'languages', 'skins',
                  'extensiontags', 'functionhooks', 'showhooks',
                  'variables', 'protocols'}

    # Valid values for the `filteriw` parameter.  NOTE(review): currently
    # unused -- `filteriw` is passed through without validation.
    FILTERIW = {'local', '!local'}

    def query(self, properties=None, filteriw=None, showalldb=None,
              numberinggroup=None, inlanguagecode=None):
        """
        General information about the site.
        See `<https://www.mediawiki.org/wiki/API:Meta#siteinfo_.2F_si>`_

        :Parameters:
            properties: set(str)
                Which sysinfo properties to get:

                * general               - Overall system information
                * namespaces            - List of registered namespaces and their canonical names
                * namespacealiases      - List of registered namespace aliases
                * specialpagealiases    - List of special page aliases
                * magicwords            - List of magic words and their aliases
                * statistics            - Returns site statistics
                * interwikimap          - Returns interwiki map (optionally filtered, (optionally localised by using siinlanguagecode))
                * dbrepllag             - Returns database server with the highest replication lag
                * usergroups            - Returns user groups and the associated permissions
                * extensions            - Returns extensions installed on the wiki
                * fileextensions        - Returns list of file extensions allowed to be uploaded
                * rightsinfo            - Returns wiki rights (license) information if available
                * restrictions          - Returns information on available restriction (protection) types
                * languages             - Returns a list of languages MediaWiki supports (optionally localised by using siinlanguagecode)
                * skins                 - Returns a list of all enabled skins
                * extensiontags         - Returns a list of parser extension tags
                * functionhooks         - Returns a list of parser function hooks
                * showhooks             - Returns a list of all subscribed hooks (contents of $wgHooks)
                * variables             - Returns a list of variable IDs
                * protocols             - Returns a list of protocols that are allowed in external links.
                * defaultoptions        - Returns the default values for user preferences.
            filteriw : str
                "local" or "!local" Return only local or only nonlocal entries of the interwiki map
            showalldb : bool
                List all database servers, not just the one lagging the most
            numberinggroup : bool
                Lists the number of users in user groups
            inlanguagecode : bool
                Language code for localised language names (best effort, use CLDR extension)
        """
        siprop = self._items(properties, levels=self.PROPERTIES)

        doc = self.session.get(
            {
                'action': "query",
                'meta': "siteinfo",
                'siprop': siprop,
                'sifilteriw': filteriw,
                'sishowalldb': showalldb,
                # The API parameter is "sinumberingroup"; the extra "g" in
                # the previous key meant the flag was silently ignored.
                'sinumberingroup': numberinggroup,
                'siinlanguagecode': inlanguagecode
            }
        )

        try:
            return doc['query']
        except KeyError as e:
            raise MalformedResponse(str(e), doc)

View File

@@ -0,0 +1,132 @@
import logging
from ...util import none_or
from ..errors import MalformedResponse
from .collection import Collection
logger = logging.getLogger("mw.api.collections.user_contribs")
class UserContribs(Collection):
    """
    A collection of revisions indexed by user.
    """

    PROPERTIES = {'ids', 'title', 'timestamp', 'comment', 'parsedcomment',
                  'size', 'sizediff', 'flags', 'patrolled', 'tags'}

    SHOW = {'minor', '!minor', 'patrolled', '!patrolled'}

    MAX_REVISIONS = 50

    def query(self, *args, limit=None, **kwargs):
        """
        Get a user's revisions.
        See `<https://www.mediawiki.org/wiki/API:Usercontribs>`_

        :Parameters:
            limit : int
                The maximum number of contributions to return.
            start : :class:`mw.Timestamp`
                The start timestamp to return from
            end : :class:`mw.Timestamp`
                The end timestamp to return to
            user : set(str)
                The users to retrieve contributions for.  Maximum number of values 50 (500 for bots)
            userprefix : set(str)
                Retrieve contributions for all users whose names begin with this value.
            direction : str
                "newer" or "older"
            namespace : int
                Only list contributions in these namespaces
            properties :
                Include additional pieces of information

                * ids            - Adds the page ID and revision ID
                * title          - Adds the title and namespace ID of the page
                * timestamp      - Adds the timestamp of the edit
                * comment        - Adds the comment of the edit
                * parsedcomment  - Adds the parsed comment of the edit
                * size           - Adds the new size of the edit
                * sizediff       - Adds the size delta of the edit against its parent
                * flags          - Adds flags of the edit
                * patrolled      - Tags patrolled edits
                * tags           - Lists tags for the edit
            show : set(str)
                Show only items that meet these criteria, e.g. non minor edits only: ucshow=!minor.
                NOTE: If ucshow=patrolled or ucshow=!patrolled is set, revisions older than
                $wgRCMaxAge (2592000) won't be shown

                * minor
                * !minor,
                * patrolled,
                * !patrolled,
                * top,
                * !top,
                * new,
                * !new
            tag : str
                Only list revisions tagged with this tag
            toponly : bool
                DEPRECATED! Only list changes which are the latest revision
        """
        limit = none_or(limit, int)

        revisions_yielded = 0
        done = False
        while not done:
            # Cap each request at MAX_REVISIONS; `limit` bounds the total.
            if limit is None:
                kwargs['limit'] = self.MAX_REVISIONS
            else:
                kwargs['limit'] = min(limit - revisions_yielded,
                                      self.MAX_REVISIONS)

            uc_docs, uccontinue = self._query(*args, **kwargs)

            for doc in uc_docs:
                yield doc
                revisions_yielded += 1

                if limit is not None and revisions_yielded >= limit:
                    done = True
                    break

            if uccontinue is None or len(uc_docs) == 0:
                done = True
            else:
                kwargs['uccontinue'] = uccontinue

    def _query(self, user=None, userprefix=None, limit=None, start=None,
               end=None, direction=None, namespace=None, properties=None,
               show=None, tag=None, toponly=None,
               uccontinue=None):
        # Performs a single API request and returns (uc_docs, uccontinue).

        params = {
            'action': "query",
            'list': "usercontribs"
        }

        params['uclimit'] = none_or(limit, int)
        params['ucstart'] = self._check_timestamp(start)
        params['ucend'] = self._check_timestamp(end)
        if uccontinue is not None:
            params.update(uccontinue)
        params['ucuser'] = self._items(user, type=str)
        params['ucuserprefix'] = self._items(userprefix, type=str)
        params['ucdir'] = self._check_direction(direction)
        params['ucnamespace'] = none_or(namespace, int)
        params['ucprop'] = self._items(properties, levels=self.PROPERTIES)
        params['ucshow'] = self._items(show, levels=self.SHOW)
        # These two parameters were previously accepted by the signature
        # (and documented) but silently dropped from the request.
        params['uctag'] = none_or(tag, str)
        params['uctoponly'] = none_or(toponly, bool)

        doc = self.session.get(params)

        try:
            if 'query-continue' in doc:
                uccontinue = doc['query-continue']['usercontribs']
            else:
                uccontinue = None

            uc_docs = doc['query']['usercontribs']

            return uc_docs, uccontinue

        except KeyError as e:
            raise MalformedResponse(str(e), doc)

View File

@@ -0,0 +1,83 @@
import logging
from ...util import none_or
from ..errors import MalformedResponse
from .collection import Collection
logger = logging.getLogger("mw.api.collections.users")
class Users(Collection):
    """
    A collection of information about users
    """

    PROPERTIES = {'blockinfo', 'implicitgroups', 'groups', 'registration',
                  'emailable', 'editcount', 'gender'}

    SHOW = {'minor', '!minor', 'patrolled', '!patrolled'}

    MAX_REVISIONS = 50

    def query(self, *args, **kwargs):
        """
        Get metadata about users.
        See `<https://www.mediawiki.org/wiki/API:Users>`_

        :Parameters:
            users : str
                The usernames of the users to be retrieved.
            properties : set(str)
                Include additional pieces of information

                blockinfo      - Tags if the user is blocked, by whom, and
                                 for what reason
                groups         - Lists all the groups the user(s) belongs to
                implicitgroups - Lists all the groups a user is automatically
                                 a member of
                rights         - Lists all the rights the user(s) has
                editcount      - Adds the user's edit count
                registration   - Adds the user's registration timestamp
                emailable      - Tags if the user can and wants to receive
                                 email through [[Special:Emailuser]]
                gender         - Tags the gender of the user. Returns "male",
                                 "female", or "unknown"
        """
        # Keep requesting batches until the API stops handing back a
        # continuation token (or returns an empty batch).
        while True:
            us_docs, query_continue = self._query(*args, **kwargs)

            yield from us_docs

            if query_continue is None or not us_docs:
                break
            kwargs['query_continue'] = query_continue

    def _query(self, users, query_continue=None, properties=None):
        # Performs a single API request and returns (us_docs, query_continue).
        params = {
            'action': "query",
            'list': "users",
        }
        params['ususers'] = self._items(users, type=str)
        params['usprop'] = self._items(properties, levels=self.PROPERTIES)
        if query_continue is not None:
            params.update(query_continue)

        doc = self.session.get(params)

        try:
            if 'query-continue' in doc:
                next_continue = doc['query-continue']['users']
            else:
                next_continue = None

            return doc['query']['users'], next_continue

        except KeyError as e:
            raise MalformedResponse(str(e), doc)

View File

@@ -0,0 +1,48 @@
class DocError(Exception):
    """
    Base class for errors that carry along the API document that caused them.
    """

    def __init__(self, message, doc):
        super().__init__(message)

        self.doc = doc
        """
        The document returned by the API that brought about this error.
        """


class APIError(DocError):
    """
    Raised when the API response contains an ``error`` document.
    """

    def __init__(self, doc):
        error = doc.get('error', {})
        code = error.get('code')
        # The MediaWiki API reports the human-readable error text under
        # 'info'; the old 'message' key is kept as a fallback.  Previously
        # only 'message' was read, so this was always None.
        message = error.get('info', error.get('message'))

        super().__init__("{0}:{1}".format(code, message), doc)

        self.code = code
        """
        The error code returned by the api -- if available.
        """

        self.message = message
        """
        The error message returned by the api -- if available.
        """


class AuthenticationError(DocError):
    """
    Raised when a login attempt does not succeed.
    """

    def __init__(self, doc):
        result = doc['login']['result']
        super().__init__(result, doc)

        self.result = result
        """
        The result code of an authentication attempt.
        """


class MalformedResponse(DocError):
    """
    Raised when an expected key is missing from an API response.
    """

    def __init__(self, key, doc):
        super().__init__("Expected to find '{0}' in result.".format(key), doc)

        self.key = key
        """
        The expected, but missing key from the API call.
        """

View File

@@ -0,0 +1,134 @@
import logging
from ..util import api
from .collections import (DeletedRevisions, Pages, RecentChanges, Revisions,
SiteInfo, UserContribs, Users)
from .errors import APIError, AuthenticationError, MalformedResponse
logger = logging.getLogger("mw.api.session")
DEFAULT_USER_AGENT = "MediaWiki-Utilities"
"""
The default User-Agent to be sent with requests to the API.
"""
class Session(api.Session):
    """
    Represents a connection to a MediaWiki API.

    Cookies and other session information is preserved.

    :Parameters:
        uri : str
            The base URI for the API to use.  Usually ends in "api.php"
        user_agent : str
            The User-Agent to be sent with requests.  Will raise a warning if
            left to default value.
    """

    def __init__(self, uri, *args, user_agent=DEFAULT_USER_AGENT, **kwargs):
        """
        Constructs a new :class:`Session`.
        """
        if user_agent == DEFAULT_USER_AGENT:
            logger.warning("Sending requests with default User-Agent. " +
                           "Set 'user_agent' on api.Session to quiet this " +
                           "message.")

        # Inject the User-Agent header without clobbering any other headers
        # the caller provided.
        if 'headers' in kwargs:
            kwargs['headers']['User-Agent'] = str(user_agent)
        else:
            kwargs['headers'] = {'User-Agent': str(user_agent)}

        super().__init__(uri, *args, **kwargs)

        self.pages = Pages(self)
        """
        An instance of :class:`mw.api.Pages`.
        """

        self.revisions = Revisions(self)
        """
        An instance of :class:`mw.api.Revisions`.
        """

        self.recent_changes = RecentChanges(self)
        """
        An instance of :class:`mw.api.RecentChanges`.
        """

        self.site_info = SiteInfo(self)
        """
        An instance of :class:`mw.api.SiteInfo`.
        """

        self.user_contribs = UserContribs(self)
        """
        An instance of :class:`mw.api.UserContribs`.
        """

        self.users = Users(self)
        """
        An instance of :class:`mw.api.Users`.
        """

        self.deleted_revisions = DeletedRevisions(self)
        """
        An instance of :class:`mw.api.DeletedRevisions`.
        """

    def login(self, username, password, token=None):
        """
        Performs a login operation.  This method usually makes two requests to
        API -- one to get a token and one to use the token to log in.  If
        authentication fails, this method will throw an
        :class:`.errors.AuthenticationError`.

        :Parameters:
            username : str
                Your username
            password : str
                Your password
            token : str
                A login token (used internally for the second step of the
                token handshake; callers normally leave this as None)

        :Returns:
            The response in a json :py:class:`dict`
        """
        doc = self.post(
            {
                'action': "login",
                'lgname': username,
                'lgpassword': password,
                'lgtoken': token,  # If None, we'll be getting a token
            }
        )

        try:
            if doc['login']['result'] == "Success":
                return doc
            elif doc['login']['result'] == "NeedToken":
                if token is not None:
                    # Woops.  We've been here before.  Better error out.
                    raise AuthenticationError(doc)
                else:
                    token = doc['login']['token']
                    return self.login(username, password, token=token)
            else:
                raise AuthenticationError(doc)

        except KeyError as e:
            # Python 3 KeyError has no `.message` attribute -- the old code
            # raised AttributeError here and masked the real problem.  Using
            # str(e) also matches the rest of this package.
            raise MalformedResponse(str(e), doc)

    def request(self, type, params, **kwargs):
        # Force JSON output and surface API-reported errors as APIError.
        params.update({'format': "json"})

        doc = super().request(type, params, **kwargs).json()

        if 'error' in doc:
            raise APIError(doc)

        return doc

View File

@@ -0,0 +1,4 @@
# from . import errors
from .db import DB
from .collections import Pages, RecentChanges, Revisions, Archives, \
AllRevisions, Users

View File

@@ -0,0 +1,4 @@
from .pages import Pages
from .recent_changes import RecentChanges
from .revisions import Revisions, Archives, AllRevisions
from .users import Users

View File

@@ -0,0 +1,11 @@
class Collection:
    """
    Abstract base class for database-backed collections.  Holds a reference
    to the database connection wrapper that concrete collections query.
    """

    DIRECTIONS = {'newer', 'older'}

    def __init__(self, db):
        self.db = db

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return "{0}({1})".format(self.__class__.__name__, repr(self.db))

View File

@@ -0,0 +1,65 @@
import logging
from ...util import none_or
from .collection import Collection
logger = logging.getLogger("mw.database.collections.pages")
class Pages(Collection):

    def get(self, page_id=None, namespace_title=None, rev_id=None):
        """
        Gets a single page based on a legitimate identifier of the page.  Note
        that namespace_title expects a tuple of namespace ID and title.

        :Parameters:
            page_id : int
                Page ID
            namespace_title : ( int, str )
                the page's namespace ID and title
            rev_id : int
                a revision ID included in the page's history

        :Returns:
            the first matching row, or None if no page matched
        """
        page_id = none_or(page_id, int)
        namespace_title = none_or(namespace_title, tuple)
        rev_id = none_or(rev_id, int)

        query = """
            SELECT page.*
            FROM page
        """
        values = []

        # Exactly one identifier selects the page.  Previously these were
        # independent `if` statements followed by an `elif`/`else`, so
        # passing only page_id fell through to the final `else` and raised
        # TypeError even though a valid identifier was given.
        if page_id is not None:
            query += """
                WHERE page_id = %s
            """
            values.append(page_id)
        elif namespace_title is not None:
            namespace, title = namespace_title
            query += " WHERE page_namespace = %s and page_title = %s "
            values.extend([int(namespace), str(title)])
        elif rev_id is not None:
            query += """
                WHERE page_id = (SELECT rev_page FROM revision WHERE rev_id = %s)
            """
            values.append(rev_id)
        else:
            raise TypeError("Must specify a page identifier.")

        cursor = self.db.shared_connection.cursor()
        cursor.execute(
            query,
            values
        )

        # Return the first row only (at most one page can match).
        for row in cursor:
            return row

View File

@@ -0,0 +1,128 @@
import logging
import time
from ...types import Timestamp
from ...util import none_or
from .collection import Collection
logger = logging.getLogger("mw.database.collections.pages")
class RecentChanges(Collection):
    """
    Represents the ``recentchanges`` table.
    (https://www.mediawiki.org/wiki/Manual:Recentchanges_table)
    """

    TYPES = {
        'edit': 0,  # edit of existing page
        'new': 1,  # new page
        'move': 2,  # Marked as obsolete
        'log': 3,  # log action (introduced in MediaWiki 1.2)
        'move_over_redirect': 4,  # Marked as obsolete
        'external': 5  # An external recent change. Primarily used by Wikidata
    }

    def listen(self, last=None, types=None, max_wait=5):
        """
        Listens to the recent changes table.  Given no parameters, this function
        will return an iterator over the entire recentchanges table and then
        continue to "listen" for new changes to come in every 5 seconds.

        :Parameters:
            last : dict
                a recentchanges row to pick up after
            types : set ( str )
                a set of recentchanges types to filter for
            max_wait : float
                the maximum number of seconds to wait between repeated queries

        :Returns:
            A never-ending iterator over change rows.
        """
        while True:
            if last is not None:
                after = last['rc_timestamp']
                after_id = last['rc_id']
            else:
                after = None
                after_id = None

            start = time.time()
            # BUG FIX: the `types` filter was silently ignored before.
            rcs = self.query(after=after, after_id=after_id, types=types,
                             direction="newer")

            for rc in rcs:
                # BUG FIX: remember the last row seen so the next poll picks
                # up where this one left off instead of re-reading rows.
                last = rc
                yield rc

            # BUG FIX: never sleep for a negative duration.
            time.sleep(max(0, max_wait - (time.time() - start)))

    def query(self, before=None, after=None, before_id=None, after_id=None,
              types=None, direction=None, limit=None):
        """
        Queries the ``recentchanges`` table.  See
        `<https://www.mediawiki.org/wiki/Manual:Recentchanges_table>`_

        :Parameters:
            before : :class:`mw.Timestamp`
                The maximum timestamp
            after : :class:`mw.Timestamp`
                The minimum timestamp
            before_id : int
                The maximum ``rc_id``
            after_id : int
                The minimum ``rc_id``
            types : set ( str )
                Which types of changes to return?

                * ``edit`` -- Edits to existing pages
                * ``new`` -- Edits that create new pages
                * ``move`` -- (obsolete)
                * ``log`` -- Log actions (introduced in MediaWiki 1.2)
                * ``move_over_redirect`` -- (obsolete)
                * ``external`` -- An external recent change. Primarily used by Wikidata

            direction : str
                "older" or "newer"
            limit : int
                limit the number of records returned

        :Returns:
            An iterator over recentchanges rows.
        """
        before = none_or(before, Timestamp)
        after = none_or(after, Timestamp)
        before_id = none_or(before_id, int)
        after_id = none_or(after_id, int)
        types = none_or(types, levels=self.TYPES)
        direction = none_or(direction, levels=self.DIRECTIONS)
        limit = none_or(limit, int)

        query = """
            SELECT * FROM recentchanges
            WHERE 1
        """
        values = []

        if before is not None:
            query += " AND rc_timestamp < %s "
            values.append(before.short_format())
        if after is not None:
            # BUG FIX: `after` is the *minimum* timestamp; used to be `<`.
            query += " AND rc_timestamp > %s "
            values.append(after.short_format())
        if before_id is not None:
            query += " AND rc_id < %s "
            values.append(before_id)
        if after_id is not None:
            # BUG FIX: `after_id` is the *minimum* rc_id; used to be `<`.
            query += " AND rc_id > %s "
            values.append(after_id)
        if types is not None:
            # BUG FIX: TYPES values are ints -- str() them before joining,
            # otherwise str.join() raises TypeError.
            query += " AND rc_type IN ({0}) ".format(
                ",".join(str(self.TYPES[t]) for t in types)
            )
        if direction is not None:
            # BUG FIX: previously formatted the builtin `dir` function into
            # the query instead of the requested direction.
            dir_sql = "ASC " if direction == "newer" else "DESC "
            query += " ORDER BY rc_timestamp {0}, rc_id {0}".format(dir_sql)
        if limit is not None:
            query += " LIMIT %s "
            values.append(limit)

        # BUG FIX: the cursor was never created before being used.
        cursor = self.db.shared_connection.cursor()
        cursor.execute(query, values)

        for row in cursor:
            yield row

View File

@@ -0,0 +1,410 @@
import logging
import time
from itertools import chain
from ...types import Timestamp
from ...util import iteration, none_or
from .collection import Collection
logger = logging.getLogger("mw.database.collections.revisions")
class AllRevisions(Collection):

    def get(self, rev_id, include_page=False):
        """
        Gets a single revision by ID.  Checks both the ``revision`` and
        ``archive`` tables.  This method throws a :class:`KeyError` if a
        revision cannot be found.

        :Parameters:
            rev_id : int
                Revision ID
            include_page : bool
                Join revision returned against ``page``

        :Returns:
            A revision row
        """
        rev_id = int(rev_id)
        try:
            rev_row = self.db.revisions.get(rev_id, include_page=include_page)
        except KeyError:
            # Not in `revision` -- the page may have been deleted, so fall
            # back to `archive`.  A second KeyError propagates to the caller.
            rev_row = self.db.archives.get(rev_id)

        return rev_row

    def query(self, *args, **kwargs):
        """
        Queries revisions from both the ``revision`` and ``archive`` tables
        and collates the two streams.

        :Parameters:
            page_id : int
                Page identifier.  Filter revisions to this page.
            user_id : int
                User identifier.  Filter revisions to those made by this user.
            user_text : str
                User text (user_name or IP address).  Filter revisions to those
                made by this user.
            before : :class:`mw.Timestamp`
                Filter revisions to those made before this timestamp.
            after : :class:`mw.Timestamp`
                Filter revisions to those made after this timestamp.
            before_id : int
                Filter revisions to those with an ID before this ID
            after_id : int
                Filter revisions to those with an ID after this ID
            direction : str
                "newer" or "older"
            limit : int
                Limit the number of results
            include_page : bool
                Join revisions returned against ``page``

        :Returns:
            An iterator over revision rows.
        """
        revisions = self.db.revisions.query(*args, **kwargs)
        archives = self.db.archives.query(*args, **kwargs)

        if 'direction' in kwargs:
            direction = kwargs['direction']
            if direction not in self.DIRECTIONS:
                raise TypeError("direction must be in {0}".format(self.DIRECTIONS))

            # Merge the two pre-sorted streams, preserving (timestamp, id)
            # order in the requested direction.
            if direction == "newer":
                collated_revisions = iteration.sequence(
                    revisions,
                    archives,
                    compare=lambda r1, r2:
                        (r1['rev_timestamp'], r1['rev_id']) <=
                        (r2['rev_timestamp'], r2['rev_id'])
                )
            else:  # direction == "older"
                collated_revisions = iteration.sequence(
                    revisions,
                    archives,
                    compare=lambda r1, r2:
                        (r1['rev_timestamp'], r1['rev_id']) >=
                        (r2['rev_timestamp'], r2['rev_id'])
                )
        else:
            collated_revisions = chain(revisions, archives)

        if 'limit' in kwargs:
            limit = kwargs['limit']

            # BUG FIX: the old loop checked `i >= limit` *after* yielding,
            # which returned limit + 1 rows.
            for i, rev in enumerate(collated_revisions):
                if i >= limit:
                    break
                yield rev
        else:
            for rev in collated_revisions:
                yield rev
class Revisions(Collection):

    def get(self, rev_id, include_page=False):
        """
        Gets a single revision by ID.  Checks the ``revision`` table.  This
        method throws a :class:`KeyError` if a revision cannot be found.

        :Parameters:
            rev_id : int
                Revision ID
            include_page : bool
                Join revision returned against ``page``

        :Returns:
            A revision row
        """
        rev_id = int(rev_id)

        query = """
            SELECT *, FALSE AS archived FROM revision
        """
        if include_page:
            query += """
                INNER JOIN page ON page_id = rev_page
            """
        query += " WHERE rev_id = %s"

        # BUG FIX: the cursor was never created before being used.
        cursor = self.db.shared_connection.cursor()
        cursor.execute(query, [rev_id])

        for row in cursor:
            return row

        raise KeyError(rev_id)

    def query(self, page_id=None, user_id=None, user_text=None,
              before=None, after=None, before_id=None, after_id=None,
              direction=None, limit=None, include_page=False):
        """
        Queries revisions (excludes revisions to deleted pages)

        :Parameters:
            page_id : int
                Page identifier.  Filter revisions to this page.
            user_id : int
                User identifier.  Filter revisions to those made by this user.
            user_text : str
                User text (user_name or IP address).  Filter revisions to those
                made by this user.
            before : :class:`mw.Timestamp`
                Filter revisions to those made before this timestamp.
            after : :class:`mw.Timestamp`
                Filter revisions to those made after this timestamp.
            before_id : int
                Filter revisions to those with an ID before this ID
            after_id : int
                Filter revisions to those with an ID after this ID
            direction : str
                "newer" or "older"
            limit : int
                Limit the number of results
            include_page : bool
                Join revisions returned against ``page``

        :Returns:
            An iterator over revision rows.
        """
        start_time = time.time()

        page_id = none_or(page_id, int)
        user_id = none_or(user_id, int)
        user_text = none_or(user_text, str)
        before = none_or(before, Timestamp)
        after = none_or(after, Timestamp)
        before_id = none_or(before_id, int)
        after_id = none_or(after_id, int)
        direction = none_or(direction, levels=self.DIRECTIONS)
        # Consistency fix: normalize `limit` like the Archives collection.
        limit = none_or(limit, int)
        include_page = bool(include_page)

        query = """
            SELECT *, FALSE AS archived FROM revision
        """
        if include_page:
            query += """
                INNER JOIN page ON page_id = rev_page
            """
        query += """
            WHERE 1
        """
        values = []

        if page_id is not None:
            query += " AND rev_page = %s "
            values.append(page_id)
        if user_id is not None:
            query += " AND rev_user = %s "
            values.append(user_id)
        if user_text is not None:
            query += " AND rev_user_text = %s "
            values.append(user_text)
        if before is not None:
            query += " AND rev_timestamp < %s "
            values.append(before.short_format())
        if after is not None:
            query += " AND rev_timestamp > %s "
            values.append(after.short_format())
        if before_id is not None:
            query += " AND rev_id < %s "
            values.append(before_id)
        if after_id is not None:
            query += " AND rev_id > %s "
            values.append(after_id)

        if direction is not None:
            direction = ("ASC " if direction == "newer" else "DESC ")

            # When filtering primarily by ID, order by ID first so the index
            # on rev_id is used; otherwise order by timestamp first.
            if before_id is not None or after_id is not None:
                query += " ORDER BY rev_id {0}, rev_timestamp {0}".format(direction)
            else:
                query += " ORDER BY rev_timestamp {0}, rev_id {0}".format(direction)

        if limit is not None:
            query += " LIMIT %s "
            values.append(limit)

        cursor = self.db.shared_connection.cursor()
        cursor.execute(query, values)
        count = 0
        for row in cursor:
            yield row
            count += 1

        logger.debug("%s revisions read in %s seconds" % (count, time.time() - start_time))
class Archives(Collection):

    def get(self, rev_id):
        """
        Gets a single revision by ID.  Checks the ``archive`` table.  This
        method throws a :class:`KeyError` if a revision cannot be found.

        :Parameters:
            rev_id : int
                Revision ID

        :Returns:
            A revision row
        """
        rev_id = int(rev_id)

        query = """
            SELECT
                ar_id,
                ar_rev_id AS rev_id,
                ar_page_id AS rev_page,
                ar_page_id AS page_id,
                ar_title AS page_title,
                ar_namespace AS page_namespace,
                ar_text_id AS rev_text_id,
                ar_comment AS rev_comment,
                ar_user AS rev_user,
                ar_user_text AS rev_user_text,
                ar_timestamp AS rev_timestamp,
                ar_minor_edit AS rev_minor_edit,
                ar_deleted AS rev_deleted,
                ar_len AS rev_len,
                ar_parent_id AS rev_parent_id,
                ar_sha1 AS rev_sha1,
                TRUE AS archived
            FROM archive
            WHERE ar_rev_id = %s
        """

        # BUG FIX: the cursor was never created before being used.
        cursor = self.db.shared_connection.cursor()
        cursor.execute(query, [rev_id])

        for row in cursor:
            return row

        raise KeyError(rev_id)

    def query(self, page_id=None, user_id=None, user_text=None,
              before=None, after=None, before_id=None, after_id=None,
              before_ar_id=None, after_ar_id=None,
              direction=None, limit=None, include_page=True):
        """
        Queries archived revisions (revisions of deleted pages)

        :Parameters:
            page_id : int
                Page identifier.  Filter revisions to this page.
            user_id : int
                User identifier.  Filter revisions to those made by this user.
            user_text : str
                User text (user_name or IP address).  Filter revisions to those
                made by this user.
            before : :class:`mw.Timestamp`
                Filter revisions to those made before this timestamp.
            after : :class:`mw.Timestamp`
                Filter revisions to those made after this timestamp.
            before_id : int
                Filter revisions to those with an ID before this ID
            after_id : int
                Filter revisions to those with an ID after this ID
            before_ar_id : int
                Filter revisions to those with an archive row ID before this ID
            after_ar_id : int
                Filter revisions to those with an archive row ID after this ID
            direction : str
                "newer" or "older"
            limit : int
                Limit the number of results
            include_page : bool
                This field is ignored.  It's only here for compatibility with
                :class:`mw.database.Revision`.

        :Returns:
            An iterator over revision rows.
        """
        page_id = none_or(page_id, int)
        user_id = none_or(user_id, int)
        # Consistency fix: user_text was used without normalization.
        user_text = none_or(user_text, str)
        before = none_or(before, Timestamp)
        after = none_or(after, Timestamp)
        before_id = none_or(before_id, int)
        after_id = none_or(after_id, int)
        direction = none_or(direction, levels=self.DIRECTIONS)
        limit = none_or(limit, int)

        start_time = time.time()
        cursor = self.db.shared_connection.cursor()

        query = """
            SELECT
                ar_id,
                ar_rev_id AS rev_id,
                ar_page_id AS rev_page,
                ar_page_id AS page_id,
                ar_title AS page_title,
                ar_namespace AS page_namespace,
                ar_text_id AS rev_text_id,
                ar_comment AS rev_comment,
                ar_user AS rev_user,
                ar_user_text AS rev_user_text,
                ar_timestamp AS rev_timestamp,
                ar_minor_edit AS rev_minor_edit,
                ar_deleted AS rev_deleted,
                ar_len AS rev_len,
                ar_parent_id AS rev_parent_id,
                ar_sha1 AS rev_sha1,
                TRUE AS archived
            FROM archive
        """
        query += """
            WHERE 1
        """
        values = []

        if page_id is not None:
            query += " AND ar_page_id = %s "
            values.append(page_id)
        if user_id is not None:
            query += " AND ar_user = %s "
            values.append(user_id)
        if user_text is not None:
            query += " AND ar_user_text = %s "
            values.append(user_text)
        if before is not None:
            query += " AND ar_timestamp < %s "
            values.append(before.short_format())
        if after is not None:
            query += " AND ar_timestamp > %s "
            values.append(after.short_format())
        if before_id is not None:
            query += " AND ar_rev_id < %s "
            values.append(before_id)
        if after_id is not None:
            query += " AND ar_rev_id > %s "
            values.append(after_id)
        if before_ar_id is not None:
            # BUG FIX: used the qmark ("?") placeholder; pymysql uses "%s".
            query += " AND ar_id < %s "
            values.append(before_ar_id)
        if after_ar_id is not None:
            # BUG FIX: used the qmark ("?") placeholder; pymysql uses "%s".
            query += " AND ar_id > %s "
            values.append(after_ar_id)

        if direction is not None:
            dir = ("ASC " if direction == "newer" else "DESC ")

            if before is not None or after is not None:
                query += " ORDER BY ar_timestamp {0}, ar_rev_id {0}".format(dir)
            elif before_id is not None or after_id is not None:
                query += " ORDER BY ar_rev_id {0}, ar_timestamp {0}".format(dir)
            else:
                query += " ORDER BY ar_id {0}".format(dir)

        if limit is not None:
            query += " LIMIT %s "
            values.append(limit)

        cursor.execute(query, values)
        count = 0
        for row in cursor:
            yield row
            count += 1

        logger.debug("%s revisions read in %s seconds" % (count, time.time() - start_time))

View File

@@ -0,0 +1,154 @@
import logging
import time
from ...types import Timestamp
from ...util import none_or
from .collection import Collection
logger = logging.getLogger("mw.database.collections.users")
class Users(Collection):

    # Log actions that represent account creation
    # (see Manual:Logging_table).
    CREATION_ACTIONS = {'newusers', 'create', 'create2', 'autocreate',
                        'byemail'}

    def get(self, user_id=None, user_name=None):
        """
        Gets a single user row from the database.  Raises a :class:`KeyError`
        if a user cannot be found.

        :Parameters:
            user_id : int
                User ID
            user_name : str
                User's name

        :Returns:
            A user row.
        """
        user_id = none_or(user_id, int)
        user_name = none_or(user_name, str)

        query = """
            SELECT user.*
            FROM user
        """
        values = []

        if user_id is not None:
            query += """
                WHERE user_id = %s
            """
            values.append(user_id)
        elif user_name is not None:
            query += """
                WHERE user_name = %s
            """
            values.append(user_name)
        else:
            raise TypeError("Must specify a user identifier.")

        cursor = self.db.shared_connection.cursor()
        cursor.execute(
            query,
            values
        )

        for row in cursor:
            return row

        raise KeyError(user_id if user_id is not None else user_name)

    def query(self, registered_before=None, registered_after=None,
              before_id=None, after_id=None, limit=None,
              direction=None, self_created_only=False):
        """
        Queries users based on various filtering parameters.

        :Parameters:
            registered_before : :class:`mw.Timestamp`
                A timestamp to search before (inclusive)
            registered_after : :class:`mw.Timestamp`
                A timestamp to search after (inclusive)
            before_id : int
                A user_id to search before (inclusive)
            after_id : int
                A user_id to search after (inclusive)
            direction : str
                "newer" or "older"
            limit : int
                Limit the results to at most this number
            self_created_only : bool
                limit results to self-created user accounts

        :Returns:
            an iterator over ``user`` table rows
        """
        start_time = time.time()

        registered_before = none_or(registered_before, Timestamp)
        registered_after = none_or(registered_after, Timestamp)
        # BUG FIX: user_id bounds are integers; they were coerced to `str`.
        before_id = none_or(before_id, int)
        after_id = none_or(after_id, int)
        direction = none_or(direction, levels=self.DIRECTIONS)
        limit = none_or(limit, int)
        self_created_only = bool(self_created_only)

        query = """
            SELECT user.*
            FROM user
        """
        values = []

        if self_created_only:
            # BUG FIX: the first join condition was missing its AND, which
            # made the generated SQL a syntax error.
            query += """
                INNER JOIN logging ON
                    log_user = user_id AND
                    log_type = "newusers" AND
                    log_action = "create"
            """

        query += "WHERE 1 "

        if registered_before is not None:
            query += "AND user_registration <= %s "
            values.append(registered_before.short_format())
        if registered_after is not None:
            query += "AND user_registration >= %s "
            values.append(registered_after.short_format())
        if before_id is not None:
            query += "AND user_id <= %s "
            values.append(before_id)
        if after_id is not None:
            query += "AND user_id >= %s "
            values.append(after_id)

        query += "GROUP BY user_id "  # In case of duplicate log events

        if direction is not None:
            if registered_before is not None or registered_after is not None:
                if direction == "newer":
                    query += "ORDER BY user_registration ASC "
                else:
                    query += "ORDER BY user_registration DESC "
            else:
                if direction == "newer":
                    query += "ORDER BY user_id ASC "
                else:
                    query += "ORDER BY user_id DESC "

        if limit is not None:
            query += "LIMIT %s "
            values.append(limit)

        cursor = self.db.shared_connection.cursor()
        cursor.execute(query, values)
        count = 0
        for row in cursor:
            yield row
            count += 1

        logger.debug("%s users queried in %s seconds" % (count, time.time() - start_time))

View File

@@ -0,0 +1,134 @@
import getpass
import logging
import os
import pymysql
import pymysql.cursors
from .collections import AllRevisions, Archives, Pages, Revisions, Users
logger = logging.getLogger("mw.database.db")
class DB:
    """
    Represents a connection to a MySQL database.

    :Parameters:
        connection : :class:`pymysql.connections.Connection`
            A connection to a MediaWiki database
    """

    def __init__(self, connection):
        self.shared_connection = connection
        self.shared_connection.cursorclass = pymysql.cursors.DictCursor

        self.revisions = Revisions(self)  # :class:`mw.database.Revisions`
        self.archives = Archives(self)  # :class:`mw.database.Archives`
        self.all_revisions = AllRevisions(self)  # :class:`mw.database.AllRevisions`
        self.pages = Pages(self)  # :class:`mw.database.Pages`
        self.users = Users(self)  # :class:`mw.database.Users`

    def __repr__(self):
        # BUG FIX: previously referenced self.args / self.kwargs, which are
        # never assigned, so repr() always raised AttributeError.
        return "%s(%s)" % (
            self.__class__.__name__,
            repr(self.shared_connection)
        )

    def __str__(self):
        return self.__repr__()

    @classmethod
    def add_arguments(cls, parser, defaults=None):
        """
        Adds the arguments to an :class:`argparse.ArgumentParser` in order to
        create a database connection.
        """
        # BUG FIX: `defaults if defaults is not None else defaults` left
        # `defaults` as None and crashed on the .get() calls below.
        defaults = defaults if defaults is not None else {}

        default_host = defaults.get('host', "localhost")
        parser.add_argument(
            '--host', '-h',
            help="MySQL database host to connect to (defaults to {0})".format(default_host),
            default=default_host
        )

        default_database = defaults.get('database', getpass.getuser())
        parser.add_argument(
            '--database', '-d',
            help="MySQL database name to connect to (defaults to {0})".format(default_database),
            default=default_database
        )

        default_defaults_file = defaults.get('defaults-file', os.path.expanduser("~/.my.cnf"))
        parser.add_argument(
            '--defaults-file',
            help="MySQL defaults file (defaults to {0})".format(default_defaults_file),
            default=default_defaults_file
        )

        default_user = defaults.get('user', getpass.getuser())
        parser.add_argument(
            '--user', '-u',
            # BUG FIX: mixed a %-style placeholder with str.format(), so the
            # help text printed a literal "%s".
            help="MySQL user (defaults to {0})".format(default_user),
            default=default_user
        )
        return parser

    @classmethod
    def from_arguments(cls, args):
        """
        Constructs a :class:`~mw.database.DB`.
        Consumes :class:`argparse.ArgumentParser` arguments given by
        :meth:`add_arguments` in order to create a :class:`DB`.

        :Parameters:
            args : :class:`argparse.Namespace`
                A collection of argument values returned by :class:`argparse.ArgumentParser`'s :meth:`parse_args()`
        """
        # Keyword arguments: pymysql's connect() parameters are keyword-only
        # in recent releases, so positional host/user would break.
        connection = pymysql.connect(
            host=args.host,
            user=args.user,
            database=args.database,
            read_default_file=args.defaults_file
        )
        return cls(connection)

    @classmethod
    def from_params(cls, *args, **kwargs):
        """
        Constructs a :class:`~mw.database.DB`.  Passes `*args` and `**kwargs`
        to :meth:`pymysql.connect` and configures the connection.

        :Parameters:
            args, kwargs
                Passed through to :meth:`pymysql.connect`; a legacy ``db``
                keyword is translated to ``database``.
        """
        kwargs['cursorclass'] = pymysql.cursors.DictCursor
        # BUG FIX: `kwargs['db']` raised a KeyError whenever 'db' was not
        # supplied at all.
        if 'db' in kwargs:
            kwargs['database'] = kwargs.pop('db')
        connection = pymysql.connect(*args, **kwargs)
        return cls(connection)

View File

@@ -0,0 +1,14 @@
"""
A package with utilities for managing the persistent word analysis across text
versions of a document. `PersistenceState` is the highest level of the
interface and the part of the system that's most interesting externally. `Word`s
are also very important. The current implementation of `Word` only accounts for
how the number of revisions in which a Word is visible. If persistent word
views (or something similar) is intended to be kept, refactoring will be
necessary.
"""
from .state import State
from .tokens import Tokens, Token
from . import defaults
from . import api

View File

@@ -0,0 +1,85 @@
from .. import reverts
from ...util import none_or
from .state import State
def track(session, rev_id, page_id=None, revert_radius=reverts.defaults.RADIUS,
          future_revisions=reverts.defaults.RADIUS, properties=None):
    """
    Computes a persistence score for a revision by processing the revisions
    that took place around it.

    :Parameters:
        session : :class:`mw.api.Session`
            An API session to make use of
        rev_id : int
            the ID of the revision to check
        page_id : int
            the ID of the page the revision occupies (slower if not provided)
        revert_radius : int
            a positive integer indicating the maximum number of revisions that can be reverted
        future_revisions : int
            the number of revisions after `rev_id` to process for persistence
        properties : set ( str )
            additional revision properties to request from the API

    :Returns:
        A (current revision, tokens added, future revisions) triple, or
        ``None`` when the revision has no retrievable history.
    """
    if not hasattr(session, "revisions"):
        raise TypeError("session is wrong type. Expected a mw.api.Session.")

    rev_id = int(rev_id)
    page_id = none_or(page_id, int)
    revert_radius = int(revert_radius)
    if revert_radius < 1:
        raise TypeError("invalid radius. Expected a positive integer.")
    properties = set(properties) if properties is not None else set()

    # If we don't have the page_id, we're going to need to look them up
    if page_id is None:
        rev = session.revisions.get(rev_id, properties={'ids'})
        page_id = rev['page']['pageid']

    # Load history and current rev.  The "older" query returns the current
    # revision first, followed by up to `revert_radius` older revisions.
    current_and_past_revs = list(session.revisions.query(
        pageids={page_id},
        limit=revert_radius + 1,
        start_id=rev_id,
        direction="older",
        properties={'ids', 'timestamp', 'content', 'sha1'} | properties
    ))

    try:
        # Extract current rev and reorder history into chronological order
        current_rev, past_revs = (
            current_and_past_revs[0],  # Current rev is the first one returned
            reversed(current_and_past_revs[1:])  # The rest are past revs, but they are in the wrong order
        )
    except IndexError:
        # Only way to get here is if there isn't enough history.  Couldn't be
        # reverted.  Just return None.
        return None

    # Load future revisions
    future_revs = session.revisions.query(
        pageids={page_id},
        limit=future_revisions,
        start_id=rev_id + 1,  # Ensures that we skip the current revision
        direction="newer",
        properties={'ids', 'timestamp', 'content', 'sha1'} | properties
    )

    state = State(revert_radius=revert_radius)

    # Process old revisions to warm up the persistence state
    for rev in past_revs:
        state.process(rev.get('*', ""), rev, rev.get('sha1'))

    # Process current revision; only its added tokens matter to the caller
    _, tokens_added, _ = state.process(current_rev.get('*'), current_rev,
                                       current_rev.get('sha1'))

    # Process new revisions so the added tokens accumulate persistence info
    future_revs = list(future_revs)
    for rev in future_revs:
        state.process(rev.get('*', ""), rev, rev.get('sha1'))

    return current_rev, tokens_added, future_revs

# `score` is an alias of :func:`track`.
score = track

View File

@@ -0,0 +1,11 @@
from . import tokenization, difference
TOKENIZE = tokenization.wikitext_split
"""
The standard tokenizing function.
"""
DIFF = difference.sequence_matcher
"""
The standard diff function
"""

View File

@@ -0,0 +1,49 @@
from difflib import SequenceMatcher
def sequence_matcher(old, new):
    """
    Generates a sequence of operations using :class:`difflib.SequenceMatcher`.

    :Parameters:
        old : list( `hashable` )
            Old tokens
        new : list( `hashable` )
            New tokens

    :Returns:
        Minimal operations needed to convert `old` to `new`
    """
    matcher = SequenceMatcher(None, list(old), list(new))
    opcodes = matcher.get_opcodes()
    return opcodes
def apply(ops, old, new):
    """
    Applies operations (delta) to copy items from `old` to `new`.

    :Parameters:
        ops : list((op, a1, a2, b1, b2))
            Operations to perform
        old : list( `hashable` )
            Old tokens
        new : list( `hashable` )
            New tokens

    :Returns:
        An iterator over elements matching `new` but copied from `old`
    """
    for op, a1, a2, b1, b2 in ops:
        if op in ("insert", "replace"):
            # Inserted/replacement content comes from `new`.
            yield from new[b1:b2]
        elif op == "equal":
            # Unchanged content is copied through from `old`.
            yield from old[a1:a2]
        elif op == "delete":
            # Deleted content contributes nothing.
            continue
        else:
            assert False, \
                "encounted an unrecognized operation code: " + repr(op)

View File

@@ -0,0 +1,149 @@
from hashlib import sha1
from . import defaults
from .. import reverts
from .tokens import Token, Tokens
class Version:
    """
    Bundles the token state for a single processed revision of a page.
    """
    __slots__ = ('tokens',)

    def __init__(self):
        # Populated by State.process() once the revision's tokens are known.
        self.tokens = None
class State:
    """
    Represents the state of word persistence in a page.
    See `<https://meta.wikimedia.org/wiki/Research:Content_persistence>`_

    :Parameters:
        tokenize : function( `str` ) --> list( `str` )
            A tokenizing function
        diff : function(list( `str` ), list( `str` )) --> list( `ops` )
            A function to perform a difference between token lists
        revert_radius : int
            a positive integer indicating the maximum revision distance that a revert can span.
        revert_detector : :class:`mw.lib.reverts.Detector`
            a revert detector to start process with

    :Example:
        >>> from pprint import pprint
        >>> from mw.lib import persistence
        >>>
        >>> state = persistence.State()
        >>>
        >>> pprint(state.process("Apples are red.", revision=1))
        ([Token(text='Apples', revisions=[1]),
          Token(text=' ', revisions=[1]),
          Token(text='are', revisions=[1]),
          Token(text=' ', revisions=[1]),
          Token(text='red', revisions=[1]),
          Token(text='.', revisions=[1])],
         [Token(text='Apples', revisions=[1]),
          Token(text=' ', revisions=[1]),
          Token(text='are', revisions=[1]),
          Token(text=' ', revisions=[1]),
          Token(text='red', revisions=[1]),
          Token(text='.', revisions=[1])],
         [])
        >>> pprint(state.process("Apples are blue.", revision=2))
        ([Token(text='Apples', revisions=[1, 2]),
          Token(text=' ', revisions=[1, 2]),
          Token(text='are', revisions=[1, 2]),
          Token(text=' ', revisions=[1, 2]),
          Token(text='blue', revisions=[2]),
          Token(text='.', revisions=[1, 2])],
         [Token(text='blue', revisions=[2])],
         [Token(text='red', revisions=[1])])
        >>> pprint(state.process("Apples are red.", revision=3))  # A revert!
        ([Token(text='Apples', revisions=[1, 2, 3]),
          Token(text=' ', revisions=[1, 2, 3]),
          Token(text='are', revisions=[1, 2, 3]),
          Token(text=' ', revisions=[1, 2, 3]),
          Token(text='red', revisions=[1, 3]),
          Token(text='.', revisions=[1, 2, 3])],
         [],
         [])
    """
    def __init__(self, tokenize=defaults.TOKENIZE, diff=defaults.DIFF,
                 revert_radius=reverts.defaults.RADIUS,
                 revert_detector=None):
        self.tokenize = tokenize
        self.diff = diff

        # Either pass a detector or the revert radius so I can make one
        if revert_detector is None:
            self.revert_detector = reverts.Detector(int(revert_radius))
        else:
            self.revert_detector = revert_detector

        # Stores the last tokens (a Version, once a revision is processed)
        self.last = None

    def process(self, text, revision=None, checksum=None):
        """
        Modifies the internal state based a change to the content and returns
        the sets of words added and removed.

        :Parameters:
            text : str
                The text content of a revision
            revision : `mixed`
                Revision meta data
            checksum : str
                A checksum hash of the text content (will be generated if not provided)

        :Returns:
            Three :class:`~mw.lib.persistence.Tokens` lists

            current_tokens : :class:`~mw.lib.persistence.Tokens`
                A sequence of :class:`~mw.lib.persistence.Token` for the
                processed revision
            tokens_added : :class:`~mw.lib.persistence.Tokens`
                A set of tokens that were inserted by the processed revision
            tokens_removed : :class:`~mw.lib.persistence.Tokens`
                A sequence of :class:`~mw.lib.persistence.Token` removed by the
                processed revision
        """
        if checksum is None:
            checksum = sha1(bytes(text, 'utf8')).hexdigest()

        version = Version()

        # The detector matches this checksum against recent history; a hit
        # means this revision is an identity revert of an earlier version.
        revert = self.revert_detector.process(checksum, version)
        if revert is not None:  # Revert

            # Empty words -- a revert neither adds nor removes tokens.
            tokens_added = Tokens()
            tokens_removed = Tokens()

            # Extract reverted_to revision and reuse its token list wholesale
            _, _, reverted_to = revert
            version.tokens = reverted_to.tokens

        else:

            if self.last is None:  # First version of the page!

                version.tokens = Tokens(Token(t) for t in self.tokenize(text))
                tokens_added = version.tokens
                tokens_removed = Tokens()

            else:

                # NOTICE: HEAVY COMPUTATION HERE!!!
                #
                # OK.  It's not that heavy.  It's just performing a diff,
                # but you're still going to spend most of your time here.
                # Diffs usually run in O(n^2) -- O(n^3) time and most
                # tokenizers produce a lot of tokens.
                version.tokens, tokens_added, tokens_removed = \
                    self.last.tokens.compare(self.tokenize(text), self.diff)

        # Record that every surviving token appeared in this revision.
        version.tokens.persist(revision)

        self.last = version

        return version.tokens, tokens_added, tokens_removed

View File

@@ -0,0 +1,12 @@
from nose.tools import eq_
from .. import difference
def test_sequence_matcher():
    """Round-trips a diff: applying the ops to `before` reproduces `after`."""
    before = "foobar derp hepl derpl"
    after = "fooasldal 3 hepl asl a derpl"

    ops = difference.sequence_matcher(before, after)

    eq_("".join(difference.apply(ops, before, after)), after)

View File

@@ -0,0 +1,25 @@
from nose.tools import eq_
from ..state import State
def test_state():
    """Processes a short revision history and checks token persistence."""
    history = [
        ("Apples are red.", 0),
        ("Apples are blue.", 1),
        ("Apples are red.", 2),
        ("Apples are tasty and red.", 3),
        ("Apples are tasty and blue.", 4)
    ]

    state = State()
    token_sets = [state.process(text, rev) for text, rev in history]

    # Every processed version must round-trip back to its source text.
    for i, (text, rev) in enumerate(history):
        eq_("".join(token_sets[i][0].texts()), text)

    # "Apples" survives all five revisions; "red" only three of them.
    eq_(token_sets[0][0][0].text, "Apples")
    eq_(len(token_sets[0][0][0].revisions), 5)
    eq_(token_sets[0][0][4].text, "red")
    eq_(len(token_sets[0][0][4].revisions), 3)

View File

@@ -0,0 +1,10 @@
from nose.tools import eq_
from .. import tokenization
def test_wikitext_split():
    """The tokenizer must separate words, whitespace and template braces."""
    expected = ["foo", " ", "bar", " ", "herp", " ", "{{", "derp", "}}"]

    tokens = tokenization.wikitext_split("foo bar herp {{derp}}")

    eq_(list(tokens), expected)

View File

@@ -0,0 +1,16 @@
import re
def wikitext_split(text):
    """
    Performs the simplest possible split of latin character-based languages
    and wikitext.

    :Parameters:
        text : str
            Text to split.
    """
    # Alternatives are tried left-to-right; the final "." catches any
    # remaining single character.
    pattern = (
        r"[\w]+"             # runs of word characters
        r"|\[\[|\]\]"        # link brackets
        r"|\{\{|\}\}"        # template braces
        r"|\n+| +"           # newlines and spaces
        r"|&\w+;"            # HTML entities
        r"|'''|''"           # bold / italic markup
        r"|=+"               # heading markers
        r"|\{\||\|\}|\|\-"   # table markup
        r"|."                # anything else, one character at a time
    )
    return re.findall(pattern, text)

View File

@@ -0,0 +1,98 @@
class Token:
    """
    Represents a chunk of text and the revisions of a page that it survived.
    """
    __slots__ = ('text', 'revisions')

    def __init__(self, text, revisions=None):
        # The text of the token.
        self.text = text

        # The meta data for the revisions that the token has appeared within.
        self.revisions = [] if revisions is None else revisions

    def persist(self, revision):
        """Records that this token survived into `revision`."""
        self.revisions.append(revision)

    def __repr__(self):
        details = "text={0}, revisions={1}".format(repr(self.text),
                                                   repr(self.revisions))
        return "{0}({1})".format(self.__class__.__name__, details)
class Tokens(list):
    """
    Represents a :class:`list` of :class:`~mw.lib.persistence.Token` with some
    useful helper functions.

    :Example:
        >>> from mw.lib.persistence import Token, Tokens
        >>>
        >>> tokens = Tokens()
        >>> tokens.append(Token("foo"))
        >>> tokens.extend([Token(" "), Token("bar")])
        >>>
        >>> tokens[0]
        Token(text='foo', revisions=[])
        >>>
        >>> "".join(tokens.texts())
        'foo bar'
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def persist(self, revision):
        """Appends `revision` to every token's revision list."""
        for token in self:
            token.persist(revision)

    def texts(self):
        """Yields the raw text of each token, in order."""
        return (token.text for token in self)

    def compare(self, new, diff):
        """Diffs this token sequence against the `new` text token list."""
        old = self.texts()
        return self.apply_diff(diff(old, new), self, new)

    @classmethod
    def apply_diff(cls, ops, old, new):
        """
        Replays diff operations, carrying `Token` objects over from `old`
        and wrapping fresh text from `new` in new `Token` objects.

        :Returns:
            A (resulting tokens, tokens added, tokens removed) triple.
        """
        result = cls()
        added = cls()
        removed = cls()

        for code, a_start, a_end, b_start, b_end in ops:
            if code in ("insert", "replace"):
                # Fresh content: wrap each text chunk in a brand new Token.
                for token_text in new[b_start:b_end]:
                    token = Token(token_text)
                    result.append(token)
                    added.append(token)
                if code == "replace":
                    removed.extend(t for t in old[a_start:a_end])
            elif code == "equal":
                result.extend(old[a_start:a_end])
            elif code == "delete":
                removed.extend(old[a_start:a_end])
            else:
                assert False, \
                    "encounted an unrecognized operation code: " + repr(code)

        return (result, added, removed)

View File

@@ -0,0 +1,24 @@
"""
This module provides a set of utilities for detecting identity reverts in
revisioned content.
To detect reverts in a stream of revisions to a single page, you can use
:func:`detect`. If you'll be detecting reverts in a collection of pages or
would, for some other reason, prefer to process revisions one at a time,
:class:`Detector` and its :meth:`~Detector.process` will allow you to do so.
To detect reverts one-at-a-time and arbitrarily, you can use the `check()`
functions:
* :func:`database.check` and :func:`database.check_row` use a :class:`mw.database.DB`
* :func:`api.check` and :func:`api.check_rev` use a :class:`mw.api.Session`
Note that these functions are less performant than detecting reverts in a
stream of page revisions. This can be practical when trying to identify
reverted revisions in a user's contribution history.
"""
from .detector import Detector, Revert
from .functions import detect, reverts
from . import database
from . import api
from . import defaults

View File

@@ -0,0 +1,134 @@
from itertools import chain
from . import defaults
from ...types import Timestamp
from ...util import none_or
from .dummy_checksum import DummyChecksum
from .functions import detect
def check_rev(session, rev, **kwargs):
    """
    Checks whether a revision (API revision document) was reverted (identity)
    and returns a named tuple of Revert(reverting, reverteds, reverted_to).

    :Parameters:
        session : :class:`mw.api.Session`
            An API session to make use of
        rev : dict
            a revision dict containing 'revid' and 'page.id' (or 'pageid')
        radius : int
            a positive integer indicating the maximum number of revisions
            that can be reverted
        before : :class:`mw.Timestamp`
            if set, limits the search for *reverting* revisions to those
            which were saved before this timestamp
        properties : set( str )
            a set of properties to include in revisions (see
            :class:`mw.api.Revisions`)
    """
    # Extract rev_id and page_id from the API-style revision document.
    if 'revid' in rev:
        rev_id = rev['revid']
    else:
        # Bug fix: the message used to say 'rev_id' (the database column
        # name) while the key actually checked is the API's 'revid'.
        raise TypeError("rev must have 'revid'")

    if 'page' in rev:
        page_id = rev['page']['id']
    elif 'pageid' in rev:
        page_id = rev['pageid']
    else:
        raise TypeError("rev must have 'page' or 'pageid'")

    # Run the regular check
    return check(session, rev_id, page_id=page_id, **kwargs)
def check(session, rev_id, page_id=None, radius=defaults.RADIUS,
          before=None, window=None, properties=None):
    """
    Checks whether a revision was reverted (identity) and returns a named tuple
    of Revert(reverting, reverteds, reverted_to).

    :Parameters:
        session : :class:`mw.api.Session`
            An API session to make use of
        rev_id : int
            the ID of the revision to check
        page_id : int
            the ID of the page the revision occupies (slower if not provided)
        radius : int
            a positive integer indicating the maximum number of revisions
            that can be reverted
        before : :class:`mw.Timestamp`
            if set, limits the search for *reverting* revisions to those which
            were saved before this timestamp
        window : int
            if set, limits the search for *reverting* revisions to those which
            were saved within `window` seconds after the reverted edit
        properties : set( str )
            a set of properties to include in revisions (see :class:`mw.api.Revisions`)
    """
    if not hasattr(session, "revisions"):
        raise TypeError("session wrong type. Expected a mw.api.Session.")

    rev_id = int(rev_id)
    radius = int(radius)
    if radius < 1:
        raise TypeError("invalid radius. Expected a positive integer.")
    page_id = none_or(page_id, int)
    before = none_or(before, Timestamp)
    properties = set(properties) if properties is not None else set()

    # If we don't have the page_id, we're going to need to look them up
    if page_id is None:
        rev = session.revisions.get(rev_id, properties={'ids'})
        page_id = rev['page']['pageid']

    # Load history and current rev.  limit=radius + 1 so that we get the
    # current revision itself plus up to `radius` revisions of history.
    current_and_past_revs = list(session.revisions.query(
        pageids={page_id},
        limit=radius + 1,
        start_id=rev_id,
        direction="older",
        properties={'ids', 'timestamp', 'sha1'} | properties
    ))

    try:
        # Extract current rev and reorder history
        current_rev, past_revs = (
            current_and_past_revs[0],  # Current rev is the first one returned
            reversed(current_and_past_revs[1:])  # The rest are past revs, but they are in the wrong order
        )
    except IndexError:
        # Only way to get here is if there isn't enough history.  Couldn't be
        # reverted.  Just return None.
        return None

    if window is not None and before is None:
        # Derive the cutoff from the edit's own timestamp plus the window.
        before = Timestamp(current_rev['timestamp']) + window

    # Load future revisions
    future_revs = session.revisions.query(
        pageids={page_id},
        limit=radius,
        start_id=rev_id + 1,  # Ensures that we skip the current revision
        end=before,
        direction="newer",
        properties={'ids', 'timestamp', 'sha1'} | properties
    )

    # Convert to an iterable of (checksum, rev) pairs for detect() to
    # consume.  Revisions without a sha1 (e.g. suppressed) get a
    # DummyChecksum, which matches nothing but itself.
    checksum_revisions = chain(
        ((rev['sha1'] if 'sha1' in rev else DummyChecksum(), rev)
         for rev in past_revs),
        [(current_rev.get('sha1', DummyChecksum()), current_rev)],
        ((rev['sha1'] if 'sha1' in rev else DummyChecksum(), rev)
         for rev in future_revs),
    )

    for revert in detect(checksum_revisions, radius=radius):
        # Check that this is a relevant revert -- i.e. that it actually
        # reverts the revision of interest.
        if rev_id in [rev['revid'] for rev in revert.reverteds]:
            return revert

    return None

View File

@@ -0,0 +1,148 @@
import random
from itertools import chain
from . import defaults
from ...types import Timestamp
from ...util import none_or
from .dummy_checksum import DummyChecksum
from .functions import detect
# Hexadecimal alphabet used to fabricate plausible-looking sha1 digests
# (used in order to not do weird things with a dummy revision).
HEX = "1234567890abcdef"


def random_sha1():
    """Return a random 40-character hex string shaped like a sha1 digest."""
    return ''.join(random.choice(HEX) for _ in range(40))
def check_row(db, rev_row, **kwargs):
    """
    Checks whether a revision (database row) was reverted (identity) and
    returns a named tuple of Revert(reverting, reverteds, reverted_to).

    :Parameters:
        db : :class:`mw.database.DB`
            A database connection to make use of.
        rev_row : dict
            a revision row containing 'rev_id' and 'rev_page' or 'page_id'
        radius : int
            a positive integer indicating the maximum number of revisions
            that can be reverted
        check_archive : bool
            should the archive table be checked for reverting revisions?
        before : `Timestamp`
            if set, limits the search for *reverting* revisions to those
            which were saved before this timestamp
    """
    # Guard clauses: pull rev_id and page_id out of the row.
    if 'rev_id' not in rev_row:
        raise TypeError("rev_row must have 'rev_id'")
    rev_id = rev_row['rev_id']

    if 'page_id' in rev_row:
        page_id = rev_row['page_id']
    elif 'rev_page' in rev_row:
        page_id = rev_row['rev_page']
    else:
        raise TypeError("rev_row must have 'page_id' or 'rev_page'")

    # Delegate to the regular check.
    return check(db, rev_id, page_id=page_id, **kwargs)
def check(db, rev_id, page_id=None, radius=defaults.RADIUS, check_archive=False,
before=None, window=None):
"""
Checks whether a revision was reverted (identity) and returns a named tuple
of Revert(reverting, reverteds, reverted_to).
:Parameters:
db : `mw.database.DB`
A database connection to make use of.
rev_id : int
the ID of the revision to check
page_id : int
the ID of the page the revision occupies (slower if not provided)
radius : int
a positive integer indicating the maximum number of revisions that can be reverted
check_archive : bool
should the archive table be checked for reverting revisions?
before : `Timestamp`
if set, limits the search for *reverting* revisions to those which were saved before this timestamp
window : int
if set, limits the search for *reverting* revisions to those which
were saved within `window` seconds after the reverted edit
"""
if not hasattr(db, "revisions") and hasattr(db, "all_revisions"):
raise TypeError("db wrong type. Expected a mw.database.DB.")
rev_id = int(rev_id)
radius = int(radius)
if radius < 1:
raise TypeError("invalid radius. Expected a positive integer.")
page_id = none_or(page_id, int)
check_archive = bool(check_archive)
before = none_or(before, Timestamp)
# If we are searching the archive, we'll need to use `all_revisions`.
if check_archive:
dbrevs = db.all_revisions
else:
dbrevs = db.revisions
# If we don't have the sha1 or page_id, we're going to need to look them up
if page_id is None:
row = dbrevs.get(rev_id=rev_id)
page_id = row['rev_page']
# Load history and current rev
current_and_past_revs = list(dbrevs.query(
page_id=page_id,
limit=radius + 1,
before_id=rev_id + 1, # Ensures that we capture the current revision
direction="older"
))
try:
# Extract current rev and reorder history
current_rev, past_revs = (
current_and_past_revs[0], # Current rev is the first one returned
reversed(current_and_past_revs[1:]) # The rest are past revs, but they are in the wrong order
)
except IndexError:
# Only way to get here is if there isn't enough history. Couldn't be
# reverted. Just return None.
return None
if window is not None and before is None:
before = Timestamp(current_rev['rev_timestamp']) + window
# Load future revisions
future_revs = dbrevs.query(
page_id=page_id,
limit=radius,
after_id=rev_id,
before=before,
direction="newer"
)
# Convert to an iterable of (checksum, rev) pairs for detect() to consume
checksum_revisions = chain(
((rev['rev_sha1'] if rev['rev_sha1'] is not None \
else DummyChecksum(), rev)
for rev in past_revs),
[(current_rev['rev_sha1'] or DummyChecksum(), current_rev)],
((rev['rev_sha1'] if rev['rev_sha1'] is not None \
else DummyChecksum(), rev)
for rev in future_revs)
)
for revert in detect(checksum_revisions, radius=radius):
# Check that this is a relevant revert
if rev_id in [rev['rev_id'] for rev in revert.reverteds]:
return revert
return None

View File

@@ -0,0 +1,24 @@
RADIUS = 15
"""
TODO: Better documentation here. For the time being, see:
Priedhorsky, R., Chen, J., Lam, S. T. K., Panciera, K., Terveen, L., &
Riedl, J. (2007, November). Creating, destroying, and restoring value in
Wikipedia. In Proceedings of the 2007 international ACM conference on
Supporting group work (pp. 259-268). ACM.
"""
class DUMMY_SHA1: pass
"""
Used in when checking for reverts when the checksum of the revision of interest
is unknown.
>>> DUMMY_SHA1 in {"aaa", "bbb"} # or any 40 character hex
False
>>>
>>> DUMMY_SHA1 == DUMMY_SHA1
True
>>> {DUMMY_SHA1, DUMMY_SHA1}
{<class '__main__.DUMMY_SHA1'>}
"""

View File

@@ -0,0 +1,83 @@
from collections import namedtuple
from ...util import ordered
from . import defaults
Revert = namedtuple("Revert", ['reverting', 'reverteds', 'reverted_to'])
"""
Represents a revert event. This class behaves like
:class:`collections.namedtuple`. Note that the datatypes of `reverting`,
`reverteds` and `reverted_to` is not specified since those types will depend
on the revision data provided during revert detection.
:Members:
**reverting**
The reverting revision data : `mixed`
**reverteds**
The reverted revision data (ordered chronologically) : list( `mixed` )
**reverted_to**
The reverted-to revision data : `mixed`
"""
class Detector(ordered.HistoricalMap):
    """
    Detects revert events in a stream of revisions (to the same page) based on
    matching checksums.  To detect reverts, construct an instance of this
    class and call :meth:`process` in chronological order
    (``direction == "newer"``).

    See `<https://meta.wikimedia.org/wiki/R:Identity_revert>`_

    :Parameters:
        radius : int
            a positive integer indicating the maximum revision distance that
            a revert can span.

    :Example:
        >>> from mw.lib import reverts
        >>> detector = reverts.Detector()
        >>>
        >>> detector.process("aaa", {'rev_id': 1})
        >>> detector.process("bbb", {'rev_id': 2})
        >>> detector.process("aaa", {'rev_id': 3})
        Revert(reverting={'rev_id': 3}, reverteds=[{'rev_id': 2}], reverted_to={'rev_id': 1})
        >>> detector.process("ccc", {'rev_id': 4})
    """

    def __init__(self, radius=defaults.RADIUS):
        """
        :Parameters:
            radius : int
                a positive integer indicating the maximum revision distance
                that a revert can span.
        """
        if radius < 1:
            raise TypeError("invalid radius. Expected a positive integer.")
        # maxlen is radius + 1 because the map must retain the reverted-to
        # revision in addition to the `radius` potentially-reverted ones.
        super().__init__(maxlen=radius + 1)

    def process(self, checksum, revision=None):
        """
        Process a new revision and detect a revert if it occurred.  Note that
        you can pass whatever you like as `revision` and it will be returned
        in the case that a revert occurs.

        :Parameters:
            checksum : str
                Any identity-matchable string-based hash of revision content
            revision : `mixed`
                Revision meta data.  Note that any data will just be returned
                in the case of a revert.

        :Returns:
            a :class:`~mw.lib.reverts.Revert` if one occurred or `None`
        """
        revert = None

        if checksum in self:  # potential revert
            reverteds = list(self.up_to(checksum))

            if len(reverteds) > 0:  # If no reverted revisions, this is a noop
                revert = Revert(revision, reverteds, self[checksum])

        # Insert last so that a checksum can never "revert" itself.
        self.insert(checksum, revision)
        return revert

View File

@@ -0,0 +1,24 @@
class DummyChecksum():
    """
    Used when checking for reverts where the checksum of the revision of
    interest is unknown.  DummyChecksums won't match each other or anything
    else, but they will match themselves and they are hashable (identity
    semantics inherited from `object`).

    >>> dummy1 = DummyChecksum()
    >>> dummy1
    <#140687347334280>
    >>> dummy1 == dummy1
    True
    >>>
    >>> dummy2 = DummyChecksum()
    >>> dummy2
    <#140687347334504>
    >>> dummy1 == dummy2
    False
    >>>
    >>> {"foo", "bar", dummy1, dummy1, dummy2}
    {<#140687347334280>, 'foo', <#140687347334504>, 'bar'}
    """

    def __str__(self):
        # Bug fix: the `return` was missing, so str() on a DummyChecksum
        # raised "TypeError: __str__ returned non-string (type NoneType)".
        return repr(self)

    def __repr__(self):
        # The id() makes each instance's rendering unique.
        return "<#" + str(id(self)) + ">"

View File

@@ -0,0 +1,46 @@
from .detector import Detector
from . import defaults
def detect(checksum_revisions, radius=defaults.RADIUS):
    """
    Detects reverts that occur in a sequence of revisions.  Note that the
    `revision` meta data is simply carried through and returned inside any
    detected revert.  This is a convenience wrapper around
    :class:`Detector` and its :meth:`~Detector.process` method.

    :Parameters:
        checksum_revisions : iter( ( checksum : str, revision : `mixed` ) )
            an iterable over tuples of checksum and revision meta data
        radius : int
            a positive integer indicating the maximum revision distance that
            a revert can span.

    :Return:
        an iterator over :class:`Revert`

    :Example:
        >>> from mw.lib import reverts
        >>>
        >>> checksum_revisions = [
        ...     ("aaa", {'rev_id': 1}),
        ...     ("bbb", {'rev_id': 2}),
        ...     ("aaa", {'rev_id': 3}),
        ...     ("ccc", {'rev_id': 4})
        ... ]
        >>>
        >>> list(reverts.detect(checksum_revisions))
        [Revert(reverting={'rev_id': 3}, reverteds=[{'rev_id': 2}], reverted_to={'rev_id': 1})]
    """
    detector = Detector(radius)

    for checksum, revision in checksum_revisions:
        outcome = detector.process(checksum, revision)
        if outcome is not None:
            yield outcome

# For backwards compatibility
reverts = detect

View File

@@ -0,0 +1,33 @@
from nose.tools import eq_
from ..detector import Detector
def test_detector():
    detector = Detector(2)

    # Nothing to revert to yet.
    eq_(detector.process("a", {'id': 1}), None)

    # Same checksum with no intervening revisions is a noop, not a revert.
    eq_(detector.process("a", {'id': 2}), None)

    # Revert spanning a single revision.
    eq_(detector.process("b", {'id': 3}), None)
    eq_(detector.process("a", {'id': 4}),
        ({'id': 4}, [{'id': 3}], {'id': 2}))

    # Revert spanning two revisions (== radius).
    eq_(detector.process("c", {'id': 5}), None)
    eq_(detector.process("d", {'id': 6}), None)
    eq_(detector.process("a", {'id': 7}),
        ({'id': 7}, [{'id': 6}, {'id': 5}], {'id': 4}))

    # Three intervening revisions exceed the radius, so no revert fires.
    for i, checksum in enumerate(["e", "f", "g", "a"], start=8):
        eq_(detector.process(checksum, {'id': i}), None)

View File

@@ -0,0 +1,23 @@
from nose.tools import eq_
from ..functions import reverts
def test_reverts():
    # Two identity reverts hide inside this checksum history.
    history = [
        ("a", {'id': 1}),
        ("b", {'id': 2}),
        ("c", {'id': 3}),
        ("a", {'id': 4}),
        ("d", {'id': 5}),
        ("b", {'id': 6}),
        ("a", {'id': 7})
    ]
    remaining = [
        ({'id': 4}, [{'id': 3}, {'id': 2}], {'id': 1}),
        ({'id': 7}, [{'id': 6}, {'id': 5}], {'id': 4})
    ]
    for detected in reverts(history, radius=2):
        eq_(detected, remaining.pop(0))

View File

@@ -0,0 +1,4 @@
from .functions import cluster, sessions
from .event import Event
from .cache import Cache, Session
from . import defaults

View File

@@ -0,0 +1,121 @@
import logging
from collections import namedtuple
from ...util import Heap
from ...types import Timestamp
from . import defaults
from .event import Event, unpack_events
logger = logging.getLogger("mw.lib.sessions.cache")
Session = namedtuple("Session", ["user", "events"])
"""
Represents a user session (a cluster over events for a user). This class
behaves like :class:`collections.namedtuple`. Note that the datatypes of
`events`, is not specified since those types will depend on the revision data
provided during revert detection.
:Members:
**user**
A hashable user identifier : `hashable`
**events**
A list of event data : list( `mixed` )
"""
class Cache:
"""
A cache of recent user session. Since sessions expire once activities stop
for at least `cutoff` seconds, this class manages a cache of *active*
sessions.
:Parameters:
cutoff : int
Maximum amount of time in seconds between session events
:Example:
>>> from mw.lib import sessions
>>>
>>> cache = sessions.Cache(cutoff=3600)
>>>
>>> list(cache.process("Willy on wheels", 100000, {'rev_id': 1}))
[]
>>> list(cache.process("Walter", 100001, {'rev_id': 2}))
[]
>>> list(cache.process("Willy on wheels", 100001, {'rev_id': 3}))
[]
>>> list(cache.process("Walter", 100035, {'rev_id': 4}))
[]
>>> list(cache.process("Willy on wheels", 103602, {'rev_id': 5}))
[Session(user='Willy on wheels', events=[{'rev_id': 1}, {'rev_id': 3}])]
>>> list(cache.get_active_sessions())
[Session(user='Walter', events=[{'rev_id': 2}, {'rev_id': 4}]), Session(user='Willy on wheels', events=[{'rev_id': 5}])]
"""
def __init__(self, cutoff=defaults.CUTOFF):
self.cutoff = int(cutoff)
self.recently_active = Heap()
self.active_users = {}
def process(self, user, timestamp, data=None):
"""
Processes a user event.
:Parameters:
user : `hashable`
A hashable value to identify a user (`int` or `str` are OK)
timestamp : :class:`mw.Timestamp`
The timestamp of the event
data : `mixed`
Event meta data
:Returns:
A generator of :class:`~mw.lib.sessions.Session` expired after
processing the user event.
"""
event = Event(user, Timestamp(timestamp), data)
for user, events in self._clear_expired(event.timestamp):
yield Session(user, unpack_events(events))
# Apply revision
if event.user in self.active_users:
events = self.active_users[event.user]
else:
events = []
self.active_users[event.user] = events
self.recently_active.push((event.timestamp, events))
events.append(event)
def get_active_sessions(self):
"""
Retrieves the active, unexpired sessions.
:Returns:
A generator of :class:`~mw.lib.sessions.Session`
"""
for last_timestamp, events in self.recently_active:
yield Session(events[-1].user, unpack_events(events))
def _clear_expired(self, timestamp):
# Cull old sessions
while (len(self.recently_active) > 0 and
timestamp - self.recently_active.peek()[0] >= self.cutoff):
_, events = self.recently_active.pop()
if timestamp - events[-1].timestamp >= self.cutoff:
del self.active_users[events[-1].user]
yield events[-1].user, events
else:
self.recently_active.push((events[-1].timestamp, events))
def __repr__(self):
return "%s(%s)".format(self.__class__.__name__, repr(self.cutoff))

View File

@@ -0,0 +1,6 @@
CUTOFF = 60 * 60
"""
TODO: Better documentation here.
For the time being, see
`<https://meta.wikimedia.org/wiki/Research:Edit_session>`_
"""

View File

@@ -0,0 +1,19 @@
import logging
from collections import namedtuple
logger = logging.getLogger("mw.lib.sessions.event")
# NOTE: Event was once a hand-rolled class with ('user', 'timestamp', 'data')
# slots; this namedtuple is its drop-in replacement.
Event = namedtuple("Event", ['user', 'timestamp', 'data'])


def unpack_events(events):
    """Return the `data` payload of each event, in order, as a list."""
    return [event.data for event in events]

View File

@@ -0,0 +1,68 @@
import logging
from .cache import Cache
from . import defaults
logger = logging.getLogger("mw.lib.sessions.functions")
def cluster(user_events, cutoff=defaults.CUTOFF):
    """
    Clusters a sequence of user events into user sessions.  The `event`
    payloads are carried through untouched.  This is a convenience wrapper
    around :class:`~mw.lib.sessions.Cache` and its
    :meth:`~mw.lib.sessions.Cache.process` method.

    :Parameters:
        user_events : iter( (user, timestamp, event) )
            an iterable over tuples of user, timestamp and event data.

            * user : `hashable`
            * timestamp : :class:`mw.Timestamp`
            * event : `mixed`
        cutoff : int
            the maximum time between events within a user session

    :Returns:
        an iterator over :class:`~mw.lib.sessions.Session`

    :Example:
        >>> from mw.lib import sessions
        >>>
        >>> user_events = [
        ...     ("Willy on wheels", 100000, {'rev_id': 1}),
        ...     ("Walter", 100001, {'rev_id': 2}),
        ...     ("Willy on wheels", 100001, {'rev_id': 3}),
        ...     ("Walter", 100035, {'rev_id': 4}),
        ...     ("Willy on wheels", 103602, {'rev_id': 5})
        ... ]
        >>>
        >>> for user, events in sessions.cluster(user_events):
        ...     (user, events)
        ...
        ('Willy on wheels', [{'rev_id': 1}, {'rev_id': 3}])
        ('Walter', [{'rev_id': 2}, {'rev_id': 4}])
        ('Willy on wheels', [{'rev_id': 5}])
    """
    session_cache = Cache(cutoff)

    # Feed each event through the cache, emitting sessions as they expire.
    for user, timestamp, event in user_events:
        yield from session_cache.process(user, timestamp, event)

    # Flush whatever sessions remain active once the stream is exhausted.
    yield from session_cache.get_active_sessions()

# For backwards compatibility
sessions = cluster

View File

@@ -0,0 +1,22 @@
from nose.tools import eq_
from ..cache import Cache
def test_session_manager():
    cache = Cache(cutoff=2)

    # Events within the cutoff keep sessions alive; nothing expires yet.
    for user, timestamp in [("foo", 1), ("bar", 2), ("foo", 2)]:
        eq_(list(cache.process(user, timestamp)), [])

    # A large jump in time expires both open sessions at once.
    eq_(len(list(cache.process("bar", 10))), 2)

    # Only the event that triggered the expiry remains active.
    eq_(len(list(cache.get_active_sessions())), 1)

View File

@@ -0,0 +1,50 @@
from itertools import chain
from nose.tools import eq_
from .. import defaults
from ..functions import sessions
# Two users, each with two expected sessions separated by the default cutoff.
EVENTS = {
    "foo": [
        [
            ("foo", 1234567890, 1),
            ("foo", 1234567892, 2),
            ("foo", 1234567894, 3)
        ],
        [
            ("foo", 1234567894 + defaults.CUTOFF, 4),
            ("foo", 1234567897 + defaults.CUTOFF, 5)
        ]
    ],
    "bar": [
        [
            ("bar", 1234567891, 6),
            ("bar", 1234567892, 7),
            ("bar", 1234567893, 8)
        ],
        [
            ("bar", 1234567895 + defaults.CUTOFF, 9),
            ("bar", 1234567898 + defaults.CUTOFF, 0)
        ]
    ]
}


def test_group_events():
    # Interleave both users' events into a single chronological stream.
    events = sorted(chain(*(EVENTS['foo'] + EVENTS['bar'])))

    session_counts = {'foo': 0, 'bar': 0}
    for user, session_events in sessions(events):
        expected = [event[2] for event in EVENTS[user][session_counts[user]]]
        eq_(expected, list(session_events))
        session_counts[user] += 1

View File

@@ -0,0 +1,2 @@
from .functions import normalize
from .parser import Parser

View File

@@ -0,0 +1,25 @@
def normalize(title):
    """
    Normalizes a page title to the database format.  E.g. spaces are
    converted to underscores and the first character in the title is
    converted to upper-case.

    :Parameters:
        title : str
            A page title

    :Returns:
        The normalized title.

    :Example:
        >>> from mw.lib import title
        >>>
        >>> title.normalize("foo bar")
        'Foo_bar'
    """
    # Guard clauses: None passes through; the empty string stays empty.
    if title is None:
        return title
    if not title:
        return ""
    return (title[0].upper() + title[1:]).replace(" ", "_")

View File

@@ -0,0 +1,171 @@
from ...types import Namespace
from ...util import autovivifying, none_or
from .functions import normalize
class Parser:
    """
    Constructs a page name parser from a set of :class:`mw.Namespace`.  Such a
    parser can be used to convert a full page name (namespace included with a
    colon; e.g, ``"Talk:Foo"``) into a namespace ID and
    :func:`mw.lib.title.normalize`'d page title (e.g., ``(1, "Foo")``).

    :Parameters:
        namespaces : set( :class:`mw.Namespace` )

    :Example:
        >>> from mw import Namespace
        >>> from mw.lib import title
        >>>
        >>> parser = title.Parser(
        ...     [
        ...         Namespace(0, "", case="first-letter"),
        ...         Namespace(1, "Discuss\u00e3o", canonical="Talk", case="first-letter"),
        ...         Namespace(2, "Usu\u00e1rio(a)", canonical="User", aliases={"U"}, case="first-letter")
        ...     ]
        ... )
        >>>
        >>> parser.parse("Discuss\u00e3o:Foo")  # Using the standard name
        (1, 'Foo')
        >>> parser.parse("Talk:Foo bar")  # Using the cannonical name
        (1, 'Foo_bar')
        >>> parser.parse("U:Foo bar")  # Using an alias
        (2, 'Foo_bar')
        >>> parser.parse("Herpderp:Foo bar")  # Psuedo namespace
        (0, 'Herpderp:Foo_bar')
    """

    def __init__(self, namespaces=None):
        namespaces = none_or(namespaces, set)

        self.ids = {}    # namespace ID -> Namespace
        self.names = {}  # normalized name/alias/canonical -> Namespace

        if namespaces is not None:
            for namespace in namespaces:
                self.add_namespace(namespace)

    def parse(self, page_name):
        """
        Parses a page name to extract the namespace.

        :Parameters:
            page_name : str
                A page name including the namespace prefix and a colon
                (if not Main)

        :Returns:
            A tuple of (namespace : `int`, title : `str`)
        """
        parts = page_name.split(":", 1)
        if len(parts) == 1:
            # No colon at all -- Main namespace.
            ns_id = 0
            title = normalize(page_name)
        else:
            ns_name, title = parts
            ns_name, title = normalize(ns_name), normalize(title)

            if self.contains_name(ns_name):
                ns_id = self.get_namespace(name=ns_name).id
            else:
                # Unknown prefix (pseudo-namespace): treat the whole page
                # name as a Main-namespace title.
                ns_id = 0
                title = normalize(page_name)

        return ns_id, title

    def add_namespace(self, namespace):
        """
        Adds a namespace to the parser.

        :Parameters:
            namespace : :class:`mw.Namespace`
                A namespace
        """
        self.ids[namespace.id] = namespace
        self.names[namespace.name] = namespace

        for alias in namespace.aliases:
            self.names[alias] = namespace

        if namespace.canonical is not None:
            self.names[namespace.canonical] = namespace

    def contains_name(self, name):
        # Membership test against all known names, aliases and canonical
        # names (after normalization).
        return normalize(name) in self.names

    def get_namespace(self, id=None, name=None):
        """
        Gets a namespace from the parser.  Throws a :class:`KeyError` if a
        namespace cannot be found.

        :Parameters:
            id : int
                A namespace ID
            name : str
                A namespace name (standard, cannonical names and aliases
                will be searched)

        :Returns:
            A :class:`mw.Namespace`.
        """
        if id is not None:
            return self.ids[int(id)]
        else:
            return self.names[normalize(name)]

    @classmethod
    def from_site_info(cls, si_doc):
        """
        Constructs a parser from the result of a :meth:`mw.api.SiteInfo.query`.

        :Parameters:
            si_doc : dict
                The result of a site_info request.

        :Returns:
            An initialized :class:`mw.lib.title.Parser`
        """
        aliases = autovivifying.Dict(vivifier=lambda k: [])

        # Gather aliases keyed by namespace ID.
        if 'namespacealiases' in si_doc:
            for alias_doc in si_doc['namespacealiases']:
                aliases[alias_doc['id']].append(alias_doc['*'])

        namespaces = []
        for ns_doc in si_doc['namespaces'].values():
            namespaces.append(
                Namespace.from_doc(ns_doc, aliases)
            )

        # Consistency fix: construct via `cls` (like from_api/from_dump) so
        # that subclasses build instances of themselves, not of Parser.
        return cls(namespaces)

    @classmethod
    def from_api(cls, session):
        """
        Constructs a parser from a :class:`mw.api.Session`

        :Parameters:
            session : :class:`mw.api.Session`
                An open API session

        :Returns:
            An initialized :class:`mw.lib.title.Parser`
        """
        si_doc = session.site_info.query(
            properties={'namespaces', 'namespacealiases'}
        )

        return cls.from_site_info(si_doc)

    @classmethod
    def from_dump(cls, dump):
        """
        Constructs a parser from a :class:`mw.xml_dump.Iterator`.  Note that
        XML database dumps do not include namespace aliases or cannonical
        names so the parser that will be constructed will only work in common
        cases.

        :Parameters:
            dump : :class:`mw.xml_dump.Iterator`
                An XML dump iterator

        :Returns:
            An initialized :class:`mw.lib.title.Parser`
        """
        return cls(dump.namespaces)

View File

@@ -0,0 +1,10 @@
from nose.tools import eq_
from ..functions import normalize
def test_normalize():
    cases = [
        ("Foobar", "Foobar"),   # already normal
        ("foobar", "Foobar"),   # first letter capitalized
        ("fooBar", "FooBar"),   # later capitals preserved
        ("Foo bar", "Foo_bar")  # spaces become underscores
    ]
    for raw, expected in cases:
        eq_(expected, normalize(raw))

View File

@@ -0,0 +1,58 @@
from nose.tools import eq_
from ....types import Namespace
from ..parser import Parser
def test_simple():
    parser = Parser([
        Namespace(0, "", case="first-letter"),
        Namespace(1, "Discuss\u00e3o", canonical="Talk", case="first-letter"),
        Namespace(2, "Usu\u00e1rio(a)", canonical="User", case="first-letter")
    ])

    # Standard namespace name, with and without a space in the title.
    eq_((1, "Foo"), parser.parse("Discuss\u00e3o:Foo"))
    eq_((1, "Foo_bar"), parser.parse("Discuss\u00e3o:Foo bar"))

    # Unknown prefixes fall back to the Main namespace.
    eq_((0, "Herpderp:Foo_bar"), parser.parse("Herpderp:Foo bar"))
def test_from_site_info():
    si_doc = {
        "namespaces": {
            "0": {
                "id": 0,
                "case": "first-letter",
                "*": "",
                "content": ""
            },
            "1": {
                "id": 1,
                "case": "first-letter",
                "*": "Discuss\u00e3o",
                "subpages": "",
                "canonical": "Talk"
            },
            "2": {
                "id": 2,
                "case": "first-letter",
                "*": "Usu\u00e1rio(a)",
                "subpages": "",
                "canonical": "User"
            }
        },
        "namespacealiases": [
            {
                "id": 1,
                "*": "WAFFLES"
            }
        ]
    }

    parser = Parser.from_site_info(si_doc)

    eq_((1, "Foo"), parser.parse("Discuss\u00e3o:Foo"))
    eq_((1, "Foo_bar"), parser.parse("Discuss\u00e3o:Foo bar"))
    eq_((0, "Herpderp:Foo_bar"), parser.parse("Herpderp:Foo bar"))

    # The alias registered via 'namespacealiases' must resolve too.
    eq_((1, "Foo_bar"), parser.parse("WAFFLES:Foo bar"))

View File

@@ -0,0 +1,2 @@
from .timestamp import Timestamp
from .namespace import Namespace

View File

@@ -0,0 +1,61 @@
from . import serializable
from ..util import none_or
class Namespace(serializable.Type):
    """
    Namespace meta data.
    """
    # Fix: 'content' was missing from __slots__.  Instances still worked
    # because the base class provides a __dict__, but the slot list should
    # enumerate every attribute assigned in __init__.
    __slots__ = ('id', 'name', 'aliases', 'case', 'canonical', 'content')

    def __init__(self, id, name, canonical=None, aliases=None, case=None,
                 content=False):
        # Namespace ID : `int`
        self.id = int(id)

        # Namespace name : `str`
        self.name = none_or(name, str)

        # Alias names : set( `str` )
        self.aliases = serializable.Set.deserialize(aliases or [], str)

        # Case sensitivity : `str` | `None`
        self.case = none_or(case, str)

        # Canonical name : `str` | `None`
        self.canonical = none_or(canonical, str)

        # Is considered a content namespace : `bool`
        self.content = bool(content)

    def __hash__(self):
        # Namespaces hash by their numeric ID.
        return self.id

    @classmethod
    def from_doc(cls, doc, aliases=None):
        """
        Constructs a namespace object from a namespace doc returned by the API
        site_info call.

        :Parameters:
            doc : dict
                a 'namespaces' value document from the site_info response
            aliases : dict( int : list( str ) )
                a map of namespace ID to a list of alias names

        :Returns:
            An initialized :class:`mw.Namespace`
        """
        # Fix: the old signature used a mutable default argument
        # (`aliases={}`); None is now the sentinel.
        aliases = aliases if aliases is not None else {}

        return cls(
            doc['id'],
            doc['*'],
            canonical=doc.get('canonical'),
            aliases=set(aliases.get(doc['id'], [])),
            case=doc['case'],
            content='content' in doc
        )

View File

@@ -0,0 +1,97 @@
from itertools import chain
class Type:
    """
    Abstract base for serializable types.  Subclasses get structural
    equality, a debug repr, and dict-based (de)serialization over their
    public attributes (``__slots__`` entries plus instance ``__dict__``
    keys).
    """

    def __eq__(self, other):
        # Structural equality over every public key.
        if other is None:
            return False
        try:
            for key in self.keys():
                if getattr(self, key) != getattr(other, key):
                    return False

            return True
        except (KeyError, AttributeError):
            # `other` lacks one of our fields -- treat as unequal rather
            # than raising.  (Previously only KeyError was caught, but
            # getattr() raises AttributeError for missing attributes.)
            return False

    def __ne__(self, other):
        # Bug fix: this method was named `__neq__`, which Python never
        # invokes; the real hook for `!=` is `__ne__`.
        return not self.__eq__(other)

    # Backwards-compatible alias for any code calling the old name directly.
    __neq__ = __ne__

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return "%s(%s)" % (
            self.__class__.__name__,
            ", ".join(
                "%s=%r" % (k, v) for k, v in self.items()
            )
        )

    def items(self):
        # Yield (key, value) pairs for every public attribute.
        for key in self.keys():
            yield key, getattr(self, key)

    def keys(self):
        # Public attribute names: declared slots plus instance dict entries,
        # excluding "__"-prefixed names.
        return (
            key for key in
            chain(getattr(self, "__slots__", []), self.__dict__.keys())
            if key[:2] != "__"
        )

    def serialize(self):
        """Return a plain-dict representation of this instance."""
        return dict(
            (k, self._serialize(v))
            for k, v in self.items()
        )

    def _serialize(self, value):
        # Recurse into values that know how to serialize themselves.
        if hasattr(value, "serialize"):
            return value.serialize()
        else:
            return value

    @classmethod
    def deserialize(cls, doc_or_instance):
        """Construct an instance from a dict, or pass an instance through."""
        if isinstance(doc_or_instance, cls):
            return doc_or_instance
        else:
            return cls(**doc_or_instance)
class Dict(dict, Type):
    """A `dict` that participates in the serializable `Type` protocol."""

    def serialize(self):
        return {key: self._serialize(value) for key, value in self.items()}

    @staticmethod
    def deserialize(d, value_deserializer=lambda v: v):
        # Pass instances through; otherwise rebuild, deserializing values.
        if isinstance(d, Dict):
            return d
        return Dict((key, value_deserializer(value))
                    for key, value in d.items())
class Set(set, Type):
    """A `set` that participates in the serializable `Type` protocol."""

    def serialize(self):
        return [self._serialize(value) for value in self]

    @staticmethod
    def deserialize(s, value_deserializer=lambda v: v):
        # Pass instances through; otherwise rebuild, deserializing values.
        if isinstance(s, Set):
            return s
        return Set(map(value_deserializer, s))
class List(list, Type):
    """A `list` that participates in the serializable `Type` protocol."""

    def serialize(self):
        return [self._serialize(value) for value in self]

    @staticmethod
    def deserialize(l, value_deserializer=lambda v: v):
        # Pass instances through; otherwise rebuild, deserializing values.
        if isinstance(l, List):
            return l
        return List(map(value_deserializer, l))

View File

@@ -0,0 +1,32 @@
from nose.tools import eq_
from ..namespace import Namespace
def test_namespace():
    # Construct a Namespace directly and check every attribute round-trips.
    namespace = Namespace(10, "Foo", canonical="Bar", aliases={'WT'},
                          case="foobar", content=False)
    eq_(namespace.id, 10)
    eq_(namespace.name, "Foo")
    eq_(namespace.canonical, "Bar")
    eq_(namespace.aliases, {'WT'})
    eq_(namespace.case, "foobar")
    eq_(namespace.content, False)
def test_namespace_from_doc():
    # An API siteinfo doc: the mere presence of a 'content' key (even with
    # an empty value) marks a content namespace.
    doc = {
        "id": 0,
        "case": "first-letter",
        "*": "",
        "content": ""
    }
    namespace = Namespace.from_doc(doc)
    eq_(namespace.id, 0)
    eq_(namespace.name, "")
    eq_(namespace.canonical, None)
    eq_(namespace.aliases, set())
    eq_(namespace.case, "first-letter")
    eq_(namespace.content, True)

View File

@@ -0,0 +1,25 @@
from nose.tools import eq_
from .. import serializable
def test_type():
    # serializable.Type subclasses should round-trip through serialize() /
    # deserialize() and pass instances through deserialize() unchanged.
    class Foo(serializable.Type):
        def __init__(self, foo, bar):
            self.foo = foo
            self.bar = bar
    foo = Foo(1, "bar")
    eq_(foo, Foo.deserialize(foo))
    eq_(foo, Foo.deserialize(foo.serialize()))
def test_dict():
    # serializable.Dict behaves like a dict and round-trips via a
    # per-value deserializer.
    d = serializable.Dict()
    d['foo'] = "bar"
    d['derp'] = "herp"
    eq_(d['foo'], "bar")
    assert 'derp' in d
    eq_(d, serializable.Dict.deserialize(d.serialize(), str))

View File

@@ -0,0 +1,87 @@
import pickle
from nose.tools import eq_
from ..timestamp import LONG_MW_TIME_STRING, Timestamp
def test_self():
    # A Timestamp reconstructed from any of its own representations
    # (unix int, short format, long format) compares equal to itself.
    t1 = Timestamp(1234567890)
    # Unix timestamp
    eq_(t1, Timestamp(int(t1)))
    # Short format
    eq_(t1, Timestamp(t1.short_format()))
    # Long format
    eq_(t1, Timestamp(t1.long_format()))
def test_comparison():
    # Pin the full set of rich comparisons between two adjacent timestamps.
    t1 = Timestamp(1234567890)
    t2 = Timestamp(1234567891)
    assert t1 < t2, "Less than comparison failed"
    assert t2 > t1, "Greater than comparison failed"
    assert not t2 < t1, "Not less than comparison failed"
    assert not t1 > t2, "Not greater than comparison failed"
    assert t1 <= t2, "Less than or equal to comparison failed"
    assert t1 <= t1, "Less than or equal to comparison failed"
    assert t2 >= t1, "Greater than or equal to comparison failed"
    assert t2 >= t2, "Greater than or equal to comparison failed"
    assert not t2 <= t1, "Not less than or equal to comparison failed"
    assert not t1 >= t2, "Not greater than or equal to comparison failed"
def test_subtraction():
    # Timestamp - Timestamp yields seconds; Timestamp - int yields Timestamp.
    t1 = Timestamp(1234567890)
    t2 = Timestamp(1234567891)
    eq_(t2 - t1, 1)
    eq_(t1 - t2, -1)
    eq_(t2 - 1, t1)
def test_strptime():
    # strptime parses explicit formats, including non-MediaWiki ones.
    eq_(
        Timestamp("2009-02-13T23:31:30Z"),
        Timestamp.strptime("2009-02-13T23:31:30Z", LONG_MW_TIME_STRING)
    )
    eq_(
        Timestamp.strptime(
            "expires 03:20, 21 November 2013 (UTC)",
            "expires %H:%M, %d %B %Y (UTC)"
        ),
        Timestamp("2013-11-21T03:20:00Z")
    )
def test_strftime():
    # strftime formats with both the long MW format and arbitrary formats.
    eq_(
        Timestamp("2009-02-13T23:31:30Z").strftime(LONG_MW_TIME_STRING),
        "2009-02-13T23:31:30Z"
    )
    eq_(
        Timestamp("2009-02-13T23:31:30Z").strftime("expires %H:%M, %d %B %Y (UTC)"),
        "expires 23:31, 13 February 2009 (UTC)"
    )
def test_serialization():
    # serialize() emits the unix seconds; deserialize() reconstructs.
    timestamp = Timestamp(1234567890)
    eq_(
        timestamp,
        Timestamp.deserialize(timestamp.serialize())
    )
def test_pickling():
    # Pickling relies on Timestamp.__getnewargs__ to rebuild via __new__.
    timestamp = Timestamp(1234567890)
    eq_(
        timestamp,
        pickle.loads(pickle.dumps(timestamp))
    )

View File

@@ -0,0 +1,318 @@
import calendar
import datetime
import time
from . import serializable
LONG_MW_TIME_STRING = '%Y-%m-%dT%H:%M:%SZ'
"""
The longhand version of MediaWiki time strings.
"""
SHORT_MW_TIME_STRING = '%Y%m%d%H%M%S'
"""
The shorthand version of MediaWiki time strings.
"""
class Timestamp(serializable.Type):
    """
    An immutable type for working with MediaWiki timestamps in their various
    forms.

    :Parameters:
        time_thing : :class:`mw.Timestamp` | :py:class:`~time.time_struct` | :py:class:`~datetime.datetime` | :py:class:`str` | :py:class:`int` | :py:class:`float`
            The timestamp type from which to construct the timestamp class.

    :Returns:
        :class:`mw.Timestamp`

    You can make use of a lot of different *time things* to initialize a
    :class:`mw.Timestamp`.

    * If a :py:class:`~time.time_struct` or :py:class:`~datetime.datetime` are provided, a `Timestamp` will be constructed from their values.
    * If an `int` or `float` are provided, they will be assumed to be a unix timestamp in seconds since Jan. 1st, 1970 UTC.
    * If a `str` is provided, it will be checked against known MediaWiki timestamp formats. E.g., ``'%Y%m%d%H%M%S'`` and ``'%Y-%m-%dT%H:%M:%SZ'``.
    * If a :class:`mw.Timestamp` is provided, the same `Timestamp` will be returned.

    For example::

        >>> import datetime, time
        >>> from mw import Timestamp
        >>> Timestamp(1234567890)
        Timestamp('2009-02-13T23:31:30Z')
        >>> Timestamp(1234567890) == Timestamp("2009-02-13T23:31:30Z")
        True
        >>> Timestamp(1234567890) == Timestamp("20090213233130")
        True
        >>> Timestamp(1234567890) == Timestamp(datetime.datetime.utcfromtimestamp(1234567890))
        True
        >>> Timestamp(1234567890) == Timestamp(time.strptime("2009-02-13T23:31:30Z", "%Y-%m-%dT%H:%M:%SZ"))
        True
        >>> Timestamp(1234567890) == Timestamp(Timestamp(1234567890))
        True

    You can also do math and comparisons of timestamps.::

        >>> from mw import Timestamp
        >>> t = Timestamp(1234567890)
        >>> t
        Timestamp('2009-02-13T23:31:30Z')
        >>> t2 = t + 10
        >>> t2
        Timestamp('2009-02-13T23:31:40Z')
        >>> t += 1
        >>> t
        Timestamp('2009-02-13T23:31:31Z')
        >>> t2 - t
        9
        >>> t < t2
        True
    """

    def __new__(cls, time_thing):
        # Dispatch on the type of `time_thing`.  An existing Timestamp is
        # returned unchanged -- the type is immutable, so sharing is safe.
        if isinstance(time_thing, cls):
            return time_thing
        elif isinstance(time_thing, time.struct_time):
            return cls.from_time_struct(time_thing)
        elif isinstance(time_thing, datetime.datetime):
            return cls.from_datetime(time_thing)
        elif type(time_thing) in (int, float):
            return cls.from_unix(time_thing)
        else:
            return cls.from_string(time_thing)

    def __init__(self, time_thing):
        # Important that this does nothing in order to allow __new__ to behave
        # as expected.  Use `initialize()` instead.
        pass

    def initialize(self, time_struct):
        # The single piece of internal state: a time.struct_time (UTC).
        self.__time = time_struct

    def short_format(self):
        """
        Constructs a short, ``'%Y%m%d%H%M%S'`` formatted string common to the
        database. This method is roughly equivalent to calling
        ``strftime('%Y%m%d%H%M%S')``.

        :Returns:
            A formatted string
        """
        return self.strftime(SHORT_MW_TIME_STRING)

    def long_format(self):
        """
        Constructs a long, ``'%Y-%m-%dT%H:%M:%SZ'`` formatted string common to the
        API. This method is roughly equivalent to calling
        ``strftime('%Y-%m-%dT%H:%M:%SZ')``.

        :Returns:
            A formatted string
        """
        return self.strftime(LONG_MW_TIME_STRING)

    def strftime(self, format):
        """
        Constructs a formatted string.
        See `<https://docs.python.org/3/library/time.html#time.strftime>`_ for a
        discussion of format descriptors.

        :Parameters:
            format : str
                The format description

        :Returns:
            A formatted string
        """
        return time.strftime(format, self.__time)

    @classmethod
    def strptime(cls, string, format):
        """
        Constructs a :class:`mw.Timestamp` from an explicitly formatted string.
        See `<https://docs.python.org/3/library/time.html#time.strftime>`_ for a
        discussion of format descriptors.

        :Parameters:
            string : str
                A formatted timestamp
            format : str
                The format description

        :Returns:
            :class:`mw.Timestamp`
        """
        return cls.from_time_struct(time.strptime(string, format))

    @classmethod
    def from_time_struct(cls, time_struct):
        """
        Constructs a :class:`mw.Timestamp` from a :class:`time.time_struct`.

        :Parameters:
            time_struct : :class:`time.time_struct`
                A time structure

        :Returns:
            :class:`mw.Timestamp`
        """
        # Bypass __new__'s dispatch: build an instance directly.
        instance = super().__new__(cls)
        instance.initialize(time_struct)
        return instance

    @classmethod
    def from_datetime(cls, dt):
        """
        Constructs a :class:`mw.Timestamp` from a :class:`datetime.datetime`.

        :Parameters:
            dt : :class:`datetime.datetime`
                A datetime.

        :Returns:
            :class:`mw.Timestamp`
        """
        time_struct = dt.timetuple()
        return cls.from_time_struct(time_struct)

    @classmethod
    def from_unix(cls, seconds):
        """
        Constructs a :class:`mw.Timestamp` from a unix timestamp (in seconds
        since Jan. 1st, 1970 UTC).

        :Parameters:
            seconds : int
                A unix timestamp

        :Returns:
            :class:`mw.Timestamp`
        """
        # NOTE: utcfromtimestamp is deprecated in Python 3.12; kept here to
        # preserve the exact (naive UTC) behavior.
        time_struct = datetime.datetime.utcfromtimestamp(seconds).timetuple()
        return cls.from_time_struct(time_struct)

    @classmethod
    def from_string(cls, string):
        """
        Constructs a :class:`mw.Timestamp` from a MediaWiki formatted string.
        This method provides a convenient way to construct from common
        MediaWiki timestamp formats. E.g., ``%Y%m%d%H%M%S`` and
        ``%Y-%m-%dT%H:%M:%SZ``.

        :Parameters:
            string : str
                A formatted timestamp

        :Returns:
            :class:`mw.Timestamp`

        :Raises:
            ValueError : when `string` matches neither known format
        """
        if type(string) == bytes:
            string = str(string, 'utf8')
        else:
            string = str(string)

        # Try the short database format first, then the long API format.
        # (Fixed: removed an unreachable trailing
        # `return cls.from_time_struct(time_struct)` that referenced an
        # undefined name.)
        try:
            return cls.strptime(string, SHORT_MW_TIME_STRING)
        except ValueError:
            try:
                return cls.strptime(string, LONG_MW_TIME_STRING)
            except ValueError:
                raise ValueError(
                    "{0} is not a valid Wikipedia date format".format(
                        repr(string)
                    )
                )

    def __format__(self, format):
        return self.strftime(format)

    def __str__(self):
        return self.short_format()

    def serialize(self):
        # Serialize as the unix seconds -- the most compact representation.
        return self.unix()

    @classmethod
    def deserialize(cls, time_thing):
        return Timestamp(time_thing)

    def __repr__(self):
        return "{0}({1})".format(
            self.__class__.__name__,
            repr(self.long_format())
        )

    def __int__(self):
        return self.unix()

    def __float__(self):
        return float(self.unix())

    def unix(self):
        """
        :Returns:
            the number of seconds since Jan. 1st, 1970 UTC.
        """
        return int(calendar.timegm(self.__time))

    def __sub__(self, other):
        # Timestamp - Timestamp -> seconds; Timestamp - seconds -> Timestamp.
        if isinstance(other, Timestamp):
            return self.unix() - other.unix()
        else:
            return self + (other * -1)

    def __add__(self, seconds):
        return Timestamp(self.unix() + seconds)

    def __eq__(self, other):
        try:
            return self.__time == other.__time
        except AttributeError:
            return False

    def __lt__(self, other):
        try:
            return self.__time < other.__time
        except AttributeError:
            return NotImplemented

    def __gt__(self, other):
        try:
            return self.__time > other.__time
        except AttributeError:
            return NotImplemented

    def __le__(self, other):
        try:
            return self.__time <= other.__time
        except AttributeError:
            return NotImplemented

    def __ge__(self, other):
        try:
            return self.__time >= other.__time
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        try:
            return not self.__time == other.__time
        except AttributeError:
            return NotImplemented

    def __getnewargs__(self):
        # Lets pickle rebuild the instance via __new__(cls, time_struct).
        return (self.__time,)

View File

@@ -0,0 +1,2 @@
from .functions import none_or
from .heap import Heap

View File

@@ -0,0 +1,53 @@
import logging
import time
import requests
from .functions import none_or
logger = logging.getLogger("mw.util.api.session")
FAILURE_THRESHOLD = 5
TIMEOUT = 20
class Session:
    """
    A thin wrapper around :class:`requests.Session` that retries failed API
    requests with exponential backoff.

    :Parameters:
        uri : str
            The API endpoint to send requests to
        headers : dict
            Headers to attach to requests
            (NOTE(review): currently stored but never sent -- confirm intent)
        timeout : float
            Seconds to wait for a response before giving up
        failure_threshold : int
            Consecutive failures tolerated before the error is re-raised
        wait_step : float
            Base of the exponential backoff between retries
    """

    def __init__(self, uri, headers=None, timeout=None,
                 failure_threshold=None, wait_step=2):
        if uri is None:
            raise TypeError("uri must not be None")
        self.uri = str(uri)
        self.headers = headers if headers is not None else {}
        self.session = requests.Session()
        self.failure_threshold = int(failure_threshold or FAILURE_THRESHOLD)
        # Fixed: the `timeout` argument was ignored and the module default
        # was always used.
        self.timeout = float(timeout or TIMEOUT)
        self.wait_step = float(wait_step)
        self.failed = 0

    def __sleep(self):
        # Exponential backoff scaled by the consecutive-failure count.
        time.sleep(self.failed * (self.wait_step ** self.failed))

    def get(self, params, **kwargs):
        return self.request('GET', params, **kwargs)

    def post(self, params, **kwargs):
        return self.request('POST', params, **kwargs)

    def request(self, type, params, **kwargs):
        # Fixed: **kwargs forwarded by get()/post() were not accepted, and
        # the recursive retry's result was not returned to the caller.
        try:
            result = self.session.request(type, self.uri, params=params,
                                          timeout=self.timeout, **kwargs)
            self.failed = 0
            return result
        except (requests.HTTPError, requests.ConnectionError):
            self.failed += 1
            if self.failed > self.failure_threshold:
                self.failed = 0
                raise
            else:
                self.__sleep()
                return self.request(type, params, **kwargs)

View File

@@ -0,0 +1,11 @@
class Dict(dict):
    """
    A dict that "autovivifies": reading a missing key first creates its
    value by calling ``vivifier(key)`` and stores it.
    """

    def __init__(self, *args, vivifier=lambda k: None, **kwargs):
        # `vivifier` builds the default value for a missing key.
        self.vivifier = vivifier
        super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        if not dict.__contains__(self, key):
            dict.__setitem__(self, key, self.vivifier(key))

        return dict.__getitem__(self, key)

View File

@@ -0,0 +1,21 @@
def none_or(val, func=None, levels=None):
    """
    Pass None through unchanged; otherwise either validate `val` against
    the allowed `levels` (raising KeyError on a miss) or convert it with
    `func`.
    """
    if val is None:
        return None

    if levels is not None:
        if val not in set(levels):
            raise KeyError(val)
        return val

    return func(val)
def try_keys(dictionary, keys):
    """
    Return the value of the first key in `keys` that is present in
    `dictionary`.  Raises a KeyError naming every attempted key when none
    are present.
    """
    missing = []
    for key in keys:
        if key in dictionary:
            return dictionary[key]
        missing.append(key)

    raise KeyError("|".join(str(k) for k in missing))

View File

@@ -0,0 +1,22 @@
import heapq
class Heap(list):
    """A list maintained as a binary min-heap via :mod:`heapq`."""

    def __init__(self, *args, **kwargs):
        list.__init__(self, *args, **kwargs)
        # Establish the heap invariant over any initial contents.
        heapq.heapify(self)

    def pop(self):
        """Remove and return the smallest item."""
        return heapq.heappop(self)

    def push(self, item):
        """Add an item, preserving the heap invariant."""
        heapq.heappush(self, item)

    def peek(self):
        """Return (without removing) the smallest item."""
        return self[0]

    def pushpop(self, item):
        """Push `item`, then pop and return the smallest item."""
        return heapq.heappushpop(self, item)

    def poppush(self, item):
        """Pop and return the smallest item, then push `item`."""
        # Fixed: parameter was misspelled `itemp` and the call targeted the
        # nonexistent `heapq.replace` (the function is `heapq.heapreplace`).
        return heapq.heapreplace(self, item)

View File

@@ -0,0 +1,3 @@
from .aggregate import aggregate, group
from .peekable import Peekable
from .sequence import sequence

View File

@@ -0,0 +1,20 @@
from .peekable import Peekable
def group(it, by=lambda i: i):
    """Alias of :func:`aggregate`: group consecutive items by a key."""
    return aggregate(it, by=by)
def aggregate(it, by=lambda i: i):
    """
    Lazily group *consecutive* items of `it` that share the same key
    (similar to :func:`itertools.groupby`).  Yields (identifier, chunk)
    pairs where chunk is a sub-iterator.  Each chunk MUST be fully
    consumed before advancing the outer iterator; otherwise the same
    group is yielded again.
    """
    it = Peekable(it)

    def chunk(it, by):
        # Yield items for as long as they map to the same identifier.
        identifier = by(it.peek())
        while not it.empty():
            if identifier == by(it.peek()):
                yield next(it)
            else:
                break

    while not it.empty():
        yield (by(it.peek()), chunk(it, by))

View File

@@ -0,0 +1,8 @@
def count(iterable):
    """
    Consumes all items in an iterable and returns a count.
    """
    return sum(1 for _ in iterable)

View File

@@ -0,0 +1,37 @@
def Peekable(it):
    """Return `it` wrapped in a PeekableType (idempotent)."""
    return it if isinstance(it, PeekableType) else PeekableType(it)
class PeekableType:
    """
    An iterator wrapper that supports one item of lookahead via peek().
    """

    class EMPTY:
        # Sentinel marking an exhausted iterator; compared by identity.
        pass

    def __init__(self, it):
        self.it = iter(it)
        self.__cycle()

    def __iter__(self):
        return self

    def __cycle(self):
        # Pull the next item into the lookahead slot.
        try:
            self.lookahead = next(self.it)
        except StopIteration:
            self.lookahead = self.EMPTY

    def __next__(self):
        item = self.peek()
        self.__cycle()
        return item

    def peek(self):
        """Return the next item without consuming it.

        :Raises:
            StopIteration : when the underlying iterator is exhausted
        """
        if self.empty():
            raise StopIteration()
        return self.lookahead

    def empty(self):
        # Fixed: was `==`, which invoked the stored items' __eq__ (breaking
        # for types with non-boolean equality, e.g. numpy arrays).  The
        # sentinel must be compared by identity.
        return self.lookahead is self.EMPTY

View File

@@ -0,0 +1,27 @@
from .peekable import Peekable
def sequence(*iterables, by=None, compare=None):
    """
    Lazily merge several individually-sorted iterables into one sorted
    stream.  Ordering comes from `compare(a, b)` (truthy when `a` should
    come first), or from the key function `by`, or from `<=` by default.
    """
    if compare is None:
        if by is not None:
            compare = lambda i1, i2: by(i1) <= by(i2)
        else:
            compare = lambda i1, i2: i1 <= i2

    streams = [Peekable(it) for it in iterables]

    while True:
        # Scan every non-empty stream for the next item to emit.
        best = None
        for index, stream in enumerate(streams):
            if stream.empty():
                continue
            if best is None or compare(stream.peek(), streams[best].peek()):
                best = index

        if best is None:
            return  # every stream is exhausted

        yield next(streams[best])

View File

@@ -0,0 +1,13 @@
from nose.tools import eq_
from ..aggregate import aggregate
def test_group():
    # Consecutive items sharing int(item / 10) are grouped together; each
    # group iterator is consumed immediately, as aggregate() requires.
    l = [0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14]
    expected = [[0, 1, 2, 3, 4, 5], [10, 11, 12, 13, 14]]
    result = []
    for identifier, group in aggregate(l, lambda item: int(item / 10)):
        result.append(list(group))
    eq_(result, expected)

View File

@@ -0,0 +1,20 @@
from nose.tools import eq_
from ..peekable import Peekable
def test_peekable():
    # peek() must not consume; interleaved peek/next still yields every item.
    iterable = range(0, 100)
    iterable = Peekable(iterable)
    expected = list(range(0, 100))
    result = []
    assert not iterable.empty()
    eq_(iterable.peek(), expected[0])
    result.append(next(iterable))
    eq_(iterable.peek(), expected[1])
    result.append(next(iterable))
    result.extend(list(iterable))
    eq_(result, expected)

View File

@@ -0,0 +1,11 @@
from nose.tools import eq_
from ..sequence import sequence
def test_sequence():
    # Two sorted streams of dicts merge into one stream sorted by 'val'.
    foo = [{'val': 3}, {'val': 5}]
    bar = [{'val': 1}, {'val': 10}, {'val': 15}]
    expected = [{'val': 1}, {'val': 3}, {'val': 5}, {'val': 10}, {'val': 15}]
    result = list(sequence(foo, bar, compare=lambda i1, i2: i1['val'] < i2['val']))
    eq_(expected, result)

View File

@@ -0,0 +1,116 @@
from . import autovivifying
class Circle(list):
    """
    A fixed-capacity ring buffer backed by a list.  Once `maxsize` items
    have been added, each new append() overwrites -- and returns -- the
    oldest item.
    """

    def __init__(self, maxsize, iterable=None):
        self._maxsize = int(maxsize)
        # Pre-allocate the backing store; unused slots hold None.
        list.__init__(self, [None] * maxsize)
        self._size = 0      # number of logical items currently stored
        self._pointer = 0   # backing-list index of the next slot to write
        if iterable is not None:
            # Fixed: extend() is a generator (it yields displaced items),
            # so it must be consumed for the appends to actually happen.
            # The original left the generator unconsumed, adding nothing.
            for _ in self.extend(iterable):
                pass

    def state(self):
        """Return the raw backing list (physical order, not logical)."""
        return list(list.__iter__(self))

    def _internalize(self, index):
        # Map a logical index (0 == oldest item) to a backing-list index.
        if self._size < self._maxsize:
            return index
        else:
            return (self._pointer + index) % self._maxsize

    def __iter__(self):
        for i in range(0, self._size):
            yield list.__getitem__(self, self._internalize(i))

    def __reversed__(self):
        for i in range(self._size - 1, -1, -1):
            yield list.__getitem__(self, self._internalize(i))

    def pop(self, index=None):
        # Removal is not supported on a ring buffer.
        raise NotImplementedError()

    def __len__(self):
        return self._size

    def __getitem__(self, index):
        return list.__getitem__(self, self._internalize(index))

    def append(self, value):
        """Append `value`; return the displaced item (None if not full)."""
        # Get the old value
        old_value = list.__getitem__(self, self._pointer)

        # Update internal list
        list.__setitem__(self, self._pointer, value)

        # Update state
        self._pointer = (self._pointer + 1) % self._maxsize
        self._size = min(self._maxsize, self._size + 1)

        # If we overwrote a value, return it.
        return old_value

    def extend(self, values):
        """Append each value, yielding items displaced once at capacity."""
        for value in values:
            expectorate = self.append(value)
            if expectorate is not None or self._size == self._maxsize:
                yield expectorate
class HistoricalMap(autovivifying.Dict):
    '''
    A datastructure for efficiently storing and retrieving a
    limited number of historical records.

    TODO: Rename this to FIFOCache
    '''

    def __init__(self, *args, maxlen, **kwargs):
        '''Maxlen specifies the maximum amount of history to keep'''
        # Fixed: was `super().__init__(self, *args, ...)` -- super() already
        # binds self, so self was also passed through as a positional
        # argument to dict.__init__.
        super().__init__(*args, vivifier=lambda k: [], **kwargs)

        self._circle = Circle(maxlen)  # List to preserve order for history

    def __iter__(self):
        return iter(self._circle)

    def __setitem__(self, key, value):
        '''Adds a new key-value pair. Returns any discarded values.'''
        # Add to history circle and catch the displaced record (if any).
        expectorate = self._circle.append((key, value))

        autovivifying.Dict.__getitem__(self, key).append(value)

        if expectorate is not None:
            old_key, old_value = expectorate
            # Drop the oldest value for the displaced key; remove the key
            # entirely when no history remains for it.
            autovivifying.Dict.__getitem__(self, old_key).pop(0)
            if len(autovivifying.Dict.__getitem__(self, old_key)) == 0:
                autovivifying.Dict.__delitem__(self, old_key)

            return (old_key, old_value)

    def insert(self, key, value):
        return self.__setitem__(key, value)

    def __getitem__(self, key):
        if key in self:
            return autovivifying.Dict.__getitem__(self, key)[-1]
        else:
            raise KeyError(key)

    def get(self, key):
        '''Gets the most recently added value for a key'''
        return self.__getitem__(key)

    def up_to(self, key):
        '''Gets the recently inserted values up to (not including) a key'''
        for okey, ovalue in reversed(self._circle):
            if okey == key:
                break
            else:
                yield ovalue

    def last(self):
        '''Returns the most recently inserted (key, value) pair'''
        # Fixed: referenced the nonexistent attribute `self.circle`; also
        # index relative to the logical length so a partially-filled
        # buffer returns the newest record rather than an empty slot.
        return self._circle[len(self._circle) - 1]

View File

@@ -0,0 +1,28 @@
from nose.tools import eq_
from .. import autovivifying
def test_word_count():
    # Word-counting with an autovivifying dict (default 0) must match a
    # manually-initialized plain dict.
    words = """
    I am a little teapot short and stout. Here is my handle and here is my
    spout. The red fox jumps over the lazy brown dog. She sells sea shells
    by the sea shore.
    """.replace(".", " ").split()
    # Lame way
    lame_counts = {}
    for word in words:
        if word not in lame_counts:
            lame_counts[word] = 0
        lame_counts[word] += 1
    # Awesome way
    awesome_counts = autovivifying.Dict(  # Autovivifies entries with zero.
        vivifier=lambda k: 0              # Useful for counting.
    )
    for word in words:
        awesome_counts[word] += 1
    eq_(lame_counts, awesome_counts)

View File

@@ -0,0 +1,14 @@
from nose.tools import eq_
from ..functions import none_or
def test_none_or():
    # Only an actual None input yields None; falsy values convert normally.
    eq_(10, none_or("10", int))
    eq_(10.75, none_or("10.75", float))
    eq_(None, none_or(None, int))
    assert none_or("", str) is not None
    assert none_or([], list) is not None
    assert none_or({}, dict) is not None
    assert none_or(0, int) is not None
    assert none_or(-1, int) is not None

View File

@@ -0,0 +1,28 @@
from nose.tools import eq_
from ..heap import Heap
def test_heap():
    # Items pop in ascending order regardless of construction order.
    h = Heap([5, 4, 7, 8, 2])
    eq_(h.pop(), 2)
    eq_(h.pop(), 4)
    eq_(h.pop(), 5)
    eq_(h.pop(), 7)
    eq_(h.pop(), 8)
    eq_(len(h), 0)
    # push() keeps the heap invariant for subsequent pops.
    h = Heap([10, 20, 100])
    eq_(h.pop(), 10)
    h.push(30)
    eq_(len(h), 3)
    eq_(h.pop(), 20)
    eq_(h.pop(), 30)
    eq_(h.pop(), 100)
    eq_(len(h), 0)
    # Tuples order by their first element.
    h = Heap([(1, 7), (2, 4), (10, -100)])
    eq_(h.peek(), (1, 7))
    h.pop()
    eq_(h.pop(), (2, 4))
    eq_(h.pop(), (10, -100))

View File

@@ -0,0 +1,41 @@
from nose.tools import eq_
from .. import ordered
def test_circle():
    # append() returns None until the buffer is full, then the displaced
    # oldest item; iteration runs oldest -> newest.
    circle = ordered.Circle(3)
    eq_(0, len(circle))
    print(circle.state())
    eq_(None, circle.append(5))
    eq_(1, len(circle))
    print(circle.state())
    eq_(None, circle.append(6))
    eq_(2, len(circle))
    print(circle.state())
    eq_(None, circle.append(7))
    eq_(3, len(circle))
    print(circle.state())
    eq_(5, circle.append(8))
    eq_(3, len(circle))
    print(circle.state())
    eq_([6, 7, 8], list(circle))
    print(circle.state())
    eq_([8, 7, 6], list(reversed(circle)))
def test_historical_map():
    # With maxlen=2, a third insert evicts -- and returns -- the oldest pair.
    hist = ordered.HistoricalMap(maxlen=2)
    assert "foo" not in hist
    eq_(None, hist.insert('foo', "bar1"))
    assert "foo" in hist
    eq_(None, hist.insert('foo', "bar2"))
    eq_(('foo', "bar1"), hist.insert('not_foo', "not_bar"))

View File

@@ -0,0 +1,69 @@
"""
This is a failed attempt. See
https://github.com/halfak/Mediawiki-Utilities/issues/13 for more details.
"""
'''
import os
import py7zlib
class SevenZFileError(py7zlib.ArchiveError):
pass
class SevenZFile(object):
@classmethod
def is_7zfile(cls, filepath):
""" Determine if filepath points to a valid 7z archive. """
is7z = False
fp = None
try:
fp = open(filepath, 'rb')
archive = py7zlib.Archive7z(fp)
n = len(archive.getnames())
is7z = True
finally:
if fp: fp.close()
return is7z
def __init__(self, filepath):
fp = open(filepath, 'rb')
self.filepath = filepath
self.archive = py7zlib.Archive7z(fp)
def __contains__(self, name):
return name in self.archive.getnames()
def bytestream(self, name):
""" Iterate stream of bytes from an archive member. """
if name not in self:
raise SevenZFileError('member %s not found in %s' %
(name, self.filepath))
else:
member = self.archive.getmember(name)
for byte in member.read():
if not byte: break
yield byte
def readlines(self, name):
""" Iterate lines from an archive member. """
linesep = os.linesep[-1]
line = ''
for ch in self.bytestream(name):
line += ch
if ch == linesep:
yield line
line = ''
if line: yield line
import os
import py7zlib
with open("/mnt/data/xmldatadumps/public/simplewiki/20141122/simplewiki-20141122-pages-meta-history.xml.7z", "rb") as f:
a = py7zlib.Archive7z(f)
print(a.getmember(a.getnames()[0]).read())
'''

View File

@@ -0,0 +1,27 @@
"""
This module is a collection of utilities for efficiently processing MediaWiki's
XML database dumps. There are two important concerns that this module intends
to address: *performance* and the *complexity* of streaming XML parsing.
Performance
Performance is a serious concern when processing large database XML dumps.
Regretfully, the Global Interpreter Lock prevents us from running threads on
multiple CPUs. This library provides a :func:`map`, a function
that maps a dump processing over a set of dump files using
:class:`multiprocessing` to distribute the work over multiple CPUs.
Complexity
Streaming XML parsing is gross. XML dumps are (1) some site meta data, (2)
a collection of pages that contain (3) collections of revisions. The
module allows you to think about dump files in this way and ignore the
fact that you're streaming XML. An :class:`Iterator` contains
site meta data and an iterator of :class:`Page`'s. A
:class:`Page` contains page meta data and an iterator of
:class:`Revision`'s. A :class:`Revision` contains revision meta data
including a :class:`Contributor` (if a contributor was specified in the
XML).
"""
from .map import map
from .iteration import *
from .functions import file, open_file

View File

@@ -0,0 +1,104 @@
try:
import xml.etree.cElementTree as etree
except ImportError:
import xml.etree.ElementTree as etree
from xml.etree.ElementTree import ParseError
from .errors import MalformedXML
def trim_ns(tag):
    """Strip an XML namespace prefix ("{ns}tag" -> "tag"); pass through
    tags that carry no namespace."""
    head, sep, local = tag.partition("}")
    return local if sep else head
class EventPointer:
    """
    Wraps an etree start/end event stream, maintaining a stack of the
    currently-open tags so the current parse depth can be queried.
    """

    def __init__(self, etree_events):
        self.tag_stack = []            # names of currently-open tags
        self.etree_events = etree_events

    def __next__(self):
        event, element = next(self.etree_events)

        tag = trim_ns(element.tag)

        if event == "start":
            self.tag_stack.append(tag)
        else:
            # An "end" event must close the most recently opened tag.
            if self.tag_stack[-1] == tag:
                self.tag_stack.pop()
            else:
                raise MalformedXML("Expected {0}, but saw {1}.".format(
                    self.tag_stack[-1],
                    tag)
                )

        return event, element

    def depth(self):
        """Return the number of currently-open tags."""
        return len(self.tag_stack)

    @classmethod
    def from_file(cls, f):
        """Construct a pointer over a file-like object of XML."""
        return EventPointer(etree.iterparse(f, events=("start", "end")))
class ElementIterator:
    """
    Lazily iterates the direct children of one XML element while exposing
    its tag, text and attributes.  Children must be consumed in document
    order; complete() skips and clears whatever remains unread.
    """

    def __init__(self, element, pointer):
        self.pointer = pointer
        self.element = element
        self.depth = pointer.depth() - 1   # depth at which this element opened

        self.done = False

    def __iter__(self):
        # Yield an ElementIterator for each direct child until this
        # element's end tag brings the pointer back to our depth.
        while not self.done and self.pointer.depth() > self.depth:
            event, element = next(self.pointer)

            if event == "start":
                sub_iterator = ElementIterator(element, self.pointer)

                yield sub_iterator

                # Fully consume the child and free its memory before
                # advancing to the next sibling.
                sub_iterator.clear()

        self.done = True

    def complete(self):
        """Consume all remaining events inside this element, clearing
        intermediate elements to bound memory use."""
        while not self.done and self.pointer.depth() > self.depth:
            event, element = next(self.pointer)
            if self.pointer.depth() > self.depth:
                element.clear()

        self.done = True

    def clear(self):
        """Consume remaining events and release this element's memory."""
        self.complete()
        self.element.clear()

    def attr(self, key, alt=None):
        """Return the value of attribute `key`, or `alt` if absent."""
        return self.element.attrib.get(key, alt)

    def __getattr__(self, attr):
        # Only 'tag' and 'text' are exposed dynamically.
        if attr == "tag":
            return trim_ns(self.element.tag)
        elif attr == "text":
            # Text is only complete once the element is fully parsed.
            self.complete()
            return self.element.text
        else:
            raise AttributeError("%s has no attribute %r" % (self.__class__.__name__, attr))

    @classmethod
    def from_file(cls, f):
        """Construct an iterator for the root element of file-like `f`."""
        try:
            pointer = EventPointer.from_file(f)
            event, element = next(pointer)
            return cls(element, pointer)
        except ParseError as e:
            # Include a sample of the input to make the error diagnosable.
            raise ParseError(
                "{0}: {1}...".format(str(e),
                                     str(f.read(500), 'utf-8', 'replace')))

View File

@@ -0,0 +1,12 @@
class FileTypeError(Exception):
    """
    Thrown when an XML dump file is not of an expected type.
    """
class MalformedXML(Exception):
    """
    Thrown when an XML dump file is not formatted as expected.
    """

View File

@@ -0,0 +1,76 @@
import os
import re
import subprocess
from .errors import FileTypeError
# Each command writes the decompressed dump to its standard output.
EXTENSIONS = {
    'xml': ["cat"],
    'gz': ["zcat"],
    'bz2': ["bzcat"],
    '7z': ["7z", "e", "-so"],
    'lzma': ["lzcat"]
}
"""
A map from file extension to the command to run to extract the data to standard out.
"""

# Matches the final ".ext" component of a path.
EXT_RE = re.compile(r'\.([^\.]+)$')
"""
A regular expression for extracting the final extension of a file.
"""
def file(path_or_f):
    """
    Verifies that a file exists at a given path and that the file has a
    known extension type.

    :Parameters:
        path_or_f : `str` | file
            the path to a dump file, or an already-open file-like object
            (returned unchanged)
    """
    # Already-open file-like objects pass straight through.
    if hasattr(path_or_f, "readline"):
        return path_or_f

    path = os.path.expanduser(path_or_f)

    if not os.path.isfile(path):
        raise FileTypeError("Can't find file %s" % path)

    match = EXT_RE.search(path)
    if match is None:
        raise FileTypeError("No extension found for %s." % path)
    if match.groups()[0] not in EXTENSIONS:
        raise FileTypeError("File type %r is not supported." % path)

    return path
def open_file(path_or_f):
    """
    Turns a path to a dump file into a file-like object of (decompressed)
    XML data.

    :Parameters:
        path_or_f : `str` | file
            the path to the dump file to read, or an already-open
            file-like object (returned unchanged)
    """
    if hasattr(path_or_f, "read"):
        return path_or_f
    else:
        path = path_or_f

    match = EXT_RE.search(path)
    ext = match.groups()[0]

    # Decompress via an external process, streaming its stdout.
    # Fixed: stderr was `open(os.devnull, "w")`, which leaked the devnull
    # file handle; subprocess.DEVNULL avoids the extra descriptor.
    p = subprocess.Popen(
        EXTENSIONS[ext] + [path],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL
    )

    return p.stdout

View File

@@ -0,0 +1,10 @@
"""
*... iteration ...*
"""
from .iterator import Iterator
from .page import Page
from .redirect import Redirect
from .revision import Revision
from .comment import Comment
from .contributor import Contributor
from .text import Text

View File

@@ -0,0 +1,54 @@
from ...types import serializable
class Comment(str, serializable.Type):
    """
    A revision comment. Behaves identically to :class:`str`, with an
    additional ``deleted`` flag recording whether the comment was deleted.

    >>> from mw.xml_dump import Comment
    >>>
    >>> c = Comment("foo")
    >>> c == "foo"
    True
    >>> c.deleted
    False

    **deleted**
        Was the comment deleted? | `bool`
    """

    def __new__(cls, string_or_comment="", deleted=False):
        # An existing Comment passes straight through unchanged.
        if isinstance(string_or_comment, cls):
            return string_or_comment

        inst = super().__new__(cls, string_or_comment)
        inst.initialize(string_or_comment, deleted)
        return inst

    def initialize(self, string, deleted):
        self.deleted = bool(deleted)

    def __str__(self):
        return str.__str__(self)

    def __repr__(self):
        fields = [str.__repr__(self),
                  "deleted={0}".format(repr(self.deleted))]
        return "{0}({1})".format(self.__class__.__name__, ", ".join(fields))

    def serialize(self):
        return {
            "string_or_comment": str(self),
            "deleted": self.deleted
        }

    @classmethod
    def from_element(cls, e):
        # The 'deleted' attribute is absent (-> False) unless set.
        return cls(e.text, e.attr('deleted', False))

View File

@@ -0,0 +1,45 @@
from ...types import serializable
from ...util import none_or
from .util import consume_tags
class Contributor(serializable.Type):
    """
    Contributor meta data.
    """
    __slots__ = ('id', 'user_text')

    # Maps XML child-tag names to extractors for their values.
    TAG_MAP = {
        'id': lambda e: int(e.text),
        'username': lambda e: str(e.text),
        'ip': lambda e: str(e.text)
    }

    def __init__(self, id, user_text):
        self.id = none_or(id, int)
        """
        User ID : int | `None` (if not specified in the XML)

        User ID of a user if the contributor is signed into an account
        while making the contribution and `None` when
        contributors are not signed in.
        """

        self.user_text = none_or(user_text, str)
        """
        User name or IP address : str | `None` (if not specified in the XML)

        If a user is logged in, this will reflect the user's account
        name.  If the user is not logged in, this will usually be
        recorded as the IPv4 or IPv6 address in the XML.
        """

    @classmethod
    def from_element(cls, element):
        # 'username' (logged-in) takes precedence over 'ip' (anonymous).
        values = consume_tags(cls.TAG_MAP, element)
        return cls(
            values.get('id'),
            values.get('username', values.get('ip'))
        )

View File

@@ -0,0 +1,221 @@
import io
from ...types import serializable
from ...util import none_or
from ..element_iterator import ElementIterator
from ..errors import MalformedXML
from .namespace import Namespace
from .page import Page
class ConcatinatingTextReader(io.TextIOBase):
    """
    A read-only text stream that concatenates several strings and/or
    file-like objects into one stream.

    NOTE: a line ending exactly at an item boundary without a newline is
    NOT joined with the start of the next item by readline().
    """

    def __init__(self, *items):
        # Wrap plain strings so that every item exposes the file interface.
        self.items = [io.StringIO(i) if isinstance(i, str) else i
                      for i in items]

    def read(self, size=-1):
        return "".join(self._read(size))

    def readline(self):
        # Fixed: the original returned "" (the EOF signal) whenever the
        # *current* item was exhausted, even if later items still held
        # data -- prematurely ending line-based iteration between items.
        while len(self.items) > 0:
            line = self.items[0].readline()
            if line != "":
                return line
            self.items.pop(0)

        return ""

    def _read(self, size):
        """Yield chunks totalling `size` characters (everything if size < 0)."""
        if size > 0:
            while len(self.items) > 0:
                byte_vals = self.items[0].read(size)
                yield byte_vals
                if len(byte_vals) < size:
                    # Current item exhausted; carry the remainder forward.
                    size = size - len(byte_vals)
                    self.items.pop(0)
                else:
                    break

        else:
            for item in self.items:
                yield item.read()
def concat(*stream_items):
    """Concatenate strings/streams into a single readable text stream."""
    return ConcatinatingTextReader(*stream_items)
class Iterator(serializable.Type):
"""
XML Dump Iterator. Dump file meta data and a
:class:`~mw.xml_dump.Page` iterator. Instances of this class can be
called as an iterator directly. E.g.::
from mw.xml_dump import Iterator
# Construct dump file iterator
dump = Iterator.from_file(open("example/dump.xml"))
# Iterate through pages
for page in dump:
# Iterate through a page's revisions
for revision in page:
print(revision.id)
"""
__slots__ = ('site_name', 'base', 'generator', 'case', 'namespaces',
'__pages')
    def __init__(self, site_name=None, dbname=None, base=None, generator=None,
                 case=None, namespaces=None, pages=None):
        """
        :Parameters:
            site_name : str
                The name of the site
            dbname : str
                The database name of the site
            base : str
                Base URL reported by the dump (purpose not documented)
            generator : str
                Generator string reported by the dump
            case : str
                Case-sensitivity setting reported by the dump
            namespaces : list(:class:`mw.Namespace`)
                Namespaces declared in the dump's <siteinfo>
            pages : generator
                A single-pass generator of :class:`Page` objects
        """
        self.site_name = none_or(site_name, str)
        """
        The name of the site. : str | `None` (if not specified in the XML)
        """

        self.dbname = none_or(dbname, str)
        """
        The database name of the site. : str | `None` (if not specified in the
        XML)
        """

        self.base = none_or(base, str)
        """
        TODO: ??? : str | `None` (if not specified in the XML)
        """

        self.generator = none_or(generator, str)
        """
        TODO: ??? : str | `None` (if not specified in the XML)
        """

        self.case = none_or(case, str)
        """
        TODO: ??? : str | `None` (if not specified in the XML)
        """

        self.namespaces = none_or(namespaces, list)
        """
        A list of :class:`mw.Namespace` | `None` (if not specified in the XML)
        """

        # Should be a lazy generator of page info
        self.__pages = pages
    def __iter__(self):
        # Iteration is delegated to the single-pass page generator.
        return self.__pages
    def __next__(self):
        # Allows the Iterator itself to be advanced like an iterator.
        return next(self.__pages)
@classmethod
def load_namespaces(cls, element):
namespaces = []
for sub_element in element:
tag = sub_element.tag
if tag == "namespace":
namespace = Namespace.from_element(sub_element)
namespaces.append(namespace)
else:
assert False, "This should never happen"
return namespaces
@classmethod
def load_site_info(cls, element):
site_name = None
dbname = None
base = None
generator = None
case = None
namespaces = {}
for sub_element in element:
if sub_element.tag == 'sitename':
site_name = sub_element.text
if sub_element.tag == 'dbname':
dbname = sub_element.text
elif sub_element.tag == 'base':
base = sub_element.text
elif sub_element.tag == 'generator':
generator = sub_element.text
elif sub_element.tag == 'case':
case = sub_element.text
elif sub_element.tag == 'namespaces':
namespaces = cls.load_namespaces(sub_element)
return site_name, dbname, base, generator, case, namespaces
@classmethod
def load_pages(cls, element):
for sub_element in element:
tag = sub_element.tag
if tag == "page":
yield Page.from_element(sub_element)
else:
assert MalformedXML("Expected to see 'page'. " +
"Instead saw '{0}'".format(tag))
@classmethod
def from_element(cls, element):
site_name = None
base = None
generator = None
case = None
namespaces = None
# Consume <siteinfo>
for sub_element in element:
tag = sub_element.tag
if tag == "siteinfo":
site_name, dbname, base, generator, case, namespaces = \
cls.load_site_info(sub_element)
break
# Consume all <page>
pages = cls.load_pages(element)
return cls(site_name, dbname, base, generator, case, namespaces, pages)
@classmethod
def from_file(cls, f):
element = ElementIterator.from_file(f)
assert element.tag == "mediawiki"
return cls.from_element(element)
@classmethod
def from_string(cls, string):
f = io.StringIO(string)
element = ElementIterator.from_file(f)
assert element.tag == "mediawiki"
return cls.from_element(element)
@classmethod
def from_page_xml(cls, page_xml):
header = """
<mediawiki xmlns="http://www.mediawiki.org/xml/export-0.5/"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.mediawiki.org/xml/export-0.5/
http://www.mediawiki.org/xml/export-0.5.xsd" version="0.5"
xml:lang="en">
<siteinfo>
<namespaces>
</namespaces>
</siteinfo>
"""
footer = "</mediawiki>"
return cls.from_file(concat(header, page_xml, footer))

View File

@@ -0,0 +1,11 @@
from ... import types
class Namespace(types.Namespace):
    """
    A namespace constructed from an XML dump's <namespace> element.
    """

    @classmethod
    def from_element(cls, element):
        """
        Build a :class:`Namespace` from a <namespace> element's
        'key' attribute, text content, and 'case' attribute.
        """
        key = element.attr('key')
        name = element.text
        case = element.attr('case')
        return cls(key, name, case=case)

View File

@@ -0,0 +1,118 @@
from ...types import serializable
from ...util import none_or
from ..errors import MalformedXML
from .redirect import Redirect
from .revision import Revision
class Page(serializable.Type):
    """
    Page meta data and a :class:`~mw.xml_dump.Revision` iterator.  Instances of
    this class can be called as iterators directly.  E.g.

    .. code-block:: python

        page = mw.xml_dump.Page( ... )

        for revision in page:
            print("{0} {1}".format(revision.id, page.id))
    """
    # '__revisions' is listed (it name-mangles to _Page__revisions) for
    # consistency with Iterator, which lists its private '__pages' slot.
    __slots__ = (
        'id',
        'title',
        'namespace',
        'redirect',
        'restrictions',
        '__revisions'
    )

    def __init__(self, id, title, namespace, redirect, restrictions, revisions=None):
        self.id = none_or(id, int)
        """
        Page ID : `int`
        """

        self.title = none_or(title, str)
        """
        Page title (namespace excluded) : `str`
        """

        self.namespace = none_or(namespace, int)
        """
        Namespace ID : `int`
        """

        self.redirect = none_or(redirect, Redirect)
        """
        Page is currently redirect? : :class:`~mw.xml_dump.Redirect` | `None`
        """

        self.restrictions = serializable.List.deserialize(restrictions)
        """
        A list of page editing restrictions (empty unless restrictions are specified) : list( `str` )
        """

        # Should be a lazy generator.  BUG FIX: the default must be an
        # empty *iterator*, not a list -- `revisions or []` made __iter__
        # return a plain list, which is not an iterator, so
        # `for revision in page` raised TypeError when no revisions were
        # provided.
        self.__revisions = revisions if revisions is not None else iter([])

    def __iter__(self):
        return self.__revisions

    def __next__(self):
        return next(self.__revisions)

    @classmethod
    def load_revisions(cls, first_revision, element):
        """
        Lazily yield :class:`Revision`, starting with *first_revision*
        (an already-consumed <revision> element) followed by the
        remaining <revision> children of *element*.

        :Raises:
            MalformedXML : if a subsequent child is not a <revision>
        """
        yield Revision.from_element(first_revision)

        for sub_element in element:
            tag = sub_element.tag

            if tag == "revision":
                yield Revision.from_element(sub_element)
            else:
                raise MalformedXML("Expected to see 'revision'. " +
                                   "Instead saw '{0}'".format(tag))

    @classmethod
    def from_element(cls, element):
        """
        Construct a :class:`Page` from a <page> element.
        """
        title = None
        namespace = None
        id = None
        redirect = None
        restrictions = []

        first_revision = None

        # Consume each of the elements until we see <revision>, which should
        # come after all of the page metadata.
        for sub_element in element:
            tag = sub_element.tag
            if tag == "title":
                title = sub_element.text
            elif tag == "ns":
                namespace = sub_element.text
            elif tag == "id":
                id = int(sub_element.text)
            elif tag == "redirect":
                redirect = Redirect.from_element(sub_element)
            elif tag == "restrictions":
                restrictions.append(sub_element.text)
            elif tag == "DiscussionThreading":
                continue
            elif tag == "sha1":
                continue
            elif tag == "revision":
                first_revision = sub_element
                break
            # Assuming that the first revision seen marks the end of page
            # metadata.  I'm not too keen on this assumption, so I'm leaving
            # this long comment to warn whoever ends up maintaining this.
            else:
                raise MalformedXML("Unexpected tag found when processing " +
                                   "a <page>: '{0}'".format(tag))

        # Assuming that I got here by seeing a <revision> tag.  See verbose
        # comment above.  Guard against a <page> with no revisions at all:
        # previously Revision.from_element(None) would have been attempted
        # lazily on first iteration.
        if first_revision is not None:
            revisions = cls.load_revisions(first_revision, element)
        else:
            revisions = None

        return cls(id, title, namespace, redirect, restrictions, revisions)

View File

@@ -0,0 +1,29 @@
from ...types import serializable
from ...util import none_or
class Redirect(serializable.Type):
    """
    Represents a redirect tag.

    **title**
        Full page name that this page is redirected to : `str`
    """
    def __new__(cls, redirect_or_title):
        # Pass already-constructed Redirect instances straight through.
        if isinstance(redirect_or_title, cls):
            return redirect_or_title

        instance = super().__new__(cls)
        instance.initialize(redirect_or_title)
        return instance

    def __init__(self, *args, **kwargs):
        # Construction happens entirely in __new__/initialize().
        pass

    def initialize(self, title):
        # Normalize the title to `str` (or keep None).
        self.title = none_or(title, str)

    @classmethod
    def from_element(cls, e):
        """
        Build a :class:`Redirect` from an element's 'title' attribute.
        """
        return cls(e.attr('title'))

View File

@@ -0,0 +1,116 @@
from ...types import serializable, Timestamp
from ...util import none_or
from .comment import Comment
from .contributor import Contributor
from .text import Text
from .util import consume_tags
class Revision(serializable.Type):
    """
    Revision meta data.
    """
    __slots__ = ('id', 'timestamp', 'contributor', 'minor', 'comment', 'text',
                 'bytes', 'sha1', 'parent_id', 'model', 'format',
                 'beginningofpage')

    # Maps a <revision> child tag name to a converter for that element.
    # NOTE(review): there is no 'bytes' entry, so from_element() always
    # passes bytes=None -- presumably the byte count lives on the <text>
    # element's attributes; TODO confirm against the dump schema.
    TAG_MAP = {
        'id': lambda e: int(e.text),
        'timestamp': lambda e: Timestamp(e.text),
        'contributor': lambda e: Contributor.from_element(e),
        'minor': lambda e: True,
        'comment': lambda e: Comment.from_element(e),
        'text': lambda e: Text.from_element(e),
        'sha1': lambda e: str(e.text),
        'parentid': lambda e: int(e.text),
        'model': lambda e: str(e.text),
        'format': lambda e: str(e.text)
    }

    def __init__(self, id, timestamp, contributor=None, minor=None,
                 comment=None, text=None, bytes=None, sha1=None,
                 parent_id=None, model=None, format=None,
                 beginningofpage=False):
        self.id = none_or(id, int)
        """
        Revision ID : `int`
        """

        self.timestamp = none_or(timestamp, Timestamp)
        """
        Revision timestamp : :class:`mw.Timestamp`
        """

        self.contributor = none_or(contributor, Contributor.deserialize)
        """
        Contributor meta data : :class:`~mw.xml_dump.Contributor` | `None`
        """

        # Cleanup: was `False or none_or(minor, bool)` -- the `False or`
        # was dead code (the expression always evaluated to the right
        # operand), so it is dropped here with identical behavior.
        self.minor = none_or(minor, bool)
        """
        Is revision a minor change? : `bool`
        """

        self.comment = none_or(comment, Comment)
        """
        Comment left with revision : :class:`~mw.xml_dump.Comment` (behaves like `str`, with additional members)
        """

        self.text = none_or(text, Text)
        """
        Content of text : :class:`~mw.xml_dump.Text` (behaves like `str`, with additional members)
        """

        self.bytes = none_or(bytes, int)
        """
        Number of bytes of content : `int`
        """

        self.sha1 = none_or(sha1, str)
        """
        sha1 hash of the content : `str`
        """

        self.parent_id = none_or(parent_id, int)
        """
        Revision ID of preceding revision : `int` | `None`
        """

        self.model = none_or(model, str)
        """
        TODO: ??? : `str`
        """

        self.format = none_or(format, str)
        """
        TODO: ??? : `str`
        """

        self.beginningofpage = bool(beginningofpage)
        """
        Is the first revision of a page : `bool`
        Used to identify the first revision of a page when using Wikihadoop
        revision pairs.  Otherwise is always set to False.  Do not expect to use
        this when processing an XML dump directly.
        """

    @classmethod
    def from_element(cls, element):
        """
        Construct a :class:`Revision` from a <revision> element using
        TAG_MAP to convert each recognized child element.
        """
        values = consume_tags(cls.TAG_MAP, element)

        return cls(
            values.get('id'),
            values.get('timestamp'),
            values.get('contributor'),
            values.get('minor') is not None,
            values.get('comment'),
            values.get('text'),
            values.get('bytes'),
            values.get('sha1'),
            values.get('parentid'),
            values.get('model'),
            values.get('format'),
            element.attr('beginningofpage') is not None
            # For Wikihadoop.
            # Probably never used by anything, ever.
        )

Some files were not shown because too many files have changed in this diff Show More