"""
Phase 1: Repository Sampling Service.

This service is responsible for searching GitHub for repositories based on a
query, filtering out previously seen repositories (if requested), and saving
the results.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any

from github_datapipe.core.config import GithubConfig
from github_datapipe.core.github_api import GithubApiClient
from github_datapipe.core.io import read_json, write_json, write_jsonl
from github_datapipe.core.runtime import build_run_id, require_github_token, utc_now

# GitHub documents that its search endpoints expose at most 1,000 results per
# query, so paging past that point cannot yield additional repositories.
_GITHUB_SEARCH_RESULT_CAP = 1000


@dataclass(frozen=True)
class SampleReposOptions:
    """Options for the repository sampling phase."""

    # Number of repositories to collect; `sample_repositories` rejects values <= 0.
    count: int = GithubConfig.default_repo_count
    # Directory holding the per-run folders and the shared seen-ID index.
    output_root: Path = Path(GithubConfig.default_output_root)
    # Search query passed to the GitHub client as-is.
    query: str = GithubConfig.default_query
    # Page size passed to the GitHub client on every search request.
    per_page: int = GithubConfig.default_per_page
    # "fresh" ignores and leaves untouched the seen-ID index; any other value
    # dedupes against the index and updates it after the run.
    mode: str = "append-deduped"
    # Explicit run identifier; a new one is built when this is None or empty.
    run_id: str | None = None


def resolve_query(query_override: str | None) -> str:
    """Resolve the search query, using the override if provided, else the default.

    Args:
        query_override: Optional user-supplied query. Surrounding whitespace is
            stripped.

    Returns:
        The stripped override, or ``GithubConfig.default_query`` when the
        override is ``None``, empty, or consists only of whitespace.
    """
    # Strip before testing for emptiness so that an override like "   " falls
    # back to the default instead of producing an empty search query.
    stripped = query_override.strip() if query_override else ""
    return stripped or GithubConfig.default_query


def sample_repositories(options: SampleReposOptions) -> dict[str, Any]:
    """
    Execute Phase 1: Sample repositories from GitHub.

    Searches for repositories, handles pagination and deduplication, and saves
    the collected repository data and a run manifest.

    Args:
        options: Configuration options for sampling.

    Returns:
        A dictionary with the keys ``run_id``, ``count_collected``,
        ``repos_path``, ``manifest_path`` and ``seen_index_path`` (the latter is
        ``None`` when ``options.mode`` is ``"fresh"``).

    Raises:
        ValueError: If ``options.count`` is not greater than 0.
    """
    token = require_github_token()
    if options.count <= 0:
        raise ValueError("`count` must be greater than 0.")

    run_id = options.run_id or build_run_id()
    output_root = options.output_root.resolve()
    run_root = output_root / run_id
    phase_root = run_root / "phase1_repository_sampling"
    phase_root.mkdir(parents=True, exist_ok=True)

    # The seen-ID index lives at the output root so it is shared across runs.
    seen_index_path = output_root / "seen_repo_ids.json"
    repos_path = phase_root / "repos.jsonl"
    manifest_path = phase_root / "manifest.json"

    fresh = options.mode == "fresh"
    seen_repo_ids: set[int] = set() if fresh else load_seen_repo_ids(seen_index_path)
    current_run_seen_repo_ids: set[int] = set()

    client = GithubApiClient(token=token)
    accepted_repos: list[dict[str, Any]] = []
    page = 1
    total_count: int | None = None
    # Count results actually received (not page * per_page) so the cap check
    # stays correct even if the server returns fewer items than requested.
    results_scanned = 0

    while (
        len(accepted_repos) < options.count
        and results_scanned < _GITHUB_SEARCH_RESULT_CAP
    ):
        payload = client.search_repositories(
            query=options.query,
            page=page,
            per_page=options.per_page,
        )
        items = payload.get("items", [])
        # Keep the last known total if a later page omits the field.
        total_count = payload.get("total_count", total_count)
        if not items:
            break
        results_scanned += len(items)

        # One timestamp per page: every repo from the same response shares it.
        sampled_at = utc_now()
        for repo in items:
            repo_id = int(repo["id"])
            if repo_id in seen_repo_ids or repo_id in current_run_seen_repo_ids:
                continue
            accepted_repos.append(
                normalize_repo_record(
                    repo=repo,
                    run_id=run_id,
                    sample_query=options.query,
                    sample_page=page,
                    sampled_at=sampled_at,
                )
            )
            current_run_seen_repo_ids.add(repo_id)
            if len(accepted_repos) >= options.count:
                break
        page += 1

    write_jsonl(repos_path, accepted_repos)

    if not fresh:
        # Persist the union of previously seen IDs and the ones accepted now.
        seen_repo_ids.update(current_run_seen_repo_ids)
        write_json(seen_index_path, sorted(seen_repo_ids))

    manifest = {
        "run_id": run_id,
        "phase_name": "phase1_repository_sampling",
        "command": "sample-repos",
        "resolved_query": options.query,
        "count_requested": options.count,
        "count_collected": len(accepted_repos),
        "per_page": options.per_page,
        "mode": options.mode,
        # `page` is only advanced after a non-empty page has been processed.
        "page_count_scanned": max(page - 1, 0),
        "github_total_count": total_count,
        "output_files": {
            "repos_jsonl": str(repos_path),
            "seen_repo_ids_json": None if fresh else str(seen_index_path),
        },
        "sampled_at": utc_now(),
    }
    write_json(manifest_path, manifest)

    return {
        "run_id": run_id,
        "count_collected": len(accepted_repos),
        "repos_path": repos_path,
        "manifest_path": manifest_path,
        "seen_index_path": None if fresh else seen_index_path,
    }


def normalize_repo_record(
    repo: dict[str, Any],
    run_id: str,
    sample_query: str,
    sample_page: int,
    sampled_at: str,
) -> dict[str, Any]:
    """Normalize a raw GitHub repository payload into our internal schema.

    Args:
        repo: Raw repository item from a search response. The keys ``id``,
            ``full_name``, ``html_url`` and ``url`` are required; the remaining
            fields are read with ``.get`` and default to ``None``.
        run_id: Identifier of the sampling run that produced the record.
        sample_query: Search query that returned the repository.
        sample_page: Result page on which the repository appeared.
        sampled_at: Timestamp recorded for the page the repository came from.

    Returns:
        A flat dictionary combining the selected repository fields with the
        sampling metadata.

    Raises:
        KeyError: If one of the required keys is missing from ``repo``.
    """
    return {
        "run_id": run_id,
        "repo_id": repo["id"],
        "full_name": repo["full_name"],
        "html_url": repo["html_url"],
        "api_url": repo["url"],
        "default_branch": repo.get("default_branch"),
        "language": repo.get("language"),
        "description": repo.get("description"),
        "stargazers_count": repo.get("stargazers_count"),
        "size_kb": repo.get("size"),
        "fork": repo.get("fork"),
        "archived": repo.get("archived"),
        "visibility": repo.get("visibility"),
        "sample_query": sample_query,
        "sample_page": sample_page,
        "sampled_at": sampled_at,
    }


def load_seen_repo_ids(path: Path) -> set[int]:
    """Load the set of repository IDs that have already been sampled.

    Args:
        path: Location of the JSON index file.

    Returns:
        An empty set when ``path`` does not exist; otherwise every entry of the
        loaded payload converted with ``int``.
    """
    if not path.exists():
        return set()
    payload = read_json(path)
    return {int(repo_id) for repo_id in payload}