Formal Verification of Transformer Architectures in Rocq
Table of Contents
- Overview
- The Problem: Runtime Shape Errors
- The Solution: Dimension-Indexed Types
- Key Invariants Enforced
- Proof-Carrying Configuration
- Type-Level Sequence Arithmetic
- Architecture
- Implementation Strategy
- Why Proofs Are "Trivial"
- Building
- FreeBSD Build Notes
- Related Work
- Future Directions
- Conclusion
- References
Overview
The rocq-transformer project demonstrates how dependent types can enforce dimensional correctness in neural network architectures. Rather than discovering tensor shape mismatches at runtime (the infamous RuntimeError: shape mismatch), the type system catches these errors at compile time.
If the code compiles, your Transformer has no dimension bugs.
The Problem: Runtime Shape Errors
In PyTorch, tensor shapes are runtime attributes:
x = torch.randn(32, 128, 512)  # Hope this is [batch, seq, d_model]
y = linear(x)                  # Did dimensions align? Runtime will tell...
Common failure modes:
- Residual connections with mismatched dimensions
- d_model not divisible by num_heads (discovered deep in attention code)
- Cross-attention masks with wrong query/key dimensions
- Off-by-one errors in autoregressive generation
The Solution: Dimension-Indexed Types
In Rocq, shapes live in the type signature:
Definition x : Tensor3D 32 128 512 := ...
Definition y : Tensor3D 32 128 256 := linearForward projection x.
(* Type checker verifies: [32,128,512] -> [32,128,256] *)
The compiler becomes the shape checker.
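To make this concrete, here is a minimal sketch of what a dimension-indexed tensor type can look like: total functions over bounded indices, with the bounds carried in the type and an abstract scalar type R. This is an illustration only; the project's Tensor.v chooses its own representation.

Require Import Fin.

Parameter R : Type.  (* abstract scalar; the project axiomatizes numerics *)

(* Bounded indices (Fin.t n has exactly n inhabitants), so every
   dimension is part of the type. *)
Definition Tensor2D (rows cols : nat) : Type :=
  Fin.t rows -> Fin.t cols -> R.

Definition Tensor3D (batch seq d_model : nat) : Type :=
  Fin.t batch -> Fin.t seq -> Fin.t d_model -> R.

With any such encoding, a Tensor3D 32 128 512 simply cannot flow into a position expecting a Tensor3D 32 128 256; unification fails and compilation stops.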
Key Invariants Enforced
| Invariant | Paper Reference | Enforcement Mechanism |
|---|---|---|
| QKᵀ requires matching inner dims | Equation (1) | matmul4D type signature |
| Multi-head splits d_model evenly | Section 3.2.2 | (num_heads \| d_model) proof required |
| Residual connections match shapes | Section 5.4 | add3D requires identical types |
| Each decode step adds exactly 1 token | Inference | Tensor2D batch n -> Tensor2D batch (S n) |
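The residual-connection row, for instance, comes down to a signature in which both operands and the result share one indexed type. A sketch of that shape discipline (the real definition lives in Tensor.v):

(* Residual addition: inputs and output carry identical indices, so
   mismatched shapes are rejected before any arithmetic happens.
   Shown as a Parameter here only for brevity. *)
Parameter add3D : forall (batch seq d_model : nat),
  Tensor3D batch seq d_model ->
  Tensor3D batch seq d_model ->
  Tensor3D batch seq d_model.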
Proof-Carrying Configuration
Invalid configurations cannot be constructed:
Record TransformerConfig := {
d_model : nat;
num_heads : nat;
heads_divide : (num_heads | d_model) (* PROOF required *)
}.
(* This compiles: 8 divides 512 *)
Definition goodConfig := mkConfig 512 2048 8 6 5000 ltac:(exists 64; reflexivity).
(* This fails: 7 does not divide 512 *)
Definition badConfig := mkConfig 512 2048 7 6 5000 ???. (* No proof exists! *)
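Because the divisibility obligation is an ordinary proposition, it can also be discharged as an explicit proof term rather than a tactic. A sketch, assuming (num_heads | d_model) is the standard Nat.divide notation (exists p, d_model = p * num_heads):

Require Import PeanoNat.

(* 512 = 64 * 8, so 64 is the witness and eq_refl closes the equation
   by computation. *)
Definition eight_divides_512 : Nat.divide 8 512 :=
  ex_intro _ 64 eq_refl.

The ltac:(exists 64; reflexivity) in goodConfig above is just the tactic-level spelling of the same witness.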
Type-Level Sequence Arithmetic
Autoregressive generation encodes length in the type:
Fixpoint greedyDecodeLoop
(remaining : nat) (curLen : nat)
(tgtSoFar : Tensor2D batch curLen)
: Tensor2D batch (curLen + remaining) := (* Type IS the proof! *)
match remaining with
| 0 => tgtSoFar
| S rem' =>
let next := decodeStep ... in
let extended := cat tgtSoFar next in
greedyDecodeLoop rem' (S curLen) extended
end.
The return type Tensor2D batch (curLen + remaining) proves the final length.
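One subtlety the types force you to confront: curLen + 0 and S curLen + rem' are not definitionally equal to curLen and curLen + S rem', so each branch of the match has to cast along the usual Peano lemmas (plus_n_O, plus_n_Sm). A sketch of such a cast, under the assumption that the project resolves the mismatch this way:

(* Transport a tensor along an equality between length expressions.
   Tensor2D is the project's indexed type. *)
Definition castLen {batch m n : nat}
  (e : m = n) (t : Tensor2D batch m) : Tensor2D batch n :=
  match e in (_ = k) return Tensor2D batch k with
  | eq_refl => t
  end.

(* Base case:  castLen (plus_n_O curLen) tgtSoFar
                 : Tensor2D batch (curLen + 0)
   Step case:  castLen (plus_n_Sm curLen rem') (greedyDecodeLoop ...)
                 : Tensor2D batch (curLen + S rem')  *)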
Architecture
The implementation mirrors the "Attention Is All You Need" paper structure:
Transformer/
├── Tensor.v       # Dimension-indexed types, ~20 structural ops
├── Config.v       # Configuration with divisibility proof
├── Linear.v       # Linear projection
├── Embedding.v    # Token + positional embeddings
├── Attention.v    # Scaled dot-product + multi-head attention
├── LayerNorm.v    # Layer normalization
├── FeedForward.v  # Position-wise FFN
├── Sublayer.v     # Pre-norm residual connections
├── Encoder.v      # N encoder layers
├── Decoder.v      # N decoder layers (3 sublayers each)
├── Model.v        # Complete encoder-decoder
├── Inference.v    # Greedy decoding with type proofs
└── Properties.v   # 15 shape preservation theorems
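Sublayer.v ties several of these modules together; the pre-norm residual pattern it implements can be sketched at the type level as follows (names and argument order assumed, with add3D as sketched earlier):

(* Pre-norm residual: y = x + body (norm x). Because norm, body, x and
   the result all share the same indices, a shape-changing sublayer
   would not typecheck here. (Illustrative sketch only.) *)
Definition preNormResidual
  (batch seq d_model : nat)
  (norm body : Tensor3D batch seq d_model -> Tensor3D batch seq d_model)
  (x : Tensor3D batch seq d_model) : Tensor3D batch seq d_model :=
  add3D batch seq d_model x (body (norm x)).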
Implementation Strategy
Implemented (~20 operations)
Structural operations that work regardless of numeric type:
- zeros, ones, fill2D
- transpose2D/3D/4D
- viewToHeads, viewFromHeads
- cat_batch, cat_seq
- subsequentMask, paddingMask
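Two representative signatures, to show how the indices carry the work (forms assumed; see Tensor.v for the real definitions, which are implemented rather than merely declared):

(* Transposition swaps the indices in the result type. *)
Parameter transpose2D : forall (m n : nat), Tensor2D m n -> Tensor2D n m.

(* Concatenation along the sequence axis adds lengths at the type level. *)
Parameter cat_seq : forall (batch s1 s2 : nat),
  Tensor2D batch s1 -> Tensor2D batch s2 -> Tensor2D batch (s1 + s2).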
Axiomatized (~30 operations)
Numerical operations with precise type signatures:
- matmul2D/3D/4D - dimension constraints in types
- softmax, relu, layerNorm
- dropout, maskedFill, scale
- embeddingLookup, argmax
The axioms encode dimensional behavior:
Parameter matmul2D : forall (m n k : nat),
  Tensor2D m k -> Tensor2D k n -> Tensor2D m n.
(* Type IS specification: (m x k) @ (k x n) = (m x n) *)
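Shape-preserving operations follow the same pattern: the output repeats the input indices, so the axiom cannot describe a shape-changing operation even by accident. Assumed forms:

(* Pointwise / normalizing ops keep every dimension. (Sketched signatures;
   the project's versions may take extra parameters such as weights.) *)
Parameter softmax   : forall (b s d : nat), Tensor3D b s d -> Tensor3D b s d.
Parameter layerNorm : forall (b s d : nat), Tensor3D b s d -> Tensor3D b s d.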
Why Proofs Are "Trivial"
Most proofs in Properties.v are one-liners:
Theorem encoder_preserves_shape :
  forall batch seq d_model enc x mask,
    exists (y : Tensor3D batch seq d_model), True.
Proof.
  intros. exists (encoderForward enc x mask). trivial.
Qed.
This isn't laziness - it's the point. The type signature of encoderForward already proves shape preservation. The theorem documents the guarantee. Compilation is verification.
Building
Nix (Rocq 9.1)
nix develop
make
FreeBSD (Coq 8.20)
pkg install coq gmake
The Makefile auto-detects rocq or coq_makefile and sets COQLIB for FreeBSD's non-standard library path.
FreeBSD Build Notes
FreeBSD 14.3 requires special handling:
- COQLIB detection: FreeBSD installs Coq libraries at /usr/local/lib/ocaml/site-lib/coq rather than the expected path
- Makefile fallback: Detects rocq (Rocq 9.1) or coq_makefile (Coq 8.x)
- gmake required: FreeBSD's system make differs from GNU make
The project compiles with both Rocq 9.1 (via Nix) and Coq 8.20.1 (FreeBSD pkg).
Related Work
- hs-annotated-transformer - Executable Haskell implementation
- The Annotated Transformer - Harvard NLP PyTorch tutorial
- Attention Is All You Need - Vaswani et al. 2017
Future Directions
- Semantic axioms: Prove softmax sums to 1, layer norm produces unit variance (a sketch follows this list)
- Extraction: Generate executable Haskell/OCaml from structural operations
- Training correctness: Formalize gradient computation
- Quantization proofs: Verify precision bounds for int8 inference
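For the first item, a hedged sketch of what a semantic axiom could state, with rowSum and onesLike2D as purely hypothetical helpers:

(* Hypothetical helpers, for illustration only. *)
Parameter rowSum     : forall (b s d : nat), Tensor3D b s d -> Tensor2D b s.
Parameter onesLike2D : forall (b s : nat), Tensor2D b s.

(* A semantic axiom would constrain behaviour, not just shape. *)
Axiom softmax_rows_sum_to_one :
  forall b s d (x : Tensor3D b s d),
    rowSum b s d (softmax b s d x) = onesLike2D b s.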
Conclusion
Dependent types transform "works on my machine" into "works by construction." For Transformer architectures, this means:
- Dimension bugs are impossible (caught at compile time)
- Invalid configurations cannot be constructed
- Sequence length arithmetic is proven correct
The overhead is minimal: proofs are trivial because types do the heavy lifting. The 13-module implementation compiles in ~10 seconds and produces 15 shape preservation theorems automatically from type signatures.
References
- Source: https://github.com/aygp-dr/rocq-transformer
- Rocq: https://rocq-prover.org/ (formerly Coq)
- Paper: https://arxiv.org/abs/1706.03762