# `docs/PHASE_3_THINK.md` — Phase 3: AdaptiveKHead and REINFORCE training

> Companion to `03_CASCADE.md` Appendix C (the model THINK.md). This is the
> CPU-bounded sub-phase: build the head and the loss, and show that the
> monotone-difficulty diagnostic passes on a synthetic task. Real-corpus
> training requires the trained base model from Phase 1 and is deferred to a GPU box.

## 1. What I understand the task to be

Build a small policy head that, given the hidden state at the boundary of a
new block, decides how many denoising steps to allocate, choosing from
`K_CHOICES = (1, 2, 4, 8, 16)`. The head is trained by REINFORCE on a
quality-minus-cost reward. The CPU-bounded sub-phase covers:

- `AdaptiveKHead`: 2-layer MLP `d → hidden → n_actions` with `sample`,
  `log_prob`, `entropy` interfaces.
- `reinforce_step_count_loss`: `-(R - baseline) · log π(K | h) - β_H · H(π)`.
- A *synthetic-task smoke test* that proves:
   1. The head can learn a non-uniform policy from a noisy reward.
   2. The monotone-difficulty diagnostic from `03_CASCADE.md` Appendix A.3
      passes: bucket inputs by an independent difficulty label and check that
      `mean(K | bucket)` increases monotonically.
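
A minimal, dependency-free sketch of the loss. It assumes the head returns one list of raw logits per sample and that the baseline is a plain float, so no gradient flows through it. `log_softmax`, `categorical_entropy` and the argument names are illustrative. If the real head is written in PyTorch, `torch.distributions.Categorical(logits=...)` exposes the same two quantities through its `log_prob` and `entropy` methods.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def categorical_entropy(logits: Sequence[float]) -> float:
    # H(pi) = -sum_k pi_k * log pi_k, in nats.
    return -sum(math.exp(lp) * lp for lp in log_softmax(logits))


def reinforce_step_count_loss(
    batch_logits: Sequence[Sequence[float]],
    actions: Sequence[int],    # chosen index into K_CHOICES, one per sample
    rewards: Sequence[float],
    baseline: float,           # treated as a constant (no gradient through it)
    beta_h: float,
) -> float:
    """Batch mean of -(R - b) * log pi(K | h) - beta_H * H(pi(. | h))."""
    total = 0.0
    for logits, a, r in zip(batch_logits, actions, rewards):
        advantage = r - baseline
        total += -advantage * log_softmax(logits)[a] - beta_h * categorical_entropy(logits)
    return total / len(actions)
```

The sign can be sanity-checked against failure mode 4 in §3. With a positive advantage, raising the chosen action's logit must *lower* the loss.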

The synthetic task is designed to be a maximally favorable testbed — if
the head fails *here*, real-corpus training will also fail and the project
is dead. If it succeeds, the head works as a function approximator and the
only remaining question is whether the real reward signal is informative
enough to drive it.

## 2. Design decisions

**Discrete action space `{1, 2, 4, 8, 16}` with categorical policy.**
Geometric spacing covers a 16× cost range with only 5 actions, keeping
REINFORCE variance manageable.

**Architecture: `Linear → SiLU → Linear`.** Small enough to overfit a
synthetic task in seconds on CPU; large enough to encode the difficulty
function in the real setting. Same activation as the rest of the model.
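
A dependency-free sketch of the head's forward pass and its `sample` / `log_prob` / `entropy` interface. It is pure Python so it runs without a framework, and it does not train. The following are assumptions:

- the class name `AdaptiveKHeadSketch`;
- the 1/√fan_in Gaussian initialization;
- zero biases;
- the sizes in the test.

The real head would presumably be a module built from `nn.Linear` and `nn.SiLU` and trained by autograd.

```python
import math
import random

K_CHOICES = (1, 2, 4, 8, 16)


def silu(x: float) -> float:
    # x * sigmoid(x), split by sign so exp() never overflows.
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def log_softmax(z):
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


class AdaptiveKHeadSketch:
    """Linear(d -> hidden) -> SiLU -> Linear(hidden -> n_actions); forward pass only."""

    def __init__(self, d: int, hidden: int, n_actions: int = len(K_CHOICES), seed: int = 0):
        rng = random.Random(seed)

        def init(fan_in, fan_out):
            std = 1.0 / math.sqrt(fan_in)
            return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]

        self.w1, self.b1 = init(d, hidden), [0.0] * hidden
        self.w2, self.b2 = init(hidden, n_actions), [0.0] * n_actions

    def logits(self, h):
        act = [silu(sum(w * x for w, x in zip(row, h)) + b) for row, b in zip(self.w1, self.b1)]
        return [sum(w * x for w, x in zip(row, act)) + b for row, b in zip(self.w2, self.b2)]

    def probs(self, h):
        return [math.exp(v) for v in log_softmax(self.logits(h))]

    def sample(self, h, rng: random.Random) -> int:
        p = self.probs(h)
        return rng.choices(range(len(p)), weights=p, k=1)[0]  # index into K_CHOICES

    def log_prob(self, h, action_idx: int) -> float:
        return log_softmax(self.logits(h))[action_idx]

    def entropy(self, h) -> float:
        lp = log_softmax(self.logits(h))
        return -sum(math.exp(v) * v for v in lp)
```

The two unit tests planned for `tests/test_step_predictor.py` map directly onto this interface:

- `probs(h)` must sum to 1.
- `log_prob(h, a)` must equal `math.log(probs(h)[a])`.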

**EMA scalar baseline (not per-cluster).** `03_CASCADE.md` proposes a
per-cluster baseline over k-means clusters. For the CPU smoke test a scalar
EMA is sufficient and proves the mechanism. The per-cluster baseline is a
Phase 3.5 optimization for real training.
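
A minimal sketch of the scalar EMA baseline, using the 0.95 coefficient from the hyperparameter list in this section. Two details are assumptions not stated in the plan:

- The baseline is seeded with the first batch's mean reward, so only the very first batch sees b = 0.
- Each batch's advantages are computed against the baseline *before* that batch updates it.

```python
from typing import Optional, Sequence


class EMABaseline:
    """Scalar exponential-moving-average baseline b for REINFORCE (hypothetical helper)."""

    def __init__(self, decay: float = 0.95) -> None:
        self.decay = decay
        self.value: Optional[float] = None

    def advantages(self, rewards: Sequence[float]) -> list[float]:
        # Use the current (pre-update) baseline; 0.0 before the first update.
        b = 0.0 if self.value is None else self.value
        return [r - b for r in rewards]

    def update(self, rewards: Sequence[float]) -> float:
        batch_mean = sum(rewards) / len(rewards)
        if self.value is None:
            self.value = batch_mean  # seed with the first batch mean
        else:
            self.value = self.decay * self.value + (1.0 - self.decay) * batch_mean
        return self.value
```

A per-cluster variant would keep one such EMA per k-means cluster id instead of a single `value`.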

**Reward design for the synthetic task.**

We construct fake inputs `h_i` and a "true difficulty" `d_i ∈ {0..4}`,
encoded as a one-hot signal in the first 5 dims of `h_i` plus i.i.d.
Gaussian noise in the rest. The reward is:

```
R(action_idx, d) = -|action_idx - d|
```

where `action_idx` is the chosen index into `K_CHOICES`. So action 0 (K=1)
is best for difficulty 0, action 4 (K=16) is best for difficulty 4, and
mismatches are linearly penalized. This is the simplest signal that
demands a *non-trivial* function from `h` to action.

(No `-λ·K` cost term in the synthetic test: the goal is to verify that the
head can *match* a difficulty signal, not to balance speed against quality.
The λ·K term belongs in the production reward, but here it only adds noise.)
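
A sketch of the synthetic batch generator and reward. It follows the description above literally: the one-hot difficulty occupies the first 5 dimensions, and only the remaining dimensions carry Gaussian noise (stddev 0.5, the value quoted in §3). The input width `d_model=32`, the uniform sampling of `d`, and the function names are assumptions for illustration. Adding noise to the first 5 dimensions as well would make the task harder and is a one-line change.

```python
import random
from typing import Optional

N_ACTIONS = 5  # len(K_CHOICES); difficulty d and action_idx both live in {0..4}


def make_synthetic_batch(batch_size: int, d_model: int = 32, noise_std: float = 0.5,
                         rng: Optional[random.Random] = None):
    """Return (hs, ds): hs[i] is a d_model-long input, ds[i] its true difficulty."""
    rng = rng or random.Random()
    hs, ds = [], []
    for _ in range(batch_size):
        d = rng.randrange(N_ACTIONS)                                   # uniform over {0..4}
        signal = [1.0 if i == d else 0.0 for i in range(N_ACTIONS)]    # one-hot difficulty
        noise = [rng.gauss(0.0, noise_std) for _ in range(d_model - N_ACTIONS)]
        hs.append(signal + noise)
        ds.append(d)
    return hs, ds


def synthetic_reward(action_idx: int, d: int) -> float:
    # R(action_idx, d) = -|action_idx - d|; no -lambda*K cost term in this test.
    return -float(abs(action_idx - d))
```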

**Hyperparameters.**

- Batch size: 256 (REINFORCE wants large batches for variance reduction).
- Steps: 1000 max (typically converges in ~300).
- LR: 3e-3 on the policy head (small head, can take large LR).
- `β_H`: 0.1 for the first 30% of training, linearly annealed to 0.
- EMA baseline coefficient: 0.95 (smooth, slow-moving).
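
The hyperparameters above, collected into a config together with one possible reading of the β_H schedule. The schedule line is ambiguous. This sketch assumes β_H is *held* at 0.1 for the first 30% of steps and then decays linearly to 0 at `max_steps`. If the intent is instead a linear decay from 0.1 to 0 *within* the first 30%, only `beta_h` changes. The dataclass and function names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SyntheticRunConfig:
    batch_size: int = 256
    max_steps: int = 1000
    lr: float = 3e-3
    beta_h_init: float = 0.1
    beta_h_hold_frac: float = 0.3   # fraction of steps at the full entropy bonus (assumed reading)
    ema_decay: float = 0.95


def beta_h(step: int, cfg: SyntheticRunConfig) -> float:
    """Entropy coefficient: constant during the hold phase, then linear decay to 0."""
    hold_steps = round(cfg.beta_h_hold_frac * cfg.max_steps)
    decay_steps = cfg.max_steps - hold_steps
    if step < hold_steps or decay_steps <= 0:
        return cfg.beta_h_init
    frac_left = max(0.0, (cfg.max_steps - step) / decay_steps)
    return cfg.beta_h_init * frac_left
```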

## 3. Failure modes and mitigations

1. **Policy collapse to a single action.** The classic REINFORCE failure.
   Entropy bonus + EMA baseline are the standard fix. The test asserts
   that the final policy has entropy > 0.5 nats (not deterministic).

2. **Synthetic signal too weak to extract.** The one-hot signal is large
   relative to the noise (signal magnitude 1, noise stddev 0.5), so the
   head should easily learn it. If it doesn't, the test fails fast and
   I add a difficulty diagnostic.

3. **Monotonicity test is noisy / unreliable.** The diagnostic is
   `bucket_mean[d=0] ≤ bucket_mean[d=1] ≤ ... ≤ bucket_mean[d=4]`. With
   256 samples per evaluation pass, the bucket means should be well
   separated. Allow some slack anyway: require *Spearman rank correlation
   > 0.7* between difficulty index and mean K rather than strict
   monotonicity, since the rank correlation is more robust to noise at the tails.

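4. **The wrong loss sign.** It is easy to write `+(R - b) · log_pi` instead
   of `-(R - b) · log_pi` and end up doing gradient *descent* on reward. The
   tests catch this: with the sign flipped, the policy drifts toward the
   lowest-reward action, so the 2-action test converges to action 0 and the
   Spearman correlation in the synthetic task turns negative.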
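
A dependency-free sketch of the monotonicity check. It computes the per-bucket mean K, then computes Spearman's ρ as the Pearson correlation of average ranks, so ties between bucket means are handled. The function names are illustrative. If SciPy is available, `scipy.stats.spearmanr` computes the same statistic.

```python
import math
from typing import Sequence


def bucket_mean_k(difficulties: Sequence[int], ks: Sequence[float], n_buckets: int = 5) -> list[float]:
    sums, counts = [0.0] * n_buckets, [0] * n_buckets
    for d, k in zip(difficulties, ks):
        sums[d] += k
        counts[d] += 1
    if 0 in counts:
        raise ValueError("every difficulty bucket needs at least one sample")
    return [s / c for s, c in zip(sums, counts)]


def average_ranks(xs: Sequence[float]) -> list[float]:
    # 1-based ranks; tied values share the mean of the ranks they span.
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    ranks = [0.0] * len(xs)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and xs[order[j + 1]] == xs[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    rx, ry = average_ranks(x), average_ranks(y)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    sxy = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sxx = sum((a - mx) ** 2 for a in rx)
    syy = sum((b - my) ** 2 for b in ry)
    if sxx == 0.0 or syy == 0.0:
        return 0.0  # a constant series (e.g. a collapsed policy) has undefined rho; treat as a fail
    return sxy / math.sqrt(sxx * syy)
```

The check would be `spearman(range(5), bucket_mean_k(ds, [K_CHOICES[a] for a in actions])) > 0.7`. With five buckets and no ties, ρ = 1 − Σd²/20, where d is the per-bucket rank difference.

- A single swap of neighbouring buckets gives ρ = 0.9.
- Two non-overlapping swaps give ρ = 0.8.
- Anything worse fails the 0.7 threshold.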

## 4. Evidence-of-success plan

Phase 3 (CPU sub-phase) closes when:

1. `tests/test_step_predictor.py` passes:
   - `test_output_is_valid_categorical` — output is a probability simplex.
   - `test_sample_log_prob_consistent` — `log_prob` matches `log(softmax(logits)[action])`.
2. `test_reinforce_loss_basic` passes:
   - On a synthetic 2-action task where action 1 always gets reward 1
     and action 0 gets reward 0, the policy converges to deterministic
     action 1 (within entropy budget).
3. `test_synthetic_monotone_difficulty` passes:
   - 5-difficulty synthetic task: Spearman rank correlation between true
     difficulty `d` and mean(K | bucket) is > 0.7 after training.
   - Final policy entropy is > 0.5 nats (not collapsed).
   - The mean predicted K differs across at least 3 difficulty buckets
     by at least a factor of 2 (e.g. `mean K | d=0 ≤ 3` and `mean K | d=4 ≥ 6`).
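
A dependency-free sketch of the 2-action check behind `test_reinforce_loss_basic`. It uses:

- an input-free softmax policy over two logits;
- the analytic REINFORCE gradient;
- plain SGD;
- an EMA baseline;
- a small constant entropy bonus.

`lr=0.5`, `steps=300`, the constant `beta_h=0.01` and the `sign` switch are assumptions for this toy, not the plan's hyperparameters. Passing `sign=-1.0` flips the loss sign (failure mode 4 in §3), and the policy then converges to the wrong action.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def train_two_action_bandit(steps=300, batch_size=256, lr=0.5, beta_h=0.01,
                            ema_decay=0.95, sign=1.0, seed=0):
    """REINFORCE on a 2-armed bandit: action 1 pays reward 1, action 0 pays 0.

    Minimizes the batch mean of -sign * (R - b) * log pi(a) - beta_h * H(pi)
    with plain SGD on the two logits. sign=1.0 is the correct loss.
    """
    rng = random.Random(seed)
    z = [0.0, 0.0]    # logits; this toy policy takes no input h
    baseline = None   # EMA of batch-mean reward, seeded by the first batch
    for _ in range(steps):
        pi = softmax(z)
        actions = rng.choices([0, 1], weights=pi, k=batch_size)
        rewards = [1.0 if a == 1 else 0.0 for a in actions]
        b = 0.0 if baseline is None else baseline
        # Entropy part of the gradient: d(-beta_h * H)/dz_j = beta_h * pi_j * (log pi_j + H)
        h = -sum(p * math.log(p) for p in pi if p > 0.0)
        grad = [beta_h * p * (math.log(p) + h) if p > 0.0 else 0.0 for p in pi]
        # Policy-gradient part: d(-adv * log pi(a))/dz_j = -adv * (1[a == j] - pi_j)
        for a, r in zip(actions, rewards):
            adv = sign * (r - b)
            for j in range(2):
                grad[j] -= adv * ((1.0 if a == j else 0.0) - pi[j]) / batch_size
        z = [zj - lr * gj for zj, gj in zip(z, grad)]   # SGD step on the loss
        batch_mean = sum(rewards) / batch_size
        baseline = batch_mean if baseline is None else (
            ema_decay * baseline + (1.0 - ema_decay) * batch_mean)
    return softmax(z)
```

For this input-free toy, the stationary logit gap under the entropy bonus works out to 1/β_H (100 here). After 300 steps, the policy is therefore limited by training time rather than by the entropy term. In the real test, "within entropy budget" depends on the annealed β_H.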

Deferred to a GPU + trained model:

- Real-corpus training of the head against the *actual* CASCADE quality
  signal (NLL gap vs. K=16 reference).
- The `λ` sweep (`{0.01, 0.05, 0.1, 0.5}`) and Pareto plot.
- MMLU preservation check.
- Per-cluster (k-means) baseline.

## 5. Scope explicitly NOT covered

- Production training loop with the real base model (Phase 4+).
- PPO variant.
- Joint optimization with body weights (Phase 3.5).
- Distillation-from-AR (Phase 4).
