Merge branch 'master' into compilade/refactor-kv-cache
This commit is contained in:
commit
4e4c41e553
66 changed files with 2719 additions and 454 deletions
|
@ -33,17 +33,21 @@ class Keys:
|
|||
FILE_TYPE = "general.file_type"
|
||||
|
||||
class LLM:
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
CONTEXT_LENGTH = "{arch}.context_length"
|
||||
EMBEDDING_LENGTH = "{arch}.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.block_count"
|
||||
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
|
||||
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
|
||||
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
|
||||
EXPERT_COUNT = "{arch}.expert_count"
|
||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
CONTEXT_LENGTH = "{arch}.context_length"
|
||||
EMBEDDING_LENGTH = "{arch}.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.block_count"
|
||||
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
|
||||
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
|
||||
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
|
||||
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
|
||||
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
|
||||
EXPERT_COUNT = "{arch}.expert_count"
|
||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
|
@ -55,6 +59,8 @@ class Keys:
|
|||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
|
@ -64,6 +70,7 @@ class Keys:
|
|||
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
|
||||
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
|
||||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
|
||||
|
||||
class SSM:
|
||||
CONV_KERNEL = "{arch}.ssm.conv_kernel"
|
||||
|
@ -141,6 +148,7 @@ class MODEL_ARCH(IntEnum):
|
|||
DBRX = auto()
|
||||
OLMO = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
|
@ -189,6 +197,12 @@ class MODEL_TENSOR(IntEnum):
|
|||
SSM_C_NORM = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
ATTN_Q_A = auto()
|
||||
ATTN_Q_B = auto()
|
||||
ATTN_KV_A_MQA = auto()
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
|
@ -226,6 +240,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.DBRX: "dbrx",
|
||||
MODEL_ARCH.OLMO: "olmo",
|
||||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
|
@ -274,6 +289,12 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
|
||||
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
|
@ -793,6 +814,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
@ -826,6 +874,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
}
|
||||
|
||||
#
|
||||
|
|
|
@ -12,6 +12,8 @@ from typing import Any, Literal, NamedTuple, TypeVar, Union
|
|||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from .quants import quant_shape_to_byte_shape
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
@ -251,6 +253,7 @@ class GGUFReader:
|
|||
tensor_names.add(tensor_name)
|
||||
ggml_type = GGMLQuantizationType(raw_dtype[0])
|
||||
n_elems = int(np.prod(dims))
|
||||
np_dims = tuple(reversed(dims.tolist()))
|
||||
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
|
||||
n_bytes = n_elems * type_size // block_size
|
||||
data_offs = int(start_offs + offset_tensor[0])
|
||||
|
@ -279,6 +282,7 @@ class GGUFReader:
|
|||
else:
|
||||
item_count = n_bytes
|
||||
item_type = np.uint8
|
||||
np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
|
||||
tensors.append(ReaderTensor(
|
||||
name = tensor_name,
|
||||
tensor_type = ggml_type,
|
||||
|
@ -286,7 +290,7 @@ class GGUFReader:
|
|||
n_elements = n_elems,
|
||||
n_bytes = n_bytes,
|
||||
data_offset = data_offs,
|
||||
data = self._get(data_offs, item_type, item_count),
|
||||
data = self._get(data_offs, item_type, item_count).reshape(np_dims),
|
||||
field = field,
|
||||
))
|
||||
self.tensors = tensors
|
||||
|
|
|
@ -13,7 +13,6 @@ from string import ascii_letters, digits
|
|||
import numpy as np
|
||||
|
||||
from .constants import (
|
||||
GGML_QUANT_SIZES,
|
||||
GGUF_DEFAULT_ALIGNMENT,
|
||||
GGUF_MAGIC,
|
||||
GGUF_VERSION,
|
||||
|
@ -26,6 +25,8 @@ from .constants import (
|
|||
TokenType,
|
||||
)
|
||||
|
||||
from .quants import quant_shape_from_byte_shape
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -229,10 +230,7 @@ class GGUFWriter:
|
|||
else:
|
||||
dtype = raw_dtype
|
||||
if tensor_dtype == np.uint8:
|
||||
block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
|
||||
if tensor_shape[-1] % type_size != 0:
|
||||
raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
|
||||
tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
|
||||
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
||||
n_dims = len(tensor_shape)
|
||||
self.ti_data += self._pack("I", n_dims)
|
||||
for i in range(n_dims):
|
||||
|
@ -376,9 +374,15 @@ class GGUFWriter:
|
|||
def add_block_count(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
|
||||
|
||||
def add_leading_dense_block_count(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
|
||||
|
||||
def add_feed_forward_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_expert_feed_forward_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_parallel_residual(self, use: bool) -> None:
|
||||
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
|
||||
|
||||
|
@ -412,6 +416,12 @@ class GGUFWriter:
|
|||
def add_expert_used_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_expert_shared_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_expert_weights_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_layer_norm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
|
||||
|
||||
|
@ -421,6 +431,12 @@ class GGUFWriter:
|
|||
def add_causal_attention(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
|
||||
|
||||
def add_q_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_kv_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
|
@ -445,6 +461,9 @@ class GGUFWriter:
|
|||
def add_rope_scaling_finetuned(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
|
||||
|
||||
def add_ssm_conv_kernel(self, value: int) -> None:
|
||||
self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
|
@ -9,6 +9,20 @@ from .lazy import LazyNumpyTensor
|
|||
import numpy as np
|
||||
|
||||
|
||||
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
if shape[-1] % block_size != 0:
|
||||
raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
|
||||
return (*shape[:-1], shape[-1] // block_size * type_size)
|
||||
|
||||
|
||||
def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
if shape[-1] % type_size != 0:
|
||||
raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
|
||||
return (*shape[:-1], shape[-1] // type_size * block_size)
|
||||
|
||||
|
||||
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
||||
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
|
||||
n = n.astype(np.float32, copy=False).view(np.int32)
|
||||
|
|
|
@ -260,6 +260,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
|
@ -290,6 +291,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
|
@ -326,6 +328,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
|
@ -410,6 +413,30 @@ class TensorNameMap:
|
|||
"backbone.layers.{bid}.mixer.out_proj", # mamba
|
||||
"model.layers.{bid}.mamba.out_proj", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A: (
|
||||
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_B: (
|
||||
"model.layers.{bid}.self_attn.q_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: (
|
||||
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_B: (
|
||||
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
|
@ -442,7 +469,7 @@ class TensorNameMap:
|
|||
if tensor not in MODEL_TENSORS[arch]:
|
||||
continue
|
||||
# TODO: make this configurable
|
||||
n_experts = 128
|
||||
n_experts = 160
|
||||
for xid in range(n_experts):
|
||||
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
|
||||
self.mapping[tensor_name] = (tensor, tensor_name)
|
||||
|
|
|
@ -118,9 +118,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
|||
|
||||
for tensor in reader.tensors:
|
||||
total_bytes += tensor.n_bytes
|
||||
# Dimensions are written in reverse order, so flip them first
|
||||
shape = np.flipud(tensor.shape).tolist()
|
||||
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
|
||||
writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
|
||||
|
||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue