From 3fabc4fb342fd5c1057a969e5b22922c6bb7af2c Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Thu, 12 Feb 2026 13:03:16 +0000
Subject: [PATCH 1/3] Initial plan

From ab06d3bd049d0269bc087cda6abb4f66e2b59903 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Thu, 12 Feb 2026 13:09:55 +0000
Subject: [PATCH 2/3] Restrict Python version to <3.12 due to dependency incompatibilities

Co-authored-by: julian-parker <19472441+julian-parker@users.noreply.github.com>
---
 README.md                                  |   9 +
 build/lib/stable_codec/__init__.py         |   1 +
 build/lib/stable_codec/ctc_loss.py         | 236 ++++
 build/lib/stable_codec/fsq.py              | 134 +++++
 build/lib/stable_codec/model.py            | 159 +++++
 build/lib/stable_codec/residual_fsq.py     |  63 ++
 build/lib/stable_codec/training_demo.py    | 157 +++++
 build/lib/stable_codec/training_module.py  | 644 +++++++++++++++++++++
 dist/stable_codec-0.1.3-py3-none-any.whl   | Bin 0 -> 19930 bytes
 setup.py                                   |   4 +-
 stable_codec.egg-info/PKG-INFO             | 223 +++
 stable_codec.egg-info/SOURCES.txt          |  16 +
 stable_codec.egg-info/dependency_links.txt |   1 +
 stable_codec.egg-info/requires.txt         |   7 +
 stable_codec.egg-info/top_level.txt        |   1 +
 15 files changed, 1653 insertions(+), 2 deletions(-)
 create mode 100644 build/lib/stable_codec/__init__.py
 create mode 100644 build/lib/stable_codec/ctc_loss.py
 create mode 100644 build/lib/stable_codec/fsq.py
 create mode 100644 build/lib/stable_codec/model.py
 create mode 100644 build/lib/stable_codec/residual_fsq.py
 create mode 100644 build/lib/stable_codec/training_demo.py
 create mode 100644 build/lib/stable_codec/training_module.py
 create mode 100644 dist/stable_codec-0.1.3-py3-none-any.whl
 create mode 100644 stable_codec.egg-info/PKG-INFO
 create mode 100644 stable_codec.egg-info/SOURCES.txt
 create mode 100644 stable_codec.egg-info/dependency_links.txt
 create mode 100644 stable_codec.egg-info/requires.txt
 create mode 100644 stable_codec.egg-info/top_level.txt

diff --git a/README.md b/README.md
index fe6b114..91f31a7 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,9 @@ Model weights: https://huggingface.co/stabilityai/stable-codec-speech-16k
 
 ## Changelog
 
+### [v0.1.3] TBD
+- __Fix__ restricted Python version to <3.12 due to dependency incompatibilities
+- __Fix__ clarified installation instructions regarding Python version requirements
 ### [v0.1.2] 14-01-25
 - __New__ added hooks for `stable-codec-speech-16k-base`.
 - __Fix__ fixed major issue with precision in FSQ token calculation, which was degrading results. Fix is currently local, will be upstreamed to `stable-audio-tools` later.
@@ -34,6 +37,12 @@ In addition to the training described in the paper, the weights for `stable-code
 
 The model itself is defined in [stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) package.
 
+### Python Version Compatibility
+
+**Important:** This package currently requires **Python 3.9, 3.10, or 3.11**. Python 3.12 and later are not supported due to incompatibilities in the `stable-audio-tools` dependency chain (specifically `PyWavelets==1.4.1` and `pandas==2.0.2`).
+
+If you attempt to install on Python 3.12+, you will encounter build errors. Please use Python 3.11 or earlier.
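(The setup.py hunk itself is not shown in this part of the patch. The sketch below only illustrates how a restriction like the one described above is usually declared in a setuptools project; the names and values are assumptions, not the actual diff.)

```python
# Hypothetical sketch of the constraint described above -- not the actual setup.py hunk.
from setuptools import find_packages, setup

setup(
    name="stable_codec",
    version="0.1.3",
    packages=find_packages(),
    # Keep installs on Python 3.9-3.11: the stable-audio-tools dependency chain
    # (PyWavelets==1.4.1, pandas==2.0.2) fails to build on 3.12 and later.
    python_requires=">=3.9,<3.12",
)
```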
+ To install `stable-codec`: ```bash diff --git a/build/lib/stable_codec/__init__.py b/build/lib/stable_codec/__init__.py new file mode 100644 index 0000000..a2cc1b3 --- /dev/null +++ b/build/lib/stable_codec/__init__.py @@ -0,0 +1 @@ +from stable_codec.model import StableCodec \ No newline at end of file diff --git a/build/lib/stable_codec/ctc_loss.py b/build/lib/stable_codec/ctc_loss.py new file mode 100644 index 0000000..a6c9e02 --- /dev/null +++ b/build/lib/stable_codec/ctc_loss.py @@ -0,0 +1,236 @@ +import torch + +from torch.nn import functional as F +from torch import nn + +from stable_audio_tools.training.losses import LossModule + +# https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html +class CTCLossModule(LossModule): + def __init__( + self, + name: str, + input_key: str, + target_key: str, + weight: float = 1.0, + decay: float = 1.0, + blank_idx: int = 0, + padding_idx: int = None, + input_lengths_key: str = None, + ): + super().__init__(name=name, weight=weight, decay=decay) + self.ctc_loss = nn.CTCLoss(blank=blank_idx, reduction='mean', zero_infinity=True) + self.input_key = input_key + self.target_key = target_key + self.input_lengths_key = input_lengths_key + self.blank_idx = blank_idx + self.padding_idx = padding_idx if padding_idx is not None else blank_idx + 1 + + def forward(self, info): + """ + Computes the CTC loss. + + Args: + info (dict): Dictionary containing model outputs and other relevant data. + - info[self.input_key]: Model logits of shape (batch_size, sequence_length, num_classes). + - info[self.target_key]: Target data (list of dicts with 'phone' key). + - info[self.input_lengths_key]: (Optional) Actual lengths of the input sequences. + + Returns: + loss (Tensor): The computed CTC loss, scaled by the weight. + """ + # Build targets and target lengths + padded_targets, target_lengths = build_target(info[self.target_key], self.padding_idx) + + # Get logits from the model output + logits = info[self.input_key] # Expected shape: (batch_size, sequence_length, num_classes) + + # Move logits to the device of phonemes + device = padded_targets.device + logits = logits.to(device) + + # Apply log_softmax to obtain log probabilities + log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_length, num_classes) + + # Transpose log_probs to match (seq_length, batch_size, num_classes) + log_probs = log_probs.permute(1, 0, 2) # Now shape is (seq_length, batch_size, num_classes) + + # Determine input lengths + if self.input_lengths_key and self.input_lengths_key in info: + input_lengths = info[self.input_lengths_key].to(device) + else: + # Assume all input sequences have the same length + input_lengths = torch.full( + (log_probs.size(1),), # batch_size + log_probs.size(0), # seq_length + dtype=torch.long, + device=device + ) + + # Compute the CTC loss + loss = self.ctc_loss(log_probs, padded_targets, input_lengths, target_lengths) + + loss = self.weight * loss + + return loss + +class PERModule(nn.Module): + def __init__( + self, + input_key: str, + target_key: str, + blank_idx: int = 0, + padding_idx: int = None, + ): + super().__init__() + self.input_key = input_key + self.target_key = target_key + self.blank_idx = blank_idx + self.padding_idx = padding_idx if padding_idx is not None else blank_idx + 1 + + def decode_predictions(self, predicted_ids): + """ + Decodes the model predictions by collapsing repeats and removing blanks. + + Args: + predicted_ids (Tensor): Tensor of shape (seq_length,) containing predicted token IDs. 
+ + Returns: + List[int]: Decoded sequence of token IDs. + """ + predicted_sequence = [] + previous_id = None + for id in predicted_ids: + id = id.item() + if id != self.blank_idx and id != previous_id: + predicted_sequence.append(id) + previous_id = id + return predicted_sequence + + def forward(self, info): + """ + Computes the CTC loss. + + Args: + info (dict): Dictionary containing model outputs and other relevant data. + - info[self.input_key]: Model logits of shape (batch_size, sequence_length, num_classes). + - info[self.target_key]: Target data (list of dicts with 'phone' key). + - info[self.input_lengths_key]: (Optional) Actual lengths of the input sequences. + + Returns: + loss (Tensor): The computed CTC loss, scaled by the weight. + """ + with torch.no_grad(): + # Build targets and target lengths + padded_targets, target_lengths = build_target(info[self.target_key], self.padding_idx) + + # Get logits from the model output + logits = info[self.input_key] # Expected shape: (batch_size, sequence_length, num_classes) + + # Move logits to the device of phonemes + device = padded_targets.device + logits = logits.to(device) + + # Apply log_softmax to obtain log probabilities + log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_length, num_classes) + + # Transpose log_probs to match (seq_length, batch_size, num_classes) + log_probs = log_probs.permute(1, 0, 2) # Now shape is (seq_length, batch_size, num_classes) + + # Get predictions via greedy decoding + predicted_ids = torch.argmax(logits, dim=-1) # Shape: (batch_size, seq_length) + + batch_size = predicted_ids.size(0) + pers = [] + + for i in range(batch_size): + # Decode predictions + pred_ids = predicted_ids[i] # Tensor of shape (seq_length,) + pred_sequence = self.decode_predictions(pred_ids) + + # Get target sequence + target_ids = padded_targets[i] # Tensor of shape (max_target_length,) + target_length = target_lengths[i] + target_sequence = target_ids[:target_length].tolist() + + # Remove padding tokens from target sequence + target_sequence = [id for id in target_sequence if id != self.padding_idx] + + # Compute edit distance using the editdistance package + # distance = editdistance.eval(pred_sequence, target_sequence) + distance = edit_distance(pred_sequence, target_sequence) + + # Compute PER + per = distance / max(len(target_sequence), 1) + pers.append(per) + + # Compute average PER over the batch + average_per = sum(pers) / len(pers) + + return average_per + +def edit_distance(seq1, seq2): + """ + Computes the edit distance between two sequences. + + Args: + seq1 (List[int]): First sequence. + seq2 (List[int]): Second sequence. + + Returns: + int: The edit distance between seq1 and seq2. + """ + m = len(seq1) + n = len(seq2) + # Create a DP table + dp = [[0] * (n + 1) for _ in range(m + 1)] + # Initialize + for i in range(m + 1): + dp[i][0] = i + for j in range(n + 1): + dp[0][j] = j + # Compute dp table + for i in range(1, m + 1): + for j in range(1, n + 1): + if seq1[i - 1] == seq2[j - 1]: + cost = 0 + else: + cost = 1 + dp[i][j] = min( + dp[i - 1][j] + 1, # deletion + dp[i][j - 1] + 1, # insertion + dp[i - 1][j - 1] + cost # substitution + ) + return dp[m][n] + +def build_target(batch, padding_idx): + """ + Builds padded targets and computes target lengths. + + Args: + batch (list): A list of dictionaries, each containing a 'phone' key with tensor values. + + Returns: + padded_targets (Tensor): Padded target sequences of shape (batch_size, max_target_length). 
+ target_lengths (Tensor): Lengths of each target sequence in the batch. + """ + # Extract phoneme sequences + phoneme_sequences = [item['phone'] for item in batch] + + # Determine device from the phoneme sequences + device = phoneme_sequences[0].device + + # Ensure phoneme sequences are 1D tensors + phoneme_sequences = [seq.view(-1) if seq.ndim > 1 else seq for seq in phoneme_sequences] + + # Compute target lengths + target_lengths = torch.tensor([seq.size(0) for seq in phoneme_sequences], dtype=torch.long, device=device) + + # Pad sequences + padded_targets = nn.utils.rnn.pad_sequence( + phoneme_sequences, + batch_first=True, + padding_value=padding_idx + ).to(device) + + return padded_targets, target_lengths diff --git a/build/lib/stable_codec/fsq.py b/build/lib/stable_codec/fsq.py new file mode 100644 index 0000000..920fa42 --- /dev/null +++ b/build/lib/stable_codec/fsq.py @@ -0,0 +1,134 @@ +""" +Dithered Finite Scalar Quantization +Code adapted from https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py +""" + +from typing import List, Tuple +import random + +import torch +import torch.nn as nn +from torch.nn import Module +from torch import Tensor, int32 +from torch.amp import autocast + +from einops import rearrange + + +def leaky_hard_clip(x: Tensor, alpha: float = 1e-3) -> Tensor: + return (1-alpha) * torch.clamp(x, -1, 1) + alpha * x + +def round_ste(z: Tensor) -> Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + +class DitheredFSQ(Module): + def __init__( + self, + levels: List[int], + dither_inference: bool = False, + num_codebooks: int = 1, + noise_dropout: float = 0.5, + scale: float = 1.0, + ): + super().__init__() + self.levels = levels + + _levels = torch.tensor(levels, dtype=torch.int64) + self.register_buffer("_levels", _levels, persistent = False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64) + self.register_buffer("_basis", _basis, persistent = False) + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + self.codebook_size = _levels.prod().item() + + self.num_codebooks = num_codebooks + + self.dim = codebook_dim * num_codebooks + + self.dither_inference = dither_inference + + self.scale = scale + + half_l = self.scale * 2 / (self._levels - 1) + self.register_buffer("half_l", half_l, persistent = False) + + self.allowed_dtypes = (torch.float32, torch.float64) + + self.noise_dropout = noise_dropout + + def quantize(self, z, skip_tanh: bool = False): + if not skip_tanh: z = torch.tanh(z) + + if not self.training: + quantized = self._scale_and_shift_inverse(round_ste(self._scale_and_shift(z))) + else: + quantized = z + mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) + quantized = torch.where(mask, quantized, self._scale_and_shift_inverse(round_ste(self._scale_and_shift(quantized)))) + mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) + quantized = torch.where(mask, quantized, z + (torch.rand_like(z) - 0.5) * self.half_l) + + return quantized + + def _scale_and_shift(self, z): + level_indices = (z + 1 * self.scale) / self.half_l + return level_indices + + def _scale_and_shift_inverse(self, level_indices): + z = level_indices * self.half_l - 1 * self.scale + return z + + def _indices_to_codes(self, indices): + level_indices = self._indices_to_level_indices(indices) + 
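        # Map the per-dimension level indices back to the quantized code values in [-scale, scale].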
codes = self._scale_and_shift_inverse(level_indices) + return codes + + def _codes_to_indices(self, zhat): + zhat = self._scale_and_shift(zhat) + zhat = zhat.round().to(torch.int64) + out = (zhat * self._basis).sum(dim=-1) + return out + + def _indices_to_level_indices(self, indices): + indices = rearrange(indices, '... -> ... 1') + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered + + def indices_to_codes(self, indices): + # Expects input of batch x sequence x num_codebooks + assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}' + codes = self._indices_to_codes(indices.to(torch.int64)) + codes = rearrange(codes, '... c d -> ... (c d)') + return codes + + @autocast(device_type="cuda", enabled = False) + def forward(self, z, skip_tanh: bool = False): + + orig_dtype = z.dtype + + assert z.shape[-1] == self.dim, f'expected dimension of {self.num_codebooks * self.dim} but found dimension of {z.shape[-1]}' + + z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) + + # make sure allowed dtype before quantizing + + if z.dtype not in self.allowed_dtypes: + z = z.to(torch.float64) + + codes = self.quantize(z, skip_tanh=skip_tanh) + indices = self._codes_to_indices(codes) + codes = rearrange(codes, 'b n c d -> b n (c d)') + + # cast codes back to original dtype + + if codes.dtype != orig_dtype: + codes = codes.type(orig_dtype) + + # return quantized output and indices + + return codes, indices diff --git a/build/lib/stable_codec/model.py b/build/lib/stable_codec/model.py new file mode 100644 index 0000000..14f541a --- /dev/null +++ b/build/lib/stable_codec/model.py @@ -0,0 +1,159 @@ +import json +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torchaudio +from einops import rearrange +from stable_audio_tools import get_pretrained_model +from stable_audio_tools.data.utils import VolumeNorm +from stable_audio_tools.models import create_model_from_config +from stable_audio_tools.models.fsq import DitheredFSQ +from stable_audio_tools.models.utils import copy_state_dict, load_ckpt_state_dict + +from .residual_fsq import ResidualFSQBottleneck + + +class StableCodec(nn.Module): + def __init__(self, + model_config_path: Optional[str] = None, ckpt_path: Optional[str] = None, pretrained_model: Optional[str] = None, device = torch.device("cpu"), + ): + super().__init__() + self.device = device + + if pretrained_model is not None: + print(f"Loading pretrained model `{pretrained_model}`.\n") + self.model, model_config = get_pretrained_model(pretrained_model) + else: + if model_config_path is None: + raise ValueError("Either `model_config_path` or `pretrained_model` should be provided.") + print(f"Loading config from `{model_config_path}`.\n") + with open(model_config_path) as f: + model_config = json.load(f) + self.model = create_model_from_config(model_config) + if ckpt_path is not None: + print(f"Loading weights from `{ckpt_path}`.\n") + state = load_ckpt_state_dict(ckpt_path) + copy_state_dict(self.model, state) + + self.model = self.model.to(self.device).eval().requires_grad_(False) + + self.residual_fsq: Optional[ResidualFSQBottleneck] = None + + self.sample_rate = model_config["sample_rate"] + self.volume_norm = VolumeNorm([-20, 0], self.sample_rate) + + self.preset_bottleneck_configs = { + "1x46656_400bps": [ + ([6, 6, 6, 6, 6, 6], 1.0) + ], + "2x15625_700bps": [ + ([5, 5, 5, 5, 5, 5], 1.0), + ([5, 5, 5, 5, 5, 5], 
0.25), + ], + "4x729_1000bps": [ + ([3, 3, 3, 3, 3, 3], 1.0), + ([3, 3, 3, 3, 3, 3], 0.5), + ([3, 3, 3, 3, 3, 3], 0.25), + ([3, 3, 3, 3, 3, 3], 0.125), + ] + } + + def set_posthoc_bottleneck(self, stages): + if isinstance(stages,str): + if stages in self.preset_bottleneck_configs: + stages = self.preset_bottleneck_configs[stages] + else: + raise ValueError(f"Unsupported preset bottleneck configuration `{stages}`.") + + self.residual_fsq = ResidualFSQBottleneck(stages).to(self.device).eval().requires_grad_(False) + + def encode(self, audio: Union[str, torch.Tensor], posthoc_bottleneck: bool = False, normalize: bool = True,**kwargs): + """ + Encode audio into latents and tokens. + + Args: + + audio : Union[str, torch.Tensor] + Path to an audio file or a `Tensor` of the eaudio itself. + posthoc_bottleneck : bool + Whether to inject a posthoc FSQ instead of the FSQ used during training. + If `True`, its configuration should've been passed in with the `self.set_posthoc_bottleneck` method. + normalize : bool + Whether to normalize the audio to -20 LUFS before encoding (recommended). + Other `kwargs` are the same as in `AudioAutoencoder.encode_audio` method. + + Returns: + + Tuple of `(continuous_latents, tokens)`. + + continuous_latents : torch.Tensor + Pre-bottleneck latents in the `(B, H, S)` shape. + tokens : torch.Tensor + Bottleneck tokens in the `(B, S, 1)` shape. + + Where `B` is the batch size, `H` is the hidden dimension and `S` is the sequence length. + """ + if isinstance(audio, str): + audio, sample_rate = torchaudio.load(audio) + audio = self.model.preprocess_audio_for_encoder(audio.to(self.device), sample_rate) + if normalize: + audio = self.volume_norm(audio.squeeze(0)).unsqueeze(0) + + latents, info = self.model.encode_audio(audio, + return_info=True, skip_bottleneck=posthoc_bottleneck, **kwargs) + if posthoc_bottleneck: + tokens = self.residual_fsq.encode(latents) + else: + tokens = info["quantizer_indices"] + + return info["pre_bottleneck_latents"], tokens + + def decode(self, tokens: torch.Tensor, posthoc_bottleneck: bool = False, **kwargs): + """ + Decode audio from tokens. + + Args: + + tokens : torch.Tensor + Integer tokens produced by `encode` stage in `(B, S, 1)` shape. + posthoc_bottleneck : bool + Whether to inject a posthoc FSQ instead of the FSQ used during training. + If `True`, its configuration should've been passed in with `self.set_posthoc_bottleneck` method. + + Returns: + + Decoded audio in the `(B, C, L)` shape. + Where `B` is the batch size, `C` is the number of channels and `L` is the number of frames. 
+ """ + if posthoc_bottleneck: + latents = self.residual_fsq.decode(tokens) + else: + latents = self.model.bottleneck.decode_tokens(tokens) + latents = rearrange(latents, "b c n -> b n c") + + audio = self.model.decode_audio(latents, **kwargs) + return audio + +def main(): + sc = StableCodec( + pretrained_model="stabilityai/stable-codec-speech-16k", + device = torch.device("cuda") + ) + + sc.set_posthoc_bottleneck("2x15625_700bps") + + wavfile = "test.wav" + + posthoc_bottleneck = False + latents, tokens = sc.encode(wavfile, posthoc_bottleneck=posthoc_bottleneck) + decoded = sc.decode(tokens, posthoc_bottleneck=posthoc_bottleneck) + torchaudio.save("decode.wav", decoded.squeeze(0).cpu(), 16000) + + posthoc_bottleneck = True + latents, tokens = sc.encode(wavfile, posthoc_bottleneck=posthoc_bottleneck) + decoded = sc.decode(tokens, posthoc_bottleneck=posthoc_bottleneck) + torchaudio.save("decode-res.wav", decoded.squeeze(0).cpu(), 16000) + +if __name__ == "__main__": + main() diff --git a/build/lib/stable_codec/residual_fsq.py b/build/lib/stable_codec/residual_fsq.py new file mode 100644 index 0000000..b83b6b6 --- /dev/null +++ b/build/lib/stable_codec/residual_fsq.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn + +from typing import List, Tuple +from einops import rearrange +from .fsq import DitheredFSQ + +class ResidualFSQBottleneck(nn.Module): + def __init__(self, stages: List[Tuple[List[int], float]]): + super().__init__() + + # 1st for single_tokens, others - residuals. + self.quantizers = nn.ModuleList([ + DitheredFSQ(levels=levels, scale=scale).eval().requires_grad_(False) + for (levels, scale) in stages]) + + self.n_codebooks = len(stages) + self.codebook_size = sum(map(len, stages)) * self.n_codebooks + + def encode(self, x): + input_dtype = x.dtype + z = torch.tanh(x.to(torch.float64)) + z = rearrange(z, "b c n -> b n c") + + r = z + res_ids = [] + for quantizer in self.quantizers: + q, ids = quantizer(r, skip_tanh=True) + r = r - q.to(torch.float64) + res_ids.append(ids) + + return res_ids + + def decode(self, res_ids): + z = sum([ + q.indices_to_codes(res_ids[i]) + for (i, q) in enumerate(self.quantizers) + ]) + return rearrange(z, "b n c -> b c n") + +if __name__ == "__main__": + fsq = DitheredFSQ([17, 17, 17, 17, 17, 17]).eval().requires_grad_(False) + # res_fsq = ResidualFSQBottleneck(stages=[ + # ([5, 5, 5, 5, 5, 5], 1.0), + # ([5, 5, 5, 5, 5, 5], 0.25), + # ]).eval().requires_grad_(False) + res_fsq = ResidualFSQBottleneck(stages=[ + ([3, 3, 3, 3, 3, 3], 1.0), + ([3, 3, 3, 3, 3, 3], 0.5), + ([3, 3, 3, 3, 3, 3], 0.25), + ([3, 3, 3, 3, 3, 3], 0.125), + ]).eval().requires_grad_(False) + + x = torch.rand(1, 6, 1) + + z1 = res_fsq.decode(res_fsq.encode(x)) + + _, y2 = fsq(rearrange(x, "b c n -> b n c")) + z2 = rearrange(fsq.indices_to_codes(y2), "b n c -> b c n") + + print(z1) + print(z2) + assert (z1 == z2).all() diff --git a/build/lib/stable_codec/training_demo.py b/build/lib/stable_codec/training_demo.py new file mode 100644 index 0000000..c7fa248 --- /dev/null +++ b/build/lib/stable_codec/training_demo.py @@ -0,0 +1,157 @@ +import os +import torch +import torchaudio +import pytorch_lightning as pl + +from einops import rearrange +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from stable_audio_tools.models.autoencoders import ( + fold_channels_into_batch, unfold_channels_from_batch, +) +from stable_audio_tools.training.utils import ( + log_image, log_point_cloud, logger_project_name, log_audio, +) +from stable_audio_tools.interface.aeiou import ( 
+ audio_spectrogram_image, tokens_spectrogram_image, +) + +def trim_to_shortest(a, b): + """Trim the longer of two tensors to the length of the shorter one.""" + if a.shape[-1] > b.shape[-1]: + return a[:,:,:b.shape[-1]], b + elif b.shape[-1] > a.shape[-1]: + return a, b[:,:,:a.shape[-1]] + return a, b + +class AutoencoderDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every = 2000, + sample_size = 65536, + sample_rate = 16000, + max_demos = 8, + ): + super().__init__() + self.demo_every = demo_every + self.demo_samples = sample_size + self.demo_dl = demo_dl + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.max_demos = max_demos + + @rank_zero_only + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + if ( + (trainer.global_step - 1) % self.demo_every != 0 or + self.last_demo_step == trainer.global_step + ): + return + + self.last_demo_step = trainer.global_step + module.eval() + + try: + demo_iter = iter(self.demo_dl) + demo_reals, _ = next(demo_iter) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + # Limit the number of demo samples + if demo_reals.shape[0] > self.max_demos: + demo_reals = demo_reals[:self.max_demos,...] + + encoder_input = demo_reals + encoder_input = encoder_input.to(module.device) + + if module.force_input_mono: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad(): + if module.use_ema: + latents = module.autoencoder_ema.ema_model.encode(encoder_input) + fakes = module.autoencoder_ema.ema_model.decode(latents) + else: + latents = module.autoencoder.encode(encoder_input) + fakes = module.autoencoder.decode(latents) + + #Trim output to remove post-padding. + fakes, demo_reals = trim_to_shortest(fakes.detach(), demo_reals) + + # Visualize discriminator sensitivity. 
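            # The block below probes where the discriminator is most sensitive: it resynthesizes
            # the fake and real signals through an STFT/iSTFT pair with a grad-enabled fake
            # spectrogram, backpropagates the discriminator loss, and logs |d loss / d STFT(fakes)|
            # as a spectrogram image. Both optimizers are zeroed afterwards so the probe does not
            # leak gradients into training.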
+ if module.discriminator is not None: + window = torch.kaiser_window(512).to(fakes.device) + stft_kwargs = { + "n_fft": 512, + "hop_length": 128, + "win_length": 512, + "window": window, + "center": True, + } + + fakes_stft = torch.stft(fold_channels_into_batch(fakes), + return_complex=True, **stft_kwargs) + fakes_stft.requires_grad = True + fakes_signal = unfold_channels_from_batch( + torch.istft(fakes_stft, **stft_kwargs), fakes.shape[1]) + + real_stft = torch.stft(fold_channels_into_batch(demo_reals), + return_complex=True, **stft_kwargs) + reals_signal = unfold_channels_from_batch( + torch.istft(real_stft, **stft_kwargs), demo_reals.shape[1]) + + _, loss, _ = module.discriminator.loss(reals_signal, fakes_signal) + fakes_stft.retain_grad() + loss.backward() + grads = unfold_channels_from_batch(fakes_stft.grad.detach().abs(), fakes.shape[1]) + + log_image(trainer.logger, 'disciminator_sensitivity', + tokens_spectrogram_image(grads.mean(dim=1).log10(), + title='Discriminator Sensitivity', symmetric=False)) + opts = module.optimizers() + opts[0].zero_grad() + opts[1].zero_grad() + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + data_dir = os.path.join( + trainer.logger.save_dir, logger_project_name(trainer.logger), + trainer.logger.experiment.id, "media") + os.makedirs(data_dir, exist_ok=True) + filename = os.path.join(data_dir, f'recon_{trainer.global_step:08}.wav') + + reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_audio(trainer.logger, 'recon', filename, self.sample_rate) + log_point_cloud(trainer.logger, 'embeddings_3dpca', latents) + log_image(trainer.logger, 'embeddings_spec', tokens_spectrogram_image(latents)) + log_image(trainer.logger, 'recon_melspec_left', audio_spectrogram_image(reals_fakes)) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + finally: + module.train() + +def create_demo_callback_from_config(model_config, **kwargs): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + demo_config = training_config.get("demo", {}) + return AutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) diff --git a/build/lib/stable_codec/training_module.py b/build/lib/stable_codec/training_module.py new file mode 100644 index 0000000..4518ea5 --- /dev/null +++ b/build/lib/stable_codec/training_module.py @@ -0,0 +1,644 @@ +import torch +import torch.nn as nn +import pytorch_lightning as pl + +from typing import Optional, Literal +from ema_pytorch import EMA +from torch.nn import Parameter +from einops import rearrange + +from stable_audio_tools.models import create_model_from_config +from stable_audio_tools.models.autoencoders import AudioAutoencoder +from stable_audio_tools.models.discriminators import ( + EncodecDiscriminator, OobleckDiscriminator, DACGANLoss, +) +from stable_audio_tools.models.bottleneck import ( + VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, + RVQVAEBottleneck, WassersteinBottleneck, +) +from stable_audio_tools.training.losses 
import ( + MelSpectrogramLoss, MultiLoss, AuralossLoss, ValueLoss, L1Loss, + LossWithTarget, MSELoss, HubertLoss, + # PESQMetric, # TODO move PESQ here? +) +from stable_audio_tools.training.losses import auraloss as auraloss +from stable_audio_tools.training.utils import ( + create_optimizer_from_config, create_scheduler_from_config, log_metric, +) + +from .ctc_loss import CTCLossModule, PERModule + +def trim_to_shortest(a, b): + """Trim the longer of two tensors to the length of the shorter one.""" + if a.shape[-1] > b.shape[-1]: + return a[:,:,:b.shape[-1]], b + elif b.shape[-1] > a.shape[-1]: + return a, b[:,:,:a.shape[-1]] + return a, b + +class ProjectionHead(nn.Module): + def __init__(self, latent_dim, proj_head_dim, mid_dim=256): + super(ProjectionHead, self).__init__() + self.proj_head = nn.Sequential( + nn.Tanh(), + nn.Linear(latent_dim, mid_dim), + nn.ReLU(), + nn.Linear(mid_dim, mid_dim), + nn.ReLU(), + nn.Linear(mid_dim, proj_head_dim) + ) + + def forward(self, x): + return self.proj_head(x) + +class AutoencoderTrainingWrapper(pl.LightningModule): + def __init__(self, + autoencoder: AudioAutoencoder, + loss_config: dict, + eval_loss_config: dict, + optimizer_configs: dict, + sample_rate: int = 16000, + lr: float = 1e-4, + warmup_steps: int = 0, + warmup_mode: Literal["adv", "full"] = "adv", + encoder_freeze_on_warmup: bool = False, + use_ema: bool = True, + ema_copy = None, + force_input_mono = False, + latent_mask_ratio = 0.0, + teacher_model: Optional[AudioAutoencoder] = None, + clip_grad_norm = 0.0, + encoder_mask_ratio = 0.0, + use_ctc: bool = False, + proj_head_dim: Optional[int] = None, + detach_proj_head: bool = False, + ): + super().__init__() + + self.automatic_optimization = False + self.autoencoder = autoencoder + + self.warmed_up = False + self.warmup_steps = warmup_steps + self.warmup_mode = warmup_mode + self.encoder_freeze_on_warmup = encoder_freeze_on_warmup + self.lr = lr + self.clip_grad_norm = clip_grad_norm + + self.force_input_mono = force_input_mono + self.teacher_model = teacher_model + + self.use_ctc = use_ctc + self.proj_head_dim = proj_head_dim + self.detach_proj_head = detach_proj_head + self.projection_head = ( + ProjectionHead(self.autoencoder.latent_dim, self.proj_head_dim) + if self.use_ctc and self.proj_head_dim is not None else + nn.Identity() + ) + + self.optimizer_configs = optimizer_configs + self.loss_config = loss_config + + # Spectral reconstruction loss + self.sdstft = auraloss.MultiResolutionSTFTLoss( + sample_rate=sample_rate, **loss_config['spectral']['config']) + + # Discriminator + self.use_disc = True if 'discriminator' in loss_config else False + self.discriminator = None + if self.use_disc: + if loss_config['discriminator']['type'] == 'oobleck': + self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'encodec': + self.discriminator = EncodecDiscriminator( + in_channels=self.autoencoder.out_channels, + **loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'dac': + self.discriminator = DACGANLoss( + channels=self.autoencoder.out_channels, + sample_rate=sample_rate, + **loss_config['discriminator']['config']) + + gen_loss_modules = [] + if self.use_disc: + # Discriminator loss. + self.losses_disc = MultiLoss([ + ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), + ]) + + # Adversarial and feature matching losses. 
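            # Both terms read values precomputed in training_step (loss_info['loss_adv'] and
            # loss_info['feature_matching_distance']) and are weighted according to
            # loss_config['discriminator']['weights'].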
+ gen_loss_modules += [ + ValueLoss( + key='loss_adv', + weight=self.loss_config['discriminator']['weights']['adversarial'], + name='loss_adv'), + ValueLoss( + key='feature_matching_distance', + weight=self.loss_config['discriminator']['weights']['feature_matching'], + name='feature_matching_loss'), + ] + + # Reconstruction loss + gen_loss_modules += [AuralossLoss(self.sdstft, + target_key='reals', input_key='decoded', name='mrstft_loss', + weight=self.loss_config['spectral']['weights']['mrstft'], + decay=self.loss_config['spectral'].get('decay', 1.0), + )] + + if "mrmel" in loss_config: + mrmel_weight = loss_config["mrmel"]["weights"]["mrmel"] + if mrmel_weight > 0: + mrmel_config = loss_config["mrmel"]["config"] + self.mrmel = MelSpectrogramLoss(sample_rate, + n_mels=mrmel_config["n_mels"], + window_lengths=mrmel_config["window_lengths"], + pow=mrmel_config["pow"], + log_weight=mrmel_config["log_weight"], + mag_weight=mrmel_config["mag_weight"], + ) + gen_loss_modules.append(LossWithTarget( + self.mrmel, "reals", "decoded", + name="mrmel_loss", weight=mrmel_weight, + )) + + if "hubert" in loss_config: + hubert_weight = loss_config["hubert"]["weights"]["hubert"] + if hubert_weight > 0: + hubert_cfg = ( + loss_config["hubert"]["config"] + if "config" in loss_config["hubert"] else + dict() + ) + self.hubert = HubertLoss(weight=1.0, **hubert_cfg) + + gen_loss_modules.append(LossWithTarget( + self.hubert, target_key = "reals", input_key = "decoded", + name="hubert_loss", weight=hubert_weight, + decay = loss_config["hubert"].get("decay", 1.0) + )) + + if "l1" in loss_config["time"]["weights"]: + if self.loss_config['time']['weights']['l1'] > 0.0: + gen_loss_modules.append(L1Loss( + key_a='reals', key_b='decoded', + weight=self.loss_config['time']['weights']['l1'], + name='l1_time_loss', + decay = self.loss_config['time'].get('decay', 1.0), + )) + + if "l2" in loss_config["time"]["weights"]: + if self.loss_config['time']['weights']['l2'] > 0.0: + gen_loss_modules.append(MSELoss( + key_a='reals', key_b='decoded', + weight=self.loss_config['time']['weights']['l2'], + name='l2_time_loss', + decay = self.loss_config['time'].get('decay', 1.0), + )) + + if self.autoencoder.bottleneck is not None: + gen_loss_modules += create_loss_modules_from_bottleneck( + self.autoencoder.bottleneck, self.loss_config) + + self.encoder_mask_ratio = encoder_mask_ratio + if encoder_mask_ratio > 0.0: + gen_loss_modules.append(L1Loss( + key_a='detached_latents', key_b='masked_latents', + weight=1.0, + name='encoder_mask_loss', + decay = 1.0, + )) + + if "ctc" in loss_config: + ctc_weight = loss_config["ctc"]["weights"]["ctc"] + if ctc_weight > 0: + gen_loss_modules.append(CTCLossModule( + name = "ctc_loss", + target_key = "ctc_tgt", + input_key = "log_probs", + weight = ctc_weight, + decay = loss_config["ctc"].get("decay", 1.0), + blank_idx = loss_config["ctc"].get("blank_idx", 80) + )) + + self.losses_gen = MultiLoss(gen_loss_modules) + + # Set up EMA for model weights + self.autoencoder_ema = None + self.use_ema = use_ema + if self.use_ema: + self.autoencoder_ema = EMA( + self.autoencoder, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.latent_mask_ratio = latent_mask_ratio + + # evaluation losses & metrics + self.eval_losses = torch.nn.ModuleDict() + if eval_loss_config is not None: + # if "pesq" in eval_loss_config: + # self.eval_losses["pesq"] = PESQMetric(sample_rate) + if "stft"in eval_loss_config: + self.eval_losses["stft"] = 
auraloss.STFTLoss(**eval_loss_config["stft"]) + if "sisdr" in eval_loss_config: + self.eval_losses["sisdr"] = auraloss.SISDRLoss(**eval_loss_config["sisdr"]) + if "mel" in eval_loss_config: + self.eval_losses["mel"] = auraloss.MelSTFTLoss( + sample_rate, **eval_loss_config["mel"]) + if "per" in eval_loss_config: + self.eval_losses["per"] = PERModule( + target_key = "ctc_tgt", + input_key = "log_probs", + blank_idx = loss_config["ctc"].get("blank_idx", 80)) + + self.validation_step_outputs = [] + + def configure_optimizers(self): + gen_params = list(self.autoencoder.parameters()) + + if not self.use_disc: + opt_gen = create_optimizer_from_config( + self.optimizer_configs['autoencoder']['optimizer'], gen_params) + if "scheduler" in self.optimizer_configs['autoencoder']: + sched_gen = create_scheduler_from_config( + self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + return [opt_gen], [sched_gen] + return [opt_gen] + + # Using discriminator. + opt_gen = create_optimizer_from_config( + self.optimizer_configs['autoencoder']['optimizer'], gen_params) + opt_disc = create_optimizer_from_config( + self.optimizer_configs['discriminator']['optimizer'], + self.discriminator.parameters()) + + use_scheduler = ( + "scheduler" in self.optimizer_configs['autoencoder'] and + "scheduler" in self.optimizer_configs['discriminator'] + ) + if use_scheduler: + sched_gen = create_scheduler_from_config( + self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + sched_disc = create_scheduler_from_config( + self.optimizer_configs['discriminator']['scheduler'], opt_disc) + return [opt_gen, opt_disc], [sched_gen, sched_disc] + return [opt_gen, opt_disc] + + def forward(self, reals): + latents, encoder_info = self.autoencoder.encode(reals, return_info=True) + decoded = self.autoencoder.decode(latents) + return decoded + + def validation_step(self, batch, batch_idx): + reals, _ = batch + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if len(reals.shape) == 2: + reals = reals.unsqueeze(1) + + loss_info = {} + loss_info["reals"] = reals + + encoder_input = reals + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + + with torch.no_grad(): + if self.use_ctc: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + continuous_latents = encoder_info["pre_bottleneck_latents"] + proj_features = rearrange(continuous_latents, "b c n -> b n c") + proj_features = self.projection_head( + proj_features.detach() + if self.detach_proj_head else + proj_features + ) + + loss_info['log_probs'] = proj_features + loss_info['ctc_tgt'] = batch[1] + else: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + + loss_info["latents"] = latents + loss_info.update(encoder_info) + + decoded = self.autoencoder.decode(latents) + #Trim output to remove post-padding. + decoded, reals = trim_to_shortest(decoded, reals) + + # Run evaluation metrics. 
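        # PER is computed from loss_info (it needs the projected logits and CTC targets);
        # the other metrics compare decoded audio against the reference, and the SI-SDR
        # value is sign-flipped before logging.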
+ val_loss_dict = {} + for eval_key, eval_fn in self.eval_losses.items(): + if eval_key == 'per': + loss_value = eval_fn(loss_info) + else: + loss_value = eval_fn(decoded, reals) + if eval_key == "sisdr": loss_value = -loss_value + + if isinstance(loss_value, torch.Tensor): + loss_value = loss_value.item() + + val_loss_dict[eval_key] = loss_value + + self.validation_step_outputs.append(val_loss_dict) + return val_loss_dict + + def on_validation_epoch_end(self): + sum_loss_dict = {} + for loss_dict in self.validation_step_outputs: + for key, value in loss_dict.items(): + if key not in sum_loss_dict: + sum_loss_dict[key] = value + else: + sum_loss_dict[key] += value + + for key, value in sum_loss_dict.items(): + val_loss = value / len(self.validation_step_outputs) + val_loss = self.all_gather(val_loss).mean().item() + log_metric(self.logger, f"val/{key}", val_loss) + + self.validation_step_outputs.clear() # free memory + + def training_step(self, batch, batch_idx): + reals, _ = batch + + log_dict = {} + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if len(reals.shape) == 2: + reals = reals.unsqueeze(1) + + if self.global_step >= self.warmup_steps: + self.warmed_up = True + + loss_info = {} + loss_info["reals"] = reals + encoder_input = reals + + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + data_std = encoder_input.std() + + if self.warmed_up and self.encoder_freeze_on_warmup: + with torch.no_grad(): + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + else: + if self.use_ctc: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + continuous_latents = encoder_info["pre_bottleneck_latents"] + proj_features = rearrange(continuous_latents, "b c n -> b n c") + proj_features = self.projection_head( + proj_features.detach() + if self.detach_proj_head else + proj_features + ) + + loss_info['log_probs'] = proj_features + loss_info['ctc_tgt'] = batch[1] + else: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + + if self.encoder_mask_ratio > 0.0: + masked_latents = self.autoencoder.encode( + encoder_input, return_info=False, encoder_mask_ratio=self.encoder_mask_ratio) + detached_latents = latents.detach() + loss_info["masked_latents"] = masked_latents + loss_info["detached_latents"] = detached_latents + + loss_info["latents"] = latents + loss_info.update(encoder_info) + + # Encode with teacher model for distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) + loss_info['teacher_latents'] = teacher_latents + + # Optionally mask out some latents for noise resistance + if self.latent_mask_ratio > 0.0: + mask = torch.rand_like(latents) < self.latent_mask_ratio + latents = torch.where(mask, torch.zeros_like(latents), latents) + + decoded = self.autoencoder.decode(latents) + #Trim output to remove post-padding + decoded, reals = trim_to_shortest(decoded, reals) + + loss_info["decoded"] = decoded + loss_info["reals"] = reals + + if self.autoencoder.out_channels == 2: + loss_info["decoded_left"] = decoded[:, 0:1, :] + loss_info["decoded_right"] = decoded[:, 1:2, :] + loss_info["reals_left"] = reals[:, 0:1, :] + loss_info["reals_right"] = reals[:, 1:2, :] + + # Distillation + if self.teacher_model is not None: + with torch.no_grad(): + 
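                # Decode the teacher's latents and cross-decode (student latents through the
                # teacher's decoder, teacher latents through the student's) to build the
                # distillation targets stored in loss_info.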
teacher_decoded = self.teacher_model.decode(teacher_latents) + own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher + teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model + + loss_info['teacher_decoded'] = teacher_decoded + loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded + loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded + + if self.use_disc: + if self.warmed_up: + loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals=reals, fakes=decoded) + else: + loss_adv = torch.tensor(0.).to(reals) + feature_matching_distance = torch.tensor(0.).to(reals) + + if self.warmup_mode == "adv": + loss_dis, _, _ = self.discriminator.loss(reals=reals, fakes=decoded) + else: + loss_dis = torch.tensor(0.0).to(reals) + + loss_info["loss_dis"] = loss_dis + loss_info["loss_adv"] = loss_adv + loss_info["feature_matching_distance"] = feature_matching_distance + + opt_gen = None + opt_disc = None + if self.use_disc: + opt_gen, opt_disc = self.optimizers() + else: + opt_gen = self.optimizers() + + lr_schedulers = self.lr_schedulers() + sched_gen = None + sched_disc = None + + if lr_schedulers is not None: + if self.use_disc: + sched_gen, sched_disc = lr_schedulers + else: + sched_gen = lr_schedulers + + # Train the discriminator + use_disc = ( + self.use_disc + and self.global_step % 2 + # Check warmup mode and if it is time to use discriminator. + and ( + (self.warmup_mode == "full" and self.warmed_up) + or self.warmup_mode == "adv") + ) + if use_disc: + loss, losses = self.losses_disc(loss_info) + log_dict['train/disc_lr'] = opt_disc.param_groups[0]['lr'] + opt_disc.zero_grad() + self.manual_backward(loss) + + if self.clip_grad_norm > 0.0: + torch.nn.utils.clip_grad_norm_( + self.discriminator.parameters(), self.clip_grad_norm) + + opt_disc.step() + if sched_disc is not None: + # sched step every step + sched_disc.step() + + # Train the generator + else: + loss, losses = self.losses_gen(loss_info) + if self.use_ema: + self.autoencoder_ema.update() + + opt_gen.zero_grad() + self.manual_backward(loss) + if self.clip_grad_norm > 0.0: + torch.nn.utils.clip_grad_norm_( + self.autoencoder.parameters(), self.clip_grad_norm) + + opt_gen.step() + if sched_gen is not None: + # scheduler step every step + sched_gen.step() + + log_dict['train/loss'] = loss.detach().item() + log_dict['train/latent_std'] = latents.std().detach().item() + log_dict['train/data_std'] = data_std.detach().item() + log_dict['train/gen_lr'] = opt_gen.param_groups[0]['lr'] + + for loss_name, loss_value in losses.items(): + log_dict[f'train/{loss_name}'] = loss_value.detach().item() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def export_model(self, path, use_safetensors=False): + if self.autoencoder_ema is not None: + model = self.autoencoder_ema.ema_model + else: + model = self.autoencoder + + if use_safetensors: + save_model(model, path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +def create_loss_modules_from_bottleneck(bottleneck, loss_config): + losses = [] + + if ( + isinstance(bottleneck, VAEBottleneck) or + isinstance(bottleneck, DACRVQVAEBottleneck) or + isinstance(bottleneck, RVQVAEBottleneck) + ): + try: + kl_weight = loss_config['bottleneck']['weights']['kl'] + except: + kl_weight = 1e-6 + + kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') + losses.append(kl_loss) + + if ( + 
isinstance(bottleneck, RVQBottleneck) or + isinstance(bottleneck, RVQVAEBottleneck) + ): + quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') + losses.append(quantizer_loss) + + if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): + codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') + commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') + losses.append(codebook_loss) + losses.append(commitment_loss) + + if isinstance(bottleneck, WassersteinBottleneck): + try: + mmd_weight = loss_config['bottleneck']['weights']['mmd'] + except: + mmd_weight = 100 + + mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') + losses.append(mmd_loss) + + return losses + +def create_training_wrapper_from_config(model_config, model): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + ema_copy = None + if training_config.get("use_ema", False): + ema_copy = create_model_from_config(model_config) + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + use_ema = training_config.get("use_ema", False) + latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) + + teacher_model = training_config.get("teacher_model", None) + if teacher_model is not None: + teacher_model = create_model_from_config(teacher_model) + teacher_model = teacher_model.eval().requires_grad_(False) + + teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) + if teacher_model_ckpt is not None: + teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) + else: + raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") + + return AutoencoderTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + warmup_steps=training_config.get("warmup_steps", 0), + encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), + sample_rate=model_config["sample_rate"], + loss_config=training_config.get("loss_configs", None), + eval_loss_config=training_config.get("eval_loss_configs", None), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=use_ema, + ema_copy=ema_copy if use_ema else None, + force_input_mono=training_config.get("force_input_mono", False), + latent_mask_ratio=latent_mask_ratio, + teacher_model=teacher_model, + encoder_mask_ratio=training_config.get("encoder_mask_ratio", 0.0), + use_ctc=training_config.get("use_ctc", False), + proj_head_dim=model_config["model"].get("proj_head_dim", False), + detach_proj_head=model_config["model"].get("detach_proj_head", None), + ) diff --git a/dist/stable_codec-0.1.3-py3-none-any.whl b/dist/stable_codec-0.1.3-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..ac31d4f20a874f124e1fbcf949997970adb5c1b2 GIT binary patch literal 19930 zcmZ^~LzHM?5;R!0ZQHhO+qP}ner4OX@yfPs+g0=W%wqoTnYs7m;@jjZPex>9q=Gas z2nqlI00e-Wg^Z3VY2rHRzZccN!1x!=E`~-nruxSACZ@*p`udi3mM;4GbPk@<)DyF+ z6m*g@6Oys=GxCzME0dr}oKP7ZpzorjPy+^-<_RDG3jdW_lZ@us83F)ckMRE^wXut_ 
zzKy-J^S{*Bn$z){Y-oL-b$evtW4fxyv-1IHV`iD6fyJ;39hhMP(#_K%M5>9I<4<+F zuM$!uq_NDBS=-XC>4_4(@V&G6mOIUjiz1V?JauZLa%?ZW$UzR=?pLA}9VpFYDv9g^ zg@&OBaVnNx<(h(@QIkT_npHCHDbRFsnzX4#Rs^2V`nM*_t2QlOax+?OR^G15d>Z?W z+`6ibD(Y$b%aM>Piqn-!I4qI(r1?_b78iI*g5zIh1ah4k-Rd0S+C`6EDv(;7d7q28`4ygvuzsx_%#}l`g)k$ zC@9}o>@A%P)nCb_b00JG;F<}ID) zZTeR%bFllbCN4KB{Q>nt4e zb-b^K+Hv|Jqt|TC@C^-A)z|NahSJab-qK=fFRPc@8UAp3q6o#+j}V&$`F$67-iwT^ zo~|Nkm!Z%IY`z5)zVOk6hF)I$xV0G~%jqVOJOxi>zNV7E9I zmh>r6G3w)C0>u-!&rX{<&=?8;TP1w<5r{|aGlgg*SkY@3vhv)lmB1VdG19Iw=epfW zcVS#G+fa9DqBe)>5td^`5rOc&oSe+!# z1fDlmQ=2RrOA|p0OhuHi@18z0p=nZK=?Xm@=3{G0%C9kgm^7ne$8KSG{F%qcApuEo zT7`J5J~V1C7Wijx+kH^|7rq`u0_6iL1gMH8TM>@E0}%p7uRQN_V3F2T zcTtUy?m}a-{IUwy`Un>etAtLtRwBNBKkE+w5@?A7E>L44VDT3@+1M<*{xF(7;MWxa z0YK{-t(tKOqs7RhW@FT<2g$^^9sbO&<2vx1S>>N`k43ha!rG$ zoO%b*(Ni4K1LMd*xTwk@+Ee@quaYY(A_u=1;?;Up>!#R~$J9`3D$UIYu#pg(fkt#g zSmry6afk)9LFA~?;yVT@=(2ZT@tQ*`(5>Jj=r`rV8EpATW8~qXgY=~Q6N~zNN7)o{ z6pVm)j)-Hj9BS^Th-4m$6#%I%1%^yjOt8# zV%c8IPG2ugx0Vg^fIbhDZjJWE>%J-4n}db%JW{s0$xlYf{3+Lr^ltQS#^Fg2Js@sjtgtI()WrGN-f=Y5}Cj52PH{az8z`RH_N`{52~ zWmN(xv=+>WFw|(r*=D$~!ml_2o?$H#vD>crGH2hNe55L-O8(k7`T4RDu1`frF=nm% z4UK%M*=QIq4Y<}vyD%UTXodA$D^~Vwii?bd8f}X*k=qckP8^nK`72~z`X#P_7gquF zm#^Or@09w7T*_uh4pkwarW1lsWe9cPVdv0JG|-8J!~98dyuhc?8h!zn{E7-0Wp@?> z^~gdJGgNd?E7{MO$qBI04lkXLPc=(>XC#1-GrS$r1!*gr1YgB8f{05^UQ9y2U&+Io zb|Y();m4csM$^QKXYVX3!JEO?6J)*>gRWy^^P0}gXCC(ZqlCPZ5W{NWf2zpF6PD%Is4 zteH4x@6QyA#9=d4+?wDt{O)6yZZd^u!p@76yBi7aO%0tMRP_(^e-jn9X=wBwPym20 zL;wKf|K&ht&W`^$%Cfqw{UIAd&!u|39!gqkBM!MOJeUrOMV7Mn+RaXoKt9O?5lymI zq!bSAOP~82X<}m!XexmffF(OXVbRCBSwYl@=a+$hd!yC}Zb6bNp8 zel}>|eGq}po>+k1`U7!gh(+hMS5Nr1Z`#pUd+ld78f|H{Hh-edo!hhKCPCE+Xqsos zj8$$;d^>V6D_6y`WQ|;fL2*uy;Rux2LN^Gs01gGnK}ms4{8__(je!4x-0+$5?h*+@ zoz@hTAIz1K<|`VcF@$>rqw!oXGnfB;A-Hw&6ciug$~~J{N@SO|J3Swp8u)9_se=M= zn!?|#87S=_ZGkp&K?Iay5!Qk5Kh1--i6^YKgNtrYftRI&+A1qJ(9Ocwjh)eR7wE*h zg#ZKk(jvP|D0XU6yh69I`Q2a=`;=l;vN3FkAA*~}?~>kjp0iOX2Z(ShW!8F{;1j=T zR%e(-v0Q+JJL&?4pjoe|8*T0RA(3>64T;!cQiixlro1saF~9~8se@3kib@E9k5DQ= z0UNGkWQX6Q@C7IQ3j`!W)4ZHQgqk!gIv>+ zX(u@90FQWUxtY!ze$r18t_XMoNEwJE-H)=}NNBZo?(}2~J0$%wT})1p$W6v$(=w7N zQhNz)F-Q+1iEbq#`5XRcaJ++itr3Z3|;xWz4?99AwnA!{11tWn}hm;ECa7h));|y6g504S9?pK(%odF!! 
z`+5fCP1bn*SvLz75SNy2Nva%lOCP{Eq;BVFwTcva8ZuggI`5?tnkgZY{5(h{CctW~ ze`KOmNJB`*ixWCo+}jwP(wL;pV7l5An67Gr8}!yz3gJ+S2*tJo*(&<0A8z}cnJpn# zLaI1q{=ly6?*%kWASKg8!(c6#Nd0I+gpz7KvnCm7x>9RFL2M$|iw;&G=`R8I&7*)r zz;xVL+>y-L!5<wq{}k*I7*`ql0)(B+v1!jo2dB>nwO%$Zj>T6p5;>b;285}u*8l!c#5E~+ zl@7M~B&M@7Z_4=qFoOsS2mDuvs-3DV$I29c+GEyh)I**Vitn|+Vs{KJ@{>)p!P4W#Z( zuk2O}h3seaeN73s$TW$>i6NCg>p12k#WC$<&Th#a2B58U#TuUH1Fxg!+6aP2d(m+x zTl#hmzYL7e)?m#Xc5X~M#Sjo5p8W2+L!I}N#rocjUZx*}4?adwIn^LQsB1fPo=EIC zi9+AOB0V=WyU_q7n4a|fdGbB3jsjSu1nY#}hU4+m)su8&cCB999MaWKv~;eLovfad z%Em$1BQ9;X6R&T>zDnWGz6xw=bT>W26hX*-wVI{@7(DhsQrqlZ=yRH(K_@uv_9dq1 zAK!PBsUcP!2o!|N(9@Q&HEl32yM!;>V06qn6S?NWfiz5`OmWFdo&HCruwVL;#6{U1 z!8RU*H!jP4YkheOFB5T0z>kP)W`+|F;RHi5Sm6}i%)2i`Qw5v24kx~2&vmvqN4qJv zLuHqOfHN9i>Vk_AZKnkXv-2j-a|F`IS&T6Uzkrk4q);ZEBF#(yL8O6F4 zi7IJNjtZ}vQ(yfRJx$yEH*Kt$Fc4=B>L2j`?)FT$xh$yQ008X&#$D9^?e@0+=3bls z^m`3W`%Ml6zt=i~r$(g$iyWI|BVdofT#G=PuFWTpF|Z^WMKft*5@*I|fpz z#%(?|2!!aPE#}=A2{zb6BwwVM8LAmJi5>QkS_H;RbdfQP9!nBKmOA67T3aieqF0n1 zb7pS|CcWyOKOssf&>|IWnQf$oVo+o$tVYzKhuM$Wwe{;LO!_Mv%cIKU8rN{O^)lFT zq_dDCrk#eQweIsIrgV?Rsl~*<{TFV*`pmuSlEs=_K|n0YYh#kAl3EN}iOYZw#YlN= zG`WZ!e{ufM)BW;_9G`2|pYVk?5%iLPf)to}W}vahUe(ONOokW>POyRFau3*FV?p>0 z7FRQlj7;0GQD5jlxKy-wKm|o7EF9UZwX(*|m_y_usZ+>8k}}B(2VPRGOkObNGRR3} z>D12ZI_W@;*UbaO>NOth(QqoK~5E`6(#5Ah9Of8QMCs|lxv++#QH0|*p6L!r+tYBj~N%S zE_5}Vsw>!2%?|Agjlw3G=M$|lL!~lV*u;1RRP~<;>$Sacp2ma^IOnubU^9**1Robr zVIo(0_dnrl@R6a#IZuqC7Sr)dCqCDd>qq4Z8;({W9+E%4esA9*6Rv_w-rXhgdL z;!+*=lKgz?g7)<7FXMKuk0#|qxkbdjui3nK*rz+avpVc5 zfp+N$)@n~cc3wqPxu|G8bRZf6AO-bu1*#d>7@1lugBDHU5WERW#wZOb(ZC_Tt)A~xaS?&(l}3GWUya)V7}2(!&_h(lpKWqXal^BoKY0Zc;@1W254&ib__U^d-SlT1TaT^p!CpRV7H4p?J`y4YLSRS{~?#`X(5A_EHyO0+vM{%_NMi zke-?&fIb0y;(D3_bI7a_78XFa)7K7SOZB&5oW)dY%yEDXTgGrLlFQM4>d54QM3|HP zD?izQwI~30-FW2Fkvd~sD>wc3BrBh}zriQkbkdSv^~pW`!4oa$k)(juTGT3;wlxFI zlo#*NcrVV+>wSIXUXh-@6(zs>#s7^&jNaGl{%JrCp$#4VLTk4|W6eus;f9PMqVhKuDA#k= zAdk6EI~XP-5O?rDQA3_6q_o#NoeJ$F#LZag^YX@*?DBl8p4n7`p<&7-b`p!!J0tD6 z#zzdRpIX?&?c@`sD@I!Z)oBGS#=Gh%#)T}sP<(BiOLOW|13(xB4FMGkIqa?L$q2K4 z%&mn%m$s&EzRv?Ur|!s3+)y1LN95yfDAp| z;u($(W@ad`9)}Y<5AUp5bMdVCyl=KbN}ZAg!2M;6ofut~{6D;|lg-;S1M+*|0)$e_ zs%c{V(UHig;b|mpLVE>Jn=c_1NKKlyn{&BP*bIye3zE$-N<$XIv!{D@h{E8C14yV5 zg++aHNI$6G11>`aF z&JSH3Xgo?-DB^_!?Hc&FdD|dxg6_H@Ev9z%kkzh*Yl#{Aci!_zTgPo)9y<8~Muf;7 z@KZA)u3B>$&p!CCbFFKGRi<2^`p6XcngT4U1@9i5Ps%o4yc$Psd&EOho2d4oPh7f6 zEVlnT*BcCz%IO=5TY-!?ZXrupsk58$piuRiX|AAAE5_*P?^6?e9$|GJG<-FE# zTi+%c5!<9fjy-@{rQZ+Szm8X=%G^i8VV@kmpEm$4)-lg}v9x)}tmh0hf*f;Pk}E4$ z{e5gw_67ky>N90`YiunsW);nrTk|$0^=fC+1}4^R{F%L!zce+~cz(Y!NsVYT1N7%A z%TZY2y;!%_6{a1eT21(+R@RcY$o*>LtPZu!5824%m#$DV(XlSCyjk7k$**}*pG2Xq zwjuOUhF%u6w#B3=v33?r8?^@KJk_zF;f^aw-c6|+B?dtH;Hwxa#pi3|gbN0rY8G{! 
z022l#xDg3HB4KQhIZF?sD|FHHF~uk?Qy=~lN73bpC8?ik`-MKnq9fzLwKV5bh6d7R zbQBc4V@fhXm!4L7nNob1O|T-0rbvMJR44n79os}|P|G-?ALrGCRJ7E@8OQF)oxex%nWltNMG8 z`o~5{-J&+`7X|s1P*6uB92g-ui)+#x@>kz7FTPOoLl6XLCF_-8;3w-vw!cy;`34Ok zFQ<&tj)^$Pvc#^Ucg|ShPc_TpQ<67Qaz}X8$n#x<;g4Ed%q7&SL>pOr;c+m^&Z;`) zOsRvt^mG-QIn9HkpyXINWU?pm*{2nQFe564GeVUrYa3cuB&P|c^?N?#DJMRNLlC3l zY@2-V56)1heFwN%rA+x5%W5nk|pzeebP zzS9S!UXgwK!)1 z$CTZ*2ytvqmGwiRLj+WtHD+aE4xX)$dg@O`p<(Z_BhjHQ(p=XyM)0?ck!{$JU-#{r ziey$xgzx`T|Be%emGm2n@6jQ^9@NZSwHGUz zm*}o6VG_AOebpDMyyC#BCp&G-s<_?8p?%-C06_;IWT}{3u2g6w>JP1Ewe&<~C`soC zAfFLg{CkaX(gYW)G3a%I{DNhRvJTe%u`Eu9aCz^atSRbvA87S1XbH0F@+^*A53 zSC?QaHxEge5btD2vdx@~i@XrQIGt%>S@l*dScGvqAa9fUAmBLym(V_|Rmho{+Bm*F z42V*b`i0IU3|gf{SN262T@Io$z$Ag`upT4pC8k>>M+zUMiVHwoCwK-5rKNSOfKcQS zidN27$Tml21i3!?NAA+A=O@6#C@uiQ5fn9%!jU>sums82rkFg4BfW_DV5~|=50J9b zq0&Q8qcXlaEZ%&R`XTNrJdxeZA4;YO#1Nkz`3t>Xhmc|vl+1_2}bJv7U=G zDX1&<%B?jNQT&z`B0Dq{p@u*v(rdJh%r-`>fVsrUvN_RviV%JPik(KJ0`wEX;g+f%lBJ?ZH6UEU-G)dNgwMLM^@eC%?7J?G^v~%sk9LtOg6fElaBtV#mqK` zxQAL0`s%iQ`Cn{xC2Z!bc6OU0HD+N+_eryzD5D`^u#<6$I8L({6T3R>)y?GAGp=D# z?FQ+{?Vt_)7=A<3zFtIKqd+>ygs|_ppP8P;_8nu#HK7*^*O`v7_fJjpm)DbPCsyGb zWfv-;Y@r3KZ8Esklkm}1^=Qo$Qy)qB6DApl6xJ8mOeUMXKHfpF1+X?{EuQcRBmUWY z6a8*wW2Ub+*{6=ISxyD{E603hZU{M6OTJeJrk_7=MH$VgZ!@R~A>AH`ATD^rjCX)nAl;eh5>o9A$N|^F$)IM6Y}=rDYVz_NBa(t?KtY4s7FnBz6NcZ(ak= zAH8x6Uj8Qy>9{%3fCVo=C7XE0mWQQjIH3v2^HzRj>@l`D-|LLV&6&#H^a5kjF4U!Y za0&*KPOq%Arrx-`Zk^)vu=HRjRHLIIn_hV(BkEg+}=G1rjt>&sux(l$kfLkt;V_c zg7;OlhNKKU*dRpw+SCVodoZjMvy!#EfHV5O7 z&*F#JQvOwp4*(oWvyE6&MJ+pOb~TY6%(g}K{-vC4JzYzd(^HJ!>Q0Qi#fLJ3RhzIH z9V4&Nk{h%~!y?#Q`z1>+uVlxbLMy08*)^krf*z<`p=_>81#)+G0j$9jm#v~Opipt) zn1vagY!K+!p}tZ1@X)LIlaFoj=d@iJU^Zwkb|&)9`Xf-Vw2oGqVn}GbrO$UdpM)I>=e0|nUzqv zH78Jj_I$Oji%+V0<}j}GZl&@|r2?U}99fI@s=mhWHrsv$)YZ-9&|TSs`5(ykCi`em zH2u0@Erk4C{5N=&Ig;1jzDnt{C|T2ISXuOiV3=u3Icqg2T@U`%2X zu5EZNN*lac3moRYi~@j#mMw%|1Z)gE2kvWULeWWf=Khvhz8Y<0>WeJi(->w3?ZV3Upwe}}TmumYOUVn5ypS^J`9N)jB7U^PQCwDy zoDP@*N<)Hn!raG)R{aA0H{m>r-SG(c=Uh-V0RX`N{{`HC3XZFd>3?hmJl6Ku;z_&D z)L<+(98YnUl7C=dLX6?~Hodz7dd4R7K^*{X*VH;aiAuqyZE@$pvc zOB0VZ!^DtpH}(`;WfjebS{w~^6`+V^QMz%8v9?=%NuWwa&3oN>R90YdsmVyL6i#{3z9t2ctg9O&^QR%DmO0aBwrH>g(-^YvR-`-VA*V^>0&6ZCcL%>P>_h*c2Gl*Grv(A3qOuwT=)jN4rh#^6MgQg(w^g97fvoTZGVTCMt994yUh&H&?K*UWF8spWL zslHU)8(&LtJTq%QO?{AF5Xu0X4IC+3r9(E#(_i~)?O|ly^+C3gsZYfC?R>5Nwx5-j zp6SgzPzP8UCrVYoQj`MKI0Zl+p((#s9w#C>*%kC9$7i+W7hDZ;=UTpctyq>BJu~S- z!1b8ZmG!m3c_LZ?}Q21lAN1DQlYb zIs-5~s)H<>rr&gyzdqjEqM4>rWoRF~vNr03EyvXEzN-;1FLPr`>c(eH`sK#n5eeyf z@;@qIwM~Wn{50c2Y|>YMJn&Um{Yv4!3yLFAc`LAYd0KOE zIFy|qg8|3M$v}@J%ZWFg>ZPLmA<>KZMb!5e=t;B(_|JkMxE_fqa{ytsudkPvRSR6V zZB^@@>eYyPDOifetvhG?jc#3ne2-POCf!m z_?@mS-Ixvu^_@Gw?`YvgM&Rx%P57IDy1uJ-6&i84YnOLGRA0a>1F8X#HCRsLCVaoa z;zvk{*{s09iyzM2#?zzTghnHIq2 z(Wh>DjBOZz;V%r3Q8*^Ym*{{I83acIl~lfvV`UtgSB@MsP9pHjIhrDfQ7nM0ANO#{ zH$ci4S_aCTA3(%cX@2c9K(qGIBHChdP#JS9vwWu+QyV!}-a2Lh=prX#n*|#optMF! 
zGcZ0H0@t;FVN*`$vVpzRgMf7jBIV5ptM|qRq}S0V08#`KpwftF zq--A*WwIS%vXY_{>42!|R1){sa~6wa^Ti9on=ZpZzBGP^slil+Po=IJUu_dz6yT@n zMqCT}RBX{i_-;$^RmP!h)BFu9`9@MSK0JkfE}Xa5ATI366i+!*Pg@nw>}ICuusd~m zq811E<0=YU-8iw9~P ztqVo8^ujHrek#Gvdsi4`pm$BB)-@c35p6Q-D`v4HJ|yy_0j=rJ6`>P1>oDNoT+d$0 z1h>QnmOnhwk;e_T1f@|6QS65BA&PmKR>`p2CWIWIB}GxqvkOGi-oD#3=0jv7J`Rwy z3dF&LmEchZ|{F!^3eqWdJ4DgezDQ6@3{qa;x!o2tIEP7T_ey%1R7PHVqeRJ~P zLE{)1Ch3}(ClidUu8s5#2Ya4<0_6Qz_)II;6L;KdFw6#C`@jQv$cV!!%H?$w=Ebtd z070>?LmL*aF=%G`gV|J8ay~kbh&yhqx9sBx)J(hWoyZn;HeE6Av#-_5g2&O_V4iwS}U%D<08&}Mr-W_qRSsu!57X^mkOZ1HeFa;p8 zr!_!2TJhqeH8S*d=E;b9&aSN@16eS=oO-xH&tZ`iPY;tV!w$TaLu8&CfiQD-krHKY zddKj3Q38&NQ5Bh1uAxJe=RoGv4boT{RhL@fy?wKY{v=EU<6+&^TBDE_y_lHz^qcn9 zuMRpHLsisiT|GO#c!hZefpd6pGFfh9KwbePM7#Ke?n7`kl22Sb@PRK6qQ>0?FbY>| zhQqMLEWV{k6h-nrP-WWo?W0pgpO5|}!_k9F^w;Wk9OxGB3+WUr%khA1NR;a9#G}? zgH92~=)kK6xv?|BwwLeTk4Sv!+^^0%g>(#Y{un=w)UhG*3~+Y=V-HiZ`(!(@C$xP6 zVkEqq!$pl^f#}LSoNv=5hBvvz%=_NnDc}*QoASS9d2zd%)yd}s0EFL3L}6>CE{?1z zmPH`Hu$`qTM^nPYu!0Parr`t@l*}kNCUZu4^~ZTchz&#LPGf1x;iE|^izF6*d{t=T z1#3zH0O(rXScn`8ZS?-w^O9Xl3auaAnmf7qwjG zJ|s!YGqP_fctRnHb0vqhic=b{LdNFWEcK-eLZeyu`xC4 zC_LHJ!=@u=vwyWu61_*=@g7UBx|9VMfmdt@$a&K zkMiPCNzKxT_nX_VtMY$eq0^hzK55M}mX{V89;_^*MejR{QE&;SQth1@;=;DeDs*R} z%0`W<%rb2EyS=HdDM-KtO&$}Hf}Qr_*e!eIAD@;jXmTi)#f^IoeKZ6IIxw+OEP>;U zD6oCH?N%*8C4wJs3(lE1NkP%nllH??kN2op(jCAFU$iBWCaX5eWgd;Nk3;=blk?fr z^~^O?J46UF9M9MFs|9a^?!4EZ^@0-KoI-u|fE#lwCkK{XAy>9A#btz<=zW_5WS5ezvnpAb+zK_UG-NSj07ui zXS~?$kODyM;{tVTvAZP+oPb#ndr5|6=`JxJfFld~@PbX2$oOosag8;oSv@)#xZZZc ze+7J7A$@ft(wqaG#=6iG45fc2$Ch+jf9)f(RZ@_u>k^8j{w7RKl)QqHz72jofFmnV z5Ne@|izxwTNXmPU(3^sw9N*jeIb4?pCQ5hOS8()^jg8WZt2(7AJORNeSr5YHngNHkJ?Xj-5@a|g3F+^Zl0gU zEmI-Ox3|;yL%IoR(Dwy-^zNB-E*2+~IJJtZC!SHjJY;;ZH+a$N&EQsMXT!_7#@n#| z8vF383sY?rnN_JPsGh%)WRLbU|0^K3)5{UyeEkPx1&Rbm{r3@U1qxKA13=WS`y`(R z)+KzthEE;8r%O158>inH(ys<%+uh|?i~h^H=T5#uPeF_O<_=){65p^t9`*nyQrjr- z5IP#dy!m%K8Pco-xfOk~PSLU5xJZ;yq)j&N0u4;}d@bl|ZDftv6JEGe&{10YM%x%6 z^t5%d^1v~4b^a+y-6i^w1HLMP6HIV}W00%^^sQ(Ei(O9G=ZnL%)o;mn$>~_()x*b<&^HdjAnZL3~`iUBp)$j3qXT~7L#k;bhil-B+~EFbL1$k zokl~?Wc(a-SQvdZT!GtBDf1v#H#>gc2~(zsXoGjFSb!PUif1|7^x59;(eXAN8~1f# zFWc|}{_&+(G4iNpykZXvoRr^p3%)B*>6NYB1KGBD|;| zvUSq*jkN^Q`J%A5+H)^_!CUQqxXa9V&P~-A4P<`hHKsiz{zBoGrF^{qGwNA+pr|oBI4QF$D2C zYi)~;31F`Zt>g0A^t)4z7-*9)NhII=InNs-R7{e!!5boQ+zY^H+zOW5bR@9T-7FW( zUBu_`j$HKx#wEejTIM3dp>f80Gc)!w7awuxi#Oeh95?T}#jXD%R#Gvctw>p#qiw)jtE81Ib%FtS6s8=b3B=p}Yfm!VsE9JdUv ze9N4RKd&Lio7hx@kjQW^v(h2cRdQ20PT3K}^a$C_pSWf)hOBmZ)cZyh3Ncxx7>Djg zx-8yjrK%0Rho^YiC^aeZ!Cdn8A~$Ad%%xwXwBVHZoUW++3(27y?&(Vl4=(k;L7|w^ z^8=B2=0HsQW|On84|R$ArtBCvY79Rb;$bQLTH@6pbH-v!UP)S)OPsPbXv=`xg39uY zDG1VijXA?{Jb2cm6zA!Q`LjHK#oh`U53h!;oLq_@z41AwMyP?Q+>tk32+NHex8aFU zRL}8qUHwo9Q~OU$mg2<~_t;@$x`g(JcBGW&xdS|2iJRPoq+R=YyPr$4b|zZ=7fk$e z7itvbL71B+WvS&TRrTMXE!Yiiac9xZBDeCLoMxSzQ8pR9m9Ln4FUixaoUEqD99WefxyP`IGxn>#&erdV@V2d!%j&1R_PbQSN zok5z+5k~;Jb;Aoj*>?b$QB}+C=c`+{i%8}%rtf7-b0}+z z%wd-aDD5S$o?jM5K(e-?5sI_$yx5Dog2rFOJxW~C7pcR46=l8~sxSn!^u4P<11BM) z1!WeB$3jYdVH7hL8CLAE+U$x5$>ZxKK@h~G9G)uFT|HJf1Yr@%>Bsmedwkr@0;5on zi4IBRKp=EeSRFuj83ohVwa4TB066ZyBNmyJlb^mD%B=TFtn`EvLY|XvkQgrKLJd$! 
zWG!pR<^^Ehti5QE zrePu3a~p0lg5>j6vXzF=FT6)k%Jz6rH=X*}UCN|f z#^|^&@%J;C#?cr&TI#CDc0qLIMf?Hlp254qu5ZFTnDIp*lSBzw5DWgiS-@llZq&+U z&+%!hn;CDco?m*&s7M?h*ifK@s>;b@Pe;H%MHDpGB{^L+KwQ|N1Hu0kV`gD1atYJW zKN3i9wt)xTXTs6&rUh@Am1qY-=NiVyJ3s)uDRjYM+nrLF1Y&NOO~1KxC=1_?Sdp11 zpoyZl2^`9hV@+(7?`9GXy{%nLx}7hz-@SttoTLa`B<-1Ku;qmPGLV0oZbUh% zv~u2=jU{;Fs_{Ji*aGwZz^gt2?a`WG8AAa^GoJmxYC3xBWUw3)>kG5|$`P(&FyNUlZF`$GF%X9W zpwJF!1p_JkR2fOmywK0{uG7lu!7>m$FSb9E%U9$ngE*Z^zA_JZ+np3qoe=$4T?2XU zDDCKD&5rh*p^qkJLL^V%WqRh_;6pA0APj<&yiM%#>5l z{@a$*g8YmB5ie;O=osmk=}au0U1%-s% z&wk*3Zv`9Q;7g^V^cQb@v%qjhl6WKF780ye&S> zxYj@l!_B$}$K?P4{m`-5JMRHtn<<7Jm|i51zxebr#}6F9Wg&W9VIKrfHy`KY8X}9^ z+=F57FubARpK4mS+t)swL&n26I?_0h9E`cMxgYL!@V4f+Q6c!s z;9(?Z0g)4CLE?Iyy+1ZMSI)b!bxIk=7)dl&pa{3Op@a^V9k#X}=Oo(*oW!EfrVuS&~Il zr2|0R((0&zsZ-5? zqM-w4-iXlIDPdAibkHD8y7%gg#3)VXWB%IR7&IZN5b?(KElH><#i*o;9WoE3Z1@PU zH=;p0&{}JaDOC{}Hp_{WXMi}Cv^UIE?&n4CLG-M)yb*-R+~IUs0Vxl^?wh~-AFj#& z)URgu+A(?bUyUfm2F3EVgSIwBWTtM$J*zHPkUPdSTldSP>h*dFHb;2#kzV-+<;zQSDst@Cl$Wb8b>qrwb z(ac2DgW4@Wb&refm}NFdlLBQfM3$H!TC994%ccG!BsXB4b_%_?yPb34qb3u?kaLS< zihJpEj<0{kX1e|y7wwZ~mc~hsIErpJr|w*F*!!k4tVK2;Waduw`G*o6Ov(Fp*&jNG z@=%|?0a*Ow{#nmq>Bo@%?4?`TbWn!gDSMe;GFgYTgI|Y^z^E4QQ5WZfP6{9Cwgyhs zD8;|lc&ev|gR=`Od%0M?W&G2KaEP08dECYKz_3)Mn=idglf(Nv6Q~%?IG3J(D_5O( zkybc1xS?3p8ctM=`jnJ=dAL#8v9?7A-n5h>aYUD)$ADW6d5o*ttarmfZ7zGZWw34Z zlEs-Rr376CGVucpzXCjWm7JYwX_5k@H_+wuOoJ=YLxxttrBS}D+qhHV13(?xVJZA4 z8^+ho{^0}14N3-ebBb_AefnZ`VE5%e4<;Xw{3~Epp8O8GrvsGY=UPN++OevLI+0&EI z))w=J57JPc9It+i*y)EOOAx4R?m0tvsf|RmceEM=&VV6pq0qc=F&b5*!6fMuOoHZ^ zP7^`Bdawb;q|B~WE7)+fv83H3ye8s;zf)GlltkAhPY(9X{$44e>uvJc&1X*SF1LZ4 z`UPON2}3Yjo|sfSvdv#I+)<-Kk%j=Ye}-BS zEL?ujyaoXLRM6NVcgMFlULM^2o+#9r!$eiyTdL646XuL$ordxIyfkUk&h&(gK`jdW2br|)z0wSH1mH){#JS+1D|v1IU21o7M=n9Mx$)YeF9ST=rn(=f zyhM{nL(;=2tYcn*B*h@1FGAE;J7Ih*Pdpp};lZ>=(J*1LdauF#b*CrDz>tG96=Cm)Q&c%@Xe*3N+A`b!%am8{Wr!YUa z$3X`-B=R9-M&%@7#0U^#4dt4sMgTs*1QEKjzfmA5cjT6FS4rh}Gac~VC&ae28WQI~ ziYr=mDAoe#0GLuQR|}Al4LUSqS9D+bBv{QQ1i*=g7-FENf*?eg0{YdrDHw3zwt%%t z>6{GcT@cVpiE@0@0XMww7)&mq14gwFFR?e!AZtV8tdRK29Pu_)_N+j}RJ2FpM1 z@|BO#Q1H09(VdxCpO$ym8`xYwu^|-E7>vwE6N@@g^ta`- z`(^LTx|)z9tdJsOasG(}feey}k0U`EiQpT>uOy49^Yw^b*px`DW8-M`|=mL1{9p-ig|h>SIAd?6dbj+3F6z zfO;MHxn_#z1n5#tEy4P_PeKIim)F)wODHl%G8<#rt1+uVp4x9c5AQJ)*dEcm6lt~2 z?K6>oz#XxTTj1xB_dg{snpk?-Dr5^p;%mG-zi+C4xl8Wb=1?zHg2&DE|VNa3BSMe`8(+z_&OsUf1WeWDi5=>+nPW{_N zZZTo-Q&HJr2H|a#oghCW z5dsSgbV)of3l>b^fGQ~+ONu#LLgR}7A~+=$ikPnpz)UA5q;MW^_%f_?>C6^zPT3B% zJ{WN8{nfj4Cx(u)lRX5;HE;Rt}e;`ChYn^cu^l ztE-9+Ry_!nKu6o!ECKtgE`eq&O`!Bzo3I_JpwCuE*uB%;*V9AXoHiNFf z;wrw$W)#l$I$ROMmYO7*JA=53cD|-&ZxiYf2Zi{J6B7z~stF4Ox&4P_druRZ=++4- zwEQ^J-%oh`U49lXfEjk-fCg5S*jKv>J!6FI$v`68!G!K^qLWt;YU;W`L3f1dXro4v z_Ps+aQUF&txG@^UWvoM9M!jq^NIODSR1K64g{T%hlI7I4`9R^=9>xV(>Z0|2BzZWP zMMF^4sXD=|QjJ?W;iQ}=ldA-KG%!7cx#q^!P^6Vwo{bIa#(tT#vIGyGV}gh1OmzLK(ERq@}eqDwbeqmDbXh5Mv#Kl#FW8 z#ZqnTOH#`u!dR+BrKGmnYF`>uq8K_2#`4C@=}VvY%rrUA$*26z^Wn+;pL=ufmEhY( zGnRgbq29C$snN1IdTb0Wu=f>C;J0m#!+eL>q zL>=!;|aKlE(Q*iiSvJx)vuu~cYu?t?ZJSrJ~Bx(6;6(tKvyi*_kgueBj zHr8D~3t@h|i3*CkCfkvn3 zP=(YYw<2;yn3}*^ebVZfvB(YM$HK(<%Tfh80lfz2e~Ymf=R6&d;9q)&4+cMsl+k#L zZT7mSBI<>*JstlfS7hJq+v)YsifxEj4(`{u-!aMsI!aoGf!w`2_2S~hG-PS|aTX6% zrmNyY?c9^!n-cjhX5wjmR}*}h%mCMo+v1NFp&xB!k2tj->MJQw=E2!3+Zg&!Pe`G< z18RiM=3m@X z-6f=oRz8Q&_ovL^CHWK04uzt?EtJkHN(F1)1_b!%wql$9MO#m-ovii2q=bJYHPz z9<9%nLy?&HUS7kWVM&8=>-p>vSA$)*h(} zpJ*ny1^HN6o4F!ipUF`zDZLxN`jpqnU)@u1<)%wp#>y=dC|x!xhQ}|kJ}rATWqi>! 
zUC&wkJ`n?zHl(aiT`F|32A!N1iy6hAWUPxMZg=#CYzgabQ&Y9-0<)w%uLPJ|Xm^Ot zHL^S8v?5%4r14SaEak1RP&aZ=#OFw7y)tfYM$x3zcmZdeUv7{g-iZ|IZ884`=dGza z%G{#!wV#t1xfzdZIFO2R$N`pv-H8`)ilg>Z@#U5f<+9q7DswfITR{mwD^QNiG9@eG zJl=2fUr-C?>+>q5!iKG?d`i@HXc!9gIZ3*kHlwo`gspmR3EBK{X#uxzD!=|(0XU3y zz@wY^9)SyV>?6naiG9#r^rFPr{1*zf#&GdiY2Pe2;F$5zkLKK4Q z7Iee!d`v_B==HLWrB7Hi#iiT`e0?TgNqmYek~Ch`|X)bu#xw*IrS23>;Uh_LIz?@z6k7+pd^F=N@8 zqx|~t?=ijNZcu}nN}jSq`MFl-u9YDWjE8}G1eWpF2Oh;vJDu5={I}9iC5r=J>q z(he-d-->l)hNs_9E)36qrA8l|);Z+24H1&XSNg}@UD3GCR~#y~xu-8}CpESZ?Q#)j z-;*Hw+F7IXucysX_R?Q);h!&RD zl_N)NQ8}8$i2LEU-EAcm1ZLZ_8fW7j)wbW!^I9Hwq>d&p#L&)JmR~9v=9eubUFfN` zefvpjJ2X8Vj=TJPN%he^1vQ^Ux9aS>|C#GmIW4Z-VPdS9_p_6>9u*c^1(LKK4P@3w>{UUI#9A@y;y0KkmKvAy6i=0CU)eCQ81k-xKe(-2(|3%K>k7Aj zHM$`w7eYN%bCW(od);^UTn$wteqM>eM~81{tMF{)T1kDd_%!uv9!4@~h8h~_1B&VY z5b96`d1DPbQ8J0zK;d%(1m;V7dpnPfO&m$fsh;CRUFJTr+H37K3dM268xPKyHDZnl$NqU#`s|>H?;jI3YP~Bjt_WcU zmN_Uay{i&Y?l0Wy1Ajw2E@dWc38dEWTJ~BWJ8QlypvM(~OuA4B$@J55xL)2Kt(4U2 zfPxW2HDSZyX-;4fRf+aUAFivMULM2*!$e?T^$8f zfP10-Hb5Y9^v&7bGe|%HI0w87EVJueU){!+sbC-h9D&*; z^dSHOFc1ZlfWstvq=NtnI9dXvfNiq76hZEvsQ)$1?!|V35U^~$i+SsOjs0H{5Lm

diff --git a/setup.py b/setup.py
index cdbab2b..38d311a 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 setup(
     name='stable-codec',
-    version='0.1.2',
+    version='0.1.3',
     author='Stability AI',
     author_email='julian.parker@stability.ai',
     description='Stable Codec: A series of codec models for speech and audio',
@@ -10,7 +10,7 @@
     long_description_content_type='text/markdown',
     url='https://github.com/Stability-AI/stable-codec/',
     packages=find_packages(),
-    python_requires='>=3.9',
+    python_requires='>=3.9,<3.12',
     install_requires=['packaging',
     'wheel',
     'torch==2.4',
diff --git a/stable_codec.egg-info/PKG-INFO b/stable_codec.egg-info/PKG-INFO
new file mode 100644
index 0000000..3f60642
--- /dev/null
+++ b/stable_codec.egg-info/PKG-INFO
@@ -0,0 +1,223 @@
+Metadata-Version: 2.4
+Name: stable-codec
+Version: 0.1.3
+Summary: Stable Codec: A series of codec models for speech and audio
+Home-page: https://github.com/Stability-AI/stable-codec/
+Author: Stability AI
+Author-email: julian.parker@stability.ai
+Requires-Python: >=3.9,<3.12
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: packaging
+Requires-Dist: wheel
+Requires-Dist: torch==2.4
+Requires-Dist: torchaudio==2.4
+Requires-Dist: stable-audio-tools==0.0.19
+Requires-Dist: pytorch-lightning==2.1
+Requires-Dist: prefigure==0.0.9
+Dynamic: author
+Dynamic: author-email
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+# Stable Codec
+
+This repository contains training and inference scripts for models in the Stable Codec series, starting with `stable-codec-speech-16k` - introduced in the paper titled Scaling Transformers for Low-bitrate High-Quality Speech Coding.
+
+Paper: https://arxiv.org/abs/2411.19842
+
+Sound demos: https://stability-ai.github.io/stable-codec-demo/
+
+Model weights: https://huggingface.co/stabilityai/stable-codec-speech-16k
+
+## Changelog
+
+### [v0.1.3] TBD
+- __Fix__ restricted Python version to <3.12 due to dependency incompatibilities
+- __Fix__ clarified installation instructions regarding Python version requirements
+### [v0.1.2] 14-01-25
+- __New__ added hooks for `stable-codec-speech-16k-base`.
+- __Fix__ fixed major issue with precision in FSQ token calculation, which was degrading results. Fix is currently local, will be upstreamed to `stable-audio-tools` later.
+### [v0.1.1] 10-01-25
+- Release
+
+
+## License
+
+Note that whilst this code is MIT licensed, the model weights are covered by the [Stability AI Community License](https://huggingface.co/stabilityai/stable-codec-speech-16k/blob/main/LICENSE.md).
+
+## Variants
+The model is currently available in two variants:
+- `stable-codec-speech-16k` is an improved finetune, with boosted latent semantics. __It should be used in 99% of use-cases.__
+- `stable-codec-speech-16k-base` is the weights corresponding to the results in our [publication](https://arxiv.org/abs/2411.19842), provided for reproducibility.
+
+### Additional Training
+
+In addition to the training described in the paper, the weights for `stable-codec-speech-16k` have undergone 500k steps of finetuning with force-aligned data from LibriLight and the English portion of Multilingual LibriSpeech. This was performed by using a CTC head to regress the force-aligned phoneme tags from pre-bottleneck latents.
We found that this additional training significantly boosted the applicability of the codec tokens to downstream tasks like TTS, at a small cost to objective reconstruction metrics.
+
+## Install
+
+The model itself is defined in the [stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) package.
+
+### Python Version Compatibility
+
+**Important:** This package currently requires **Python 3.9, 3.10, or 3.11**. Python 3.12 and later are not supported due to incompatibilities in the `stable-audio-tools` dependency chain (specifically `PyWavelets==1.4.1` and `pandas==2.0.2`).
+
+If you attempt to install on Python 3.12+, you will encounter build errors. Please use Python 3.11 or earlier.
+
+To install `stable-codec`:
+
+```bash
+pip install stable-codec
+pip install -U flash-attn --no-build-isolation
+```
+
+**IMPORTANT NOTE:** This model currently has a hard requirement for FlashAttention due to its use of sliding window attention. Inference without FlashAttention will likely be greatly degraded. This also means that the model currently does not support CPU inference. We will relax the dependency on FlashAttention in the future.
+
+## Encoding and decoding
+
+To encode audio or decode tokens, the `StableCodec` class provides a convenient wrapper for the model. It can be used with a local checkpoint and config as follows:
+
+```python
+import torch
+import torchaudio
+from stable_codec import StableCodec
+
+model = StableCodec(
+    model_config_path="",
+    ckpt_path="",  # optional, can be `None`
+    device=torch.device("cuda")
+)
+
+audiopath = "audio.wav"
+
+latents, tokens = model.encode(audiopath)
+decoded_audio = model.decode(tokens)
+
+torchaudio.save("decoded.wav", decoded_audio, model.sample_rate)
+```
+
+To download the model weights automatically from HuggingFace, simply provide the model name:
+
+```python
+model = StableCodec(
+    pretrained_model='stabilityai/stable-codec-speech-16k'
+)
+```
+### Posthoc bottleneck configuration
+
+Most use cases will benefit from replacing the training-time FSQ bottleneck with a post-hoc FSQ bottleneck, as described in the paper. This allows the token dictionary size to be reduced to a reasonable level for modern language models. This is achieved by calling the `set_posthoc_bottleneck` function and setting a flag on the encode/decode calls:
+
+```python
+model.set_posthoc_bottleneck("2x15625_700bps")
+latents, tokens = model.encode(audiopath, posthoc_bottleneck=True)
+decoded_audio = model.decode(tokens, posthoc_bottleneck=True)
+```
+`set_posthoc_bottleneck` takes a string argument, which allows selection of a number of recommended preset settings for the bottleneck:
+
+| Bottleneck Preset | Number of Tokens per step | Dictionary Size | Bits Per Second (bps) |
+|-------------------|---------------------------|-----------------|-----------------------|
+| `1x46656_400bps`  | 1                         | 46656           | 400                   |
+| `2x15625_700bps`  | 2                         | 15625           | 700                   |
+| `4x729_1000bps`   | 4                         | 729             | 1000                  |
+
+Alternatively, the bottleneck stages can be specified directly. The format for specifying this can be seen in the definition of the `StableCodec` class in `model.py`.
+
+### Normalization
+
+The model is trained with utterances normalized to -20 ±5 LUFS. The `encode` function normalizes input to -20 LUFS by default, but this can be disabled by setting `normalize=False` when calling the function.
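+
+As a minimal sketch (reusing the `model` and `audiopath` from the example above), disabling that normalization looks like this:
+
+```python
+# Skip the built-in -20 LUFS loudness normalization; this assumes the input
+# is already at a suitable level.
+latents, tokens = model.encode(audiopath, normalize=False)
+decoded_audio = model.decode(tokens)
+```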
+ +## Finetune + +To finetune a model given its config and checkpoint, execute `train.py` file: + +```bash +python train.py \ + --project "stable-codec" \ + --name "finetune" \ + --config-file "defaults.ini" \ + --save-dir "" \ + --model-config "" \ + --dataset-config "" \ + --val-dataset-config "" \ + --pretrained-ckpt-path "" \ + --ckpt-path "$CKPT_PATH" \ + --num-nodes $SLURM_JOB_NUM_NODES \ + --num-workers 16 --batch-size 10 --precision "16-mixed" \ + --checkpoint-every 10000 \ + --logger "wandb" +``` + +For dataset configuration, refer to `stable-audio-tools` [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md). + + +### Using CTC loss + +To use [CTC loss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) +during training you have to enable it in the training configuration file +and in the training dataset configuration. + +1. Modifying training configuration: + - Enable CTC projection head and set its hidden dimension: + ```python + config["model"]["use_proj_head"] = True + config["model"]["proj_head_dim"] = 81 + ``` + - Enable CTC in the training part of the config: + ```python + config["training"]["use_ctc"] = True + ``` + - And set its loss config: + ```python + config["training"]["loss_configs"]["ctc"] = { + "blank_idx": 80, + "decay": 1.0, + "weights": {"ctc": 1.0} + } + ``` + - Optionally, you can enable computation of the Phone-Error-Rate (PER) during validation: + ```python + config["training"]["eval_loss_configs"]["per"] = {} + ``` + +2. Configuring dataset (only WebDataset format is supported for CTC): + - The dataset configuration should have one additional field set to it (see [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md) for other options): + ```python + config["force_align_text"] = True + ``` + - And the JSON metadata file for each sample should contain force aligned transcript under `force_aligned_text` entry in the format specified below (besides other metadata). + Where `transcript` is a list of word-level alignments with `start` and `end` fields specifying range **in seconds** of each word. 
+ ```json + "normalized_text":"and i feel" + "force_aligned_text":{ + "transcript":[ + { + "word":"and", + "start":0.2202, + "end":0.3403 + }, + { + "word":"i", + "start":0.4604, + "end":0.4804 + }, + { + "word":"feel", + "start":0.5204, + "end":0.7006 + } + ] + } + ``` +## Objective Metrics + +| Model | SI-SDR | Mel Dis | STFT Dis | PESQ | STOI | +|---------------------------|-------:|--------:|---------:|-----:|-----:| +| `stable-codec-speech-16k-base` | 4.73 | 0.86 | 1.26 | 3.09 | 0.92 | +| `stable-codec-speech-16k` | 3.58 | 0.90 | 1.30 | 3.01 | 0.90 | + diff --git a/stable_codec.egg-info/SOURCES.txt b/stable_codec.egg-info/SOURCES.txt new file mode 100644 index 0000000..974d85b --- /dev/null +++ b/stable_codec.egg-info/SOURCES.txt @@ -0,0 +1,16 @@ +LICENSE +README.md +pyproject.toml +setup.py +stable_codec/__init__.py +stable_codec/ctc_loss.py +stable_codec/fsq.py +stable_codec/model.py +stable_codec/residual_fsq.py +stable_codec/training_demo.py +stable_codec/training_module.py +stable_codec.egg-info/PKG-INFO +stable_codec.egg-info/SOURCES.txt +stable_codec.egg-info/dependency_links.txt +stable_codec.egg-info/requires.txt +stable_codec.egg-info/top_level.txt \ No newline at end of file diff --git a/stable_codec.egg-info/dependency_links.txt b/stable_codec.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/stable_codec.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/stable_codec.egg-info/requires.txt b/stable_codec.egg-info/requires.txt new file mode 100644 index 0000000..959f1d4 --- /dev/null +++ b/stable_codec.egg-info/requires.txt @@ -0,0 +1,7 @@ +packaging +wheel +torch==2.4 +torchaudio==2.4 +stable-audio-tools==0.0.19 +pytorch-lightning==2.1 +prefigure==0.0.9 diff --git a/stable_codec.egg-info/top_level.txt b/stable_codec.egg-info/top_level.txt new file mode 100644 index 0000000..e12fb46 --- /dev/null +++ b/stable_codec.egg-info/top_level.txt @@ -0,0 +1 @@ +stable_codec From 8b77ab79fcc8319c8387aac86d8b3eb7fe5c0993 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Feb 2026 13:10:14 +0000 Subject: [PATCH 3/3] Add build artifacts to .gitignore Co-authored-by: julian-parker <19472441+julian-parker@users.noreply.github.com> --- .gitignore | 5 + build/lib/stable_codec/__init__.py | 1 - build/lib/stable_codec/ctc_loss.py | 236 -------- build/lib/stable_codec/fsq.py | 134 ----- build/lib/stable_codec/model.py | 159 ----- build/lib/stable_codec/residual_fsq.py | 63 -- build/lib/stable_codec/training_demo.py | 157 ----- build/lib/stable_codec/training_module.py | 644 --------------------- dist/stable_codec-0.1.3-py3-none-any.whl | Bin 19930 -> 0 bytes stable_codec.egg-info/PKG-INFO | 223 ------- stable_codec.egg-info/SOURCES.txt | 16 - stable_codec.egg-info/dependency_links.txt | 1 - stable_codec.egg-info/requires.txt | 7 - stable_codec.egg-info/top_level.txt | 1 - 14 files changed, 5 insertions(+), 1642 deletions(-) delete mode 100644 build/lib/stable_codec/__init__.py delete mode 100644 build/lib/stable_codec/ctc_loss.py delete mode 100644 build/lib/stable_codec/fsq.py delete mode 100644 build/lib/stable_codec/model.py delete mode 100644 build/lib/stable_codec/residual_fsq.py delete mode 100644 build/lib/stable_codec/training_demo.py delete mode 100644 build/lib/stable_codec/training_module.py delete mode 100644 dist/stable_codec-0.1.3-py3-none-any.whl delete mode 100644 stable_codec.egg-info/PKG-INFO delete mode 100644 stable_codec.egg-info/SOURCES.txt delete mode 
100644 stable_codec.egg-info/dependency_links.txt delete mode 100644 stable_codec.egg-info/requires.txt delete mode 100644 stable_codec.egg-info/top_level.txt diff --git a/.gitignore b/.gitignore index 927da94..39c8981 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,8 @@ __pycache__/ venv/ *.wav + +# Build artifacts +build/ +dist/ +*.egg-info/ diff --git a/build/lib/stable_codec/__init__.py b/build/lib/stable_codec/__init__.py deleted file mode 100644 index a2cc1b3..0000000 --- a/build/lib/stable_codec/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from stable_codec.model import StableCodec \ No newline at end of file diff --git a/build/lib/stable_codec/ctc_loss.py b/build/lib/stable_codec/ctc_loss.py deleted file mode 100644 index a6c9e02..0000000 --- a/build/lib/stable_codec/ctc_loss.py +++ /dev/null @@ -1,236 +0,0 @@ -import torch - -from torch.nn import functional as F -from torch import nn - -from stable_audio_tools.training.losses import LossModule - -# https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html -class CTCLossModule(LossModule): - def __init__( - self, - name: str, - input_key: str, - target_key: str, - weight: float = 1.0, - decay: float = 1.0, - blank_idx: int = 0, - padding_idx: int = None, - input_lengths_key: str = None, - ): - super().__init__(name=name, weight=weight, decay=decay) - self.ctc_loss = nn.CTCLoss(blank=blank_idx, reduction='mean', zero_infinity=True) - self.input_key = input_key - self.target_key = target_key - self.input_lengths_key = input_lengths_key - self.blank_idx = blank_idx - self.padding_idx = padding_idx if padding_idx is not None else blank_idx + 1 - - def forward(self, info): - """ - Computes the CTC loss. - - Args: - info (dict): Dictionary containing model outputs and other relevant data. - - info[self.input_key]: Model logits of shape (batch_size, sequence_length, num_classes). - - info[self.target_key]: Target data (list of dicts with 'phone' key). - - info[self.input_lengths_key]: (Optional) Actual lengths of the input sequences. - - Returns: - loss (Tensor): The computed CTC loss, scaled by the weight. 
- """ - # Build targets and target lengths - padded_targets, target_lengths = build_target(info[self.target_key], self.padding_idx) - - # Get logits from the model output - logits = info[self.input_key] # Expected shape: (batch_size, sequence_length, num_classes) - - # Move logits to the device of phonemes - device = padded_targets.device - logits = logits.to(device) - - # Apply log_softmax to obtain log probabilities - log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_length, num_classes) - - # Transpose log_probs to match (seq_length, batch_size, num_classes) - log_probs = log_probs.permute(1, 0, 2) # Now shape is (seq_length, batch_size, num_classes) - - # Determine input lengths - if self.input_lengths_key and self.input_lengths_key in info: - input_lengths = info[self.input_lengths_key].to(device) - else: - # Assume all input sequences have the same length - input_lengths = torch.full( - (log_probs.size(1),), # batch_size - log_probs.size(0), # seq_length - dtype=torch.long, - device=device - ) - - # Compute the CTC loss - loss = self.ctc_loss(log_probs, padded_targets, input_lengths, target_lengths) - - loss = self.weight * loss - - return loss - -class PERModule(nn.Module): - def __init__( - self, - input_key: str, - target_key: str, - blank_idx: int = 0, - padding_idx: int = None, - ): - super().__init__() - self.input_key = input_key - self.target_key = target_key - self.blank_idx = blank_idx - self.padding_idx = padding_idx if padding_idx is not None else blank_idx + 1 - - def decode_predictions(self, predicted_ids): - """ - Decodes the model predictions by collapsing repeats and removing blanks. - - Args: - predicted_ids (Tensor): Tensor of shape (seq_length,) containing predicted token IDs. - - Returns: - List[int]: Decoded sequence of token IDs. - """ - predicted_sequence = [] - previous_id = None - for id in predicted_ids: - id = id.item() - if id != self.blank_idx and id != previous_id: - predicted_sequence.append(id) - previous_id = id - return predicted_sequence - - def forward(self, info): - """ - Computes the CTC loss. - - Args: - info (dict): Dictionary containing model outputs and other relevant data. - - info[self.input_key]: Model logits of shape (batch_size, sequence_length, num_classes). - - info[self.target_key]: Target data (list of dicts with 'phone' key). - - info[self.input_lengths_key]: (Optional) Actual lengths of the input sequences. - - Returns: - loss (Tensor): The computed CTC loss, scaled by the weight. 
- """ - with torch.no_grad(): - # Build targets and target lengths - padded_targets, target_lengths = build_target(info[self.target_key], self.padding_idx) - - # Get logits from the model output - logits = info[self.input_key] # Expected shape: (batch_size, sequence_length, num_classes) - - # Move logits to the device of phonemes - device = padded_targets.device - logits = logits.to(device) - - # Apply log_softmax to obtain log probabilities - log_probs = F.log_softmax(logits, dim=-1) # Shape: (batch_size, seq_length, num_classes) - - # Transpose log_probs to match (seq_length, batch_size, num_classes) - log_probs = log_probs.permute(1, 0, 2) # Now shape is (seq_length, batch_size, num_classes) - - # Get predictions via greedy decoding - predicted_ids = torch.argmax(logits, dim=-1) # Shape: (batch_size, seq_length) - - batch_size = predicted_ids.size(0) - pers = [] - - for i in range(batch_size): - # Decode predictions - pred_ids = predicted_ids[i] # Tensor of shape (seq_length,) - pred_sequence = self.decode_predictions(pred_ids) - - # Get target sequence - target_ids = padded_targets[i] # Tensor of shape (max_target_length,) - target_length = target_lengths[i] - target_sequence = target_ids[:target_length].tolist() - - # Remove padding tokens from target sequence - target_sequence = [id for id in target_sequence if id != self.padding_idx] - - # Compute edit distance using the editdistance package - # distance = editdistance.eval(pred_sequence, target_sequence) - distance = edit_distance(pred_sequence, target_sequence) - - # Compute PER - per = distance / max(len(target_sequence), 1) - pers.append(per) - - # Compute average PER over the batch - average_per = sum(pers) / len(pers) - - return average_per - -def edit_distance(seq1, seq2): - """ - Computes the edit distance between two sequences. - - Args: - seq1 (List[int]): First sequence. - seq2 (List[int]): Second sequence. - - Returns: - int: The edit distance between seq1 and seq2. - """ - m = len(seq1) - n = len(seq2) - # Create a DP table - dp = [[0] * (n + 1) for _ in range(m + 1)] - # Initialize - for i in range(m + 1): - dp[i][0] = i - for j in range(n + 1): - dp[0][j] = j - # Compute dp table - for i in range(1, m + 1): - for j in range(1, n + 1): - if seq1[i - 1] == seq2[j - 1]: - cost = 0 - else: - cost = 1 - dp[i][j] = min( - dp[i - 1][j] + 1, # deletion - dp[i][j - 1] + 1, # insertion - dp[i - 1][j - 1] + cost # substitution - ) - return dp[m][n] - -def build_target(batch, padding_idx): - """ - Builds padded targets and computes target lengths. - - Args: - batch (list): A list of dictionaries, each containing a 'phone' key with tensor values. - - Returns: - padded_targets (Tensor): Padded target sequences of shape (batch_size, max_target_length). - target_lengths (Tensor): Lengths of each target sequence in the batch. 
- """ - # Extract phoneme sequences - phoneme_sequences = [item['phone'] for item in batch] - - # Determine device from the phoneme sequences - device = phoneme_sequences[0].device - - # Ensure phoneme sequences are 1D tensors - phoneme_sequences = [seq.view(-1) if seq.ndim > 1 else seq for seq in phoneme_sequences] - - # Compute target lengths - target_lengths = torch.tensor([seq.size(0) for seq in phoneme_sequences], dtype=torch.long, device=device) - - # Pad sequences - padded_targets = nn.utils.rnn.pad_sequence( - phoneme_sequences, - batch_first=True, - padding_value=padding_idx - ).to(device) - - return padded_targets, target_lengths diff --git a/build/lib/stable_codec/fsq.py b/build/lib/stable_codec/fsq.py deleted file mode 100644 index 920fa42..0000000 --- a/build/lib/stable_codec/fsq.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Dithered Finite Scalar Quantization -Code adapted from https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py -""" - -from typing import List, Tuple -import random - -import torch -import torch.nn as nn -from torch.nn import Module -from torch import Tensor, int32 -from torch.amp import autocast - -from einops import rearrange - - -def leaky_hard_clip(x: Tensor, alpha: float = 1e-3) -> Tensor: - return (1-alpha) * torch.clamp(x, -1, 1) + alpha * x - -def round_ste(z: Tensor) -> Tensor: - """Round with straight through gradients.""" - zhat = z.round() - return z + (zhat - z).detach() - -class DitheredFSQ(Module): - def __init__( - self, - levels: List[int], - dither_inference: bool = False, - num_codebooks: int = 1, - noise_dropout: float = 0.5, - scale: float = 1.0, - ): - super().__init__() - self.levels = levels - - _levels = torch.tensor(levels, dtype=torch.int64) - self.register_buffer("_levels", _levels, persistent = False) - - _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int64) - self.register_buffer("_basis", _basis, persistent = False) - - codebook_dim = len(levels) - self.codebook_dim = codebook_dim - - self.codebook_size = _levels.prod().item() - - self.num_codebooks = num_codebooks - - self.dim = codebook_dim * num_codebooks - - self.dither_inference = dither_inference - - self.scale = scale - - half_l = self.scale * 2 / (self._levels - 1) - self.register_buffer("half_l", half_l, persistent = False) - - self.allowed_dtypes = (torch.float32, torch.float64) - - self.noise_dropout = noise_dropout - - def quantize(self, z, skip_tanh: bool = False): - if not skip_tanh: z = torch.tanh(z) - - if not self.training: - quantized = self._scale_and_shift_inverse(round_ste(self._scale_and_shift(z))) - else: - quantized = z - mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) - quantized = torch.where(mask, quantized, self._scale_and_shift_inverse(round_ste(self._scale_and_shift(quantized)))) - mask = torch.bernoulli(torch.full([z.shape[0],1,1,1], self.noise_dropout, device = z.device)).bool().expand_as(z) - quantized = torch.where(mask, quantized, z + (torch.rand_like(z) - 0.5) * self.half_l) - - return quantized - - def _scale_and_shift(self, z): - level_indices = (z + 1 * self.scale) / self.half_l - return level_indices - - def _scale_and_shift_inverse(self, level_indices): - z = level_indices * self.half_l - 1 * self.scale - return z - - def _indices_to_codes(self, indices): - level_indices = self._indices_to_level_indices(indices) - codes = self._scale_and_shift_inverse(level_indices) - return codes - - 
def _codes_to_indices(self, zhat): - zhat = self._scale_and_shift(zhat) - zhat = zhat.round().to(torch.int64) - out = (zhat * self._basis).sum(dim=-1) - return out - - def _indices_to_level_indices(self, indices): - indices = rearrange(indices, '... -> ... 1') - codes_non_centered = (indices // self._basis) % self._levels - return codes_non_centered - - def indices_to_codes(self, indices): - # Expects input of batch x sequence x num_codebooks - assert indices.shape[-1] == self.num_codebooks, f'expected last dimension of {self.num_codebooks} but found last dimension of {indices.shape[-1]}' - codes = self._indices_to_codes(indices.to(torch.int64)) - codes = rearrange(codes, '... c d -> ... (c d)') - return codes - - @autocast(device_type="cuda", enabled = False) - def forward(self, z, skip_tanh: bool = False): - - orig_dtype = z.dtype - - assert z.shape[-1] == self.dim, f'expected dimension of {self.num_codebooks * self.dim} but found dimension of {z.shape[-1]}' - - z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) - - # make sure allowed dtype before quantizing - - if z.dtype not in self.allowed_dtypes: - z = z.to(torch.float64) - - codes = self.quantize(z, skip_tanh=skip_tanh) - indices = self._codes_to_indices(codes) - codes = rearrange(codes, 'b n c d -> b n (c d)') - - # cast codes back to original dtype - - if codes.dtype != orig_dtype: - codes = codes.type(orig_dtype) - - # return quantized output and indices - - return codes, indices diff --git a/build/lib/stable_codec/model.py b/build/lib/stable_codec/model.py deleted file mode 100644 index 14f541a..0000000 --- a/build/lib/stable_codec/model.py +++ /dev/null @@ -1,159 +0,0 @@ -import json -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torchaudio -from einops import rearrange -from stable_audio_tools import get_pretrained_model -from stable_audio_tools.data.utils import VolumeNorm -from stable_audio_tools.models import create_model_from_config -from stable_audio_tools.models.fsq import DitheredFSQ -from stable_audio_tools.models.utils import copy_state_dict, load_ckpt_state_dict - -from .residual_fsq import ResidualFSQBottleneck - - -class StableCodec(nn.Module): - def __init__(self, - model_config_path: Optional[str] = None, ckpt_path: Optional[str] = None, pretrained_model: Optional[str] = None, device = torch.device("cpu"), - ): - super().__init__() - self.device = device - - if pretrained_model is not None: - print(f"Loading pretrained model `{pretrained_model}`.\n") - self.model, model_config = get_pretrained_model(pretrained_model) - else: - if model_config_path is None: - raise ValueError("Either `model_config_path` or `pretrained_model` should be provided.") - print(f"Loading config from `{model_config_path}`.\n") - with open(model_config_path) as f: - model_config = json.load(f) - self.model = create_model_from_config(model_config) - if ckpt_path is not None: - print(f"Loading weights from `{ckpt_path}`.\n") - state = load_ckpt_state_dict(ckpt_path) - copy_state_dict(self.model, state) - - self.model = self.model.to(self.device).eval().requires_grad_(False) - - self.residual_fsq: Optional[ResidualFSQBottleneck] = None - - self.sample_rate = model_config["sample_rate"] - self.volume_norm = VolumeNorm([-20, 0], self.sample_rate) - - self.preset_bottleneck_configs = { - "1x46656_400bps": [ - ([6, 6, 6, 6, 6, 6], 1.0) - ], - "2x15625_700bps": [ - ([5, 5, 5, 5, 5, 5], 1.0), - ([5, 5, 5, 5, 5, 5], 0.25), - ], - "4x729_1000bps": [ - ([3, 3, 3, 3, 3, 3], 1.0), - ([3, 3, 
3, 3, 3, 3], 0.5), - ([3, 3, 3, 3, 3, 3], 0.25), - ([3, 3, 3, 3, 3, 3], 0.125), - ] - } - - def set_posthoc_bottleneck(self, stages): - if isinstance(stages,str): - if stages in self.preset_bottleneck_configs: - stages = self.preset_bottleneck_configs[stages] - else: - raise ValueError(f"Unsupported preset bottleneck configuration `{stages}`.") - - self.residual_fsq = ResidualFSQBottleneck(stages).to(self.device).eval().requires_grad_(False) - - def encode(self, audio: Union[str, torch.Tensor], posthoc_bottleneck: bool = False, normalize: bool = True,**kwargs): - """ - Encode audio into latents and tokens. - - Args: - - audio : Union[str, torch.Tensor] - Path to an audio file or a `Tensor` of the eaudio itself. - posthoc_bottleneck : bool - Whether to inject a posthoc FSQ instead of the FSQ used during training. - If `True`, its configuration should've been passed in with the `self.set_posthoc_bottleneck` method. - normalize : bool - Whether to normalize the audio to -20 LUFS before encoding (recommended). - Other `kwargs` are the same as in `AudioAutoencoder.encode_audio` method. - - Returns: - - Tuple of `(continuous_latents, tokens)`. - - continuous_latents : torch.Tensor - Pre-bottleneck latents in the `(B, H, S)` shape. - tokens : torch.Tensor - Bottleneck tokens in the `(B, S, 1)` shape. - - Where `B` is the batch size, `H` is the hidden dimension and `S` is the sequence length. - """ - if isinstance(audio, str): - audio, sample_rate = torchaudio.load(audio) - audio = self.model.preprocess_audio_for_encoder(audio.to(self.device), sample_rate) - if normalize: - audio = self.volume_norm(audio.squeeze(0)).unsqueeze(0) - - latents, info = self.model.encode_audio(audio, - return_info=True, skip_bottleneck=posthoc_bottleneck, **kwargs) - if posthoc_bottleneck: - tokens = self.residual_fsq.encode(latents) - else: - tokens = info["quantizer_indices"] - - return info["pre_bottleneck_latents"], tokens - - def decode(self, tokens: torch.Tensor, posthoc_bottleneck: bool = False, **kwargs): - """ - Decode audio from tokens. - - Args: - - tokens : torch.Tensor - Integer tokens produced by `encode` stage in `(B, S, 1)` shape. - posthoc_bottleneck : bool - Whether to inject a posthoc FSQ instead of the FSQ used during training. - If `True`, its configuration should've been passed in with `self.set_posthoc_bottleneck` method. - - Returns: - - Decoded audio in the `(B, C, L)` shape. - Where `B` is the batch size, `C` is the number of channels and `L` is the number of frames. 
- """ - if posthoc_bottleneck: - latents = self.residual_fsq.decode(tokens) - else: - latents = self.model.bottleneck.decode_tokens(tokens) - latents = rearrange(latents, "b c n -> b n c") - - audio = self.model.decode_audio(latents, **kwargs) - return audio - -def main(): - sc = StableCodec( - pretrained_model="stabilityai/stable-codec-speech-16k", - device = torch.device("cuda") - ) - - sc.set_posthoc_bottleneck("2x15625_700bps") - - wavfile = "test.wav" - - posthoc_bottleneck = False - latents, tokens = sc.encode(wavfile, posthoc_bottleneck=posthoc_bottleneck) - decoded = sc.decode(tokens, posthoc_bottleneck=posthoc_bottleneck) - torchaudio.save("decode.wav", decoded.squeeze(0).cpu(), 16000) - - posthoc_bottleneck = True - latents, tokens = sc.encode(wavfile, posthoc_bottleneck=posthoc_bottleneck) - decoded = sc.decode(tokens, posthoc_bottleneck=posthoc_bottleneck) - torchaudio.save("decode-res.wav", decoded.squeeze(0).cpu(), 16000) - -if __name__ == "__main__": - main() diff --git a/build/lib/stable_codec/residual_fsq.py b/build/lib/stable_codec/residual_fsq.py deleted file mode 100644 index b83b6b6..0000000 --- a/build/lib/stable_codec/residual_fsq.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -import torch.nn as nn - -from typing import List, Tuple -from einops import rearrange -from .fsq import DitheredFSQ - -class ResidualFSQBottleneck(nn.Module): - def __init__(self, stages: List[Tuple[List[int], float]]): - super().__init__() - - # 1st for single_tokens, others - residuals. - self.quantizers = nn.ModuleList([ - DitheredFSQ(levels=levels, scale=scale).eval().requires_grad_(False) - for (levels, scale) in stages]) - - self.n_codebooks = len(stages) - self.codebook_size = sum(map(len, stages)) * self.n_codebooks - - def encode(self, x): - input_dtype = x.dtype - z = torch.tanh(x.to(torch.float64)) - z = rearrange(z, "b c n -> b n c") - - r = z - res_ids = [] - for quantizer in self.quantizers: - q, ids = quantizer(r, skip_tanh=True) - r = r - q.to(torch.float64) - res_ids.append(ids) - - return res_ids - - def decode(self, res_ids): - z = sum([ - q.indices_to_codes(res_ids[i]) - for (i, q) in enumerate(self.quantizers) - ]) - return rearrange(z, "b n c -> b c n") - -if __name__ == "__main__": - fsq = DitheredFSQ([17, 17, 17, 17, 17, 17]).eval().requires_grad_(False) - # res_fsq = ResidualFSQBottleneck(stages=[ - # ([5, 5, 5, 5, 5, 5], 1.0), - # ([5, 5, 5, 5, 5, 5], 0.25), - # ]).eval().requires_grad_(False) - res_fsq = ResidualFSQBottleneck(stages=[ - ([3, 3, 3, 3, 3, 3], 1.0), - ([3, 3, 3, 3, 3, 3], 0.5), - ([3, 3, 3, 3, 3, 3], 0.25), - ([3, 3, 3, 3, 3, 3], 0.125), - ]).eval().requires_grad_(False) - - x = torch.rand(1, 6, 1) - - z1 = res_fsq.decode(res_fsq.encode(x)) - - _, y2 = fsq(rearrange(x, "b c n -> b n c")) - z2 = rearrange(fsq.indices_to_codes(y2), "b n c -> b c n") - - print(z1) - print(z2) - assert (z1 == z2).all() diff --git a/build/lib/stable_codec/training_demo.py b/build/lib/stable_codec/training_demo.py deleted file mode 100644 index c7fa248..0000000 --- a/build/lib/stable_codec/training_demo.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import torch -import torchaudio -import pytorch_lightning as pl - -from einops import rearrange -from pytorch_lightning.utilities.rank_zero import rank_zero_only - -from stable_audio_tools.models.autoencoders import ( - fold_channels_into_batch, unfold_channels_from_batch, -) -from stable_audio_tools.training.utils import ( - log_image, log_point_cloud, logger_project_name, log_audio, -) -from stable_audio_tools.interface.aeiou 
import ( - audio_spectrogram_image, tokens_spectrogram_image, -) - -def trim_to_shortest(a, b): - """Trim the longer of two tensors to the length of the shorter one.""" - if a.shape[-1] > b.shape[-1]: - return a[:,:,:b.shape[-1]], b - elif b.shape[-1] > a.shape[-1]: - return a, b[:,:,:a.shape[-1]] - return a, b - -class AutoencoderDemoCallback(pl.Callback): - def __init__( - self, - demo_dl, - demo_every = 2000, - sample_size = 65536, - sample_rate = 16000, - max_demos = 8, - ): - super().__init__() - self.demo_every = demo_every - self.demo_samples = sample_size - self.demo_dl = demo_dl - self.sample_rate = sample_rate - self.last_demo_step = -1 - self.max_demos = max_demos - - @rank_zero_only - def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): - if ( - (trainer.global_step - 1) % self.demo_every != 0 or - self.last_demo_step == trainer.global_step - ): - return - - self.last_demo_step = trainer.global_step - module.eval() - - try: - demo_iter = iter(self.demo_dl) - demo_reals, _ = next(demo_iter) - - # Remove extra dimension added by WebDataset - if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: - demo_reals = demo_reals[0] - - # Limit the number of demo samples - if demo_reals.shape[0] > self.max_demos: - demo_reals = demo_reals[:self.max_demos,...] - - encoder_input = demo_reals - encoder_input = encoder_input.to(module.device) - - if module.force_input_mono: - encoder_input = encoder_input.mean(dim=1, keepdim=True) - - demo_reals = demo_reals.to(module.device) - - with torch.no_grad(): - if module.use_ema: - latents = module.autoencoder_ema.ema_model.encode(encoder_input) - fakes = module.autoencoder_ema.ema_model.decode(latents) - else: - latents = module.autoencoder.encode(encoder_input) - fakes = module.autoencoder.decode(latents) - - #Trim output to remove post-padding. - fakes, demo_reals = trim_to_shortest(fakes.detach(), demo_reals) - - # Visualize discriminator sensitivity. 
- if module.discriminator is not None: - window = torch.kaiser_window(512).to(fakes.device) - stft_kwargs = { - "n_fft": 512, - "hop_length": 128, - "win_length": 512, - "window": window, - "center": True, - } - - fakes_stft = torch.stft(fold_channels_into_batch(fakes), - return_complex=True, **stft_kwargs) - fakes_stft.requires_grad = True - fakes_signal = unfold_channels_from_batch( - torch.istft(fakes_stft, **stft_kwargs), fakes.shape[1]) - - real_stft = torch.stft(fold_channels_into_batch(demo_reals), - return_complex=True, **stft_kwargs) - reals_signal = unfold_channels_from_batch( - torch.istft(real_stft, **stft_kwargs), demo_reals.shape[1]) - - _, loss, _ = module.discriminator.loss(reals_signal, fakes_signal) - fakes_stft.retain_grad() - loss.backward() - grads = unfold_channels_from_batch(fakes_stft.grad.detach().abs(), fakes.shape[1]) - - log_image(trainer.logger, 'disciminator_sensitivity', - tokens_spectrogram_image(grads.mean(dim=1).log10(), - title='Discriminator Sensitivity', symmetric=False)) - opts = module.optimizers() - opts[0].zero_grad() - opts[1].zero_grad() - - #Interleave reals and fakes - reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') - # Put the demos together - reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') - - data_dir = os.path.join( - trainer.logger.save_dir, logger_project_name(trainer.logger), - trainer.logger.experiment.id, "media") - os.makedirs(data_dir, exist_ok=True) - filename = os.path.join(data_dir, f'recon_{trainer.global_step:08}.wav') - - reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - torchaudio.save(filename, reals_fakes, self.sample_rate) - - log_audio(trainer.logger, 'recon', filename, self.sample_rate) - log_point_cloud(trainer.logger, 'embeddings_3dpca', latents) - log_image(trainer.logger, 'embeddings_spec', tokens_spectrogram_image(latents)) - log_image(trainer.logger, 'recon_melspec_left', audio_spectrogram_image(reals_fakes)) - except Exception as e: - print(f'{type(e).__name__}: {e}') - raise e - finally: - module.train() - -def create_demo_callback_from_config(model_config, **kwargs): - model_type = model_config.get('model_type', None) - assert model_type is not None, 'model_type must be specified in model config' - - training_config = model_config.get('training', None) - assert training_config is not None, 'training config must be specified in model config' - - demo_config = training_config.get("demo", {}) - return AutoencoderDemoCallback( - demo_every=demo_config.get("demo_every", 2000), - sample_size=model_config["sample_size"], - sample_rate=model_config["sample_rate"], - **kwargs - ) diff --git a/build/lib/stable_codec/training_module.py b/build/lib/stable_codec/training_module.py deleted file mode 100644 index 4518ea5..0000000 --- a/build/lib/stable_codec/training_module.py +++ /dev/null @@ -1,644 +0,0 @@ -import torch -import torch.nn as nn -import pytorch_lightning as pl - -from typing import Optional, Literal -from ema_pytorch import EMA -from torch.nn import Parameter -from einops import rearrange - -from stable_audio_tools.models import create_model_from_config -from stable_audio_tools.models.autoencoders import AudioAutoencoder -from stable_audio_tools.models.discriminators import ( - EncodecDiscriminator, OobleckDiscriminator, DACGANLoss, -) -from stable_audio_tools.models.bottleneck import ( - VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, - RVQVAEBottleneck, WassersteinBottleneck, -) -from stable_audio_tools.training.losses 
import ( - MelSpectrogramLoss, MultiLoss, AuralossLoss, ValueLoss, L1Loss, - LossWithTarget, MSELoss, HubertLoss, - # PESQMetric, # TODO move PESQ here? -) -from stable_audio_tools.training.losses import auraloss as auraloss -from stable_audio_tools.training.utils import ( - create_optimizer_from_config, create_scheduler_from_config, log_metric, -) - -from .ctc_loss import CTCLossModule, PERModule - -def trim_to_shortest(a, b): - """Trim the longer of two tensors to the length of the shorter one.""" - if a.shape[-1] > b.shape[-1]: - return a[:,:,:b.shape[-1]], b - elif b.shape[-1] > a.shape[-1]: - return a, b[:,:,:a.shape[-1]] - return a, b - -class ProjectionHead(nn.Module): - def __init__(self, latent_dim, proj_head_dim, mid_dim=256): - super(ProjectionHead, self).__init__() - self.proj_head = nn.Sequential( - nn.Tanh(), - nn.Linear(latent_dim, mid_dim), - nn.ReLU(), - nn.Linear(mid_dim, mid_dim), - nn.ReLU(), - nn.Linear(mid_dim, proj_head_dim) - ) - - def forward(self, x): - return self.proj_head(x) - -class AutoencoderTrainingWrapper(pl.LightningModule): - def __init__(self, - autoencoder: AudioAutoencoder, - loss_config: dict, - eval_loss_config: dict, - optimizer_configs: dict, - sample_rate: int = 16000, - lr: float = 1e-4, - warmup_steps: int = 0, - warmup_mode: Literal["adv", "full"] = "adv", - encoder_freeze_on_warmup: bool = False, - use_ema: bool = True, - ema_copy = None, - force_input_mono = False, - latent_mask_ratio = 0.0, - teacher_model: Optional[AudioAutoencoder] = None, - clip_grad_norm = 0.0, - encoder_mask_ratio = 0.0, - use_ctc: bool = False, - proj_head_dim: Optional[int] = None, - detach_proj_head: bool = False, - ): - super().__init__() - - self.automatic_optimization = False - self.autoencoder = autoencoder - - self.warmed_up = False - self.warmup_steps = warmup_steps - self.warmup_mode = warmup_mode - self.encoder_freeze_on_warmup = encoder_freeze_on_warmup - self.lr = lr - self.clip_grad_norm = clip_grad_norm - - self.force_input_mono = force_input_mono - self.teacher_model = teacher_model - - self.use_ctc = use_ctc - self.proj_head_dim = proj_head_dim - self.detach_proj_head = detach_proj_head - self.projection_head = ( - ProjectionHead(self.autoencoder.latent_dim, self.proj_head_dim) - if self.use_ctc and self.proj_head_dim is not None else - nn.Identity() - ) - - self.optimizer_configs = optimizer_configs - self.loss_config = loss_config - - # Spectral reconstruction loss - self.sdstft = auraloss.MultiResolutionSTFTLoss( - sample_rate=sample_rate, **loss_config['spectral']['config']) - - # Discriminator - self.use_disc = True if 'discriminator' in loss_config else False - self.discriminator = None - if self.use_disc: - if loss_config['discriminator']['type'] == 'oobleck': - self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) - elif loss_config['discriminator']['type'] == 'encodec': - self.discriminator = EncodecDiscriminator( - in_channels=self.autoencoder.out_channels, - **loss_config['discriminator']['config']) - elif loss_config['discriminator']['type'] == 'dac': - self.discriminator = DACGANLoss( - channels=self.autoencoder.out_channels, - sample_rate=sample_rate, - **loss_config['discriminator']['config']) - - gen_loss_modules = [] - if self.use_disc: - # Discriminator loss. - self.losses_disc = MultiLoss([ - ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), - ]) - - # Adversarial and feature matching losses. 
- gen_loss_modules += [ - ValueLoss( - key='loss_adv', - weight=self.loss_config['discriminator']['weights']['adversarial'], - name='loss_adv'), - ValueLoss( - key='feature_matching_distance', - weight=self.loss_config['discriminator']['weights']['feature_matching'], - name='feature_matching_loss'), - ] - - # Reconstruction loss - gen_loss_modules += [AuralossLoss(self.sdstft, - target_key='reals', input_key='decoded', name='mrstft_loss', - weight=self.loss_config['spectral']['weights']['mrstft'], - decay=self.loss_config['spectral'].get('decay', 1.0), - )] - - if "mrmel" in loss_config: - mrmel_weight = loss_config["mrmel"]["weights"]["mrmel"] - if mrmel_weight > 0: - mrmel_config = loss_config["mrmel"]["config"] - self.mrmel = MelSpectrogramLoss(sample_rate, - n_mels=mrmel_config["n_mels"], - window_lengths=mrmel_config["window_lengths"], - pow=mrmel_config["pow"], - log_weight=mrmel_config["log_weight"], - mag_weight=mrmel_config["mag_weight"], - ) - gen_loss_modules.append(LossWithTarget( - self.mrmel, "reals", "decoded", - name="mrmel_loss", weight=mrmel_weight, - )) - - if "hubert" in loss_config: - hubert_weight = loss_config["hubert"]["weights"]["hubert"] - if hubert_weight > 0: - hubert_cfg = ( - loss_config["hubert"]["config"] - if "config" in loss_config["hubert"] else - dict() - ) - self.hubert = HubertLoss(weight=1.0, **hubert_cfg) - - gen_loss_modules.append(LossWithTarget( - self.hubert, target_key = "reals", input_key = "decoded", - name="hubert_loss", weight=hubert_weight, - decay = loss_config["hubert"].get("decay", 1.0) - )) - - if "l1" in loss_config["time"]["weights"]: - if self.loss_config['time']['weights']['l1'] > 0.0: - gen_loss_modules.append(L1Loss( - key_a='reals', key_b='decoded', - weight=self.loss_config['time']['weights']['l1'], - name='l1_time_loss', - decay = self.loss_config['time'].get('decay', 1.0), - )) - - if "l2" in loss_config["time"]["weights"]: - if self.loss_config['time']['weights']['l2'] > 0.0: - gen_loss_modules.append(MSELoss( - key_a='reals', key_b='decoded', - weight=self.loss_config['time']['weights']['l2'], - name='l2_time_loss', - decay = self.loss_config['time'].get('decay', 1.0), - )) - - if self.autoencoder.bottleneck is not None: - gen_loss_modules += create_loss_modules_from_bottleneck( - self.autoencoder.bottleneck, self.loss_config) - - self.encoder_mask_ratio = encoder_mask_ratio - if encoder_mask_ratio > 0.0: - gen_loss_modules.append(L1Loss( - key_a='detached_latents', key_b='masked_latents', - weight=1.0, - name='encoder_mask_loss', - decay = 1.0, - )) - - if "ctc" in loss_config: - ctc_weight = loss_config["ctc"]["weights"]["ctc"] - if ctc_weight > 0: - gen_loss_modules.append(CTCLossModule( - name = "ctc_loss", - target_key = "ctc_tgt", - input_key = "log_probs", - weight = ctc_weight, - decay = loss_config["ctc"].get("decay", 1.0), - blank_idx = loss_config["ctc"].get("blank_idx", 80) - )) - - self.losses_gen = MultiLoss(gen_loss_modules) - - # Set up EMA for model weights - self.autoencoder_ema = None - self.use_ema = use_ema - if self.use_ema: - self.autoencoder_ema = EMA( - self.autoencoder, - ema_model=ema_copy, - beta=0.9999, - power=3/4, - update_every=1, - update_after_step=1 - ) - - self.latent_mask_ratio = latent_mask_ratio - - # evaluation losses & metrics - self.eval_losses = torch.nn.ModuleDict() - if eval_loss_config is not None: - # if "pesq" in eval_loss_config: - # self.eval_losses["pesq"] = PESQMetric(sample_rate) - if "stft"in eval_loss_config: - self.eval_losses["stft"] = 
auraloss.STFTLoss(**eval_loss_config["stft"]) - if "sisdr" in eval_loss_config: - self.eval_losses["sisdr"] = auraloss.SISDRLoss(**eval_loss_config["sisdr"]) - if "mel" in eval_loss_config: - self.eval_losses["mel"] = auraloss.MelSTFTLoss( - sample_rate, **eval_loss_config["mel"]) - if "per" in eval_loss_config: - self.eval_losses["per"] = PERModule( - target_key = "ctc_tgt", - input_key = "log_probs", - blank_idx = loss_config["ctc"].get("blank_idx", 80)) - - self.validation_step_outputs = [] - - def configure_optimizers(self): - gen_params = list(self.autoencoder.parameters()) - - if not self.use_disc: - opt_gen = create_optimizer_from_config( - self.optimizer_configs['autoencoder']['optimizer'], gen_params) - if "scheduler" in self.optimizer_configs['autoencoder']: - sched_gen = create_scheduler_from_config( - self.optimizer_configs['autoencoder']['scheduler'], opt_gen) - return [opt_gen], [sched_gen] - return [opt_gen] - - # Using discriminator. - opt_gen = create_optimizer_from_config( - self.optimizer_configs['autoencoder']['optimizer'], gen_params) - opt_disc = create_optimizer_from_config( - self.optimizer_configs['discriminator']['optimizer'], - self.discriminator.parameters()) - - use_scheduler = ( - "scheduler" in self.optimizer_configs['autoencoder'] and - "scheduler" in self.optimizer_configs['discriminator'] - ) - if use_scheduler: - sched_gen = create_scheduler_from_config( - self.optimizer_configs['autoencoder']['scheduler'], opt_gen) - sched_disc = create_scheduler_from_config( - self.optimizer_configs['discriminator']['scheduler'], opt_disc) - return [opt_gen, opt_disc], [sched_gen, sched_disc] - return [opt_gen, opt_disc] - - def forward(self, reals): - latents, encoder_info = self.autoencoder.encode(reals, return_info=True) - decoded = self.autoencoder.decode(latents) - return decoded - - def validation_step(self, batch, batch_idx): - reals, _ = batch - # Remove extra dimension added by WebDataset - if reals.ndim == 4 and reals.shape[0] == 1: - reals = reals[0] - - if len(reals.shape) == 2: - reals = reals.unsqueeze(1) - - loss_info = {} - loss_info["reals"] = reals - - encoder_input = reals - if self.force_input_mono and encoder_input.shape[1] > 1: - encoder_input = encoder_input.mean(dim=1, keepdim=True) - - loss_info["encoder_input"] = encoder_input - - with torch.no_grad(): - if self.use_ctc: - latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) - continuous_latents = encoder_info["pre_bottleneck_latents"] - proj_features = rearrange(continuous_latents, "b c n -> b n c") - proj_features = self.projection_head( - proj_features.detach() - if self.detach_proj_head else - proj_features - ) - - loss_info['log_probs'] = proj_features - loss_info['ctc_tgt'] = batch[1] - else: - latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) - - loss_info["latents"] = latents - loss_info.update(encoder_info) - - decoded = self.autoencoder.decode(latents) - #Trim output to remove post-padding. - decoded, reals = trim_to_shortest(decoded, reals) - - # Run evaluation metrics. 
- val_loss_dict = {} - for eval_key, eval_fn in self.eval_losses.items(): - if eval_key == 'per': - loss_value = eval_fn(loss_info) - else: - loss_value = eval_fn(decoded, reals) - if eval_key == "sisdr": loss_value = -loss_value - - if isinstance(loss_value, torch.Tensor): - loss_value = loss_value.item() - - val_loss_dict[eval_key] = loss_value - - self.validation_step_outputs.append(val_loss_dict) - return val_loss_dict - - def on_validation_epoch_end(self): - sum_loss_dict = {} - for loss_dict in self.validation_step_outputs: - for key, value in loss_dict.items(): - if key not in sum_loss_dict: - sum_loss_dict[key] = value - else: - sum_loss_dict[key] += value - - for key, value in sum_loss_dict.items(): - val_loss = value / len(self.validation_step_outputs) - val_loss = self.all_gather(val_loss).mean().item() - log_metric(self.logger, f"val/{key}", val_loss) - - self.validation_step_outputs.clear() # free memory - - def training_step(self, batch, batch_idx): - reals, _ = batch - - log_dict = {} - # Remove extra dimension added by WebDataset - if reals.ndim == 4 and reals.shape[0] == 1: - reals = reals[0] - - if len(reals.shape) == 2: - reals = reals.unsqueeze(1) - - if self.global_step >= self.warmup_steps: - self.warmed_up = True - - loss_info = {} - loss_info["reals"] = reals - encoder_input = reals - - if self.force_input_mono and encoder_input.shape[1] > 1: - encoder_input = encoder_input.mean(dim=1, keepdim=True) - - loss_info["encoder_input"] = encoder_input - data_std = encoder_input.std() - - if self.warmed_up and self.encoder_freeze_on_warmup: - with torch.no_grad(): - latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) - else: - if self.use_ctc: - latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) - continuous_latents = encoder_info["pre_bottleneck_latents"] - proj_features = rearrange(continuous_latents, "b c n -> b n c") - proj_features = self.projection_head( - proj_features.detach() - if self.detach_proj_head else - proj_features - ) - - loss_info['log_probs'] = proj_features - loss_info['ctc_tgt'] = batch[1] - else: - latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) - - if self.encoder_mask_ratio > 0.0: - masked_latents = self.autoencoder.encode( - encoder_input, return_info=False, encoder_mask_ratio=self.encoder_mask_ratio) - detached_latents = latents.detach() - loss_info["masked_latents"] = masked_latents - loss_info["detached_latents"] = detached_latents - - loss_info["latents"] = latents - loss_info.update(encoder_info) - - # Encode with teacher model for distillation - if self.teacher_model is not None: - with torch.no_grad(): - teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) - loss_info['teacher_latents'] = teacher_latents - - # Optionally mask out some latents for noise resistance - if self.latent_mask_ratio > 0.0: - mask = torch.rand_like(latents) < self.latent_mask_ratio - latents = torch.where(mask, torch.zeros_like(latents), latents) - - decoded = self.autoencoder.decode(latents) - #Trim output to remove post-padding - decoded, reals = trim_to_shortest(decoded, reals) - - loss_info["decoded"] = decoded - loss_info["reals"] = reals - - if self.autoencoder.out_channels == 2: - loss_info["decoded_left"] = decoded[:, 0:1, :] - loss_info["decoded_right"] = decoded[:, 1:2, :] - loss_info["reals_left"] = reals[:, 0:1, :] - loss_info["reals_right"] = reals[:, 1:2, :] - - # Distillation - if self.teacher_model is not None: - with torch.no_grad(): - 
teacher_decoded = self.teacher_model.decode(teacher_latents) - own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher - teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model - - loss_info['teacher_decoded'] = teacher_decoded - loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded - loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded - - if self.use_disc: - if self.warmed_up: - loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals=reals, fakes=decoded) - else: - loss_adv = torch.tensor(0.).to(reals) - feature_matching_distance = torch.tensor(0.).to(reals) - - if self.warmup_mode == "adv": - loss_dis, _, _ = self.discriminator.loss(reals=reals, fakes=decoded) - else: - loss_dis = torch.tensor(0.0).to(reals) - - loss_info["loss_dis"] = loss_dis - loss_info["loss_adv"] = loss_adv - loss_info["feature_matching_distance"] = feature_matching_distance - - opt_gen = None - opt_disc = None - if self.use_disc: - opt_gen, opt_disc = self.optimizers() - else: - opt_gen = self.optimizers() - - lr_schedulers = self.lr_schedulers() - sched_gen = None - sched_disc = None - - if lr_schedulers is not None: - if self.use_disc: - sched_gen, sched_disc = lr_schedulers - else: - sched_gen = lr_schedulers - - # Train the discriminator - use_disc = ( - self.use_disc - and self.global_step % 2 - # Check warmup mode and if it is time to use discriminator. - and ( - (self.warmup_mode == "full" and self.warmed_up) - or self.warmup_mode == "adv") - ) - if use_disc: - loss, losses = self.losses_disc(loss_info) - log_dict['train/disc_lr'] = opt_disc.param_groups[0]['lr'] - opt_disc.zero_grad() - self.manual_backward(loss) - - if self.clip_grad_norm > 0.0: - torch.nn.utils.clip_grad_norm_( - self.discriminator.parameters(), self.clip_grad_norm) - - opt_disc.step() - if sched_disc is not None: - # sched step every step - sched_disc.step() - - # Train the generator - else: - loss, losses = self.losses_gen(loss_info) - if self.use_ema: - self.autoencoder_ema.update() - - opt_gen.zero_grad() - self.manual_backward(loss) - if self.clip_grad_norm > 0.0: - torch.nn.utils.clip_grad_norm_( - self.autoencoder.parameters(), self.clip_grad_norm) - - opt_gen.step() - if sched_gen is not None: - # scheduler step every step - sched_gen.step() - - log_dict['train/loss'] = loss.detach().item() - log_dict['train/latent_std'] = latents.std().detach().item() - log_dict['train/data_std'] = data_std.detach().item() - log_dict['train/gen_lr'] = opt_gen.param_groups[0]['lr'] - - for loss_name, loss_value in losses.items(): - log_dict[f'train/{loss_name}'] = loss_value.detach().item() - - self.log_dict(log_dict, prog_bar=True, on_step=True) - return loss - - def export_model(self, path, use_safetensors=False): - if self.autoencoder_ema is not None: - model = self.autoencoder_ema.ema_model - else: - model = self.autoencoder - - if use_safetensors: - save_model(model, path) - else: - torch.save({"state_dict": model.state_dict()}, path) - -def create_loss_modules_from_bottleneck(bottleneck, loss_config): - losses = [] - - if ( - isinstance(bottleneck, VAEBottleneck) or - isinstance(bottleneck, DACRVQVAEBottleneck) or - isinstance(bottleneck, RVQVAEBottleneck) - ): - try: - kl_weight = loss_config['bottleneck']['weights']['kl'] - except: - kl_weight = 1e-6 - - kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') - losses.append(kl_loss) - - if ( - 
isinstance(bottleneck, RVQBottleneck) or - isinstance(bottleneck, RVQVAEBottleneck) - ): - quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') - losses.append(quantizer_loss) - - if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): - codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') - commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') - losses.append(codebook_loss) - losses.append(commitment_loss) - - if isinstance(bottleneck, WassersteinBottleneck): - try: - mmd_weight = loss_config['bottleneck']['weights']['mmd'] - except: - mmd_weight = 100 - - mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') - losses.append(mmd_loss) - - return losses - -def create_training_wrapper_from_config(model_config, model): - model_type = model_config.get('model_type', None) - assert model_type is not None, 'model_type must be specified in model config' - - training_config = model_config.get('training', None) - assert training_config is not None, 'training config must be specified in model config' - - ema_copy = None - if training_config.get("use_ema", False): - ema_copy = create_model_from_config(model_config) - # Copy each weight to the ema copy - for name, param in model.state_dict().items(): - if isinstance(param, Parameter): - # backwards compatibility for serialized parameters - param = param.data - ema_copy.state_dict()[name].copy_(param) - - use_ema = training_config.get("use_ema", False) - latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) - - teacher_model = training_config.get("teacher_model", None) - if teacher_model is not None: - teacher_model = create_model_from_config(teacher_model) - teacher_model = teacher_model.eval().requires_grad_(False) - - teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) - if teacher_model_ckpt is not None: - teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) - else: - raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") - - return AutoencoderTrainingWrapper( - model, - lr=training_config.get("learning_rate", None), - warmup_steps=training_config.get("warmup_steps", 0), - encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), - sample_rate=model_config["sample_rate"], - loss_config=training_config.get("loss_configs", None), - eval_loss_config=training_config.get("eval_loss_configs", None), - optimizer_configs=training_config.get("optimizer_configs", None), - use_ema=use_ema, - ema_copy=ema_copy if use_ema else None, - force_input_mono=training_config.get("force_input_mono", False), - latent_mask_ratio=latent_mask_ratio, - teacher_model=teacher_model, - encoder_mask_ratio=training_config.get("encoder_mask_ratio", 0.0), - use_ctc=training_config.get("use_ctc", False), - proj_head_dim=model_config["model"].get("proj_head_dim", False), - detach_proj_head=model_config["model"].get("detach_proj_head", None), - ) diff --git a/dist/stable_codec-0.1.3-py3-none-any.whl b/dist/stable_codec-0.1.3-py3-none-any.whl deleted file mode 100644 index ac31d4f20a874f124e1fbcf949997970adb5c1b2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 19930 zcmZ^~LzHM?5;R!0ZQHhO+qP}ner4OX@yfPs+g0=W%wqoTnYs7m;@jjZPex>9q=Gas z2nqlI00e-Wg^Z3VY2rHRzZccN!1x!=E`~-nruxSACZ@*p`udi3mM;4GbPk@<)DyF+ z6m*g@6Oys=GxCzME0dr}oKP7ZpzorjPy+^-<_RDG3jdW_lZ@us83F)ckMRE^wXut_ 
<*oAE98 a??7*3$;G|%B?0z3oqay2vYTB%p#K7ISnm-4 diff --git a/stable_codec.egg-info/PKG-INFO b/stable_codec.egg-info/PKG-INFO deleted file mode 100644 index 3f60642..0000000 --- a/stable_codec.egg-info/PKG-INFO +++ /dev/null @@ -1,223 +0,0 @@ -Metadata-Version: 2.4 -Name: stable-codec -Version: 0.1.3 -Summary: Stable Codec: A series of codec models for speech and audio -Home-page: https://github.com/Stability-AI/stable-codec/ -Author: Stability AI -Author-email: julian.parker@stability.ai -Requires-Python: >=3.9,<3.12 -Description-Content-Type: text/markdown -License-File: LICENSE -Requires-Dist: packaging -Requires-Dist: wheel -Requires-Dist: torch==2.4 -Requires-Dist: torchaudio==2.4 -Requires-Dist: stable-audio-tools==0.0.19 -Requires-Dist: pytorch-lightning==2.1 -Requires-Dist: prefigure==0.0.9 -Dynamic: author -Dynamic: author-email -Dynamic: description -Dynamic: description-content-type -Dynamic: home-page -Dynamic: license-file -Dynamic: requires-dist -Dynamic: requires-python -Dynamic: summary - -# Stable Codec - -This repository contains training and inference scripts for models in the Stable Codec series, starting with `stable-codec-speech-16k` - introduced in the paper titled Scaling Transformers for Low-bitrate High-Quality Speech Coding. - -Paper: https://arxiv.org/abs/2411.19842 - -Sound demos: https://stability-ai.github.io/stable-codec-demo/ - -Model weights: https://huggingface.co/stabilityai/stable-codec-speech-16k - -## Changelog - -### [v0.1.3] TBD -- __Fix__ restricted Python version to <3.12 due to dependency incompatibilities -- __Fix__ clarified installation instructions regarding Python version requirements -### [v0.1.2] 14-01-25 -- __New__ added hooks for `stable-codec-speech-16k-base`. -- __Fix__ fixed major issue with precision in FSQ token calculation, which was degrading results. Fix is currently local, will be upstreamed to `stable-audio-tools` later. -### [v0.1.1] 10-01-25 -- Release - - -## - -Note that whilst this code is MIT licensed, the model weights are covered by the [Stability AI Community License](https://huggingface.co/stabilityai/stable-codec-speech-16k/blob/main/LICENSE.md) - -## Variants -The model is currently available in two variants: -- `stable-codec-speech-16k` is an improved finetune, with boosted latent semantics. __It should be used in 99% of use-cases.__ -- `stable-codec-speech-16k-base` is the weights corresponding to the results in our [publication](https://arxiv.org/abs/2411.19842), provided for reproducibility. - -### Additional Training - -In addition to the training described in the paper, the weights for `stable-codec-speech-16k` have undergone 500k steps of finetuning with force-aligned data from LibriLight and the English portion Multilingual LibriSpeech. This was performed by using a CTC head to regress the force-aligned phoneme tags from pre-bottleneck latents. We found that this additional training significantly boosted the applicability of the codec tokens to downstream tasks like TTS, at a small cost to objective reconstruction metrics. - -## Install - -The model itself is defined in [stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) package. - -### Python Version Compatibility - -**Important:** This package currently requires **Python 3.9, 3.10, or 3.11**. Python 3.12 and later are not supported due to incompatibilities in the `stable-audio-tools` dependency chain (specifically `PyWavelets==1.4.1` and `pandas==2.0.2`). - -If you attempt to install on Python 3.12+, you will encounter build errors. 
Please use Python 3.11 or earlier. - -To install `stable-codec`: - -```bash -pip install stable-codec -pip install -U flash-attn --no-build-isolation -``` - -**IMPORTANT NOTE:** This model currently has a hard requirement for FlashAttention due to its use of sliding window attention. Inference without FlashAttention will likely be greatly degraded. This also means that the model currently does not support CPU inference. We will relax the dependency on FlashAttention in the future. - -## Encoding and decoding - -To encode audio or decode tokens, the `StableCodec` class provides a convenient wrapper for the model. It can be used with a local checkpoint and config as follows: - -```python -import torch -import torchaudio -from stable_codec import StableCodec - -model = StableCodec( - model_config_path="", - ckpt_path="", # optional, can be `None`, - device = torch.device("cuda") -) - -audiopath = "audio.wav" - -latents, tokens = model.encode(audiopath) -decoded_audio = model.decode(tokens) - -torchaudio.save("decoded.wav", decoded_audio, model.sample_rate) -``` - -To download the model weights automatically from HuggingFace, simply provide the model name: - -```python -model = StableCodec( - pretrained_model = 'stabilityai/stable-codec-speech-16k' -) -``` -### Posthoc bottleneck configuration - -Most usecases will benefit from replacing the training-time FSQ bottleneck with a post-hoc FSQ bottleneck, as described in the paper. This allows token dictionary size to be reduced to a reasonable level for modern language models. This is achieved by calling the `set_posthoc_bottleneck` function, and setting a flag to the encode/decode calls: - -```python -model.set_posthoc_bottleneck("2x15625_700bps") -latents, tokens = model.encode(audiopath, posthoc_bottleneck = True) -decoded_audio = model.decode(tokens, posthoc_bottleneck = True) -``` -`set_posthoc_bottleneck` can take a string as argument, which allows selection a number of recommended preset settings for the bottleneck: - -| Bottleneck Preset | Number of Tokens per step | Dictionary Size | Bits Per Second (bps) | -|-------------------|------------------|-----------------|-----------------------| -| `1x46656_400bps` | 1 | 46656 | 400 | -| `2x15625_700bps` | 2 | 15625 | 700 | -| `4x729_1000bps` | 4 | 729 | 1000 | - -Alternatively, the bottleneck stages can be specified directly. The format for specifying this can be seen in the definition of the `StableCodec` class in `model.py`. - -### Normalization - -The model is trained with utterances normalized to -20 +-5 LUFS. The `encode` function normalizes to -20 LUFS by default, but it can be disabled by setting `normalize = False` when calling the function. - -## Finetune - -To finetune a model given its config and checkpoint, execute `train.py` file: - -```bash -python train.py \ - --project "stable-codec" \ - --name "finetune" \ - --config-file "defaults.ini" \ - --save-dir "" \ - --model-config "" \ - --dataset-config "" \ - --val-dataset-config "" \ - --pretrained-ckpt-path "" \ - --ckpt-path "$CKPT_PATH" \ - --num-nodes $SLURM_JOB_NUM_NODES \ - --num-workers 16 --batch-size 10 --precision "16-mixed" \ - --checkpoint-every 10000 \ - --logger "wandb" -``` - -For dataset configuration, refer to `stable-audio-tools` [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md). 
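As a rough illustration only (the authoritative schema is in the linked `stable-audio-tools` dataset docs, and the `id` and path below are placeholders), a minimal local-directory dataset config of the kind passed via `--dataset-config` might look like this:

```json
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_speech_set",
            "path": "/path/to/audio/files/"
        }
    ],
    "random_crop": true
}
```

Note that the CTC training setup described below requires the WebDataset format rather than a local audio directory.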
- - -### Using CTC loss - -To use [CTC loss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) -during training you have to enable it in the training configuration file -and in the training dataset configuration. - -1. Modifying training configuration: - - Enable CTC projection head and set its hidden dimension: - ```python - config["model"]["use_proj_head"] = True - config["model"]["proj_head_dim"] = 81 - ``` - - Enable CTC in the training part of the config: - ```python - config["training"]["use_ctc"] = True - ``` - - And set its loss config: - ```python - config["training"]["loss_configs"]["ctc"] = { - "blank_idx": 80, - "decay": 1.0, - "weights": {"ctc": 1.0} - } - ``` - - Optionally, you can enable computation of the Phone-Error-Rate (PER) during validation: - ```python - config["training"]["eval_loss_configs"]["per"] = {} - ``` - -2. Configuring dataset (only WebDataset format is supported for CTC): - - The dataset configuration should have one additional field set to it (see [dataset docs](https://github.com/Stability-AI/stable-audio-tools/blob/main/docs/datasets.md) for other options): - ```python - config["force_align_text"] = True - ``` - - And the JSON metadata file for each sample should contain force aligned transcript under `force_aligned_text` entry in the format specified below (besides other metadata). - Where `transcript` is a list of word-level alignments with `start` and `end` fields specifying range **in seconds** of each word. - ```json - "normalized_text":"and i feel" - "force_aligned_text":{ - "transcript":[ - { - "word":"and", - "start":0.2202, - "end":0.3403 - }, - { - "word":"i", - "start":0.4604, - "end":0.4804 - }, - { - "word":"feel", - "start":0.5204, - "end":0.7006 - } - ] - } - ``` -## Objective Metrics - -| Model | SI-SDR | Mel Dis | STFT Dis | PESQ | STOI | -|---------------------------|-------:|--------:|---------:|-----:|-----:| -| `stable-codec-speech-16k-base` | 4.73 | 0.86 | 1.26 | 3.09 | 0.92 | -| `stable-codec-speech-16k` | 3.58 | 0.90 | 1.30 | 3.01 | 0.90 | - diff --git a/stable_codec.egg-info/SOURCES.txt b/stable_codec.egg-info/SOURCES.txt deleted file mode 100644 index 974d85b..0000000 --- a/stable_codec.egg-info/SOURCES.txt +++ /dev/null @@ -1,16 +0,0 @@ -LICENSE -README.md -pyproject.toml -setup.py -stable_codec/__init__.py -stable_codec/ctc_loss.py -stable_codec/fsq.py -stable_codec/model.py -stable_codec/residual_fsq.py -stable_codec/training_demo.py -stable_codec/training_module.py -stable_codec.egg-info/PKG-INFO -stable_codec.egg-info/SOURCES.txt -stable_codec.egg-info/dependency_links.txt -stable_codec.egg-info/requires.txt -stable_codec.egg-info/top_level.txt \ No newline at end of file diff --git a/stable_codec.egg-info/dependency_links.txt b/stable_codec.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/stable_codec.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/stable_codec.egg-info/requires.txt b/stable_codec.egg-info/requires.txt deleted file mode 100644 index 959f1d4..0000000 --- a/stable_codec.egg-info/requires.txt +++ /dev/null @@ -1,7 +0,0 @@ -packaging -wheel -torch==2.4 -torchaudio==2.4 -stable-audio-tools==0.0.19 -pytorch-lightning==2.1 -prefigure==0.0.9 diff --git a/stable_codec.egg-info/top_level.txt b/stable_codec.egg-info/top_level.txt deleted file mode 100644 index e12fb46..0000000 --- a/stable_codec.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -stable_codec