tts : add OuteTTS support (#10784)

* server : add "tokens" output

ggml-ci

* server : output embeddings for all tokens when pooling = none

ggml-ci

* server : be explicit about the pooling type in the tests

ggml-ci

* server : do not normalize embeddings when there is no pooling

ggml-ci

* llama : add OuteTTS support (wip)

* wip

* extract features

* first conv

* group norm

* resnet conv

* resnet

* attn

* pos net

* layer norm

* convnext

* head

* hann window

* fix n_embd + remove llama.cpp hacks

* compute hann window

* fft

* spectrum processing

* clean-up

* tts : receive input text and generate codes

* clip : fix new conv name

* tts : minor fix

* tts : add header + minor fixes

ggml-ci

* tts : add matchematical constant

ggml-ci

* tts : fix sampling + cut initial noise

* tts : fixes

* tts : update default samplers

ggml-ci

* tts : text pre-processing

* tts : outetts-voc -> wavtokenizer-dec

* tts : remove hardcoded constants

ggml-ci

* tts : fix tensor shapes

* llama : refactor wavtokenizer tensors

ggml-ci

* cont

ggml-ci

* cont [no ci]

* llama : update WavTokenizer to non-causal attn

* llama : handle no-vocab detokenization

* tts : add Python example for OuteTTS (wip)

* tts : extend python example to generate spectrogram

ggml-ci

* server : fix rebase artifacts

* tts : enable "return_tokens" in Python example

ggml-ci

* tts : minor fixes

* common : support HF download for vocoder
This commit is contained in:
Georgi Gerganov 2024-12-18 19:27:21 +02:00 committed by GitHub
parent 7bbb5acf12
commit 0bf2d10c55
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 2509 additions and 532 deletions

View file

@ -90,6 +90,7 @@ class Keys:
VOCAB_SIZE = "{arch}.vocab_size"
CONTEXT_LENGTH = "{arch}.context_length"
EMBEDDING_LENGTH = "{arch}.embedding_length"
FEATURES_LENGTH = "{arch}.features_length"
BLOCK_COUNT = "{arch}.block_count"
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
@ -122,6 +123,8 @@ class Keys:
VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
CAUSAL = "{arch}.attention.causal"
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
@ -155,6 +158,14 @@ class Keys:
class WKV:
HEAD_SIZE = "{arch}.wkv.head_size"
class PosNet:
EMBEDDING_LENGTH = "{arch}.posnet.embedding_length"
BLOCK_COUNT = "{arch}.posnet.block_count"
class ConvNext:
EMBEDDING_LENGTH = "{arch}.convnext.embedding_length"
BLOCK_COUNT = "{arch}.convnext.block_count"
class Tokenizer:
MODEL = "tokenizer.ggml.model"
PRE = "tokenizer.ggml.pre"
@ -209,58 +220,59 @@ class GGUFType:
class MODEL_ARCH(IntEnum):
LLAMA = auto()
FALCON = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
GPTJ = auto()
GPTNEOX = auto()
MPT = auto()
STARCODER = auto()
REFACT = auto()
BERT = auto()
NOMIC_BERT = auto()
JINA_BERT_V2 = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
QWEN2 = auto()
QWEN2MOE = auto()
QWEN2VL = auto()
PHI2 = auto()
PHI3 = auto()
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()
MINICPM = auto()
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
RWKV6 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
DBRX = auto()
OLMO = auto()
OLMO2 = auto()
OLMOE = auto()
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK = auto()
DEEPSEEK2 = auto()
CHATGLM = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
CHAMELEON = auto()
LLAMA = auto()
FALCON = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
GPTJ = auto()
GPTNEOX = auto()
MPT = auto()
STARCODER = auto()
REFACT = auto()
BERT = auto()
NOMIC_BERT = auto()
JINA_BERT_V2 = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
QWEN2 = auto()
QWEN2MOE = auto()
QWEN2VL = auto()
PHI2 = auto()
PHI3 = auto()
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()
MINICPM = auto()
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
RWKV6 = auto()
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
DBRX = auto()
OLMO = auto()
OLMO2 = auto()
OLMOE = auto()
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK = auto()
DEEPSEEK2 = auto()
CHATGLM = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
CHAMELEON = auto()
WAVTOKENIZER_DEC = auto()
class MODEL_TENSOR(IntEnum):
@ -370,61 +382,78 @@ class MODEL_TENSOR(IntEnum):
ENC_OUTPUT_NORM = auto()
CLS = auto() # classifier
CLS_OUT = auto() # classifier output projection
CONV1D = auto()
CONVNEXT_DW = auto()
CONVNEXT_NORM = auto()
CONVNEXT_PW1 = auto()
CONVNEXT_PW2 = auto()
CONVNEXT_GAMMA = auto()
POSNET_CONV1 = auto()
POSNET_CONV2 = auto()
POSNET_NORM = auto()
POSNET_NORM1 = auto()
POSNET_NORM2 = auto()
POSNET_ATTN_NORM = auto()
POSNET_ATTN_Q = auto()
POSNET_ATTN_K = auto()
POSNET_ATTN_V = auto()
POSNET_ATTN_OUT = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GROK: "grok",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.QWEN2: "qwen2",
MODEL_ARCH.QWEN2MOE: "qwen2moe",
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMO2: "olmo2",
MODEL_ARCH.OLMOE: "olmoe",
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK: "deepseek",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GROK: "grok",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.QWEN2: "qwen2",
MODEL_ARCH.QWEN2MOE: "qwen2moe",
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMO2: "olmo2",
MODEL_ARCH.OLMOE: "olmoe",
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK: "deepseek",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.CHAMELEON: "chameleon",
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -534,6 +563,22 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
MODEL_TENSOR.CLS: "cls",
MODEL_TENSOR.CLS_OUT: "cls.output",
MODEL_TENSOR.CONV1D: "conv1d",
MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm",
MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1",
MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2",
MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma",
MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1",
MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2",
MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm",
MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1",
MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2",
MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm",
MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q",
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -1372,6 +1417,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.WAVTOKENIZER_DEC: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.CONV1D,
MODEL_TENSOR.CONVNEXT_DW,
MODEL_TENSOR.CONVNEXT_NORM,
MODEL_TENSOR.CONVNEXT_PW1,
MODEL_TENSOR.CONVNEXT_PW2,
MODEL_TENSOR.CONVNEXT_GAMMA,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.POSNET_CONV1,
MODEL_TENSOR.POSNET_CONV2,
MODEL_TENSOR.POSNET_NORM,
MODEL_TENSOR.POSNET_NORM1,
MODEL_TENSOR.POSNET_NORM2,
MODEL_TENSOR.POSNET_ATTN_NORM,
MODEL_TENSOR.POSNET_ATTN_Q,
MODEL_TENSOR.POSNET_ATTN_K,
MODEL_TENSOR.POSNET_ATTN_V,
MODEL_TENSOR.POSNET_ATTN_OUT,
],
# TODO
}