fix NamedTuple and Enum usage

Jared Van Bortel 2023-11-07 21:12:26 -05:00
parent f364636b2e
commit f2292fcc19
2 changed files with 89 additions and 91 deletions

View file

@@ -1,8 +1,8 @@
 from __future__ import annotations

 import sys
-from enum import Enum, IntEnum, StrEnum, auto
-from typing import Any, NamedTuple, Type
+from enum import IntEnum, StrEnum, auto
+from typing import Any, Type

 #
 # constants
@@ -73,16 +73,14 @@ class LLMKeys(StrEnum):
     TENSOR_DATA_LAYOUT: str = "{arch}.tensor_data_layout"


-class Keys(NamedTuple):
-    GENERAL: Type[GeneralKeys] = GeneralKeys
-    LLM: Type[LLMKeys] = LLMKeys
-    ATTENTION: Type[AttentionKeys] = AttentionKeys
-    ROPE: Type[RopeKeys] = RopeKeys
-    TOKENIZER: Type[TokenizerKeys] = TokenizerKeys
-
-KEY = Keys()
+class Keys:
+    GENERAL = GeneralKeys
+    LLM = LLMKeys
+    ATTENTION = AttentionKeys
+    ROPE = RopeKeys
+    TOKENIZER = TokenizerKeys


 #
 # recommended mapping of model tensor names for storage in gguf
 #
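Why the NamedTuple had to go: fields of a typing.NamedTuple become tuple descriptors on the class, so the nested key enums were only reachable through an instance, which is what the now-deleted KEY = Keys() singleton provided. With a plain namespace class the attributes live on the class itself. A minimal sketch with toy stand-in names (not the real gguf classes):

from typing import NamedTuple, Type

class General:                        # stand-in for GeneralKeys
    ARCHITECTURE = "general.architecture"

class KeysAsNamedTuple(NamedTuple):   # old style: GENERAL is a field descriptor on the class
    GENERAL: Type[General] = General

class KeysAsNamespace:                # new style: GENERAL is an ordinary class attribute
    GENERAL = General

# The NamedTuple version needs an instance before the nested attribute resolves:
assert KeysAsNamedTuple().GENERAL.ARCHITECTURE == "general.architecture"

# The plain class works directly on the class, so no module-level singleton is required:
assert KeysAsNamespace.GENERAL.ARCHITECTURE == "general.architecture"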
@@ -345,7 +343,7 @@ class TokenType(IntEnum):
     BYTE = 6


-class RopeScalingType(Enum):
+class RopeScalingType(StrEnum):
     NONE = 'none'
     LINEAR = 'linear'
     YARN = 'yarn'
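RopeScalingType switching from Enum to StrEnum matters for the writer file below: a StrEnum member (Python 3.11+) is itself an instance of str, so it can be handed to string-typed APIs without the explicit .value that a plain Enum needs. A small illustration of that difference, using throwaway class names:

from enum import Enum, StrEnum

class OldScaling(Enum):       # plain Enum: members are not strings
    LINEAR = 'linear'

class NewScaling(StrEnum):    # StrEnum: members are str instances
    LINEAR = 'linear'

assert not isinstance(OldScaling.LINEAR, str)   # old code had to pass value.value
assert isinstance(NewScaling.LINEAR, str)       # new code can pass the member itself
assert NewScaling.LINEAR == 'linear'
assert f"scaling type: {NewScaling.LINEAR}" == "scaling type: linear"

This is why add_rope_scaling_type in the writer diff below now passes value instead of value.value.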
@@ -430,52 +428,52 @@ GGML_QUANT_SIZES = {
 # Aliases for backward compatibility.

 # general
-KEY_GENERAL_ARCHITECTURE: str = KEY.GENERAL.ARCHITECTURE
-KEY_GENERAL_QUANTIZATION_VERSION: str = KEY.GENERAL.QUANTIZATION_VERSION
-KEY_GENERAL_ALIGNMENT: str = KEY.GENERAL.ALIGNMENT
-KEY_GENERAL_NAME: str = KEY.GENERAL.NAME
-KEY_GENERAL_AUTHOR: str = KEY.GENERAL.AUTHOR
-KEY_GENERAL_URL: str = KEY.GENERAL.URL
-KEY_GENERAL_DESCRIPTION: str = KEY.GENERAL.DESCRIPTION
-KEY_GENERAL_LICENSE: str = KEY.GENERAL.LICENSE
-KEY_GENERAL_SOURCE_URL: str = KEY.GENERAL.SOURCE_URL
-KEY_GENERAL_SOURCE_HF_REPO: str = KEY.GENERAL.SOURCE_HF_REPO
-KEY_GENERAL_FILE_TYPE: str = KEY.GENERAL.FILE_TYPE
+KEY_GENERAL_ARCHITECTURE: str = Keys.GENERAL.ARCHITECTURE
+KEY_GENERAL_QUANTIZATION_VERSION: str = Keys.GENERAL.QUANTIZATION_VERSION
+KEY_GENERAL_ALIGNMENT: str = Keys.GENERAL.ALIGNMENT
+KEY_GENERAL_NAME: str = Keys.GENERAL.NAME
+KEY_GENERAL_AUTHOR: str = Keys.GENERAL.AUTHOR
+KEY_GENERAL_URL: str = Keys.GENERAL.URL
+KEY_GENERAL_DESCRIPTION: str = Keys.GENERAL.DESCRIPTION
+KEY_GENERAL_LICENSE: str = Keys.GENERAL.LICENSE
+KEY_GENERAL_SOURCE_URL: str = Keys.GENERAL.SOURCE_URL
+KEY_GENERAL_SOURCE_HF_REPO: str = Keys.GENERAL.SOURCE_HF_REPO
+KEY_GENERAL_FILE_TYPE: str = Keys.GENERAL.FILE_TYPE

 # LLM
-KEY_CONTEXT_LENGTH: str = KEY.LLM.CONTEXT_LENGTH
-KEY_EMBEDDING_LENGTH: str = KEY.LLM.EMBEDDING_LENGTH
-KEY_BLOCK_COUNT: str = KEY.LLM.BLOCK_COUNT
-KEY_FEED_FORWARD_LENGTH: str = KEY.LLM.FEED_FORWARD_LENGTH
-KEY_USE_PARALLEL_RESIDUAL: str = KEY.LLM.USE_PARALLEL_RESIDUAL
-KEY_TENSOR_DATA_LAYOUT: str = KEY.LLM.TENSOR_DATA_LAYOUT
+KEY_CONTEXT_LENGTH: str = Keys.LLM.CONTEXT_LENGTH
+KEY_EMBEDDING_LENGTH: str = Keys.LLM.EMBEDDING_LENGTH
+KEY_BLOCK_COUNT: str = Keys.LLM.BLOCK_COUNT
+KEY_FEED_FORWARD_LENGTH: str = Keys.LLM.FEED_FORWARD_LENGTH
+KEY_USE_PARALLEL_RESIDUAL: str = Keys.LLM.USE_PARALLEL_RESIDUAL
+KEY_TENSOR_DATA_LAYOUT: str = Keys.LLM.TENSOR_DATA_LAYOUT

 # attention
-KEY_ATTENTION_HEAD_COUNT: str = KEY.ATTENTION.HEAD_COUNT
-KEY_ATTENTION_HEAD_COUNT_KV: str = KEY.ATTENTION.HEAD_COUNT_KV
-KEY_ATTENTION_MAX_ALIBI_BIAS: str = KEY.ATTENTION.MAX_ALIBI_BIAS
-KEY_ATTENTION_CLAMP_KQV: str = KEY.ATTENTION.CLAMP_KQV
-KEY_ATTENTION_LAYERNORM_EPS: str = KEY.ATTENTION.LAYERNORM_EPS
-KEY_ATTENTION_LAYERNORM_RMS_EPS: str = KEY.ATTENTION.LAYERNORM_RMS_EPS
+KEY_ATTENTION_HEAD_COUNT: str = Keys.ATTENTION.HEAD_COUNT
+KEY_ATTENTION_HEAD_COUNT_KV: str = Keys.ATTENTION.HEAD_COUNT_KV
+KEY_ATTENTION_MAX_ALIBI_BIAS: str = Keys.ATTENTION.MAX_ALIBI_BIAS
+KEY_ATTENTION_CLAMP_KQV: str = Keys.ATTENTION.CLAMP_KQV
+KEY_ATTENTION_LAYERNORM_EPS: str = Keys.ATTENTION.LAYERNORM_EPS
+KEY_ATTENTION_LAYERNORM_RMS_EPS: str = Keys.ATTENTION.LAYERNORM_RMS_EPS

 # RoPE
-KEY_ROPE_DIMENSION_COUNT: str = KEY.ROPE.DIMENSION_COUNT
-KEY_ROPE_FREQ_BASE: str = KEY.ROPE.FREQ_BASE
-KEY_ROPE_SCALING_TYPE: str = KEY.ROPE.SCALING_TYPE
-KEY_ROPE_SCALING_FACTOR: str = KEY.ROPE.SCALING_FACTOR
-KEY_ROPE_SCALING_ORIG_CTX_LEN: str = KEY.ROPE.SCALING_ORIG_CTX_LEN
-KEY_ROPE_SCALING_FINETUNED: str = KEY.ROPE.SCALING_FINETUNED
+KEY_ROPE_DIMENSION_COUNT: str = Keys.ROPE.DIMENSION_COUNT
+KEY_ROPE_FREQ_BASE: str = Keys.ROPE.FREQ_BASE
+KEY_ROPE_SCALING_TYPE: str = Keys.ROPE.SCALING_TYPE
+KEY_ROPE_SCALING_FACTOR: str = Keys.ROPE.SCALING_FACTOR
+KEY_ROPE_SCALING_ORIG_CTX_LEN: str = Keys.ROPE.SCALING_ORIG_CTX_LEN
+KEY_ROPE_SCALING_FINETUNED: str = Keys.ROPE.SCALING_FINETUNED

 # tokenization
-KEY_TOKENIZER_MODEL: str = KEY.TOKENIZER.MODEL
-KEY_TOKENIZER_LIST: str = KEY.TOKENIZER.LIST
-KEY_TOKENIZER_TOKEN_TYPE: str = KEY.TOKENIZER.TOKEN_TYPE
-KEY_TOKENIZER_SCORES: str = KEY.TOKENIZER.SCORES
-KEY_TOKENIZER_MERGES: str = KEY.TOKENIZER.MERGES
-KEY_TOKENIZER_BOS_ID: str = KEY.TOKENIZER.BOS_ID
-KEY_TOKENIZER_EOS_ID: str = KEY.TOKENIZER.EOS_ID
-KEY_TOKENIZER_UNK_ID: str = KEY.TOKENIZER.UNK_ID
-KEY_TOKENIZER_SEP_ID: str = KEY.TOKENIZER.SEP_ID
-KEY_TOKENIZER_PAD_ID: str = KEY.TOKENIZER.PAD_ID
-KEY_TOKENIZER_HF_JSON: str = KEY.TOKENIZER.HF_JSON
-KEY_TOKENIZER_RWKV: str = KEY.TOKENIZER.RWKV
+KEY_TOKENIZER_MODEL: str = Keys.TOKENIZER.MODEL
+KEY_TOKENIZER_LIST: str = Keys.TOKENIZER.LIST
+KEY_TOKENIZER_TOKEN_TYPE: str = Keys.TOKENIZER.TOKEN_TYPE
+KEY_TOKENIZER_SCORES: str = Keys.TOKENIZER.SCORES
+KEY_TOKENIZER_MERGES: str = Keys.TOKENIZER.MERGES
+KEY_TOKENIZER_BOS_ID: str = Keys.TOKENIZER.BOS_ID
+KEY_TOKENIZER_EOS_ID: str = Keys.TOKENIZER.EOS_ID
+KEY_TOKENIZER_UNK_ID: str = Keys.TOKENIZER.UNK_ID
+KEY_TOKENIZER_SEP_ID: str = Keys.TOKENIZER.SEP_ID
+KEY_TOKENIZER_PAD_ID: str = Keys.TOKENIZER.PAD_ID
+KEY_TOKENIZER_HF_JSON: str = Keys.TOKENIZER.HF_JSON
+KEY_TOKENIZER_RWKV: str = Keys.TOKENIZER.RWKV
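The alias block only swaps the KEY. prefix for Keys., so downstream code that imports the flat KEY_* constants is unaffected. A quick sanity check of that equivalence (the import path is illustrative and assumes the gguf package re-exports the names from this constants module):

import gguf

# The old flat alias and the new Keys namespace resolve to the same key string.
assert gguf.KEY_GENERAL_NAME == gguf.Keys.GENERAL.NAME == "general.name"
assert gguf.KEY_CONTEXT_LENGTH == gguf.Keys.LLM.CONTEXT_LENGTH == "{arch}.context_length"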

View file

@@ -14,10 +14,10 @@ from .constants import (
     GGUF_DEFAULT_ALIGNMENT,
     GGUF_MAGIC,
     GGUF_VERSION,
-    KEY,
     GGMLQuantizationType,
     GGUFEndian,
     GGUFValueType,
+    Keys,
     RopeScalingType,
     TokenType,
 )
@@ -278,132 +278,132 @@ class GGUFWriter:
         self.fout.close()

     def add_architecture(self) -> None:
-        self.add_string(KEY.GENERAL.ARCHITECTURE, self.arch)
+        self.add_string(Keys.GENERAL.ARCHITECTURE, self.arch)

     def add_author(self, author: str) -> None:
-        self.add_string(KEY.GENERAL.AUTHOR, author)
+        self.add_string(Keys.GENERAL.AUTHOR, author)

     def add_tensor_data_layout(self, layout: str) -> None:
-        self.add_string(KEY.LLM.TENSOR_DATA_LAYOUT.value.format(arch=self.arch), layout)
+        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.value.format(arch=self.arch), layout)

     def add_url(self, url: str) -> None:
-        self.add_string(KEY.GENERAL.URL, url)
+        self.add_string(Keys.GENERAL.URL, url)

     def add_description(self, description: str) -> None:
-        self.add_string(KEY.GENERAL.DESCRIPTION, description)
+        self.add_string(Keys.GENERAL.DESCRIPTION, description)

     def add_source_url(self, url: str) -> None:
-        self.add_string(KEY.GENERAL.SOURCE_URL, url)
+        self.add_string(Keys.GENERAL.SOURCE_URL, url)

     def add_source_hf_repo(self, repo: str) -> None:
-        self.add_string(KEY.GENERAL.SOURCE_HF_REPO, repo)
+        self.add_string(Keys.GENERAL.SOURCE_HF_REPO, repo)

     def add_file_type(self, ftype: int) -> None:
-        self.add_uint32(KEY.GENERAL.FILE_TYPE, ftype)
+        self.add_uint32(Keys.GENERAL.FILE_TYPE, ftype)

     def add_name(self, name: str) -> None:
-        self.add_string(KEY.GENERAL.NAME, name)
+        self.add_string(Keys.GENERAL.NAME, name)

     def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
         self.add_uint32(
-            KEY.GENERAL.QUANTIZATION_VERSION, quantization_version)
+            Keys.GENERAL.QUANTIZATION_VERSION, quantization_version)

     def add_custom_alignment(self, alignment: int) -> None:
         self.data_alignment = alignment
-        self.add_uint32(KEY.GENERAL.ALIGNMENT, alignment)
+        self.add_uint32(Keys.GENERAL.ALIGNMENT, alignment)

     def add_context_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.CONTEXT_LENGTH.value.format(arch=self.arch), length)
+            Keys.LLM.CONTEXT_LENGTH.value.format(arch=self.arch), length)

     def add_embedding_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.EMBEDDING_LENGTH.value.format(arch=self.arch), length)
+            Keys.LLM.EMBEDDING_LENGTH.value.format(arch=self.arch), length)

     def add_block_count(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.BLOCK_COUNT.value.format(arch=self.arch), length)
+            Keys.LLM.BLOCK_COUNT.value.format(arch=self.arch), length)

     def add_feed_forward_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.FEED_FORWARD_LENGTH.value.format(arch=self.arch), length)
+            Keys.LLM.FEED_FORWARD_LENGTH.value.format(arch=self.arch), length)

     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(
-            KEY.LLM.USE_PARALLEL_RESIDUAL.value.format(arch=self.arch), use)
+            Keys.LLM.USE_PARALLEL_RESIDUAL.value.format(arch=self.arch), use)

     def add_head_count(self, count: int) -> None:
         self.add_uint32(
-            KEY.ATTENTION.HEAD_COUNT.value.format(arch=self.arch), count)
+            Keys.ATTENTION.HEAD_COUNT.value.format(arch=self.arch), count)

     def add_head_count_kv(self, count: int) -> None:
         self.add_uint32(
-            KEY.ATTENTION.HEAD_COUNT_KV.value.format(arch=self.arch), count)
+            Keys.ATTENTION.HEAD_COUNT_KV.value.format(arch=self.arch), count)

     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.MAX_ALIBI_BIAS.value.format(arch=self.arch), bias)
+            Keys.ATTENTION.MAX_ALIBI_BIAS.value.format(arch=self.arch), bias)

     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.CLAMP_KQV.value.format(arch=self.arch), value)
+            Keys.ATTENTION.CLAMP_KQV.value.format(arch=self.arch), value)

     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.LAYERNORM_EPS.value.format(arch=self.arch), value)
+            Keys.ATTENTION.LAYERNORM_EPS.value.format(arch=self.arch), value)

     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.LAYERNORM_RMS_EPS.value.format(arch=self.arch), value)
+            Keys.ATTENTION.LAYERNORM_RMS_EPS.value.format(arch=self.arch), value)

     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(
-            KEY.ROPE.DIMENSION_COUNT.value.format(arch=self.arch), count)
+            Keys.ROPE.DIMENSION_COUNT.value.format(arch=self.arch), count)

     def add_rope_freq_base(self, value: float) -> None:
-        self.add_float32(KEY.ROPE.FREQ_BASE.value.format(arch=self.arch), value)
+        self.add_float32(Keys.ROPE.FREQ_BASE.value.format(arch=self.arch), value)

     def add_rope_scaling_type(self, value: RopeScalingType) -> None:
-        self.add_string(KEY.ROPE.SCALING_TYPE.value.format(arch=self.arch), value.value)
+        self.add_string(Keys.ROPE.SCALING_TYPE.value.format(arch=self.arch), value)

     def add_rope_scaling_factor(self, value: float) -> None:
-        self.add_float32(KEY.ROPE.SCALING_FACTOR.value.format(arch=self.arch), value)
+        self.add_float32(Keys.ROPE.SCALING_FACTOR.value.format(arch=self.arch), value)

     def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
-        self.add_uint32(KEY.ROPE.SCALING_ORIG_CTX_LEN.value.format(arch=self.arch), value)
+        self.add_uint32(Keys.ROPE.SCALING_ORIG_CTX_LEN.value.format(arch=self.arch), value)

     def add_rope_scaling_finetuned(self, value: bool) -> None:
-        self.add_bool(KEY.ROPE.SCALING_FINETUNED.value.format(arch=self.arch), value)
+        self.add_bool(Keys.ROPE.SCALING_FINETUNED.value.format(arch=self.arch), value)

     def add_tokenizer_model(self, model: str) -> None:
-        self.add_string(KEY.TOKENIZER.MODEL, model)
+        self.add_string(Keys.TOKENIZER.MODEL, model)

     def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
-        self.add_array(KEY.TOKENIZER.LIST, tokens)
+        self.add_array(Keys.TOKENIZER.LIST, tokens)

     def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
-        self.add_array(KEY.TOKENIZER.MERGES, merges)
+        self.add_array(Keys.TOKENIZER.MERGES, merges)

     def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
-        self.add_array(KEY.TOKENIZER.TOKEN_TYPE, types)
+        self.add_array(Keys.TOKENIZER.TOKEN_TYPE, types)

     def add_token_scores(self, scores: Sequence[float]) -> None:
-        self.add_array(KEY.TOKENIZER.SCORES, scores)
+        self.add_array(Keys.TOKENIZER.SCORES, scores)

     def add_bos_token_id(self, id: int) -> None:
-        self.add_uint32(KEY.TOKENIZER.BOS_ID, id)
+        self.add_uint32(Keys.TOKENIZER.BOS_ID, id)

     def add_eos_token_id(self, id: int) -> None:
-        self.add_uint32(KEY.TOKENIZER.EOS_ID, id)
+        self.add_uint32(Keys.TOKENIZER.EOS_ID, id)

     def add_unk_token_id(self, id: int) -> None:
-        self.add_uint32(KEY.TOKENIZER.UNK_ID, id)
+        self.add_uint32(Keys.TOKENIZER.UNK_ID, id)

     def add_sep_token_id(self, id: int) -> None:
-        self.add_uint32(KEY.TOKENIZER.SEP_ID, id)
+        self.add_uint32(Keys.TOKENIZER.SEP_ID, id)

     def add_pad_token_id(self, id: int) -> None:
-        self.add_uint32(KEY.TOKENIZER.PAD_ID, id)
+        self.add_uint32(Keys.TOKENIZER.PAD_ID, id)

     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
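The only call that changes shape in this file is add_rope_scaling_type: since RopeScalingType is now a StrEnum, the member itself is a valid string value for add_string. A hedged usage sketch; the GGUFWriter constructor arguments are illustrative, assuming it takes an output path and an architecture name as in gguf-py:

from gguf import GGUFWriter, RopeScalingType

writer = GGUFWriter("model.gguf", "llama")             # illustrative path and arch name
writer.add_rope_scaling_type(RopeScalingType.LINEAR)   # member is a str, stored as 'linear'
writer.add_rope_scaling_factor(4.0)
writer.add_rope_scaling_orig_ctx_len(4096)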