struct2token

Adaptive all-atom tokenization of proteins, nucleic acids, and small molecules.

Combines APT (adaptive protein tokenization via diffusion autoencoder + FSQ + nested dropout) with Bio2Token (all-atom representation). The result is APT's architecture extended to every heavy atom in the structure — not just C-alpha.

Architecture: Transformer encoder → FSQ quantization (1000 tokens) → DiT diffusion decoder with conditional flow matching. Nested dropout during training creates a coarse-to-fine hierarchy so any prefix of tokens is a valid reconstruction.

~79M parameters. Trains on a single A100/H100 in float32.
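
The FSQ codebook size follows directly from the levels: 8 × 5 × 5 × 5 = 1000. The sketch below illustrates the quantization and the nested-dropout prefix idea in plain PyTorch; it is not the repo's implementation (that lives in src/struct2token/model/fsq.py), and the function names are made up for illustration.

# Illustrative FSQ quantization with levels [8, 5, 5, 5]; names are hypothetical.
import torch

LEVELS = [8, 5, 5, 5]                            # 8 * 5 * 5 * 5 = 1000 codes

def fsq_quantize(z):
    """Round each of the 4 latent dims onto its own small integer grid.

    z: (..., 4) continuous latents; returns per-dimension codes in [0, level).
    """
    half = torch.tensor([(l - 1) / 2.0 for l in LEVELS])
    bounded = torch.tanh(z) * half + half        # squash into [0, level - 1]
    return torch.round(bounded)                  # real training code would use a straight-through estimator

def codes_to_token_id(codes):
    """Mixed-radix flatten of per-dimension codes into a single id in [0, 1000)."""
    token_id = torch.zeros(codes.shape[:-1], dtype=torch.long)
    for d, level in enumerate(LEVELS):
        token_id = token_id * level + codes[..., d].long()
    return token_id

z = torch.randn(2, 128, 4)                       # (batch, n_tokens, fsq dims)
token_ids = codes_to_token_id(fsq_quantize(z))   # (2, 128), values in [0, 1000)

# Nested dropout: during training, keep only a random prefix of the 128 tokens so
# that short prefixes remain decodable on their own (coarse-to-fine hierarchy).
keep = torch.randint(1, 129, (1,)).item()
prefix_mask = torch.arange(128) < keep           # True for the kept prefix

Because any prefix is trained to decode on its own, you can trade reconstruction fidelity against token budget at inference time simply by truncating the token sequence.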

Setup

Requires uv.

# Install the package and all dependencies
uv sync

# With Flash Attention 2 (recommended for GPU training)
uv sync --extra flash

# With dev dependencies (for running tests)
uv sync --extra dev

WandB

Training logs to Weights & Biases by default. Log in before training:

uv run wandb login

To disable wandb, pass --no-wandb to the training script.

Quickstart

1. Build the data index

Scan your mmCIF files and create a parquet index:

uv run python scripts/preprocess_data.py \
    --mmcif_dir ~/tim1/helico-data/raw/mmCIF \
    --output data/index.parquet

For a quick test with a subset:

uv run python scripts/preprocess_data.py \
    --mmcif_dir ~/tim1/helico-data/raw/mmCIF \
    --output data/index.parquet \
    --max_files 100

2. Train the tokenizer

uv run python scripts/train_tokenizer.py --config configs/default.yaml

Common overrides:

# Name your wandb run
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
    --wandb_run_name first-run

# Adjust batch size and learning rate
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
    --batch_size 4 --lr 1e-4

# Resume from checkpoint
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
    --resume checkpoints/step_10000.pt

# Train without wandb
uv run python scripts/train_tokenizer.py --config configs/default.yaml \
    --no-wandb

# Use a specific GPU
CUDA_VISIBLE_DEVICES=0 uv run python scripts/train_tokenizer.py --config configs/default.yaml

All CLI flags (--batch_size, --lr, --max_steps, --seed, --index_path, --max_atoms, --wandb_project, --wandb_run_name) override the corresponding YAML config values.

3. Evaluate

uv run python scripts/evaluate.py \
    --checkpoint checkpoints/final.pt \
    --output results.json

Options:

uv run python scripts/evaluate.py \
    --checkpoint checkpoints/final.pt \
    --max_samples 500 \
    --n_steps 100 \
    --cfg_weight 2.0 \
    --output results.json
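
The --n_steps flag presumably controls the number of decoder sampling steps, and --cfg_weight scales classifier-free guidance. The usual guidance rule, shown here as a sketch (the repo's exact formulation may differ), combines conditional and unconditional velocity predictions:

# Standard classifier-free guidance; cfg_weight = 1.0 recovers the purely
# conditional prediction, larger values extrapolate away from the unconditional one.
v = v_uncond + cfg_weight * (v_cond - v_uncond)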

4. Run tests

uv run pytest

Project structure

struct2token/
├── configs/
│   └── default.yaml              # master config
├── src/struct2token/
│   ├── config.py                 # dataclass configs + YAML loading
│   ├── data/
│   │   ├── tokens.py             # atom-type, residue-type vocabularies
│   │   ├── molecule_conventions.py  # per-residue canonical atom ordering
│   │   ├── mmcif_parser.py       # mmCIF → all-atom features
│   │   ├── dataset.py            # PyTorch Dataset with caching
│   │   └── collate.py            # variable-length batching
│   ├── model/
│   │   ├── embeddings.py         # coord + atom + residue + meta embeddings
│   │   ├── attention.py          # Flash Attention 2 transformer (SDPA fallback)
│   │   ├── rotary.py             # RoPE positional embeddings
│   │   ├── fsq.py                # Finite Scalar Quantization (8,5,5,5 → 1000 codes)
│   │   ├── cfm.py                # Conditional Flow Matching
│   │   ├── dit.py                # DiT decoder with adaLN
│   │   └── dae.py                # main Diffusion Autoencoder
│   ├── losses/
│   ├── rmsd.py               # Kabsch-aligned RMSD (see sketch below)
│   │   ├── inter_atom_distance.py
│   │   ├── permutation.py        # symmetric sidechain resolution
│   │   └── tm.py                 # TM-score
│   ├── training/
│   │   ├── trainer.py            # training loop + wandb
│   │   ├── ema.py                # exponential moving average
│   │   └── augmentation.py       # random SO(3) rotation
│   └── inference/
│       ├── encode.py
│       ├── decode.py
│       └── metrics.py
├── scripts/
│   ├── preprocess_data.py        # build data index
│   ├── train_tokenizer.py        # training entry point
│   └── evaluate.py               # evaluation entry point
└── tests/
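
For orientation, here is a minimal Kabsch-aligned RMSD in plain PyTorch. The version in losses/rmsd.py is presumably batched and masked, so treat this as illustration only.

import torch

def kabsch_rmsd(P, Q):
    """RMSD between point sets P, Q of shape (N, 3) after optimal superposition."""
    P = P - P.mean(dim=0)
    Q = Q - Q.mean(dim=0)
    U, S, Vt = torch.linalg.svd(P.T @ Q)          # SVD of the 3x3 covariance matrix
    D = torch.eye(3)
    D[2, 2] = torch.sign(torch.det(Vt.T @ U.T))   # guard against reflections
    R = Vt.T @ D @ U.T                            # optimal rotation of P onto Q
    return torch.sqrt(((P @ R.T - Q) ** 2).sum(dim=-1).mean())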

Config

All parameters live in configs/default.yaml. Key settings:

Parameter                 Default        Notes
model.encoder.d_model     256            Encoder hidden dim
model.decoder.d_model     512            Decoder hidden dim
model.decoder.n_layers    12             DiT depth
model.fsq.levels          [8,5,5,5]      1000-token codebook
model.n_tokens            128            Max adaptive tokens
model.max_seq_len         8192           Max atoms per structure
training.lr               3e-4           AdamW learning rate
training.batch_size       2              Per-GPU batch size
training.max_steps        500000         Total training steps
training.wandb_project    struct2token   WandB project name
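
The dotted parameter names above presumably correspond to nested YAML sections. A quick way to inspect the resolved values (a sketch assuming PyYAML is available; the repo's own loader lives in src/struct2token/config.py):

import yaml

# Print a couple of the values listed in the table above.
with open("configs/default.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["fsq"]["levels"])   # expected: [8, 5, 5, 5]
print(cfg["training"]["lr"])           # expected: 3e-4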

WandB metrics

During training the following are logged:

  • train/flow_loss — flow matching MSE (main training signal)
  • train/size_loss — atom-count prediction cross-entropy
  • train/total_loss — weighted sum of the flow and size losses (see the sketch after this list)
  • train/grad_norm — gradient norm before clipping
  • train/lr — current learning rate
  • val/flow_loss, val/size_loss — validation losses (EMA model)
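
As a sketch of how the logged total relates to its parts (the weight name below is hypothetical; the actual weighting comes from the YAML config):

# Hypothetical combination of the logged training losses.
total_loss = flow_loss + size_loss_weight * size_loss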

Data

Training data: PDB mmCIF files (gzipped or plain). The preprocessing script scans all files and writes a parquet index with path, chain ID, entity type, and atom count per chain. The dataset lazily parses mmCIF files on access and caches parsed tensors as .pt files.
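
To sanity-check the index after preprocessing, you can open it with pandas (a sketch; the exact column names are not documented here, so do not rely on the comments below):

import pandas as pd

index = pd.read_parquet("data/index.parquet")
print(len(index), "chains")
print(index.columns.tolist())   # expect fields like path, chain id, entity type, atom count
print(index.head())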
