Add phase 2 - download default branch commits for repositories collected in phase 1
This commit is contained in:
148
src/github_datapipe/phases/phase1_repository_sampling/service.py
Normal file
148
src/github_datapipe/phases/phase1_repository_sampling/service.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from github_datapipe.core.config import GithubConfig
|
||||
from github_datapipe.core.github_api import GithubApiClient
|
||||
from github_datapipe.core.io import read_json, write_json, write_jsonl
|
||||
from github_datapipe.core.runtime import build_run_id, require_github_token, utc_now
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SampleReposOptions:
    """Immutable settings for one phase-1 repository sampling run.

    Every field has a default, most of them read from ``GithubConfig``, so
    ``SampleReposOptions()`` describes a run with the configured defaults.
    """

    # Number of repositories to collect; must be positive
    # (``sample_repositories`` rejects values <= 0).
    count: int = GithubConfig.default_repo_count

    # Directory under which ``<run_id>/phase1_repository_sampling`` and the
    # shared ``seen_repo_ids.json`` index are placed.
    output_root: Path = Path(GithubConfig.default_output_root)

    # Search query forwarded to the GitHub repository search.
    query: str = GithubConfig.default_query

    # Page size forwarded to the GitHub repository search.
    per_page: int = GithubConfig.default_per_page

    # "fresh" ignores and does not update the seen-id index; any other value
    # (including the default) deduplicates against it and updates it.
    mode: str = "append-deduped"

    # Explicit run identifier; when falsy, one is generated at run time.
    run_id: str | None = None
|
||||
|
||||
|
||||
def resolve_query(query_override: str | None) -> str:
    """Return the search query to use, preferring a caller-supplied override.

    Args:
        query_override: Optional query text; surrounding whitespace is
            removed before it is used.

    Returns:
        The stripped override when it contains non-whitespace text,
        otherwise ``GithubConfig.default_query``.
    """
    # Strip *before* deciding whether an override was given: a
    # whitespace-only string is truthy but strips down to "", which would
    # otherwise be returned as an empty search query.
    stripped = query_override.strip() if query_override else ""
    return stripped or GithubConfig.default_query
|
||||
|
||||
|
||||
def sample_repositories(options: SampleReposOptions) -> dict[str, Any]:
    """Page through GitHub repository search results and persist a sample.

    Repositories are accepted in the order the search returns them. An item
    is skipped when its id was already accepted during this run or, unless
    ``options.mode`` is ``"fresh"``, when it is listed in the shared
    ``seen_repo_ids.json`` index under the output root. Paging stops as soon
    as ``options.count`` repositories are accepted or a page has no items.

    Files written:
        * ``<output_root>/<run_id>/phase1_repository_sampling/repos.jsonl``
        * ``<output_root>/<run_id>/phase1_repository_sampling/manifest.json``
        * ``<output_root>/seen_repo_ids.json`` (not touched in ``"fresh"`` mode)

    Args:
        options: Settings for this sampling run.

    Returns:
        A dict with ``run_id``, ``count_collected``, ``repos_path``,
        ``manifest_path`` and ``seen_index_path`` (``None`` in ``"fresh"``
        mode).

    Raises:
        ValueError: If ``options.count`` is not greater than 0. Note that the
            GitHub token is requested before this check runs.
    """
    token = require_github_token()
    if options.count <= 0:
        raise ValueError("`count` must be greater than 0.")

    # Every mode other than "fresh" reads and rewrites the seen-id index.
    uses_seen_index = options.mode != "fresh"
    phase_name = "phase1_repository_sampling"

    run_id = options.run_id or build_run_id()
    output_root = options.output_root.resolve()
    phase_root = output_root / run_id / phase_name
    phase_root.mkdir(parents=True, exist_ok=True)

    seen_index_path = output_root / "seen_repo_ids.json"
    repos_path = phase_root / "repos.jsonl"
    manifest_path = phase_root / "manifest.json"

    # One set covers both "seen in earlier runs" and "accepted in this run";
    # in "fresh" mode it starts empty and is never written back.
    known_repo_ids: set[int] = (
        load_seen_repo_ids(seen_index_path) if uses_seen_index else set()
    )

    client = GithubApiClient(token=token)
    collected: list[dict[str, Any]] = []
    total_count: int | None = None
    next_page = 1

    while len(collected) < options.count:
        response = client.search_repositories(
            query=options.query,
            page=next_page,
            per_page=options.per_page,
        )
        page_items = response.get("items", [])
        # Keep the previous total if this response does not carry one.
        total_count = response.get("total_count", total_count)
        if not page_items:
            break

        # All records taken from one page share a single timestamp.
        page_sampled_at = utc_now()
        for item in page_items:
            item_id = int(item["id"])
            if item_id in known_repo_ids:
                continue

            collected.append(
                normalize_repo_record(
                    repo=item,
                    run_id=run_id,
                    sample_query=options.query,
                    sample_page=next_page,
                    sampled_at=page_sampled_at,
                )
            )
            known_repo_ids.add(item_id)

            # Stop mid-page so the sample never exceeds the requested count.
            if len(collected) >= options.count:
                break

        next_page += 1

    write_jsonl(repos_path, collected)
    if uses_seen_index:
        write_json(seen_index_path, sorted(known_repo_ids))

    collected_count = len(collected)
    manifest = {
        "run_id": run_id,
        "phase_name": phase_name,
        "command": "sample-repos",
        "resolved_query": options.query,
        "count_requested": options.count,
        "count_collected": collected_count,
        "per_page": options.per_page,
        "mode": options.mode,
        # next_page is one past the last page that yielded items.
        "page_count_scanned": max(next_page - 1, 0),
        "github_total_count": total_count,
        "output_files": {
            "repos_jsonl": str(repos_path),
            "seen_repo_ids_json": str(seen_index_path) if uses_seen_index else None,
        },
        "sampled_at": utc_now(),
    }
    write_json(manifest_path, manifest)

    return {
        "run_id": run_id,
        "count_collected": collected_count,
        "repos_path": repos_path,
        "manifest_path": manifest_path,
        "seen_index_path": seen_index_path if uses_seen_index else None,
    }
