#!/usr/bin/env python3
"""
prime_labeling.py  —  Workflow 00: two SEPARATE, recomputable properties per prime.

This script owns the *distinctiveness / neighborhood-density* property end to end
(it is fully computable) and the *ingestion + validation + versioning* of the
*structural<->framed* grade (the grade itself is an LLM judgment produced out of
band; see `grade-template` / `ingest-grades`).

The two properties are deliberately kept in SEPARATE derived artifacts because
they recompute on different cadences:
  - distinctiveness  : RELATIVE, drifts as the corpus grows -> recompute often.
  - structural/framed: STABLE, re-grade only when criteria change or a prime is
                       edited.

Nothing is denormalized into dist/encyclopedia.primes.jsonl. The derived index is
the source of truth and joins back to the corpus by `slug`. This keeps batch
recompute clean (never rewrites the corpus) and keeps drift inspectable.

----------------------------------------------------------------------------------
LAYOUT (all under dist/derived/)
  embeddings/
    primes_sig.<date>.<snap>.npy            frozen signature-embedding matrix
    primes_sig.<date>.<snap>.meta.json      [{slug,name,text_source}] aligned to rows
    primes_sig.<date>.<snap>.centroids.npy  k-means family centroids for that batch
    primes_sig.latest.npy / .meta.json / .centroids.npy
                                            copies of the most recent batch
    primes_sig.overlay.npy / .meta.json     incremental additions since last batch (optional)
  distinctiveness.<date>.<snap>.jsonl   one line per prime, the computed property
  distinctiveness.latest.jsonl          copy of most recent + any incremental appends
  clusters.<date>.<snap>.json           k-means families: members, size, name, description
  clusters.latest.json                  copy of the most recent families (names carried over)
  structural_framed.<date>.jsonl        one line per graded prime (the S/F grade)
  structural_framed.latest.jsonl
  manifest.json                         current versions, params, snapshot hashes,
                                        counts, and drift counters

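Versioned artifacts are stamped with computed_at (date) and, where shown above,
corpus_snapshot (first 12 hex chars of the sha256 of dist/encyclopedia.primes.jsonl)
in their filenames. Each distinctiveness row also records computed_at,
corpus_snapshot, the embedder id, and params (k, threshold, max_tokens; plus
n_clusters on batch rows).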

----------------------------------------------------------------------------------
USAGE
  # full rebuild from current dist/ (embeddings + all neighborhoods + clusters)
  python scripts/prime_labeling.py batch [--k 10] [--threshold 0.82] [--max-tokens 200]

  # score ONE newly-added prime against the frozen index (does not recompute others)
  python scripts/prime_labeling.py incremental --slug <new_slug>

  # tune k / threshold against the known crowded clusters
  python scripts/prime_labeling.py calibrate [--k 10]

  # print neighborhood diagnostics for given slugs
  python scripts/prime_labeling.py inspect symmetry recursion sovereignty ...

  # emit the per-prime grading prompt(s) for the structural/framed judgment
  python scripts/prime_labeling.py grade-template --slug sovereignty
  python scripts/prime_labeling.py grade-template --sample 18    # calibration sample

  # ingest + validate LLM-produced grades from a JSONL file into the derived index
  python scripts/prime_labeling.py ingest-grades --in grades_in.jsonl

  # validate the whole derived index against the corpus
  python scripts/prime_labeling.py verify
"""
import os, sys, json, time, argparse, hashlib, datetime, glob
import numpy as np

# ----------------------------------------------------------------------------------
# paths
def _find_repo_root(start):
    d = os.path.abspath(start)
    while True:
        if os.path.isdir(os.path.join(d, "_models")) and os.path.isdir(os.path.join(d, "dist")):
            return d
        parent = os.path.dirname(d)
        if parent == d:
            return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        d = parent

# Embedder identity; the quantized ONNX model lives in a directory of the same name.
EMBEDDER_ID = "bge-small-en-v1.5-onnx-q"
# Repository root and everything derived from it.
REPO    = _find_repo_root(os.path.dirname(os.path.abspath(__file__)))
MODEL   = os.path.join(REPO, "_models", EMBEDDER_ID)
DIST    = os.path.join(REPO, "dist")
PRIMES  = os.path.join(DIST, "encyclopedia.primes.jsonl")
DERIVED = os.path.join(DIST, "derived")
EMB_DIR = os.path.join(DERIVED, "embeddings")
# Query-side instruction prefix; prepended only when embed(..., is_query=True).
QINSTR  = "Represent this sentence for searching relevant passages: "

# The five diagnostic criteria from conceptual/structural-and-framed-primes.md sec 3.2.
# Each scored 0.0 = structural pole, 0.5 = mixed, 1.0 = framed pole.
# Entries are (key, question) pairs. The keys are the exact field names a grade's
# "criteria" object must contain; the questions are rendered verbatim, numbered,
# into the grading prompt.
CRITERIA = [
    ("vocab_travels",
     "Does the prime's home vocabulary travel with it into a new domain? "
     "(framed: imports a recognizable home-discipline lexicon; structural: just names a pattern)"),
    ("evaluative_weight",
     "Does the prime carry evaluative/normative weight by default? "
     "(framed: yes, e.g. justice/dignity/due-process; structural: neutral at definition)"),
    ("institutional_origin",
     "Does the prime have an institutional or normative referent at its origin? "
     "(framed: arose in a discipline working with human institutions/norms; structural: formal origin)"),
    ("human_practice_bound",
     "Can the prime be defined WITHOUT reference to human practices? "
     "(structural: yes; framed: no -- presupposes agents/agreements/institutions). "
     "NOTE: scored so 1.0 = framed = cannot be defined without human practices."),
    ("import_vs_recognize",
     "Does applying it feel like recognizing a pattern already there (structural) "
     "or importing a perspective that recasts the domain (framed)?"),
]
# Bins for the aggregate (mean of the five marks), checked in order; the first bin
# with aggregate <= upper bound wins. With marks restricted to 0.0/0.5/1.0 the
# aggregate is a multiple of 0.1: 0.0-0.2 structural, 0.3-0.4 mixed-structural,
# 0.5-0.7 mixed-framed, 0.8-1.0 framed. The final 1.01 bound presumably leaves a
# margin so an aggregate of exactly 1.0 always lands in "framed".
LABEL_BINS = [  # (upper_bound_inclusive, label); aggregate in [0,1]
    (0.20, "structural"),
    (0.45, "mixed-structural"),
    (0.70, "mixed-framed"),
    (1.01, "framed"),
]

# Known crowded clusters used to calibrate the neighborhood threshold. These SHOULD
# come out low-distinctiveness. authority/sovereignty is a crowded FRAMED pair: a
# test that distinctiveness is NOT the same axis as structural/framed.
# Keys are display labels only; values are corpus slugs. Slugs absent from the
# corpus are reported as "missing" by `calibrate` rather than raising.
CALIBRATION_CLUSTERS = {
    "recursion/iteration":            ["recursion", "iteration"],
    "symmetry/invariance/conserv":    ["symmetry", "invariance", "conservation_laws"],
    "equilibrium/homeostasis":        ["equilibrium", "homeostasis", "thermodynamic_equilibrium"],
    "authority/sovereignty(framed)":  ["authority", "sovereignty"],
}

# ----------------------------------------------------------------------------------
# embedder (CLS-pool + L2-normalize; query instruction only on queries)
_sess = _tok = None
_tok_max_tokens = None  # truncation length currently configured on _tok


def _load_model(max_tokens):
    """Lazily build the ONNX session + tokenizer and make truncation match *max_tokens*.

    The session and tokenizer are created once per process and cached in module
    globals. Truncation is (re)applied whenever the requested ``max_tokens``
    differs from the one currently configured. Previously only the first call's
    value took effect, so later ``embed(..., max_tokens=N)`` calls silently
    truncated at the stale length.
    """
    global _sess, _tok, _tok_max_tokens
    if _sess is None:
        import onnxruntime as ort
        from tokenizers import Tokenizer
        so = ort.SessionOptions()
        so.intra_op_num_threads = os.cpu_count() or 4
        _sess = ort.InferenceSession(os.path.join(MODEL, "model_optimized.onnx"),
                                     sess_options=so, providers=["CPUExecutionProvider"])
        _tok = Tokenizer.from_file(os.path.join(MODEL, "tokenizer.json"))
    if _tok_max_tokens != max_tokens:
        _tok.enable_truncation(max_length=max_tokens)
        _tok_max_tokens = max_tokens

def embed(texts, is_query=False, max_tokens=200, B=64):
    """Embed *texts* with the bge model; return an (len(texts), dim) float32 array.

    Each row is the CLS hidden state, L2-normalized, so row dot products are
    cosines. Empty/falsy texts are replaced by "[unknown]" (without the query
    prefix); otherwise QINSTR is prepended when *is_query* is true. Texts are run
    in batches of *B*, zero-padded to the longest encoding in the batch.
    """
    _load_model(max_tokens)
    batches = []
    for start in range(0, len(texts), B):
        prepared = []
        for t in texts[start:start + B]:
            if not t:
                prepared.append("[unknown]")
            elif is_query:
                prepared.append(QINSTR + t)
            else:
                prepared.append(t)
        encodings = _tok.encode_batch(prepared)
        width = max(len(enc.ids) for enc in encodings)
        input_ids = np.zeros((len(encodings), width), np.int64)
        attn = np.zeros((len(encodings), width), np.int64)
        for row, enc in enumerate(encodings):
            n_tok = len(enc.ids)
            input_ids[row, :n_tok] = enc.ids
            attn[row, :n_tok] = enc.attention_mask
        feeds = {
            "input_ids": input_ids,
            "attention_mask": attn,
            "token_type_ids": np.zeros_like(input_ids),
        }
        hidden = _sess.run(None, feeds)[0]
        pooled = hidden[:, 0, :]
        pooled = pooled / np.linalg.norm(pooled, axis=1, keepdims=True)
        batches.append(pooled.astype(np.float32))
    return np.vstack(batches)

# ----------------------------------------------------------------------------------
# io helpers
def _iter_jsonl(path):
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)

def _signature_text(rec):
    """Pick the text to embed for one prime record and name the field it came from."""
    for field, source in (("structural_signature", "structural_signature"),
                          ("core_idea", "core_idea(fallback)")):
        candidate = (rec.get(field) or "").strip()
        if candidate:
            return candidate, source
    return rec.get("slug", ""), "slug(fallback)"


def load_primes():
    """Read the corpus and return four index-aligned lists: (slug, name, embed_text, text_source).

    The embed text is the stripped structural_signature, else the stripped
    core_idea, else the slug itself; text_source records which one was used.
    """
    slugs, names, texts, srcs = [], [], [], []
    for rec in _iter_jsonl(PRIMES):
        text, source = _signature_text(rec)
        slugs.append(rec.get("slug", ""))
        names.append(rec.get("name", ""))
        texts.append(text)
        srcs.append(source)
    return slugs, names, texts, srcs

def corpus_snapshot():
    """Return the first 12 hex chars of the SHA-256 of the corpus file (read in 1 MiB chunks)."""
    digest = hashlib.sha256()
    with open(PRIMES, "rb") as fh:
        while True:
            piece = fh.read(1 << 20)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()[:12]

def today():
    """Return the local calendar date as an ISO 'YYYY-MM-DD' string."""
    return str(datetime.date.today())

def _ensure_dirs():
    """Create the embeddings directory (and its dist/derived parent) if missing."""
    os.makedirs(name=EMB_DIR, exist_ok=True)

def _read_manifest():
    """Return the parsed dist/derived/manifest.json, or {} if it does not exist yet.

    The file is opened in a ``with`` block so the handle is closed
    deterministically instead of being left to the garbage collector.
    """
    p = os.path.join(DERIVED, "manifest.json")
    if not os.path.exists(p):
        return {}
    with open(p, encoding="utf-8") as f:
        return json.load(f)

def _write_manifest(m):
    """Overwrite dist/derived/manifest.json with *m* (indent=2).

    Uses a ``with`` block so the data is flushed and the handle closed before
    returning. The previous inline ``open()`` relied on garbage collection to
    flush.
    """
    with open(os.path.join(DERIVED, "manifest.json"), "w", encoding="utf-8") as f:
        json.dump(m, f, indent=2)

# ----------------------------------------------------------------------------------
# neighborhood / distinctiveness math
def _knn_stats(S, k, threshold):
    """S: (n,n) cosine matrix with diagonal already set to -inf.
    Returns per-row dict of metrics and the order matrix (top neighbors)."""
    n = S.shape[0]
    order = np.argsort(-S, axis=1)            # neighbors by descending similarity
    rows = []
    for i in range(n):
        nbr = order[i]
        topk = nbr[:k]
        topk_sims = S[i, topk]
        above = int((S[i] >= threshold).sum())  # diagonal is -inf so self excluded
        rows.append({
            "mean_cos_knn": round(float(topk_sims.mean()), 4),
            "max_cos": round(float(S[i, nbr[0]]), 4),
            "n_neighbors_above_threshold": above,
            # nearest few for human inspection
            "_nbr_idx": topk.tolist(),
            "_nbr_sims": [round(float(x), 4) for x in topk_sims],
        })
    return rows, order

def _cluster_kmeans(E, n_clusters, seed=0):
    """Partition L2-normalized embeddings with k-means (~spherical/cosine k-means).

    WHY NOT threshold-graph connected components: in bge-small signature space the
    corpus has one giant dense core -- ~250 primes are mutual near-neighbors above
    0.80 -- so single-link OR mutual-kNN connected components chain the whole core
    into one blob (verified during calibration). K-means partitions the dense core
    into coherent neighborhoods and yields reusable centroids, which also gives a
    clean incremental story: a new prime is assigned to its nearest saved centroid.

    Returns (labels, sizes, centroids):
      labels    -- list of Python int cluster ids, one per row of E;
      sizes     -- {cluster_id: member count}, keys in first-seen order;
      centroids -- (K, d) float32 array of cluster centers.
    """
    from sklearn.cluster import KMeans
    model = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    model.fit(E)
    labels = model.labels_.tolist()
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    centers = model.cluster_centers_.astype(np.float32)
    return labels, counts, centers

def _percentile_ranks(values):
    """Map each value to its percentile rank in [0,1] (0=lowest value, 1=highest)."""
    v = np.asarray(values, float)
    order = v.argsort().argsort().astype(float)
    return order / max(1, len(v) - 1)

def _clusters_path():
    """Return the path of the rolling (unversioned) k-means family file."""
    path = os.path.join(DERIVED, "clusters.latest.json")
    return path

# Minimum Jaccard overlap between a recomputed cluster's members and a prior
# family's members for that family's name + description to carry over.
CLUSTER_CARRY_JACCARD = 0.6

def _write_clusters(cid, sizes, slugs, date, snap):
    """Write cluster_id -> {members, size, name, description}.

    Family NAMES + DESCRIPTIONS are LLM-generated out of band (see `name-clusters`
    / `ingest-cluster-names`). k-means ids/membership shift on every recompute, so
    to avoid regenerating all of them each time we CARRY a prior family's name AND
    description onto the recomputed cluster it overlaps most, when the Jaccard
    overlap of their member sets is >= CLUSTER_CARRY_JACCARD. Matching is greedy and
    one-to-one (each prior family is inherited by at most one new cluster), so a
    split/merge surfaces as an un-named cluster that genuinely needs a fresh label.

    The prior families are read from clusters.latest.json BEFORE it is
    overwritten. The payload is written to both clusters.<date>.<snap>.json and
    clusters.latest.json, each inside a ``with`` block so handles are closed
    deterministically.

    Returns (n_clusters, n_carried, n_need_label) where n_need_label counts clusters
    still missing a name or description after carry-over."""
    members = {}
    for i, c in enumerate(cid):
        members.setdefault(c, []).append(slugs[i])

    latest_path = _clusters_path()
    prior = []  # [{members:set, name, description}]
    if os.path.exists(latest_path):
        with open(latest_path) as f:
            prior_clusters = json.load(f).get("clusters", {})
        for rec in prior_clusters.values():
            if rec.get("name") or rec.get("description"):
                prior.append({"members": set(rec.get("members", [])),
                              "name": rec.get("name"), "description": rec.get("description")})

    # score all (new cluster, prior family) overlaps, then greedily assign best-first
    new_ids = list(members)
    new_sets = {c: set(members[c]) for c in new_ids}
    pairs = []
    for c in new_ids:
        ms = new_sets[c]
        for pj, pr in enumerate(prior):
            union = ms | pr["members"]
            jac = len(ms & pr["members"]) / len(union) if union else 0.0
            if jac >= CLUSTER_CARRY_JACCARD:
                pairs.append((jac, c, pj))
    pairs.sort(reverse=True)
    assign = {}
    used_prior = set()
    for jac, c, pj in pairs:
        if c in assign or pj in used_prior:
            continue
        assign[c] = pj
        used_prior.add(pj)

    clusters = {}
    for c in new_ids:
        rec = {"size": sizes[c], "members": sorted(members[c]), "name": None, "description": None}
        if c in assign:
            pr = prior[assign[c]]
            rec["name"] = pr["name"]
            rec["description"] = pr["description"]
        clusters[str(c)] = rec

    payload = {"computed_at": date, "corpus_snapshot": snap,
               "n_clusters": len(clusters), "clusters": clusters}
    # versioned copy first, then the rolling "latest" copy
    for path in (os.path.join(DERIVED, f"clusters.{date}.{snap}.json"), latest_path):
        with open(path, "w") as f:
            json.dump(payload, f, indent=1)
    n_carried = len(assign)
    n_need = sum(1 for r in clusters.values() if not (r["name"] and r["description"]))
    return len(clusters), n_carried, n_need

# ----------------------------------------------------------------------------------
# BATCH
def cmd_batch(args):
    """Rebuild the whole distinctiveness index from the current corpus.

    Embeds every prime's signature text, scores k-NN neighborhood density
    against all other primes, and partitions the corpus into k-means families.
    Writes under dist/derived/:
      * embeddings/primes_sig.<date>.<snap>.{npy,centroids.npy,meta.json} and the
        matching primes_sig.latest.* copies;
      * clusters.<date>.<snap>.json + clusters.latest.json (via _write_clusters);
      * distinctiveness.<date>.<snap>.jsonl + distinctiveness.latest.jsonl;
      * manifest.json, with the incremental-adds counter reset to 0.
    An existing incremental overlay is emptied (overwritten, not deleted -- some
    mounts disallow unlink), because the new batch already covers every prime.

    args: argparse namespace providing k, threshold, max_tokens and clusters.
    A falsy clusters value means max(2, n // 10). The cluster count is capped at
    n because k-means cannot fit more clusters than samples.
    """
    _ensure_dirs()
    snap = corpus_snapshot()
    date = today()
    slugs, names, texts, srcs = load_primes()
    n = len(slugs)
    print(f"[batch] {n} primes  snapshot={snap}  k={args.k}  threshold={args.threshold}")
    t0 = time.time()
    E = embed(texts, is_query=False, max_tokens=args.max_tokens)
    print(f"[batch] embedded {n} signatures in {time.time()-t0:.1f}s  dim={E.shape[1]}")

    # cosine matrix (E is L2-normalized so dot == cosine); mask self
    S = E @ E.T
    np.fill_diagonal(S, -np.inf)

    rows, order = _knn_stats(S, args.k, args.threshold)
    K = min(args.clusters or max(2, n // 10), n)
    cid, sizes, centroids = _cluster_kmeans(E, K)

    # distinctiveness percentile: rank by 1 - mean_cos_knn so 0 = most crowded
    # (densest neighborhood), 1 = most distinctive (sparsest). Percentile is used as
    # the headline because absolute cosine is compressed in this space.
    mean_knn = np.array([r["mean_cos_knn"] for r in rows])
    dist_pct = _percentile_ranks(1.0 - mean_knn)

    # persist versioned + latest embeddings, centroids and row metadata
    meta = [{"slug": slugs[i], "name": names[i], "text_source": srcs[i]} for i in range(n)]
    for stem in (f"primes_sig.{date}.{snap}", "primes_sig.latest"):
        np.save(os.path.join(EMB_DIR, stem + ".npy"), E)
        np.save(os.path.join(EMB_DIR, stem + ".centroids.npy"), centroids)
        with open(os.path.join(EMB_DIR, stem + ".meta.json"), "w") as f:
            json.dump(meta, f)
    nclus_w, carried_w, need_w = _write_clusters(cid, sizes, slugs, date, snap)
    # reset any stale overlay from a previous incremental run. Some mounts disallow
    # unlink, so overwrite-empty rather than remove.
    _ov = os.path.join(EMB_DIR, "primes_sig.overlay.npy")
    _ovm = os.path.join(EMB_DIR, "primes_sig.overlay.meta.json")
    if os.path.exists(_ov):
        np.save(_ov, np.zeros((0, E.shape[1]), np.float32))
        with open(_ovm, "w") as f:
            json.dump([], f)

    # build one distinctiveness record per prime
    params = {"k": args.k, "threshold": args.threshold,
              "max_tokens": args.max_tokens, "n_clusters": K}
    out_lines = []
    for i, r in enumerate(rows):
        nearest = [[slugs[j], r["_nbr_sims"][t]] for t, j in enumerate(r["_nbr_idx"])]
        out_lines.append({
            "slug": slugs[i],
            "name": names[i],
            "property": "distinctiveness",
            "computed_at": date,
            "corpus_snapshot": snap,
            "embedder": EMBEDDER_ID,
            "mode": "batch",
            "params": params,
            "text_source": srcs[i],
            # headline: 0 = most crowded / least retrievable, 1 = most distinctive
            "distinctiveness_percentile": round(float(dist_pct[i]), 4),
            # broad-neighborhood-density signal (raw, compressed scale)
            "mean_cos_knn": r["mean_cos_knn"],
            "distinctiveness_raw": round(float(1.0 - r["mean_cos_knn"]), 4),
            # near-twin signals
            "max_cos": r["max_cos"],
            "n_neighbors_above_threshold": r["n_neighbors_above_threshold"],
            "cluster_id": cid[i],
            "cluster_size": sizes[cid[i]],
            "nearest": nearest,
        })
    # write distinctiveness jsonl (versioned + latest share identical content)
    vpath = os.path.join(DERIVED, f"distinctiveness.{date}.{snap}.jsonl")
    jsonl_text = "".join(json.dumps(o) + "\n" for o in out_lines)
    for path in (vpath, os.path.join(DERIVED, "distinctiveness.latest.jsonl")):
        with open(path, "w") as f:
            f.write(jsonl_text)

    # distribution summary
    raw = np.array([o["distinctiveness_raw"] for o in out_lines])
    singletons = sum(1 for s in sizes.values() if s == 1)
    print(f"[batch] distinctiveness_raw (1-mean_cos@{args.k}): "
          f"mean={raw.mean():.3f} min={raw.min():.3f} max={raw.max():.3f}  "
          f"(percentile-normalized as headline)")
    print(f"[batch] clusters: K={K} kmeans, sizes min={min(sizes.values())} "
          f"median={int(np.median(list(sizes.values())))} max={max(sizes.values())}; "
          f"{carried_w}/{nclus_w} families carried over by overlap, "
          f"{need_w} need a fresh name/description")

    m = _read_manifest()
    m["distinctiveness"] = {
        "latest": os.path.relpath(vpath, REPO),
        "computed_at": date, "corpus_snapshot": snap, "embedder": EMBEDDER_ID,
        "n_primes": n, "params": dict(params),
        "n_clusters": K, "singletons": singletons,
        "incremental_adds_since_batch": 0,
    }
    # keep any existing S/F entry untouched; only seed a placeholder if absent
    m.setdefault("structural_framed", {"latest": None})
    _write_manifest(m)
    print(f"[batch] wrote {vpath}")
    print("[batch] manifest updated; incremental counter reset to 0")
    _print_next_steps(need_w, snap)

def _print_next_steps(need_w, snap):
    """Hand the human a paste-ready plan for the steps a script can't do itself
    (LLM family naming/describing + grading new primes) and the final site rebuild."""
    bar = "=" * 72
    print("\n" + bar)
    print("NEXT STEPS (for you) — recompute done; a couple of LLM steps may remain")
    print(bar)
    print("This script did the deterministic work: embeddings, neighborhoods,")
    print("distinctiveness, and k-means families. What it can't do by itself is the")
    print("LLM judgment — naming/describing families whose membership changed too much")
    print("to carry over, and grading any newly-added primes.\n")
    if need_w == 0:
        print("All families carried their name + description over from the previous run.")
        print("No LLM step needed. Just rebuild the static site:\n")
        print("    ./scripts/build_eoa_site.sh\n")
        print(bar)
        return
    print(f"{need_w} family/families need a fresh name + 1-2 sentence description")
    print("(and any new primes need a structural/framed grade). To finish, open a NEW")
    print("Claude (Cowork) window on this repo and paste the prompt between the lines:\n")
    print("    " + "-" * 64)
    print("    I just ran `python scripts/prime_labeling.py batch` on the Encyclopedia")
    print("    of Abstractions. Please read the workflow-00 memory, then finish the")
    print("    derived-label refresh:")
    print("      1. `python scripts/prime_labeling.py name-clusters` lists families")
    print("         missing a name or description. For each, write a short Title-Case")
    print("         name (2-5 words) and a 1-2 sentence description from its member")
    print("         slugs into a JSONL of {\"cluster_id\":N,\"name\":\"...\",")
    print("         \"description\":\"...\"}, then `python scripts/prime_labeling.py")
    print("         ingest-cluster-names --in <file>`.")
    print("      2. If any primes were newly ADDED, grade them with the 3-independent-")
    print("         grader protocol, ingest with `ingest-grades`, generate locked-")
    print("         template explanations, and apply via")
    print("         `scripts/apply_structural_framed_to_v2.py`.")
    print("      3. Tell me when done and I'll rebuild the site.")
    print("    " + "-" * 64)
    print("\nThen rebuild the static site:\n")
    print("    ./scripts/build_eoa_site.sh\n")
    print(bar)

# ----------------------------------------------------------------------------------
# INCREMENTAL  (score one new prime vs the frozen batch index; do NOT recompute others)
def cmd_incremental(args):
    """Score ONE prime against the frozen batch index without recomputing others.

    Loads the latest batch embeddings plus any overlay rows from earlier
    incremental runs. Embeds args.slug's signature text, using the same fallback
    order as load_primes. Appends/replaces its line in
    distinctiveness.latest.jsonl with mode="incremental". k / threshold /
    max_tokens come from the manifest's last batch params, so the score is
    comparable to batch rows; args supplies the fallback.

    Cluster assignment uses the nearest saved k-means centroid by Euclidean
    distance, the same rule k-means uses. Centroids of unit vectors are not unit
    length, so the previous argmax(C @ v) could pick a different cluster.

    Side effects:
      * the prime's vector is written to the overlay, replacing any earlier
        overlay row for the same slug so re-scoring never duplicates it;
      * the manifest's incremental_adds_since_batch counter is incremented, and
        a drift hint is printed once it reaches args.batch_cadence.

    Exits with an error message if args.slug is not in the corpus.
    """
    _ensure_dirs()
    E = np.load(os.path.join(EMB_DIR, "primes_sig.latest.npy"))
    with open(os.path.join(EMB_DIR, "primes_sig.latest.meta.json")) as f:
        meta = json.load(f)
    params = _read_manifest().get("distinctiveness", {}).get("params", {})
    k = params.get("k", args.k)
    threshold = params.get("threshold", args.threshold)
    max_tokens = params.get("max_tokens", args.max_tokens)

    # locate the prime in the current corpus
    rec = next((o for o in _iter_jsonl(PRIMES) if o.get("slug") == args.slug), None)
    if rec is None:
        sys.exit(f"slug not found in corpus: {args.slug}")
    sig = (rec.get("structural_signature") or "").strip()
    core = (rec.get("core_idea") or "").strip()
    if sig:
        text, src = sig, "structural_signature"
    elif core:
        text, src = core, "core_idea(fallback)"
    else:
        text, src = args.slug, "slug(fallback)"

    # include any prior incremental additions so repeated incrementals see each other
    overlay_npy = os.path.join(EMB_DIR, "primes_sig.overlay.npy")
    overlay_meta_p = os.path.join(EMB_DIR, "primes_sig.overlay.meta.json")
    if os.path.exists(overlay_npy):
        Eov = np.load(overlay_npy)
        with open(overlay_meta_p) as f:
            meta_ov = json.load(f)
    else:
        Eov = np.zeros((0, E.shape[1]), np.float32)
        meta_ov = []
    Efull = np.vstack([E, Eov]) if len(Eov) else E
    meta_full = meta + meta_ov
    # exclude any existing row with the same slug (re-scoring an existing prime should
    # not match itself; a genuinely new prime simply won't be present).
    keep = [i for i, mm in enumerate(meta_full) if mm["slug"] != args.slug]
    if len(keep) != len(meta_full):
        Efull = Efull[keep]
        meta_full = [meta_full[i] for i in keep]

    v = embed([text], is_query=False, max_tokens=max_tokens)[0]
    sims = Efull @ v
    order = np.argsort(-sims)
    topk = order[:k]
    mean_cos_knn = round(float(sims[topk].mean()), 4)
    above = int((sims >= threshold).sum())
    nearest = [[meta_full[j]["slug"], round(float(sims[j]), 4)] for j in topk]

    # cluster assignment: nearest saved k-means centroid by Euclidean distance
    cpath = os.path.join(EMB_DIR, "primes_sig.latest.centroids.npy")
    if os.path.exists(cpath):
        C = np.load(cpath)
        cluster_id = int(np.argmin(((C - v[None, :]) ** 2).sum(axis=1)))
    else:
        cluster_id = None

    # distinctiveness percentile of the NEW prime relative to the frozen batch:
    # where 1-mean_cos_knn falls among existing distinctiveness_raw values.
    dist_latest = os.path.join(DERIVED, "distinctiveness.latest.jsonl")
    existing_raw = [o.get("distinctiveness_raw") for o in _iter_jsonl(dist_latest)
                    if o.get("mode") == "batch" and o.get("distinctiveness_raw") is not None]
    my_raw = 1.0 - mean_cos_knn
    pct = (sum(1 for x in existing_raw if x <= my_raw) / max(1, len(existing_raw))
           if existing_raw else None)

    line = {
        "slug": args.slug, "name": rec.get("name", ""),
        "property": "distinctiveness", "computed_at": today(),
        "corpus_snapshot": corpus_snapshot(), "embedder": EMBEDDER_ID,
        "mode": "incremental",
        "params": {"k": k, "threshold": threshold, "max_tokens": max_tokens},
        "text_source": src,
        "distinctiveness_percentile": round(pct, 4) if pct is not None else None,
        "mean_cos_knn": mean_cos_knn,
        "distinctiveness_raw": round(my_raw, 4),
        "max_cos": round(float(sims[order[0]]), 4),
        "n_neighbors_above_threshold": above,
        "cluster_id": cluster_id,
        "cluster_size": None,  # unknown without global recompute
        "nearest": nearest,
        "note": "incremental: scored against frozen batch index (percentile vs batch "
                "rows, cluster = nearest saved centroid); existing primes' neighborhoods "
                "are slightly stale until the next batch recompute.",
    }

    # store the vector in the overlay so subsequent incrementals see it; drop any
    # earlier overlay row for this slug first so re-scoring never duplicates it.
    ov_keep = [i for i, mm in enumerate(meta_ov) if mm["slug"] != args.slug]
    Eov_kept = Eov[np.asarray(ov_keep, dtype=int)]
    np.save(overlay_npy, np.vstack([Eov_kept, v[None, :]]))
    new_meta_ov = [meta_ov[i] for i in ov_keep]
    new_meta_ov.append({"slug": args.slug, "name": rec.get("name", ""), "text_source": src})
    with open(overlay_meta_p, "w") as f:
        json.dump(new_meta_ov, f)
    # append/replace this slug's line in distinctiveness.latest.jsonl
    existing = [o for o in _iter_jsonl(dist_latest) if o.get("slug") != args.slug]
    with open(dist_latest, "w") as f:
        for o in existing:
            f.write(json.dumps(o) + "\n")
        f.write(json.dumps(line) + "\n")

    m = _read_manifest()
    m.setdefault("distinctiveness", {}).setdefault("incremental_adds_since_batch", 0)
    m["distinctiveness"]["incremental_adds_since_batch"] += 1
    _write_manifest(m)

    print(json.dumps(line, indent=2))
    cnt = m["distinctiveness"]["incremental_adds_since_batch"]
    if cnt >= args.batch_cadence:
        print(f"\n[drift] {cnt} incremental adds since last batch >= cadence "
              f"({args.batch_cadence}). Run `batch` to refresh all neighborhoods.")

# ----------------------------------------------------------------------------------
# CALIBRATE  (sweep thresholds against the known crowded clusters)
def cmd_calibrate(args):
    """Report how well distinctiveness and k-means flag the known crowded clusters.

    Re-embeds the corpus and prints:
      1. off-diagonal cosine quantiles;
      2. pairwise cosine within each CALIBRATION_CLUSTERS anchor group;
      3. each anchor's distinctiveness percentile at k=args.k, with near-twin
         counts at a fixed cosine cut-off;
      4. whether each group's anchors land in the same k-means family
         (K = args.clusters, or max(2, n // 10) when unset).
    Groups with fewer than two anchors in the corpus are reported as "n/a"
    rather than as "split". Nothing is written to disk.
    """
    slugs, names, texts, srcs = load_primes()
    idx = {s: i for i, s in enumerate(slugs)}
    E = embed(texts, is_query=False, max_tokens=args.max_tokens)
    S = E @ E.T
    np.fill_diagonal(S, -np.inf)

    # overall similarity distribution (off-diagonal upper triangle)
    iu = np.triu_indices(len(slugs), k=1)
    flat = S[iu]
    qs = np.quantile(flat, [0.5, 0.9, 0.95, 0.99, 0.995, 0.999])
    print("[calibrate] off-diagonal cosine quantiles:")
    for q, val in zip([50, 90, 95, 99, 99.5, 99.9], qs):
        print(f"    p{q:<5}= {val:.3f}")

    # within-cluster similarities for the known crowded anchors
    print("\n[calibrate] within-anchor-cluster pairwise cosine (should be HIGH):")
    anchor_pair_sims = []
    for label, members in CALIBRATION_CLUSTERS.items():
        present = [m for m in members if m in idx]
        missing = [m for m in members if m not in idx]
        pairs = []
        for a in range(len(present)):
            for b in range(a + 1, len(present)):
                s = float(S[idx[present[a]], idx[present[b]]])
                pairs.append((present[a], present[b], round(s, 3)))
                anchor_pair_sims.append(s)
        print(f"  {label}: present={present} missing={missing}")
        for a, b, s in pairs:
            print(f"      {a:30s} ~ {b:30s} {s}")
    if anchor_pair_sims:
        print(f"  --> anchor pair cosine: min={min(anchor_pair_sims):.3f} "
              f"mean={np.mean(anchor_pair_sims):.3f}")

    # distinctiveness percentile of each anchor (LOW = crowded = should be flagged)
    near_twin_thr = 0.80  # fixed cut-off; the printed "n>=" label is derived from it
    rows, _ = _knn_stats(S, args.k, near_twin_thr)
    knn = np.array([r["mean_cos_knn"] for r in rows])
    pct = _percentile_ranks(1.0 - knn)
    print(f"\n[calibrate] anchor distinctiveness percentile (k={args.k}; "
          f"LOW=crowded, the calibration target):")
    for label, members in CALIBRATION_CLUSTERS.items():
        print(f"  {label}:")
        for m in members:
            if m in idx:
                i = idx[m]
                print(f"      {m:30s} mean_cos={knn[i]:.3f}  pct={pct[i]:.2f}  "
                      f"max_cos={rows[i]['max_cos']:.3f}  "
                      f"n>={near_twin_thr:.2f}={rows[i]['n_neighbors_above_threshold']}")

    # k-means sizing check
    K = args.clusters or max(2, len(slugs) // 10)
    cid, sizes, _ = _cluster_kmeans(E, K)
    sv = list(sizes.values())
    print(f"\n[calibrate] kmeans K={K}: sizes min={min(sv)} "
          f"median={int(np.median(sv))} max={max(sv)}")
    for label, members in CALIBRATION_CLUSTERS.items():
        present = [m for m in members if m in idx]
        cs = {cid[idx[m]] for m in present}
        if len(present) < 2:
            verdict = "(n/a: fewer than 2 anchors present)"
        elif len(cs) == 1:
            verdict = "(together)"
        else:
            verdict = "(split)"
        print(f"  {label}: anchors fall in {len(cs)} cluster(s) {verdict}")
    print("\n  Headline distinctiveness is the percentile; the dense compressed space "
          "makes absolute cosine weak, so we normalize by rank. Both broad-density "
          "(mean_cos) and near-twin (max_cos, n>=thr) signals are reported per prime.")

# ----------------------------------------------------------------------------------
# INSPECT
def cmd_inspect(args):
    """Print the args.k nearest neighbors (by signature cosine) of each slug in args.slugs.

    The whole corpus is re-embedded here; nothing is read from dist/derived/.
    Unknown slugs print a '?? <slug> not found' line and are skipped.
    """
    slugs, names, texts, srcs = load_primes()
    position = {s: i for i, s in enumerate(slugs)}
    E = embed(texts, is_query=False, max_tokens=args.max_tokens)
    S = E @ E.T
    np.fill_diagonal(S, -np.inf)
    k = args.k
    for slug in args.slugs:
        i = position.get(slug)
        if i is None:
            print(f"  ?? {slug} not found")
            continue
        closest = np.argsort(-S[i])[:k]
        mean_sim = S[i, closest].mean()
        print(f"\n=== {slug}  (mean_cos@{k}={mean_sim:.3f}, "
              f"distinctiveness={1 - mean_sim:.3f}) ===")
        for j in closest:
            print(f"    {S[i, j]:.3f}  {slugs[j]}")

# ----------------------------------------------------------------------------------
# STRUCTURAL/FRAMED grade template + ingestion
def _prime_record(slug):
    """Return the first corpus record whose slug equals *slug*, or None if absent."""
    matches = (rec for rec in _iter_jsonl(PRIMES) if rec.get("slug") == slug)
    return next(matches, None)

def _grade_prompt(rec):
    """Render the structural/framed grading prompt for one prime record.

    Text fields are truncated (core_idea 1200, structural_signature 800,
    what_it_is_not 500 chars). The CRITERIA questions are listed as a numbered
    list, and the prompt ends with the exact JSON shape the grader must return.
    """
    criteria_lines = []
    for number, (key, question) in enumerate(CRITERIA, start=1):
        criteria_lines.append(f"  {number}. [{key}] {question}")
    crit = "\n".join(criteria_lines)
    slug = rec.get('slug')
    name = rec.get('name')
    origin = rec.get('origin_domain')
    core = (rec.get('core_idea') or '')[:1200]
    signature = (rec.get('structural_signature') or '')[:800]
    not_this = (rec.get('what_it_is_not') or '')[:500]
    return f"""Grade this prime on the STRUCTURAL<->FRAMED spectrum using the five
diagnostic criteria from conceptual/structural-and-framed-primes.md sec 3.2.
Score EACH criterion as 0.0 (structural pole), 0.5 (mixed), or 1.0 (framed pole).
Then give a one-paragraph rationale. Do not consult the prime's existing labels.

PRIME: {name}  ({slug})
ORIGIN_DOMAIN: {origin}
CORE_IDEA: {core}
STRUCTURAL_SIGNATURE: {signature}
WHAT_IT_IS_NOT: {not_this}

CRITERIA (score each 0.0 / 0.5 / 1.0):
{crit}

Return JSON ONLY:
{{"slug":"{slug}","criteria":{{"vocab_travels":_,"evaluative_weight":_,
"institutional_origin":_,"human_practice_bound":_,"import_vs_recognize":_}},
"rationale":"..."}}"""

def cmd_grade_template(args):
    """Print structural/framed grading prompts to stdout for the grader LLM.

    With --slug, prints the prompt for that single prime. Otherwise prints prompts
    for an evenly strided sample of corpus slugs (--sample of them; 18 when 0/unset)
    to calibrate the grader. Slugs with no corpus record print '# slug not found'.
    """
    if not args.slug:
        corpus_slugs, *_ = load_primes()
        # evenly spaced picks across the corpus, for calibration
        wanted = args.sample or 18
        stride = max(1, len(corpus_slugs) // wanted)
        targets = corpus_slugs[::stride][:wanted]
    else:
        targets = [args.slug]
    divider = "\n" + "=" * 88
    for target in targets:
        record = _prime_record(target)
        if record is not None:
            print(divider)
            print(_grade_prompt(record))
        else:
            print(f"# slug not found: {target}")

def aggregate_and_label(criteria):
    """Collapse per-criterion marks into an ``(aggregate, label)`` pair.

    aggregate is the unweighted mean of ``float(criteria[key])`` over every key in
    CRITERIA, rounded to 4 decimal places. label is taken from the first
    ``(upper_bound, label)`` pair in LABEL_BINS with ``aggregate <= upper_bound``;
    "framed" is the fallback when no bin matches. Raises KeyError when *criteria*
    lacks one of the CRITERIA keys.
    """
    marks = [float(criteria[key]) for key, _desc in CRITERIA]
    aggregate = round(sum(marks) / len(marks), 4)
    label = next((name for upper, name in LABEL_BINS if aggregate <= upper), "framed")
    return aggregate, label

def _sf_corpus_names():
    """Return {slug: name} for PRIMES, keeping the first record per slug (as _prime_record does)."""
    names = {}
    for o in _iter_jsonl(PRIMES):
        names.setdefault(o.get("slug"), o.get("name", ""))
    return names

def _sf_rejection(o, valid_keys, corpus_names):
    """Return the reason raw grader row *o* cannot be ingested, or None if it is valid."""
    if not isinstance(o, dict):
        return "row is not a JSON object"
    slug = o.get("slug")
    if not isinstance(slug, str) or not slug:
        return "missing slug"
    if slug not in corpus_names:
        # Grades join back to the corpus by slug; an unknown slug (usually a typo)
        # would never join yet would still inflate n_graded_total.
        return "slug not in corpus"
    crit = o.get("criteria")
    if not isinstance(crit, dict):
        return f"criteria must be an object, got {crit!r}"
    if set(crit.keys()) != valid_keys:
        return f"criteria keys {set(crit.keys())} != {valid_keys}"
    # bool is an int subclass (True == 1.0), so it would otherwise pass as a mark.
    if any(isinstance(crit[k], bool) or crit[k] not in (0.0, 0.5, 1.0) for k in valid_keys):
        return f"marks must be 0.0/0.5/1.0: {crit}"
    rationale = o.get("rationale")
    if not isinstance(rationale, str) or not rationale.strip():
        return "missing rationale"
    return None

def cmd_ingest_grades(args):
    """Ingest LLM structural/framed grades into the derived S/F artifacts.

    Reads JSONL rows of {slug, criteria{...}, rationale, [grader], [self_consistency]}
    produced by the LLM grader, validates each row, computes aggregate + label via
    aggregate_and_label(), and writes:
      - structural_framed.<today>.jsonl : grades ingested today, merged by slug with
        any earlier same-day run so a second run does not clobber the first;
      - structural_framed.latest.jsonl  : all ingested grades, replaced by slug;
      - manifest.json["structural_framed"].
    Rejected rows are printed with a reason. If no row is accepted, nothing is written.
    A slug repeated in the input keeps its last valid grade.
    """
    _ensure_dirs()
    valid_keys = {k for k, _ in CRITERIA}
    corpus_names = _sf_corpus_names()
    # Stamp once per run: all rows share one date/snapshot, the dated filename agrees
    # with graded_at, and the corpus is hashed once rather than once per row.
    stamp = today()
    snapshot = corpus_snapshot()
    accepted = {}  # slug -> derived record
    errors = []
    for o in _iter_jsonl(args.infile):
        reason = _sf_rejection(o, valid_keys, corpus_names)
        if reason is not None:
            errors.append((o.get("slug") if isinstance(o, dict) else None, reason))
            continue
        slug, crit = o["slug"], o["criteria"]
        agg, label = aggregate_and_label(crit)
        rec = {
            "slug": slug, "name": corpus_names[slug],
            "property": "structural_framed",
            "graded_at": stamp, "corpus_snapshot": snapshot,
            "grader": o.get("grader", "claude (LLM-assisted)"),
            # CRITERIA order (not set order) keeps the written JSON deterministic
            "criteria": {k: float(crit[k]) for k, _ in CRITERIA},
            "aggregate": agg, "label": label,
            "rationale": o["rationale"].strip(),
        }
        if "self_consistency" in o:
            rec["self_consistency"] = o["self_consistency"]
        accepted[slug] = rec
    if errors:
        print(f"[ingest] {len(errors)} rejected:")
        for s, e in errors:
            print(f"    {s}: {e}")
    if not accepted:
        # Writing now would empty today's versioned file and repoint the manifest at it.
        print("[ingest] no grades accepted; nothing written")
        return
    # merge with any existing latest (replace by slug)
    latest = os.path.join(DERIVED, "structural_framed.latest.jsonl")
    existing = {o["slug"]: o for o in _iter_jsonl(latest)} if os.path.exists(latest) else {}
    existing.update(accepted)
    vpath = os.path.join(DERIVED, f"structural_framed.{stamp}.jsonl")
    dated = {o["slug"]: o for o in _iter_jsonl(vpath)} if os.path.exists(vpath) else {}
    dated.update(accepted)
    with open(vpath, "w") as f:
        f.writelines(json.dumps(r) + "\n" for r in dated.values())
    with open(latest, "w") as f:
        f.writelines(json.dumps(r) + "\n" for r in existing.values())
    m = _read_manifest()
    m["structural_framed"] = {"latest": os.path.relpath(vpath, REPO),
                              "graded_at": stamp, "n_graded_total": len(existing),
                              "n_graded_this_run": len(accepted)}
    _write_manifest(m)
    print(f"[ingest] accepted {len(accepted)} grades ({len(existing)} total); wrote {vpath}")

# ----------------------------------------------------------------------------------
# VERIFY
def cmd_verify(args):
    """Cross-check the derived artifacts against the current corpus and print a report.

    - distinctiveness.latest.jsonl must exist and cover every corpus slug; whether
      it holds repeated slugs is reported as ``duplicates=``.
    - structural_framed.latest.jsonl is optional; each row's stored aggregate and
      label must equal a fresh aggregate_and_label() of its criteria. Malformed rows
      count as inconsistent instead of aborting the check, and graded slugs that are
      no longer in the corpus are listed for information.
    Prints PASS / ISSUES FOUND and returns True when no issues were found.
    """
    slugs, *_ = load_primes()
    sset = set(slugs)
    ok = True
    # distinctiveness coverage
    dl = os.path.join(DERIVED, "distinctiveness.latest.jsonl")
    if os.path.exists(dl):
        dslugs = [o.get("slug") for o in _iter_jsonl(dl)]
        # sorted so the "up to 10" sample is stable between runs (set order is not)
        missing = sorted(sset - set(dslugs), key=str)
        dup = len(dslugs) != len(set(dslugs))
        print(f"[verify] distinctiveness: {len(dslugs)} rows, "
              f"{len(missing)} corpus primes missing, duplicates={dup}")
        if missing:
            ok = False
            print(f"    missing (up to 10): {missing[:10]}")
    else:
        ok = False
        print("[verify] distinctiveness.latest.jsonl MISSING")
    # structural/framed validity
    sf = os.path.join(DERIVED, "structural_framed.latest.jsonl")
    if os.path.exists(sf):
        rows = list(_iter_jsonl(sf))
        bad = 0
        for o in rows:
            try:
                agg, label = aggregate_and_label(o["criteria"])
                consistent = abs(agg - o["aggregate"]) <= 1e-6 and label == o["label"]
            except (KeyError, TypeError, ValueError):
                # missing criteria/aggregate/label, or non-numeric marks
                consistent = False
            if not consistent:
                bad += 1
        graded = {o.get("slug") for o in rows}
        print(f"[verify] structural_framed: {len(rows)} graded, "
              f"{len(sset - graded)} ungraded, {bad} internally inconsistent")
        stale = sorted(graded - sset, key=str)
        if stale:
            print(f"    {len(stale)} graded slug(s) not in corpus (up to 10): {stale[:10]}")
        if bad:
            ok = False
    else:
        print("[verify] structural_framed.latest.jsonl absent (grades not yet ingested)")
    print(f"[verify] {'PASS' if ok else 'ISSUES FOUND'}")
    return ok

# ----------------------------------------------------------------------------------
# CLUSTER NAMING  (LLM-assisted, like the structural/framed grade)
def cmd_name_clusters(args):
    """Print each cluster's membership so an LLM can propose a short family name +
    description.

    With --ids 1,13,29 prints exactly those clusters (whitespace around ids is
    ignored; ids not present in the clusters file are reported). By default prints
    only clusters MISSING a name or a description (the ones a recompute didn't carry
    over), followed by a count and the ingest command.
    """
    with open(_clusters_path()) as f:
        data = json.load(f)
    clusters = data["clusters"]
    want = None
    if args.ids:
        # cluster ids are the string keys of data["clusters"]; tolerate "1, 13"
        want = {tok.strip() for tok in args.ids.split(",") if tok.strip()} or None
    if want:
        unknown = want - set(clusters)
        if unknown:
            print(f"[name-clusters] unknown cluster id(s): {', '.join(sorted(unknown))}")
    shown = 0
    for cid, rec in sorted(clusters.items(), key=lambda kv: int(kv[0])):
        if want and cid not in want:
            continue
        if not want and rec.get("name") and rec.get("description"):
            continue  # already fully labeled
        shown += 1
        print(f"\n--- cluster {cid}  (size {rec['size']}, name={rec.get('name')!r}, "
              f"has_description={bool(rec.get('description'))}) ---")
        print("  members:", ", ".join(rec["members"]))
    if not want:
        print(f"\n{shown} cluster(s) need a name and/or description. After generating, "
              f"ingest with: python scripts/prime_labeling.py ingest-cluster-names --in <file.jsonl>")

def cmd_ingest_cluster_names(args):
    """Read JSONL of {cluster_id, name?, description?} and write into
    clusters.latest.json. Either field may be present; both are merged.

    Values are stripped, and blank or non-string values are ignored (so a
    whitespace-only name never overwrites a real one). Rows whose cluster_id is
    missing or unknown are skipped and reported. The clusters file is rewritten via
    a temp file + os.replace, so a failure mid-write leaves the old file intact.
    """
    path = _clusters_path()
    with open(path) as f:
        data = json.load(f)
    clusters = data["clusters"]
    n = nd = 0
    skipped = []
    for o in _iter_jsonl(args.infile):
        raw_cid = o.get("cluster_id")
        cid = None if raw_cid is None else str(raw_cid)
        if cid not in clusters:
            skipped.append(raw_cid)
            continue
        name = o.get("name")
        if isinstance(name, str) and name.strip():
            clusters[cid]["name"] = name.strip()
            n += 1
        desc = o.get("description")
        if isinstance(desc, str) and desc.strip():
            clusters[cid]["description"] = desc.strip()
            nd += 1
    # Opening the real file with "w" truncates it before json.dump runs; write a
    # sibling temp file and swap it in instead.
    tmp = f"{path}.tmp"
    with open(tmp, "w") as f:
        json.dump(data, f, indent=1)
    os.replace(tmp, path)
    named = sum(1 for r in clusters.values() if r.get("name"))
    desc_count = sum(1 for r in clusters.values() if r.get("description"))
    print(f"[cluster-names] set {n} names, {nd} descriptions; "
          f"{named}/{len(clusters)} named, {desc_count}/{len(clusters)} described")
    if skipped:
        print(f"[cluster-names] skipped {len(skipped)} row(s) with missing/unknown "
              f"cluster_id (up to 10): {skipped[:10]}")

# ----------------------------------------------------------------------------------
def main():
    """Parse the command line and dispatch to the selected cmd_* handler."""
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    sub = ap.add_subparsers(dest="cmd", required=True)

    b = sub.add_parser("batch"); b.set_defaults(fn=cmd_batch)
    b.add_argument("--k", type=int, default=10)
    b.add_argument("--threshold", type=float, default=0.80)
    b.add_argument("--max-tokens", type=int, default=200)
    b.add_argument("--clusters", type=int, default=0, help="k for k-means; 0 => n//10")

    inc = sub.add_parser("incremental"); inc.set_defaults(fn=cmd_incremental)
    inc.add_argument("--slug", required=True)
    inc.add_argument("--k", type=int, default=10)
    inc.add_argument("--threshold", type=float, default=0.80)
    inc.add_argument("--max-tokens", type=int, default=200)
    inc.add_argument("--batch-cadence", type=int, default=20)

    c = sub.add_parser("calibrate"); c.set_defaults(fn=cmd_calibrate)
    c.add_argument("--k", type=int, default=10)
    c.add_argument("--max-tokens", type=int, default=200)
    c.add_argument("--clusters", type=int, default=0)

    ins = sub.add_parser("inspect"); ins.set_defaults(fn=cmd_inspect)
    ins.add_argument("slugs", nargs="+")
    ins.add_argument("--k", type=int, default=10)
    ins.add_argument("--max-tokens", type=int, default=200)

    gt = sub.add_parser("grade-template"); gt.set_defaults(fn=cmd_grade_template)
    gt.add_argument("--slug", default=None,
                    help="print the prompt for this one prime (overrides --sample)")
    gt.add_argument("--sample", type=int, default=0,
                    help="number of primes to sample evenly across the corpus; 0 => 18")

    ig = sub.add_parser("ingest-grades"); ig.set_defaults(fn=cmd_ingest_grades)
    ig.add_argument("--in", dest="infile", required=True)

    nc = sub.add_parser("name-clusters"); nc.set_defaults(fn=cmd_name_clusters)
    # The default view lists clusters lacking a name OR a description, not only unnamed ones.
    nc.add_argument("--ids", default=None,
                    help="comma-separated cluster ids; default = clusters missing a name "
                         "or description")

    icn = sub.add_parser("ingest-cluster-names"); icn.set_defaults(fn=cmd_ingest_cluster_names)
    icn.add_argument("--in", dest="infile", required=True)

    v = sub.add_parser("verify"); v.set_defaults(fn=cmd_verify)

    a = ap.parse_args()
    a.fn(a)

if __name__ == "__main__":
    main()
