feat: initial support for llama.cpp
parent 2339a0be1c
commit 59a08be755

5 changed files with 55 additions and 0 deletions
@@ -106,6 +106,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
 - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
 - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
+- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
 
 (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
 
@@ -2756,6 +2756,7 @@ class MambaModel(Model):
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_mamba_b_dt_rms(False)  # For classic Mamba we don't apply rms norm on B / DT layers
         self.gguf_writer.add_file_type(self.ftype)
 
     _tok_embd = None
@@ -3854,6 +3855,39 @@ class ExaoneModel(Model):
         self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
 
         super().prepare_tensors()
 
 
+@Model.register("FalconMambaForCausalLM")
+class FalconMambaModel(MambaModel):
+    model_arch = gguf.MODEL_ARCH.MAMBA
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        assert d_inner == 2 * d_model
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_mamba_b_dt_rms(True)  # For FalconMamba we do apply rms norm on B / DT layers
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+
 ###### CONVERSION LOGIC ######
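A quick aside on the dt_rank fallback above: -(d_model // -16) is the ceiling-division idiom referenced in the comments. A minimal standalone check, not part of this diff (the hidden sizes are illustrative):

import math

# -(x // -n) floor-divides by a negative divisor and negates the result,
# which rounds up instead of down, i.e. it equals math.ceil(x / n).
for d_model in (768, 1000, 4096):   # illustrative hidden sizes only
    dt_rank = -(d_model // -16)
    assert dt_rank == math.ceil(d_model / 16)
    print(d_model, "->", dt_rank)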
@@ -130,6 +130,7 @@ class Keys:
         INNER_SIZE     = "{arch}.ssm.inner_size"
         STATE_SIZE     = "{arch}.ssm.state_size"
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+        B_DT_RMS       = "{arch}.ssm.b_dt_rms"
 
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
@@ -1372,6 +1373,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
 KEY_SSM_INNER_SIZE     = Keys.SSM.INNER_SIZE
 KEY_SSM_STATE_SIZE     = Keys.SSM.STATE_SIZE
 KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+KEY_SSM_B_DT_RMS       = Keys.SSM.B_DT_RMS
 
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
@@ -714,6 +714,9 @@ class GGUFWriter:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 
+    def add_mamba_b_dt_rms(self, value: bool) -> None:
+        self.add_bool(Keys.SSM.B_DT_RMS.format(arch=self.arch), value)
+
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
 
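For context on how the new writer method is used, here is a minimal sketch (assumptions: the gguf-py package from this commit is importable; the output filename and the "mamba" arch string are illustrative). It prints the concrete key the method writes and emits a small GGUF file carrying the flag:

import gguf
from gguf.constants import Keys

print(Keys.SSM.B_DT_RMS.format(arch="mamba"))      # -> "mamba.ssm.b_dt_rms"

writer = gguf.GGUFWriter("falcon-mamba-sketch.gguf", arch="mamba")
writer.add_mamba_b_dt_rms(True)                    # method added in this commit
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()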
@@ -328,6 +328,7 @@ enum llm_kv {
     LLM_KV_SSM_CONV_KERNEL,
     LLM_KV_SSM_STATE_SIZE,
     LLM_KV_SSM_TIME_STEP_RANK,
+    LLM_KV_SSM_B_DT_RMS,
 
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
@@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_INNER_SIZE,     "%s.ssm.inner_size"     },
     { LLM_KV_SSM_STATE_SIZE,     "%s.ssm.state_size"     },
     { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
+    { LLM_KV_SSM_B_DT_RMS,       "%s.ssm.b_dt_rms"       },
 
     { LLM_KV_TOKENIZER_MODEL,    "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE,      "tokenizer.ggml.pre"   },
@@ -2237,6 +2239,7 @@ struct llama_hparams {
     uint32_t ssm_d_inner = 0;
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
+    bool ssm_b_dt_rms = false;
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -5052,6 +5055,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
                 ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
                 ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_B_DT_RMS, hparams.ssm_b_dt_rms);
 
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -12161,6 +12165,10 @@ struct llm_build_context {
         GGML_ASSERT(2 * d_model == d_inner);
         const int64_t d_state = hparams.ssm_d_state;
         const int64_t dt_rank = hparams.ssm_dt_rank;
+        // Some variants of the Mamba arch (e.g. FalconMamba) do apply layer norm on B and Dt layers
+        const bool ssm_b_dt_rms = hparams.ssm_b_dt_rms;
+        // Use the same RMS norm as the final layer norm
+        const float norm_rms_eps = hparams.f_norm_rms_eps;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -12241,6 +12249,13 @@ struct llm_build_context {
             struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
             struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
+            // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
+            if (ssm_b_dt_rms) {
+                dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
+                B = ggml_rms_norm(ctx0, B, norm_rms_eps);
+                C = ggml_rms_norm(ctx0, C, norm_rms_eps);
+            }
+
             // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
             dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
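For reference, the normalization applied in this branch has no learned weight. A NumPy sketch of the math (an assumption-level mirror of what ggml_rms_norm computes, not the ggml code itself), applied to a tensor shaped like the dt view:

import numpy as np

def rms_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # divide each row by its root mean square; no scale or offset parameters
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms

dt = np.random.randn(8, 160).astype(np.float32)   # (n_tokens, dt_rank), illustrative shape
print(rms_norm(dt).shape)                          # (8, 160)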