Adaptive all-atom tokenization of proteins, nucleic acids, and small molecules.
Combines APT (adaptive protein tokenization via diffusion autoencoder + FSQ + nested dropout) with Bio2Token (all-atom representation). The result is APT's architecture extended to every heavy atom in the structure — not just C-alpha.
Architecture: Transformer encoder → FSQ quantization (1000-token codebook) → DiT diffusion decoder with conditional flow matching. Nested dropout during training creates a coarse-to-fine hierarchy, so any prefix of the token sequence still decodes to a valid (coarser) reconstruction.
~79M parameters. Trains on a single A100/H100 in float32.
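The two pieces that make the tokenizer adaptive are compact enough to sketch. The snippet below is illustrative only and is not the code in src/struct2token/model/fsq.py or the trainer: it shows finite scalar quantization rounding each latent dimension onto the [8,5,5,5] grid (8*5*5*5 = 1000 possible codes per token), plus a nested-dropout mask that keeps only a random prefix of the 128 token slots during training.

```python
# Illustrative sketch only; not the code in src/struct2token/model/fsq.py.
import torch

LEVELS = torch.tensor([8, 5, 5, 5])              # per-dim levels: 8*5*5*5 = 1000 codes

def fsq_quantize(z: torch.Tensor) -> torch.Tensor:
    """Round each latent dim onto its finite grid, passing gradients straight through."""
    half = (LEVELS - 1) / 2                          # 3.5, 2, 2, 2
    offset = torch.where(LEVELS % 2 == 0, 0.5, 0.0)  # half-step shift for even level counts
    z = torch.tanh(z) * half - offset                # bound dim i so exactly LEVELS[i] integers are reachable
    z_q = torch.round(z)                             # snap to the grid
    return z + (z_q - z).detach()                    # forward: z_q, backward: gradient of z

def nested_dropout_mask(n_tokens: int = 128, batch: int = 2) -> torch.Tensor:
    """Keep a random prefix of the token slots so short prefixes stay decodable."""
    keep = torch.randint(1, n_tokens + 1, (batch, 1))         # prefix length per sample
    return (torch.arange(n_tokens)[None, :] < keep).float()   # (batch, n_tokens) 0/1 mask

z = torch.randn(2, 128, 4)                       # (batch, token slots, FSQ dims)
tokens = fsq_quantize(z) * nested_dropout_mask()[..., None]
print(tokens.shape)                              # torch.Size([2, 128, 4])
```

At inference you choose how many leading tokens to keep, trading reconstruction fidelity for a shorter code.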
Requires uv.
# Install the package and all dependencies
uv sync
# With Flash Attention 2 (recommended for GPU training)
uv sync --extra flash
# With dev dependencies (for running tests)
uv sync --extra dev

Training logs to Weights & Biases by default. Log in before training:
uv run wandb login

To disable wandb, pass --no-wandb to the training script.
Scan your mmCIF files and create a parquet index:
uv run python scripts/preprocess_data.py \
--mmcif_dir ~/tim1/helico-data/raw/mmCIF \
--output data/index.parquet

For a quick test with a subset:
uv run python scripts/preprocess_data.py \
--mmcif_dir ~/tim1/helico-data/raw/mmCIF \
--output data/index.parquet \
--max_files 100

uv run python scripts/train_tokenizer.py --config configs/default.yaml

Common overrides:
# Name your wandb run
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
--wandb_run_name first-run
# Adjust batch size and learning rate
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
--batch_size 4 --lr 1e-4
# Resume from checkpoint
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
--resume checkpoints/step_10000.pt
# Train without wandb
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
--no-wandb
# Use a specific GPU
CUDA_VISIBLE_DEVICES=0 uv run python scripts/train_tokenizer.py --config configs/default.yaml

All CLI flags (--batch_size, --lr, --max_steps, --seed, --index_path, --max_atoms, --wandb_project, --wandb_run_name) override the corresponding YAML config values.
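Under the hood this follows the usual flag-overrides-YAML pattern. The sketch below is a minimal illustration of that pattern, not the actual code in scripts/train_tokenizer.py or src/struct2token/config.py, and it assumes the dotted names used here map onto nested YAML keys:

```python
# Minimal sketch of the flag-overrides-YAML pattern; the real script may differ.
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True)
parser.add_argument("--batch_size", type=int)    # default None means "flag not passed"
parser.add_argument("--lr", type=float)
args = parser.parse_args()

with open(args.config) as f:
    cfg = yaml.safe_load(f)                      # nested dict from configs/default.yaml

# Only flags the user actually passed override the YAML values.
if args.batch_size is not None:
    cfg["training"]["batch_size"] = args.batch_size
if args.lr is not None:
    cfg["training"]["lr"] = args.lr
```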
uv run python scripts/evaluate.py \
--checkpoint checkpoints/final.pt \
--output results.json

Options:
uv run python scripts/evaluate.py \
--checkpoint checkpoints/final.pt \
--max_samples 500 \
--n_steps 100 \
--cfg_weight 2.0 \
--output results.json

uv run pytest

struct2token/
├── configs/
│ └── default.yaml # master config
├── src/struct2token/
│ ├── config.py # dataclass configs + YAML loading
│ ├── data/
│ │ ├── tokens.py # atom-type, residue-type vocabularies
│ │ ├── molecule_conventions.py # per-residue canonical atom ordering
│ │ ├── mmcif_parser.py # mmCIF → all-atom features
│ │ ├── dataset.py # PyTorch Dataset with caching
│ │ └── collate.py # variable-length batching
│ ├── model/
│ │ ├── embeddings.py # coord + atom + residue + meta embeddings
│ │ ├── attention.py # Flash Attention 2 transformer (SDPA fallback)
│ │ ├── rotary.py # RoPE positional embeddings
│ │ ├── fsq.py # Finite Scalar Quantization (8,5,5,5 → 1000 codes)
│ │ ├── cfm.py # Conditional Flow Matching
│ │ ├── dit.py # DiT decoder with adaLN
│ │ └── dae.py # main Diffusion Autoencoder
│ ├── losses/
│ │ ├── rmsd.py # Kabsch-aligned RMSD
│ │ ├── inter_atom_distance.py
│ │ ├── permutation.py # symmetric sidechain resolution
│ │ └── tm.py # TM-score
│ ├── training/
│ │ ├── trainer.py # training loop + wandb
│ │ ├── ema.py # exponential moving average
│ │ └── augmentation.py # random SO(3) rotation
│ └── inference/
│ ├── encode.py
│ ├── decode.py
│ └── metrics.py
├── scripts/
│ ├── preprocess_data.py # build data index
│ ├── train_tokenizer.py # training entry point
│ └── evaluate.py # evaluation entry point
└── tests/
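Of the losses above, Kabsch-aligned RMSD (losses/rmsd.py) is the core reconstruction metric: it compares predicted and ground-truth coordinates after the best rigid alignment. For reference, a minimal version of the textbook algorithm follows; it is a sketch, not the project's implementation:

```python
# Reference sketch of Kabsch-aligned RMSD; losses/rmsd.py may differ in details.
import torch

def kabsch_rmsd(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """RMSD between two (N, 3) coordinate sets after optimal rigid alignment."""
    p = pred - pred.mean(dim=0)                       # center both point clouds
    q = target - target.mean(dim=0)
    u, _, vt = torch.linalg.svd(p.T @ q)              # SVD of the 3x3 covariance
    d = torch.eye(3)
    d[-1, -1] = torch.sign(torch.linalg.det(u @ vt))  # flip last axis to avoid a reflection
    aligned = p @ u @ d @ vt                          # rotate p onto q
    return torch.sqrt(((aligned - q) ** 2).sum(dim=-1).mean())

# e.g. kabsch_rmsd(decoded_coords, true_coords) on (n_atoms, 3) tensors
```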
All parameters live in configs/default.yaml. Key settings:
| Parameter | Default | Notes |
|---|---|---|
| model.encoder.d_model | 256 | Encoder hidden dim |
| model.decoder.d_model | 512 | Decoder hidden dim |
| model.decoder.n_layers | 12 | DiT depth |
| model.fsq.levels | [8,5,5,5] | 1000-token codebook |
| model.n_tokens | 128 | Max adaptive tokens |
| model.max_seq_len | 8192 | Max atoms per structure |
| training.lr | 3e-4 | AdamW learning rate |
| training.batch_size | 2 | Per-GPU batch size |
| training.max_steps | 500000 | Total training steps |
| training.wandb_project | struct2token | WandB project name |
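The table covers only the most commonly changed settings. To list every value in the same dotted form, the YAML can be flattened; the sketch below assumes only that configs/default.yaml parses as nested mappings:

```python
# Print every setting in configs/default.yaml as a dotted path, matching the table above.
import yaml

def flatten(d: dict, prefix: str = "") -> dict:
    out = {}
    for key, value in d.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten(value, path))    # recurse into nested sections
        else:
            out[path] = value                   # leaf value, e.g. training.lr
    return out

with open("configs/default.yaml") as f:
    cfg = yaml.safe_load(f)

for path, value in flatten(cfg).items():
    print(f"{path} = {value}")
```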
During training the following are logged:
- train/flow_loss — flow matching MSE (main training signal)
- train/size_loss — atom count prediction CE
- train/total_loss — weighted sum
- train/grad_norm — gradient norm before clipping
- train/lr — current learning rate
- val/flow_loss, val/size_loss — validation losses (EMA model)
Training data: PDB mmCIF files (gzipped or plain). The preprocessing script scans all files and writes a parquet index with path, chain ID, entity type, and atom count per chain. The dataset lazily parses mmCIF files on access and caches parsed tensors as .pt files.
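To sanity-check the index before training, it can be read back with pandas. Column names are whatever scripts/preprocess_data.py writes; the snippet below just prints the first rows rather than assuming a schema:

```python
# Quick look at the preprocessed index (one row per chain).
import pandas as pd

index = pd.read_parquet("data/index.parquet")
print(len(index), "chains indexed")
print(index.head())   # expect columns for path, chain ID, entity type, atom count
```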
- references/apt.pdf — Adaptive Protein Tokenization
- references/bio2token.pdf — Bio2Token: All-atom tokenization