|
||||
|
||||
|
||||
def normalize_repo_record(
    repo: dict[str, Any],
    run_id: str,
    sample_query: str,
    sample_page: int,
    sampled_at: str,
) -> dict[str, Any]:
    """Flatten one search-result item into the record stored in ``repos.jsonl``.

    Args:
        repo: A repository item from the search response. ``id``,
            ``full_name``, ``html_url`` and ``url`` are required; the other
            fields are copied when present and stored as ``None`` otherwise.
        run_id: Identifier of the sampling run that produced the record.
        sample_query: Query string that returned the repository.
        sample_page: Result page the repository was found on.
        sampled_at: Timestamp associated with that page.

    Returns:
        A new dict holding the run id, the selected repository fields and the
        sampling provenance.

    Raises:
        KeyError: If one of the required repository fields is missing.
    """
    # Required fields use indexing so a malformed item fails loudly.
    record: dict[str, Any] = {
        "run_id": run_id,
        "repo_id": repo["id"],
        "full_name": repo["full_name"],
        "html_url": repo["html_url"],
        "api_url": repo["url"],
    }

    # (output key, source key) pairs for fields that may be absent.
    optional_fields = (
        ("default_branch", "default_branch"),
        ("language", "language"),
        ("description", "description"),
        ("stargazers_count", "stargazers_count"),
        ("size_kb", "size"),
        ("fork", "fork"),
        ("archived", "archived"),
        ("visibility", "visibility"),
    )
    for output_key, source_key in optional_fields:
        record[output_key] = repo.get(source_key)

    record["sample_query"] = sample_query
    record["sample_page"] = sample_page
    record["sampled_at"] = sampled_at
    return record
|
||||
|
||||
|
||||
def load_seen_repo_ids(path: Path) -> set[int]:
    """Read the seen-repository index at ``path`` as a set of integer ids.

    Args:
        path: Location of the JSON index file.

    Returns:
        The ids found in the file, each converted with ``int``; an empty set
        when the file does not exist.
    """
    # A missing index is treated the same as an empty one.
    stored_ids = read_json(path) if path.exists() else ()
    return {int(stored_id) for stored_id in stored_ids}
|
||||
Reference in New Issue
Block a user