# `docs/PHASE_1_THINK.md` — Phase 1: pretraining-from-scratch nano

> Briefer than Phase 0's THINK because the design space is narrower: BD3-LM's recipe is well-documented and CASCADE's deltas (adaptive K, distillation) don't come in until Phases 3 and 4.

## 1. What I understand the task to be

Wire up everything between "raw token tensor" and "trained CASCADE-nano". Specifically:

- `corruption.py`: BD3-LM corruption. Pick target block `b`, sample mask rate `m`, mask positions in block `b` at rate `m`, and fully mask blocks `>b`.
- `losses.py`: LLaDA-style `1/m`-weighted masked CE.
- `modules/rope.py`, `modules/ffn.py`: standard RoPE and SwiGLU.
- `modules/block_causal_attn.py`: production multi-head attention with block-causal mask. **Dense fallback** in this phase (the flash-attn dispatch is gated behind a runtime check and is a no-op in CPU environments).
- `modules/cascade_block.py`: pre-norm + attn + FFN + residuals.
- `model.py:CascadeLM`: token embedding + N CascadeBlocks + output head, with logits at every position.

**Realistic exit criteria for this environment (no GPU, no data downloads):**

- All Phase 1 modules implemented (no `NotImplementedError` left in `cascade/{corruption,losses,model}.py` or `cascade/modules/{block_causal_attn,cascade_block,ffn,rope}.py`).
- `test_corruption.py`, `test_block_mask.py` and a new `test_losses.py` + `test_rope.py` all pass.
- The production `BlockCausalAttention`, run with `n_heads=1` and RoPE disabled, matches the dense reference oracle (`cascade/attention_reference.py`) within `1e-6` in fp32 on the same inputs.
- End-to-end smoke test: feed CASCADE-nano random tokens (no real data), train for 200 steps to memorize a single batch, and check that the loss falls from roughly `log(vocab_size)` toward zero. This verifies the wiring without needing FineWeb-Edu.

What's deferred to a real GPU box:

- The actual 12M / 200M-token training run.
- The "K=32 fills masked blocks coherently" sanity check (needs a trained model).
- Throughput benchmarks at production scale.

## 2. Design decisions for the new pieces

**Multi-head attention layout.** Q, K, V tensors are shaped `(B, n_heads, N, d_head)` internally. A single fused `(d, 3*d)` qkv projection is more efficient but harder to test against the single-head reference, so this phase uses three separate `(d, d)` projections. Switch to fused qkv in Phase 5 if profiling shows it matters.

**RoPE position handling.** Apply RoPE inside the attention `forward`, after the Q and K projections and before the dot product. Positions are `arange(N)` for training; for cache reuse (Phase 2), positions of in-progress block `b` are `[b·B, b·B+1, ..., b·B+B-1]`, where `B` here is the block size (not the batch dimension in the layout above). The cache stores already-rotated K, so prior-block K is not rotated again.
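To make the cache argument concrete, here is a minimal dependency-free sketch of RoPE in plain Python (the real `modules/rope.py` would operate on tensors). The names `rope_rotate`, `block_positions`, and `dot`, and the `base=10000.0` default, are my own illustrative choices, not the module's API. It shows that the rotated dot product depends only on the position offset, which is why K rotated once at its absolute position can be cached and reused.

```python
import math

def rope_rotate(vec, position, base=10000.0):
    """Rotate each consecutive (even, odd) pair of `vec` by a position-dependent angle."""
    d = len(vec)
    assert d % 2 == 0, "head dimension must be even"
    out = []
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)  # pair frequency: base^(-2k/d) with k = i/2
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def block_positions(b, block_size):
    """Absolute positions of block b: [b*B, ..., b*B + B - 1]."""
    return list(range(b * block_size, (b + 1) * block_size))

def dot(u, v):
    return sum(a * c for a, c in zip(u, v))

# Hypothetical query/key vectors for a single head with d_head = 4.
q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.9]

# Same offset (query 6 positions after key) at two different absolute locations.
score_early = dot(rope_rotate(q, 9), rope_rotate(k, 3))
score_late = dot(rope_rotate(q, 17), rope_rotate(k, 11))
```

The two scores agree up to floating-point error, so a key rotated at write time stays valid for every later query without being rotated a second time.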

**Dense fallback for `BlockCausalAttention`.** Build the block-causal mask as a `(N, N)` bool tensor, using the same code as `cascade/attention_reference.py:build_block_causal_mask`, and apply it to the attention scores before the softmax via `masked_fill(-inf)`. This costs O(N²) memory and O(N²·d) compute: fine at nano scale (N=512) but unusable at production scale (N=4096+). The flash-attn dispatch is the production path; this fallback exists so tests work on CPU.
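A minimal sketch of the dense path in plain Python, using nested lists in place of tensors. It assumes the mask convention "True means query `i` may attend to key `j`", with `j` allowed whenever its block index is at most that of `i`; the actual polarity and signature of `build_block_causal_mask` in the reference may differ. `sketch_block_causal_mask` and `masked_softmax` are hypothetical names.

```python
import math

def sketch_block_causal_mask(n, block_size):
    """mask[i][j] is True when query i may attend to key j (same or earlier block)."""
    return [[(j // block_size) <= (i // block_size) for j in range(n)]
            for i in range(n)]

def masked_softmax(scores, mask):
    """Row-wise softmax after filling disallowed entries with -inf."""
    out = []
    for row, allowed in zip(scores, mask):
        filled = [s if ok else -math.inf for s, ok in zip(row, allowed)]
        top = max(filled)  # finite, because every query can see its own block
        exps = [math.exp(s - top) for s in filled]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

# Six positions in blocks of two; uniform scores make the weights easy to read.
mask = sketch_block_causal_mask(6, 2)
weights = masked_softmax([[0.0] * 6 for _ in range(6)], mask)
```

Both the mask and the score matrix are N×N, which is the quadratic memory cost noted above.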

**Loss masking semantics.** `masked_diffusion_loss` reads `batch.loss_mask` (True at positions whose loss contributes). Per BD3-LM, these are exactly the masked positions of the target block `b` — *not* the masked positions of blocks `>b` (those are fully masked but contribute no loss; they exist only so the model sees a consistent forward shape). Tested in `test_corruption.py::test_loss_mask_matches_corruption`.
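The distinction between the token mask and the loss mask can be sketched for a single sequence as follows. This is an illustration under stated assumptions, not the contents of `corruption.py`: `MASK_ID`, the function name `corrupt`, uniform sampling of `b` and `m`, independent per-position masking at rate `m`, and the "force one masked position" fallback are all my own choices.

```python
import random

MASK_ID = -1  # hypothetical mask-token id

def corrupt(tokens, block_size, rng, m_min=1e-3):
    """Return (noisy_tokens, loss_mask, b, m) for one sequence."""
    n = len(tokens)
    assert n % block_size == 0, "sequence length must be a multiple of block size"
    b = rng.randrange(n // block_size)      # target block
    m = max(m_min, rng.random())            # mask rate, clipped from below
    start, end = b * block_size, (b + 1) * block_size

    noisy = list(tokens)
    loss_mask = [False] * n
    for i in range(start, end):             # block b: mask at rate m, and score it
        if rng.random() < m:
            noisy[i] = MASK_ID
            loss_mask[i] = True
    if not any(loss_mask):                  # assumption: never return an empty loss mask
        i = rng.randrange(start, end)
        noisy[i] = MASK_ID
        loss_mask[i] = True
    for i in range(end, n):                 # blocks > b: fully masked, never scored
        noisy[i] = MASK_ID
    return noisy, loss_mask, b, m

rng = random.Random(0)
tokens = list(range(100, 132))              # 32 tokens, 4 blocks of 8
noisy, loss_mask, b, m = corrupt(tokens, 8, rng)
```

Blocks before `b` stay clean, block `b` carries every `True` entry of the loss mask, and later blocks are `MASK_ID` throughout with a loss mask that is all `False`.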

**The `1/m` factor goes outside the mean over loss positions, not inside.** I.e. the per-example loss is `(1/m_i) × mean over loss positions of -log p(target)`. This matches the ELBO derivation in `03_CASCADE.md` Appendix A.1 (the `1/m` factor comes from `(d/dt) log m(t)`, which is constant across positions within a sequence at fixed `m`).
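As a small worked example of where the factor sits, here is a plain-Python sketch that takes precomputed per-position negative log-likelihoods. The function name and the batch reduction (a simple mean over examples) are assumptions of mine; only the placement of `1/m_i` outside the per-example mean comes from the text above.

```python
def masked_diffusion_loss_sketch(nll, loss_mask, m):
    """nll[i][j] = -log p(target) at position j of example i;
    loss_mask[i][j] selects scored positions; m[i] is that example's mask rate."""
    per_example = []
    for row_nll, row_mask, m_i in zip(nll, loss_mask, m):
        picked = [v for v, keep in zip(row_nll, row_mask) if keep]
        assert picked, "loss mask must select at least one position"
        mean_nll = sum(picked) / len(picked)      # mean over loss positions
        per_example.append((1.0 / m_i) * mean_nll)  # 1/m applied once, outside the mean
    return sum(per_example) / len(per_example)      # assumed batch reduction: mean

nll = [[2.0, 4.0, 9.0], [1.0, 1.0, 1.0]]
loss_mask = [[True, True, False], [False, False, True]]
m = [0.5, 0.25]
loss = masked_diffusion_loss_sketch(nll, loss_mask, m)
# Example 0: (1/0.5) * mean(2, 4) = 6.0; example 1: (1/0.25) * 1.0 = 4.0; batch mean = 5.0
```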

**Smoke-test design.** Generate one random batch of shape `(4, 64)` with `vocab_size=128`, `block_size=8`. Train the model to memorize it: at every step apply the corruption, compute loss, backprop. The mean masked-CE loss should fall from `log(128) ≈ 4.85` to under `0.5` within 200 steps. If it doesn't, something in the wiring is broken — typically the loss mask, the `1/m` reweighting, or the gradient flow through the embedding/output head.

## 3. Failure modes I'm watching for

1. **Off-by-one in block boundary.** Block `b` runs from index `b·B` to `(b+1)·B - 1` inclusive. As a half-open slice that is `[b·B : (b+1)·B]`, equivalently `[b·B : b·B + B]`; it is easy to write `[b·B : (b+1)·B + 1]` by accident, which pulls in the first position of the next block. Tested explicitly in `test_corruption.py`.

2. **Loss mask vs. token mask confusion.** Positions in blocks `>b` are masked in the *tokens* (so the model sees `[MASK]` there) but **not** in the *loss mask* (no gradient at those positions). Mixing these up gives the model a free pass to predict garbage in the future and is a silent quality drop.

3. **Dense fallback diverges from the reference oracle on multi-head.** The reference is single-head; the production path is multi-head with RoPE. Test explicitly that with `n_heads=1` and no RoPE, the production path equals the reference within `1e-6`.

4. **`1/m` exploding the gradient at low `m`.** The `m_min=1e-3` clip caps `1/m ≤ 1000`. Combined with batch averaging and AdamW's per-parameter scaling, this is fine — but if it ever isn't, watch for NaNs at the first low-`m` sample.

5. **Smoke-test doesn't converge.** If the loss plateaus near `log(vocab_size)`, the gradient isn't flowing. The most likely causes are that the loss mask is empty for some `m`, or that the `1/m` factor is inverted so the loss is scaled by `m` instead of `1/m`. Add a guard: assert the loss mask sums to ≥ 1 per example.
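For failure mode 1, a tiny sketch of the kind of boundary check I have in mind. `block_slice` and `buggy_block_slice` are hypothetical helpers written for this note, not functions from `corruption.py`.

```python
def block_slice(b, block_size):
    """Half-open slice covering indices b*B .. (b+1)*B - 1 inclusive."""
    return slice(b * block_size, (b + 1) * block_size)

def buggy_block_slice(b, block_size):
    """The off-by-one variant: leaks the first index of block b + 1."""
    return slice(b * block_size, (b + 1) * block_size + 1)

tokens = list(range(24))  # three blocks of eight
blocks = [tokens[block_slice(b, 8)] for b in range(3)]
leaky = tokens[buggy_block_slice(0, 8)]
```

Correct slices tile the sequence exactly, so concatenating `blocks` reproduces `tokens`; the buggy slice returns nine elements for a non-final block, which is the cheapest property to assert.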

## 4. Evidence-of-success plan

Phase 1 (CPU-bounded sub-phase) closes when:

1. All test files under `tests/` are at-or-above their pre-Phase-1 pass rate. New tests added in this phase pass.
2. `cascade.modules.block_causal_attn.BlockCausalAttention` with `n_heads=1` and RoPE disabled produces outputs within `1e-6` of `cascade.attention_reference.BlockCausalAttention` on identical inputs.
3. The smoke-test loss falls below `0.5` within 200 steps on a memorizable batch. (The threshold is loose because `1/m` reweighting makes the smoke-test loss not directly comparable to plain CE; the qualitative signal is *a steady downward trend and convergence to near-zero*, with some step-to-step noise expected because the corruption is re-applied at every step.)
4. No `NotImplementedError` remains in any Phase 1 module (per `grep -r "NotImplementedError" cascade/modules cascade/{corruption,losses,model}.py`).

Phase 1 (full GPU sub-phase) — deferred — closes when nano trains stably on 200M FineWeb-Edu tokens and the block-AR equivalence (`K_b = B`) sanity check fires correctly.

## 5. What this phase explicitly does NOT cover

- `cache.py:BlockCache` (Phase 2 — the inference cache).
- `denoise.py:generate` (Phase 2).
- The flash-attn production kernel (deferred to a real GPU environment).
- Real-data training (deferred).
- The adaptive-K head (Phase 3).
