llama : add Mixtral support (#4406)
* convert : support Mixtral as LLAMA arch
* convert : fix n_ff typo
* llama : model loading
* ggml : sync latest ggml_mul_mat_id
* llama : update graph to support MoE
* llama : fix cur -> cur_expert
* llama : first working version
* llama : fix expert weighting in the FFN
* ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only)
* ggml : add n_as argument to ggml_mul_mat_id
* ggml : fix ggml_get_rows to take into account ne02 / ne11
* metal : add more general support for ggml_get_rows + tests
* llama : add basic support for offloading moe with CUDA
* metal : add/mul/div use general kernel when src1 not cont
* metal : reduce the kernel launches for ggml_mul_mat_id
* ggml : get_rows : support non-contiguous tensors with gaps, generalize up to 3D
* ggml : update get_rows f16 and q
* cuda : support non-contiguous src1 in get_rows
* llama : offload missing ffn_moe_silu
* metal : fix ggml_get_rows to work with non-cont src1
* metal : add indirect mat-vec kernels for all quantization types
* llama : do not quantize expert gating tensors
* llama : add n_expert and n_expert_used to hparams + change quants
* test-backend-ops : add moe test
* cuda : fix get_rows when ncols is odd
* convert : determine n_ctx correctly
* metal : fix ggml_mul_mat_id for F32
* test-backend-ops : make experts more evenly probable (test_moe)
* test-backend-ops : cleanup, add moe test for batches
* test-backend-ops : add cpy from f32 -> all types test
* test-backend-ops : fix dequantize block offset
* llama : fix hard-coded number of experts
* test-backend-ops : simplify and disable slow tests to avoid CI timeout
* test-backend-ops : disable MOE test with thread sanitizer
* cuda : fix mul_mat_id with multi gpu
* convert : use 1e6 rope_freq_base for mixtral
* convert : fix style
* convert : support safetensors format
* gguf-py : bump version
* metal : add cpy f16 -> f32 kernel
* metal : fix binary ops for ne10 % 4 != 0
* test-backend-ops : add one more sum_rows test
* ggml : do not use BLAS with ggml_mul_mat_id
* convert-hf : support for mixtral-instruct (#4428)
* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct
* convert : use sentencepiece tokenizer for Mixtral-instruct
* convert : make flake8 happy
* metal : fix soft_max kernels
ref: 1914017863
* metal : limit kernels to not use more than the allowed threads
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Radek Pilar <github@mrkva.eu>
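
For orientation, the core of the change is the new mixture-of-experts feed-forward path in the LLaMA graph builder. The sketch below condenses that graph construction into a single hypothetical helper function; the op sequence, tensor names and the ggml_mul_mat_id call are taken from the llama.cpp hunk shown further below, while the helper itself and its parameter list are only illustrative (the real code builds the graph inline, interleaved with cb() callbacks).

    // Hypothetical helper, for illustration only: one MoE feed-forward block
    // as constructed in build_llama() by this commit.
    #include "ggml.h"

    static struct ggml_tensor * build_moe_ffn(
            struct ggml_context *  ctx,
            struct ggml_tensor  *  cur,           // ffn_norm output, [n_embd, n_tokens]
            struct ggml_tensor  *  ffn_gate_inp,  // router (gating) weights
            struct ggml_tensor  ** ffn_up_exp,    // n_expert tensors each
            struct ggml_tensor  ** ffn_gate_exp,
            struct ggml_tensor  ** ffn_down_exp,
            int n_expert, int n_expert_used, int n_tokens) {
        // router scores and probabilities, one per expert per token
        struct ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
        struct ggml_tensor * probs  = ggml_soft_max(ctx, logits);

        // ids of the top-k experts for each token (k = n_expert_used)
        struct ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used);

        // gather the probabilities of the selected experts and renormalize them to sum to 1
        struct ggml_tensor * weights = ggml_get_rows(ctx,
                ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts);
        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
        weights = ggml_div(ctx, weights, ggml_sum_rows(ctx, weights));

        // run each selected expert (SwiGLU FFN) and accumulate the weighted outputs
        struct ggml_tensor * moe_out = NULL;
        for (int i = 0; i < n_expert_used; ++i) {
            struct ggml_tensor * up   = ggml_mul_mat_id(ctx, ffn_up_exp,   n_expert, selected_experts, i, cur);
            struct ggml_tensor * gate = ggml_mul_mat_id(ctx, ffn_gate_exp, n_expert, selected_experts, i, cur);
            struct ggml_tensor * out  = ggml_mul(ctx, up, ggml_silu(ctx, gate));
            out = ggml_mul_mat_id(ctx, ffn_down_exp, n_expert, selected_experts, i, out);
            // scale by this expert's normalized routing weight
            out = ggml_mul(ctx, out, ggml_view_2d(ctx, weights, 1, n_tokens,
                                                  weights->nb[1], i*weights->nb[0]));
            moe_out = moe_out ? ggml_add(ctx, moe_out, out) : out;
        }
        return moe_out; // added to ffn_inp by the caller, as in the dense FFN path
    }
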
This commit is contained in: parent fecac45658, commit 799a1cb13b
14 changed files with 2370 additions and 395 deletions

llama.cpp | 202 lines changed (other changed files not shown here)
@@ -91,7 +91,8 @@
 #define LLAMA_ATTRIBUTE_FORMAT(...)
 #endif

 #define LLAMA_MAX_NODES 8192
+#define LLAMA_MAX_EXPERTS 8

 //
 // logging
@@ -231,6 +232,8 @@ enum llm_kv {
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
+    LLM_KV_EXPERT_COUNT,
+    LLM_KV_EXPERT_USED_COUNT,

     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -281,6 +284,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
     { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
     { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
+    { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
+    { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },

     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -338,10 +343,14 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_NORM,
     LLM_TENSOR_ATTN_NORM_2,
     LLM_TENSOR_ATTN_ROT_EMBD,
+    LLM_TENSOR_FFN_GATE_INP,
+    LLM_TENSOR_FFN_NORM,
     LLM_TENSOR_FFN_GATE,
     LLM_TENSOR_FFN_DOWN,
     LLM_TENSOR_FFN_UP,
-    LLM_TENSOR_FFN_NORM,
+    LLM_TENSOR_FFN_DOWN_EXP,
+    LLM_TENSOR_FFN_GATE_EXP,
+    LLM_TENSOR_FFN_UP_EXP,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
 };
@@ -360,10 +369,14 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
             { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
         },
     },
     {
@@ -585,6 +598,10 @@ struct LLM_TN {
     std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
         return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
     }
+
+    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
+        return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix;
+    }
 };

 //
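
Together with the LLM_TENSOR_NAMES entries above, the new two-index operator() gives every expert tensor its own per-layer, per-expert GGUF name. A small illustration (the layer and expert indices here are made up for the example, not taken from the diff):

    const LLM_TN tn(LLM_ARCH_LLAMA);
    // two-index overload (layer id, expert id) added above:
    // tn(LLM_TENSOR_FFN_GATE_EXP, "weight", 0, 3) -> "blk.0.ffn_gate.3.weight"
    // tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", 0, 3) -> "blk.0.ffn_down.3.weight"
    // tn(LLM_TENSOR_FFN_UP_EXP,   "weight", 0, 3) -> "blk.0.ffn_up.3.weight"
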
@@ -1164,6 +1181,8 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_ff;
+    uint32_t n_expert = 0;
+    uint32_t n_expert_used = 0;

     float f_norm_eps;
     float f_norm_rms_eps;
@@ -1178,15 +1197,18 @@ struct llama_hparams {
     float f_max_alibi_bias;

     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
         if (this->n_vocab != other.n_vocab) return true;
         if (this->n_ctx_train != other.n_ctx_train) return true;
         if (this->n_embd != other.n_embd) return true;
         if (this->n_head != other.n_head) return true;
         if (this->n_head_kv != other.n_head_kv) return true;
         if (this->n_layer != other.n_layer) return true;
         if (this->n_rot != other.n_rot) return true;
         if (this->n_ff != other.n_ff) return true;
+        if (this->n_expert != other.n_expert) return true;
+        if (this->n_expert_used != other.n_expert_used) return true;

         if (this->rope_finetuned != other.rope_finetuned) return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
@@ -1268,6 +1290,12 @@ struct llama_layer {
     struct ggml_tensor * ffn_down; // w2
     struct ggml_tensor * ffn_up; // w3

+    // ff MoE
+    struct ggml_tensor * ffn_gate_inp;
+    struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_up_exp [LLAMA_MAX_EXPERTS];
+
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b; // b3
@@ -2440,6 +2468,16 @@ static void llm_load_hparams(
     ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
     ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
     ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer);
+    ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
+    ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+
+    GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
+    GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
+    if (hparams.n_expert > 0) {
+        GGML_ASSERT(hparams.n_expert_used > 0);
+    } else {
+        GGML_ASSERT(hparams.n_expert_used == 0);
+    }

     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
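
The trailing false argument marks both keys as optional, so existing non-MoE GGUF files keep loading with n_expert = 0. For a converted Mixtral 8x7B model the header is expected to carry values like the following; they come from the upstream Mixtral configuration, not from this diff:

    llama.expert_count      = 8   // num_local_experts in the HF config
    llama.expert_used_count = 2   // num_experts_per_tok in the HF config
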
@@ -2871,6 +2909,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
+    LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
     LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
     LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
@@ -3025,9 +3065,26 @@ static void llm_load_tensors(

             layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);

-            layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
-            layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
-            layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+            layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false);
+
+            if (layer.ffn_gate_inp == nullptr) {
+                GGML_ASSERT(hparams.n_expert == 0);
+                GGML_ASSERT(hparams.n_expert_used == 0);
+
+                layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
+                layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
+                layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+            } else {
+                GGML_ASSERT(hparams.n_expert > 0);
+                GGML_ASSERT(hparams.n_expert_used > 0);
+
+                // MoE branch
+                for (uint32_t x = 0; x < hparams.n_expert; ++x) {
+                    layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
+                    layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}, backend_split);
+                    layer.ffn_up_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}, backend_split);
+                }
+            }

             if (backend == GGML_BACKEND_GPU) {
                 vram_weights +=
@@ -3037,8 +3094,18 @@
                     (layer.bk ? ggml_nbytes(layer.bk) : 0) +
                     (layer.bv ? ggml_nbytes(layer.bv) : 0) +
                     (layer.bo ? ggml_nbytes(layer.bo) : 0) +
-                    ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) +
-                    ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+                    ggml_nbytes(layer.ffn_norm);
+
+                if (layer.ffn_gate_inp == nullptr) {
+                    vram_weights +=
+                        ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+                } else {
+                    vram_weights += ggml_nbytes(layer.ffn_gate_inp);
+                    for (uint32_t x = 0; x < hparams.n_expert; ++x) {
+                        vram_weights +=
+                            ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]);
+                    }
+                }
             }
         }
     } break;
@@ -4019,6 +4086,8 @@ struct llm_build_context {
     const int64_t n_head_kv;
     const int64_t n_embd_head;
     const int64_t n_embd_gqa;
+    const int64_t n_expert;
+    const int64_t n_expert_used;

     const float freq_base;
     const float freq_scale;
@@ -4060,6 +4129,8 @@
         n_head_kv (hparams.n_head_kv),
         n_embd_head (hparams.n_embd_head()),
         n_embd_gqa (hparams.n_embd_gqa()),
+        n_expert (hparams.n_expert),
+        n_expert_used (hparams.n_expert_used),
         freq_base (cparams.rope_freq_base),
         freq_scale (cparams.rope_freq_scale),
         ext_factor (cparams.yarn_ext_factor),
@@ -4184,7 +4255,7 @@
             cb(ffn_inp, "ffn_inp", il);

             // feed-forward network
-            {
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -4196,6 +4267,69 @@
                         model.layers[il].ffn_down, NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+                cb(logits, "ffn_moe_logits", il);
+
+                ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+                cb(probs, "ffn_moe_probs", il);
+
+                // select experts
+                ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+                cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+                ggml_tensor * weights = ggml_get_rows(ctx0,
+                        ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+                cb(weights, "ffn_moe_weights", il);
+
+                weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+                ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+                cb(weights_sum, "ffn_moe_weights_sum", il);
+
+                weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+                cb(weights, "ffn_moe_weights_norm", il);
+
+                // compute expert outputs
+                ggml_tensor * moe_out = nullptr;
+
+                for (int i = 0; i < n_expert_used; ++i) {
+                    ggml_tensor * cur_expert;
+
+                    ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
+                    cb(cur_up, "ffn_moe_up", il);
+
+                    ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
+                    cb(cur_gate, "ffn_moe_gate", il);
+
+                    cur_gate = ggml_silu(ctx0, cur_gate);
+                    cb(cur_gate, "ffn_moe_silu", il);
+
+                    cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+                    cb(cur_expert, "ffn_moe_gate_par", il);
+
+                    cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                    cb(cur_expert, "ffn_moe_down", il);
+
+                    cur_expert = ggml_mul(ctx0, cur_expert,
+                            ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                    cb(cur_expert, "ffn_moe_weighted", il);
+
+                    if (i == 0) {
+                        moe_out = cur_expert;
+                    } else {
+                        moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                        cb(moe_out, "ffn_moe_out", il);
+                    }
+                }
+
+                cur = moe_out;
             }

             cur = ggml_add(ctx0, cur, ffn_inp);
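
In the weighting steps above, the probabilities of the k = n_expert_used selected experts are renormalized per token before the expert outputs are mixed: with p_i the softmax probability of selected expert i, the applied weight is w_i = p_i / (p_1 + ... + p_k), and the layer output is sum_i w_i * expert_i(cur). For Mixtral, k = 2, so the two chosen experts' probabilities are simply rescaled to sum to 1.
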
@@ -5450,6 +5584,20 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "ffn_relu", OFFLOAD_FUNC },
     { "ffn_sqr(relu)", OFFLOAD_FUNC },

+    { "ffn_moe_logits", OFFLOAD_FUNC },
+    { "ffn_moe_probs", OFFLOAD_FUNC },
+    { "ffn_moe_argsort", OFFLOAD_FUNC },
+    { "ffn_moe_weights", OFFLOAD_FUNC },
+    { "ffn_moe_weights_sum", OFFLOAD_FUNC },
+    { "ffn_moe_weights_norm", OFFLOAD_FUNC },
+    { "ffn_moe_weighted", OFFLOAD_FUNC },
+    { "ffn_moe_up", OFFLOAD_FUNC },
+    { "ffn_moe_gate", OFFLOAD_FUNC },
+    { "ffn_moe_silu", OFFLOAD_FUNC },
+    { "ffn_moe_gate_par", OFFLOAD_FUNC },
+    { "ffn_moe_down", OFFLOAD_FUNC },
+    { "ffn_moe_out", OFFLOAD_FUNC },
+
     { "l_out", OFFLOAD_FUNC },

     { "result_norm", OFFLOAD_FUNC_EMB },
@@ -8067,11 +8215,9 @@ static void llama_convert_tensor_internal(
     workers.clear();
 }

-static ggml_type get_k_quant_type(
-    quantize_state_internal & qs,
-    ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
-) {
+static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);

     // TODO: avoid hardcoded tensor names - use the TN_* constants
     const llm_arch arch = qs.model.arch;
     const auto tn = LLM_TN(arch);
@@ -8105,7 +8251,18 @@
             // nearly negligible increase in model size by quantizing this tensor with more bits:
             if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
         }
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
         ++qs.i_attention_wv;
+    } else if (name.find("attn_k.weight") != std::string::npos) {
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
     } else if (name.find("ffn_down.weight") != std::string::npos) {
         if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
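
A brief note on the n_expert == 8 special case above: in an 8-expert model the attention V and K projection weights are a small fraction of the total parameter count, so storing them as Q8_0 costs only the roughly 128 MB mentioned in the comments while keeping the attention path at near-full precision; the hard-coded 8 is an interim heuristic, as the TODO notes.
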
@@ -8318,6 +8475,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= params->quantize_output_tensor || name != "output.weight";
         quantize &= !params->only_copy;

+        // do not quantize expert gating tensors
+        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
+
         enum ggml_type new_type;
         void * new_data;
         size_t new_size;