for new minicpm

This commit is contained in:
zhangkaihuo 2024-03-30 10:34:40 +08:00
parent cfc4d75df6
commit e913ac9c38
4 changed files with 8 additions and 3 deletions

View file

@ -1559,6 +1559,7 @@ class InternLM2Model(Model):
self.gguf_writer.add_add_space_prefix(add_prefix)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
print(special_vocab)
old_eos = special_vocab.special_token_ids["eos"]
if "chat" in os.path.basename(self.dir_model.absolute()):
# For the chat model, we replace the eos with '<|im_end|>'.

View file

@ -795,7 +795,9 @@ int main(int argc, char ** argv) {
}
// deal with end of text token in interactive mode
if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
auto last_token = llama_sampling_last(ctx_sampling);
if (last_token == llama_token_eos(model) || last_token == 122753)
{
LOG("found EOS token\n");
if (params.interactive) {
@ -920,7 +922,7 @@ int main(int argc, char ** argv) {
}
// end of text token
if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
if (!embd.empty() && (embd.back() == llama_token_eos(model) || embd.back() == 122753) && !(params.instruct || params.interactive || params.chatml)) {
LOG_TEE(" [end of text]\n");
break;
}

View file

@ -548,6 +548,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
],
MODEL_ARCH.MINICPM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,

View file

@ -4375,6 +4375,7 @@ static bool llm_load_tensors(
case LLM_ARCH_MINICPM:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
// output
{
@ -8699,7 +8700,7 @@ struct llm_build_context {
cb(cur, "lmhead_scaling", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);