Add vocab size as part of the model metadata

This commit is contained in:
Michael Podvitskiy 2024-02-27 14:34:29 +01:00
parent e700b44217
commit cc3fe18b43
4 changed files with 9 additions and 1 deletions

View file

@@ -977,6 +977,7 @@ class OutputFile:
         name = str(params.path_model.parent).split('/')[-1]
         self.gguf.add_name                (name)
+        self.gguf.add_vocab_size          (params.n_vocab)
         self.gguf.add_context_length      (params.n_ctx)
         self.gguf.add_embedding_length    (params.n_embd)
         self.gguf.add_block_count         (params.n_layer)

View file

@@ -32,6 +32,7 @@ class Keys:
         FILE_TYPE = "general.file_type"

     class LLM:
+        VOCAB_SIZE       = "{arch}.vocab_size"
         CONTEXT_LENGTH   = "{arch}.context_length"
         EMBEDDING_LENGTH = "{arch}.embedding_length"
         BLOCK_COUNT      = "{arch}.block_count"
@@ -711,6 +712,7 @@
 KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
 KEY_GENERAL_FILE_TYPE      = Keys.General.FILE_TYPE

 # LLM
+KEY_VOCAB_SIZE       = Keys.LLM.VOCAB_SIZE
 KEY_CONTEXT_LENGTH   = Keys.LLM.CONTEXT_LENGTH
 KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH
 KEY_BLOCK_COUNT      = Keys.LLM.BLOCK_COUNT

View file

@@ -313,6 +313,9 @@ class GGUFWriter:
         self.data_alignment = alignment
         self.add_uint32(Keys.General.ALIGNMENT, alignment)

+    def add_vocab_size(self, size: int) -> None:
+        self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
+
     def add_context_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)

View file

@@ -256,6 +256,7 @@ enum llm_kv {
     LLM_KV_GENERAL_SOURCE_URL,
     LLM_KV_GENERAL_SOURCE_HF_REPO,

+    LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
     LLM_KV_EMBEDDING_LENGTH,
     LLM_KV_BLOCK_COUNT,
@@ -314,6 +315,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_GENERAL_SOURCE_URL,     "general.source.url"                    },
    { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },

+   { LLM_KV_VOCAB_SIZE,             "%s.vocab_size"                         },
    { LLM_KV_CONTEXT_LENGTH,         "%s.context_length"                     },
    { LLM_KV_EMBEDDING_LENGTH,       "%s.embedding_length"                   },
    { LLM_KV_BLOCK_COUNT,            "%s.block_count"                        },
@@ -12485,7 +12487,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 }

 int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->vocab.id_to_token.size();
+    return model->hparams.n_vocab;
 }

 int32_t llama_n_ctx_train(const struct llama_model * model) {