add mobilevlm

This commit is contained in:
Xuan Son Nguyen 2025-01-19 16:29:20 +01:00
parent 6cabdda0df
commit d0068ef0ed
9 changed files with 210 additions and 61 deletions

View file

@ -67,6 +67,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
static const std::map<vision_arch, const char *> VISION_ARCH_NAMES = {
{ VISION_ARCH_LLAVA, "llava" },
{ VISION_ARCH_MOBILEVLM, "mobilevlm" },
{ VISION_ARCH_UNKNOWN, "(unknown)" },
};
@ -1345,7 +1346,27 @@ static const std::map<vision_arch, std::map<vision_tensor, const char *>> VISION
{ VISION_TENSOR_PRE_NORM, "v.pre_norm" },
{ VISION_TENSOR_POST_NORM, "v.post_norm" },
}
}
},
{
VISION_ARCH_MOBILEVLM,
{
{ VISION_TENSOR_MMPROJ_MLP, "v.mmproj.mlp.%d" },
{ VISION_TENSOR_MMPROJ_PEG, "v.mmproj.peg.%d" },
{ VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" },
{ VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
{ VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
{ VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
{ VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
{ VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
{ VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
{ VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
{ VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
{ VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
{ VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
{ VISION_TENSOR_PRE_NORM, "v.pre_norm" },
{ VISION_TENSOR_POST_NORM, "v.post_norm" },
}
},
};
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@ -1499,6 +1520,10 @@ std::string LLM_KV::operator()(llm_kv kv) const {
template<>
std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
if (LLM_TENSOR_NAMES.find(arch) == LLM_TENSOR_NAMES.end()) {
throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
}
if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
return "__missing__";
}
@ -1515,6 +1540,10 @@ std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
template<>
std::string BASE_TN_IMPL<vision_arch, vision_tensor>::str() const {
if (VISION_TENSOR_NAMES.find(arch) == VISION_TENSOR_NAMES.end()) {
throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
}
if (VISION_TENSOR_NAMES.at(arch).find(tensor) == VISION_TENSOR_NAMES.at(arch).end()) {
return "__missing__";
}