Formal Verification of Transformer Architectures in Rocq

Overview

The rocq-transformer project demonstrates how dependent types can enforce dimensional correctness in neural network architectures. Rather than discovering tensor shape mismatches at runtime (the infamous RuntimeError: shape mismatch), the type system catches these errors at compile time.

If the code compiles, your Transformer has no dimension bugs.

The Problem: Runtime Shape Errors

In PyTorch, tensor shapes are runtime attributes:

x = torch.randn(32, 128, 512)    # Hope this is [batch, seq, d_model]
y = linear(x)                     # Did dimensions align? Runtime will tell...

Common failure modes:

  • Residual connections with mismatched dimensions
  • d_model not divisible by num_heads (discovered deep in attention code)
  • Cross-attention masks with wrong query/key dimensions
  • Off-by-one errors in autoregressive generation

The Solution: Dimension-Indexed Types

In Rocq, shapes live in the type signature:

Definition x : Tensor3D 32 128 512 := ...
Definition y : Tensor3D 32 128 256 := linearForward projection x.
(* Type checker verifies: [32,128,512] -> [32,128,256] *)

The compiler becomes the shape checker.
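
Conversely, a deliberate mismatch is rejected before anything runs. A minimal sketch, reusing the (assumed) projection and x from the snippet above; the error text is paraphrased:

(* Claiming an output width of 512 when the projection yields 256 is a
   compile-time error; Fail records that the definition is rejected. *)
Fail Definition y_bad : Tensor3D 32 128 512 := linearForward projection x.
(* Rejected with an error along the lines of: "The term has type
   Tensor3D 32 128 256 while it is expected to have type Tensor3D 32 128 512." *)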

Key Invariants Enforced

Invariant                                Paper Reference   Enforcement Mechanism
QKᵀ requires matching inner dims         Equation (1)      matmul4D type signature
Multi-head splits d_model evenly         Section 3.2.2     (num_heads | d_model) proof required
Residual connections match shapes        Section 5.4       add3D requires identical types
Each decode step adds exactly 1 token    Inference         Tensor2D batch n -> Tensor2D batch (S n)
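
The residual-connection row, for instance, relies on add3D demanding one shared type for both operands; a minimal sketch of such a signature (the exact declaration in Tensor.v may differ):

(* Both operands and the result share the same three indices, so a mismatched
   skip connection cannot even be written down. *)
Parameter add3D : forall (batch seq d_model : nat),
  Tensor3D batch seq d_model -> Tensor3D batch seq d_model -> Tensor3D batch seq d_model.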

Proof-Carrying Configuration

Invalid configurations cannot be constructed:

Record TransformerConfig := {
  d_model   : nat;
  num_heads : nat;
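  (* ...further configuration fields elided... *)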
  heads_divide : (num_heads | d_model)   (* PROOF required *)
}.

(* This compiles: 8 divides 512 *)
Definition goodConfig := mkConfig 512 2048 8 6 5000 ltac:(exists 64; reflexivity).

(* This fails: 7 does not divide 512 *)
Definition badConfig := mkConfig 512 2048 7 6 5000 ???.  (* No proof exists! *)
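
The divisibility proof itself is just an existence witness. A minimal sketch, assuming the (_ | _) notation is the standard library's Nat.divide (under Rocq 9 the stdlib prefix is Stdlib rather than Coq):

From Coq Require Import Arith.  (* Rocq 9: From Stdlib Require Import Arith. *)

(* Nat.divide n m unfolds to "exists p, m = p * n", so proving 8 | 512 means
   exhibiting the witness 64 and letting 64 * 8 compute to 512. *)
Lemma eight_divides_512 : Nat.divide 8 512.
Proof. exists 64. reflexivity. Qed.

No such witness exists for 7, so the proof obligation in badConfig can never be discharged.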

Type-Level Sequence Arithmetic

Autoregressive generation encodes length in the type:

Fixpoint greedyDecodeLoop
  (remaining : nat) (curLen : nat)
  (tgtSoFar : Tensor2D batch curLen)
  : Tensor2D batch (curLen + remaining) :=   (* Type IS the proof! *)
  match remaining with
  | 0 => tgtSoFar
  | S rem' =>
      let next := decodeStep ... in
      let extended := cat tgtSoFar next in
      greedyDecodeLoop rem' (S curLen) extended
  end.

The return type Tensor2D batch (curLen + remaining) proves the final length.
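
The concatenation step can be given the same treatment; a plausible 2-D signature for the sequence-axis append used above (the exact name and argument order in Tensor.v may differ):

(* Appending along the sequence axis adds the lengths in the type; appending a
   single predicted token specializes extra := 1, i.e. one new token per step. *)
Parameter cat_seq : forall (batch len extra : nat),
  Tensor2D batch len -> Tensor2D batch extra -> Tensor2D batch (len + extra).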

Architecture

The implementation mirrors the "Attention Is All You Need" paper structure:

Transformer/
├── Tensor.v      # Dimension-indexed types, ~20 structural ops
├── Config.v      # Configuration with divisibility proof
├── Linear.v      # Linear projection
├── Embedding.v   # Token + positional embeddings
├── Attention.v   # Scaled dot-product + multi-head attention
├── LayerNorm.v   # Layer normalization
├── FeedForward.v # Position-wise FFN
├── Sublayer.v    # Pre-norm residual connections
├── Encoder.v     # N encoder layers
├── Decoder.v     # N decoder layers (3 sublayers each)
├── Model.v       # Complete encoder-decoder
├── Inference.v   # Greedy decoding with type proofs
└── Properties.v  # 15 shape preservation theorems

Implementation Strategy

Implemented (~20 operations)

Structural operations that work regardless of numeric type:

  • zeros, ones, fill2D
  • transpose2D/3D/4D
  • viewToHeads, viewFromHeads
  • cat_batch, cat_seq
  • subsequentMask, paddingMask
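
One way such structural operations can be realized is over length-indexed stdlib vectors, with the scalar type left abstract so the same code works for any numeric representation. The sketch below is illustrative; the actual encoding in Tensor.v may differ:

From Coq Require Import Vectors.Vector.  (* Rocq 9: From Stdlib Require Import Vectors.Vector. *)

Parameter Scalar : Type.
Parameter zeroS  : Scalar.

(* A 2-D tensor as a length-indexed vector of rows; dimensions live in the type. *)
Definition Tensor2D (rows cols : nat) : Type :=
  Vector.t (Vector.t Scalar cols) rows.

(* fill2D only replicates a value, so it needs no arithmetic on Scalar. *)
Definition fill2D (rows cols : nat) (v : Scalar) : Tensor2D rows cols :=
  Vector.const (Vector.const v cols) rows.

Definition zeros (rows cols : nat) : Tensor2D rows cols :=
  fill2D rows cols zeroS.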

Axiomatized (~30 operations)

Numerical operations with precise type signatures:

  • matmul2D/3D/4D - dimension constraints in types
  • softmax, relu, layerNorm
  • dropout, maskedFill, scale
  • embeddingLookup, argmax

The axioms encode dimensional behavior:

Parameter matmul2D : forall (m n k : nat),
  Tensor2D m k -> Tensor2D k n -> Tensor2D m n.
(* Type IS specification: (m x k) @ (k x n) = (m x n) *)
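
The remaining numerical axioms follow the same pattern; the signatures below are illustrative (attention scores as 4-D, position-wise activations as 3-D) and may not match the project's exact argument order:

(* Shape-preserving numerical ops: only the dimensional behavior is specified. *)
Parameter softmax : forall (batch heads q_len k_len : nat),
  Tensor4D batch heads q_len k_len -> Tensor4D batch heads q_len k_len.

Parameter relu : forall (batch seq d : nat),
  Tensor3D batch seq d -> Tensor3D batch seq d.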

Why Proofs Are "Trivial"

Most proofs in Properties.v are one-liners:

Theorem encoder_preserves_shape : forall batch seq d_model enc x mask,
  exists (y : Tensor3D batch seq d_model), True.
Proof. intros. exists (encoderForward enc x mask). trivial. Qed.

This isn't laziness - it's the point. The type signature of encoderForward already proves shape preservation. The theorem documents the guarantee. Compilation is verification.

Building

Nix (Rocq 9.1)

nix develop
make

FreeBSD (Coq 8.20)

pkg install coq
gmake

The Makefile auto-detects rocq or coq_makefile and sets COQLIB for FreeBSD's non-standard library path.

FreeBSD Build Notes

FreeBSD 14.3 requires special handling:

  1. COQLIB detection: FreeBSD installs Coq libraries at /usr/local/lib/ocaml/site-lib/coq rather than the expected path
  2. Makefile fallback: Detects rocq (Rocq 9.1) or coq_makefile (Coq 8.x)
  3. gmake required: FreeBSD's system make differs from GNU make

The project compiles with both Rocq 9.1 (via Nix) and Coq 8.20.1 (FreeBSD pkg).

Future Directions

  1. Semantic axioms: Prove softmax sums to 1 and layer norm produces unit variance (a sketch follows this list)
  2. Extraction: Generate executable Haskell/OCaml from structural operations
  3. Training correctness: Formalize gradient computation
  4. Quantization proofs: Verify precision bounds for int8 inference
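
As a rough illustration of the first item, a semantic axiom could be layered on top of the shape-indexed signatures. The sketch reuses the illustrative softmax signature from the axioms section above; sumLastDim and ones3D are hypothetical helpers, not part of the current development:

(* Hypothetical semantic strengthening: every row of a softmax output sums to 1. *)
Parameter sumLastDim : forall (b h q k : nat), Tensor4D b h q k -> Tensor3D b h q.
Parameter ones3D : forall (b h q : nat), Tensor3D b h q.

Axiom softmax_rows_sum_to_one :
  forall b h q k (scores : Tensor4D b h q k),
    sumLastDim b h q k (softmax b h q k scores) = ones3D b h q.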

Conclusion

Dependent types transform "works on my machine" into "works by construction." For Transformer architectures, this means:

  • Dimension bugs are impossible (caught at compile time)
  • Invalid configurations cannot be constructed
  • Sequence length arithmetic is proven correct

The overhead is minimal: proofs are trivial because types do the heavy lifting. The 13-module implementation compiles in ~10 seconds and produces 15 shape preservation theorems automatically from type signatures.

Author: Jason Walsh

j@wal.sh

Last Updated: 2025-12-22 23:21:33
