support glm-4-9b-chat

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>
XingXing Qiao 2024-06-17 10:08:52 +08:00
parent f3bc337f43
commit 1fc5bf5bcb
5 changed files with 116 additions and 7 deletions

llama.cpp

@@ -4508,6 +4508,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 28: model.type = e_model::MODEL_7B; break;
case 40: model.type = e_model::MODEL_8B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
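
Note: the new case 40 is what routes glm-4-9b-chat, which has 40 transformer layers, to a known model type; the existing case 28 covers the 28-layer chatglm3-6b added in the parent commit. For illustration only, a minimal standalone sketch of the same dispatch (the enum is trimmed to the values used here):

// Illustration: map a ChatGLM-family layer count to a model type,
// mirroring the switch in llm_load_hparams above.
enum e_model { MODEL_UNKNOWN, MODEL_7B, MODEL_8B };

static e_model chatglm_model_type(const int n_layer) {
    switch (n_layer) {
        case 28: return e_model::MODEL_7B;  // chatglm3-6b
        case 40: return e_model::MODEL_8B;  // glm-4-9b
        default: return e_model::MODEL_UNKNOWN;
    }
}
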
@@ -4636,9 +4637,9 @@ static void llm_load_vocab(
if (merges_keyidx == -1) {
throw std::runtime_error("cannot find tokenizer merges in model file\n");
}
printf("merges_keyidx: %d\n", merges_keyidx);
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
printf("n_merges: %d\n", n_merges);
for (int i = 0; i < n_merges; i++) {
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
@@ -4728,6 +4729,9 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "smaug-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
} else if (
tokenizer_pre == "chatglm-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
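
For context, tokenizer_pre is read from the GGUF metadata key tokenizer.ggml.pre, which the conversion step presumably writes as "chatglm-bpe" for this model. A sketch that inspects that key with ggml's public gguf API (assuming ggml.h is on the include path; error handling kept minimal):

// Sketch (assumes ggml's gguf API): print the pre-tokenizer tag of a GGUF
// file; for a glm-4-9b-chat conversion this should read "chatglm-bpe".
#include <cstdio>
#include "ggml.h"

int main(int argc, char ** argv) {
    if (argc < 2) { return 1; }
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (ctx == nullptr) { return 1; }
    const int keyidx = gguf_find_key(ctx, "tokenizer.ggml.pre");
    if (keyidx != -1) {
        printf("tokenizer.ggml.pre = %s\n", gguf_get_val_str(ctx, keyidx));
    }
    gguf_free(ctx);
    return 0;
}
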
@@ -11449,7 +11453,7 @@ struct llm_build_context {
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -13032,6 +13036,7 @@ struct llm_tokenizer_bpe {
break;
case LLAMA_VOCAB_PRE_TYPE_DBRX:
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
word_collection = unicode_regex_split(text, {
// same as llama3
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@@ -18741,6 +18746,15 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "ChatGLM4") {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else {
// template not supported
return -1;
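
For reference, the new ChatGLM4 branch renders a conversation as [gMASK]<sop> followed by <|role|> plus a newline and the content for each turn, with a trailing <|assistant|> when a generation prompt is requested. A self-contained sketch that mimics the loop above (chat_msg is a stand-in for llama_chat_message):

#include <iostream>
#include <string>
#include <vector>

struct chat_msg { std::string role, content; };

int main() {
    const std::vector<chat_msg> chat = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello" },
    };
    std::string ss = "[gMASK]<sop>";
    for (const chat_msg & m : chat) {
        ss += "<|" + m.role + "|>\n" + m.content;
    }
    ss += "<|assistant|>"; // add_ass: cue the model to produce the reply
    std::cout << ss << std::endl;
    // Output:
    // [gMASK]<sop><|system|>
    // You are a helpful assistant.<|user|>
    // Hello<|assistant|>
}
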