model is loadable
parent cd806a7e88
commit a75c5c4295
9 changed files with 371 additions and 39 deletions
Makefile (12 changed lines)
@@ -1120,6 +1120,7 @@ src/llama.o: \
 	src/llama-vocab.h \
 	src/llama-grammar.h \
 	src/llama-sampling.h \
+	src/llama-vision.h \
 	src/unicode.h \
 	include/llama.h \
 	ggml/include/ggml-cuda.h \
@@ -1152,6 +1153,17 @@ src/llama-sampling.o: \
 	include/llama.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+src/llama-vision.o: \
+	src/llama-vision.cpp \
+	src/llama-vision.h \
+	include/llama.h \
+	ggml/include/ggml-cuda.h \
+	ggml/include/ggml-metal.h \
+	ggml/include/ggml.h \
+	ggml/include/ggml-alloc.h \
+	ggml/include/ggml-backend.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 $(LIB_LLAMA): \
 	$(OBJ_LLAMA) \
 	$(LIB_GGML)

convert_hf_to_gguf.py
@@ -1583,6 +1583,9 @@ class LlamaModel(Model):
         self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
         self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
         self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
+        # TODO: should not hardcode these, but they are currently missing from config.json
+        self.gguf_writer.add_vision_clip_max_position_embeddings(577)
+        self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
 
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
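The hardcoded 577 is consistent with the CLIP ViT-L/14-336 encoder that LLaVA uses: a 336×336 image cut into 14×14 patches gives 24×24 = 576 patch positions, plus one class-token position. A quick sanity check (hypothetical helper, not part of this commit):

#include <cstdint>

// Position count for a ViT-style CLIP encoder: one slot per patch plus the
// class token. For CLIP ViT-L/14 at 336px: (336/14)^2 + 1 = 577.
constexpr uint32_t clip_n_positions(uint32_t image_size, uint32_t patch_size) {
    const uint32_t n_per_side = image_size / patch_size;
    return n_per_side * n_per_side + 1;
}

static_assert(clip_n_positions(336, 14) == 577, "matches the hardcoded value above");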

gguf-py/gguf/constants.py
@@ -195,6 +195,7 @@ class Keys:
             PROJECTION_TYPE   = "vision.clip.projection_type"
             PROJECTION_DIM    = "vision.clip.projection_dim"
             USE_GELU          = "vision.clip.use_gelu"
+            MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings"
             HEAD_COUNT        = "vision.clip.attention.head_count"
             LAYERNORM_EPS     = "vision.clip.attention.layer_norm_epsilon"
 

gguf-py/gguf/gguf_writer.py
@@ -841,6 +841,9 @@ class GGUFWriter:
     def add_vision_clip_head_count(self, value: int) -> None:
         self.add_uint32(Keys.Vision.Clip.HEAD_COUNT, value)
 
+    def add_vision_clip_max_position_embeddings(self, value: int) -> None:
+        self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value)
+
     def add_vision_clip_layer_norm_epsilon(self, value: float) -> None:
         self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value)
 

include/llama.h
@@ -224,6 +224,20 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
+    // represent an RGB image
+    // size of data must be equal to 3*nx*ny
+    typedef struct llama_img {
+        uint32_t nx;
+        uint32_t ny;
+        unsigned char * data;
+    } llama_img;
+
+    // Input data for llama_vision_decode
+    typedef struct llama_img_batch {
+        int32_t      n_imgs;
+        llama_img  * imgs;
+    } llama_img_batch;
+
     // Input data for llama_decode
     // A llama_batch object can contain input about one or many sequences
     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
@@ -875,6 +889,16 @@ extern "C" {
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    //
+    // Vision
+    //
+
+    // encode image into embeddings
+    LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch);
+
+    // get output embeddings, to be put into language batch
+    LLAMA_API float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx);
+
     //
     // Vocab
     //
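A minimal sketch of how the new API is presumably meant to be driven. The llama_context is assumed to come from a model with a vision tower, the non-zero-return-on-failure convention is an assumption, and loading the RGB bytes is left to the caller; none of this code is part of the commit:

#include "llama.h"

// Encode one RGB image and fetch its embeddings. `rgb` must hold 3*nx*ny
// bytes, per the comment on llama_img above.
static float * encode_image(struct llama_context * ctx, unsigned char * rgb,
                            uint32_t nx, uint32_t ny) {
    llama_img img = { nx, ny, rgb };
    llama_img_batch batch = { /*n_imgs=*/1, /*imgs=*/&img };

    if (llama_vision_encode(ctx, &batch) != 0) {
        return nullptr; // assumption: non-zero signals failure
    }
    // embeddings for image 0, to be placed into a language batch
    return llama_vision_get_embeddings(ctx, 0);
}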

src/CMakeLists.txt
@@ -17,6 +17,7 @@ add_library(llama
             llama-vocab.cpp
             llama-grammar.cpp
            llama-sampling.cpp
+            llama-vision.cpp
             unicode.h
             unicode.cpp
             unicode-data.cpp

src/llama-vision.cpp (new file, 5 lines)
@@ -0,0 +1,5 @@
+#include "llama.h"
+
+#include "llama-vision.h"
+
+

src/llama-vision.h (new file, 91 lines)
@@ -0,0 +1,91 @@
+#pragma once
+
+#include "ggml.h"
+
+#include <vector>
+
+enum vision_arch {
+    VISION_ARCH_LLAVA,
+    VISION_ARCH_UNKNOWN,
+};
+
+enum mm_patch_merge {
+    MM_PATCH_MERGE_FLAT,
+    MM_PATCH_MERGE_SPATIAL_UNPAD,
+};
+
+struct clip_hparams {
+    vision_arch arch = VISION_ARCH_UNKNOWN;
+
+    uint32_t image_size;
+    uint32_t patch_size;
+    uint32_t hidden_size;
+    uint32_t n_intermediate;
+    uint32_t projection_dim;
+    uint32_t n_head;
+    uint32_t n_layer;
+    uint32_t max_pos_embd;
+
+    float eps;
+
+    mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT;
+
+    int32_t image_grid_pinpoints[32];
+    int32_t image_crop_resolution;
+};
+
+struct clip_layer {
+    // attention
+    struct ggml_tensor * k_w;
+    struct ggml_tensor * k_b;
+    struct ggml_tensor * q_w;
+    struct ggml_tensor * q_b;
+    struct ggml_tensor * v_w;
+    struct ggml_tensor * v_b;
+
+    struct ggml_tensor * output_w;
+    struct ggml_tensor * output_b;
+
+    // layernorm 1
+    struct ggml_tensor * norm_in_w;
+    struct ggml_tensor * norm_in_b;
+
+    // ff
+    struct ggml_tensor * ffn_up_w;
+    struct ggml_tensor * ffn_up_b;
+
+    struct ggml_tensor * ffn_down_w;
+    struct ggml_tensor * ffn_down_b;
+
+    // layernorm 2
+    struct ggml_tensor * norm_out_w;
+    struct ggml_tensor * norm_out_b;
+};
+
+struct clip_vision_model {
+    struct clip_hparams hparams;
+
+    // embeddings
+    struct ggml_tensor * class_embedding;
+    struct ggml_tensor * patch_embeddings;
+    struct ggml_tensor * patch_bias;
+    struct ggml_tensor * position_embeddings;
+
+    struct ggml_tensor * pre_norm_w;
+    struct ggml_tensor * pre_norm_b;
+
+    std::vector<clip_layer> layers;
+
+    struct ggml_tensor * post_norm_w;
+    struct ggml_tensor * post_norm_b;
+
+    struct ggml_tensor * projection;
+
+    // LLaVA projection
+    struct ggml_tensor * mm_a_w = NULL;
+    struct ggml_tensor * mm_a_b = NULL;
+    struct ggml_tensor * mm_b_w = NULL;
+    struct ggml_tensor * mm_b_b = NULL;
+
+    struct ggml_tensor * image_newline = NULL;
+};
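For orientation, clip_layer follows the usual pre-norm ViT encoder block: norm_in feeds attention (q/k/v plus the output projection), norm_out feeds a GELU MLP, each half with a residual connection. The forward pass is not implemented in this commit (llama-vision.cpp is still a stub); below is a hedged sketch of how the FFN half of one block could be wired with existing ggml ops, attention elided:

#include "ggml.h"

// Hypothetical wiring of one clip_layer, FFN half only. `cur` is the
// [n_embd, n_pos] activation after the attention residual; `eps` comes
// from clip_hparams. Not part of the commit.
static ggml_tensor * clip_layer_ffn(ggml_context * ctx, const clip_layer & layer,
                                    ggml_tensor * cur, float eps) {
    ggml_tensor * residual = cur;

    // layernorm 2 (norm_out_w/b broadcast over positions)
    cur = ggml_norm(ctx, cur, eps);
    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer.norm_out_w), layer.norm_out_b);

    // ffn_up -> GELU -> ffn_down
    cur = ggml_add(ctx, ggml_mul_mat(ctx, layer.ffn_up_w, cur), layer.ffn_up_b);
    cur = ggml_gelu(ctx, cur);
    cur = ggml_add(ctx, ggml_mul_mat(ctx, layer.ffn_down_w, cur), layer.ffn_down_b);

    // residual connection
    return ggml_add(ctx, cur, residual);
}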

src/llama.cpp (270 changed lines)
@@ -1,6 +1,7 @@
 #include "llama-impl.h"
 #include "llama-vocab.h"
 #include "llama-sampling.h"
+#include "llama-vision.h"
 
 #include "unicode.h"
 
@@ -273,6 +274,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
+static const std::map<vision_arch, const char *> VISION_ARCH_NAMES = {
+    { VISION_ARCH_LLAVA,   "llava"     },
+    { VISION_ARCH_UNKNOWN, "(unknown)" },
+};
+
 enum llm_kv {
     LLM_KV_GENERAL_TYPE,
     LLM_KV_GENERAL_ARCHITECTURE,
@@ -379,6 +385,24 @@ enum llm_kv {
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
+
+    // TODO: these are vision-related KV, probably should be moved to a new enum
+    LLM_KV_VISION_TYPE,
+    LLM_KV_VISION_IMAGE_SIZE,
+    LLM_KV_VISION_PATCH_SIZE,
+    LLM_KV_VISION_IMAGE_MEAN,
+    LLM_KV_VISION_IMAGE_STD,
+    LLM_KV_VISION_CLIP_ARCHITECTURE,
+    LLM_KV_VISION_CLIP_CONTEXT_LENGTH,
+    LLM_KV_VISION_CLIP_EMBEDDING_LENGTH,
+    LLM_KV_VISION_CLIP_BLOCK_COUNT,
+    LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH,
+    LLM_KV_VISION_CLIP_PROJECTION_TYPE,
+    LLM_KV_VISION_CLIP_PROJECTION_DIM,
+    LLM_KV_VISION_CLIP_USE_GELU,
+    LLM_KV_VISION_CLIP_HEAD_COUNT,
+    LLM_KV_VISION_CLIP_MAX_POS_EMBD,
+    LLM_KV_VISION_CLIP_LAYERNORM_EPS,
 };
 
 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@@ -487,6 +511,23 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 
     { LLM_KV_ADAPTER_TYPE,       "adapter.type" },
     { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+
+    { LLM_KV_VISION_TYPE,                     "vision.type" },
+    { LLM_KV_VISION_IMAGE_SIZE,               "vision.image_size" },
+    { LLM_KV_VISION_PATCH_SIZE,               "vision.patch_size" },
+    { LLM_KV_VISION_IMAGE_MEAN,               "vision.image_mean" },
+    { LLM_KV_VISION_IMAGE_STD,                "vision.image_std" },
+    { LLM_KV_VISION_CLIP_ARCHITECTURE,        "vision.clip.architecture" },
+    { LLM_KV_VISION_CLIP_CONTEXT_LENGTH,      "vision.clip.context_length" },
+    { LLM_KV_VISION_CLIP_EMBEDDING_LENGTH,    "vision.clip.embedding_length" },
+    { LLM_KV_VISION_CLIP_BLOCK_COUNT,         "vision.clip.block_count" },
+    { LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, "vision.clip.feed_forward_length" },
+    { LLM_KV_VISION_CLIP_PROJECTION_TYPE,     "vision.clip.projection_type" },
+    { LLM_KV_VISION_CLIP_PROJECTION_DIM,      "vision.clip.projection_dim" },
+    { LLM_KV_VISION_CLIP_USE_GELU,            "vision.clip.use_gelu" },
+    { LLM_KV_VISION_CLIP_MAX_POS_EMBD,        "vision.clip.max_position_embeddings" },
+    { LLM_KV_VISION_CLIP_HEAD_COUNT,          "vision.clip.attention.head_count" },
+    { LLM_KV_VISION_CLIP_LAYERNORM_EPS,       "vision.clip.attention.layer_norm_epsilon" },
 };
 
 struct LLM_KV {
@@ -608,6 +649,24 @@ enum llm_tensor {
     LLM_TENSOR_ENC_OUTPUT_NORM,
 };
 
+enum vision_tensor {
+    VISION_TENSOR_MMPROJ_A,
+    VISION_TENSOR_MMPROJ_B,
+    VISION_TENSOR_ENC_EMBD_CLS,
+    VISION_TENSOR_ENC_EMBD_PATCH,
+    VISION_TENSOR_ENC_EMBD_POS,
+    VISION_TENSOR_ENC_ATTN_Q,
+    VISION_TENSOR_ENC_ATTN_K,
+    VISION_TENSOR_ENC_ATTN_V,
+    VISION_TENSOR_ENC_INPUT_NORM,
+    VISION_TENSOR_ENC_OUTPUT,
+    VISION_TENSOR_ENC_OUTPUT_NORM,
+    VISION_TENSOR_ENC_FFN_UP,
+    VISION_TENSOR_ENC_FFN_DOWN,
+    VISION_TENSOR_PRE_NORM,
+    VISION_TENSOR_POST_NORM,
+};
+
 static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
     {
         LLM_ARCH_LLAMA,
@@ -1530,6 +1589,29 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
     },
 };
 
+static const std::map<vision_arch, std::map<vision_tensor, std::string>> VISION_TENSOR_NAMES = {
+    {
+        VISION_ARCH_LLAVA,
+        {
+            { VISION_TENSOR_MMPROJ_A,        "v.mmproj_a" },
+            { VISION_TENSOR_MMPROJ_B,        "v.mmproj_b" },
+            { VISION_TENSOR_ENC_EMBD_CLS,    "v.enc.embd.cls" },
+            { VISION_TENSOR_ENC_EMBD_PATCH,  "v.enc.embd.patch" },
+            { VISION_TENSOR_ENC_EMBD_POS,    "v.enc.embd.pos" },
+            { VISION_TENSOR_ENC_ATTN_Q,      "v.enc.blk.%d.attn_q" },
+            { VISION_TENSOR_ENC_ATTN_K,      "v.enc.blk.%d.attn_k" },
+            { VISION_TENSOR_ENC_ATTN_V,      "v.enc.blk.%d.attn_v" },
+            { VISION_TENSOR_ENC_INPUT_NORM,  "v.enc.blk.%d.input_norm" },
+            { VISION_TENSOR_ENC_OUTPUT,      "v.enc.blk.%d.output" },
+            { VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
+            { VISION_TENSOR_ENC_FFN_UP,      "v.enc.blk.%d.ffn_up" },
+            { VISION_TENSOR_ENC_FFN_DOWN,    "v.enc.blk.%d.ffn_down" },
+            { VISION_TENSOR_PRE_NORM,        "v.pre_norm" },
+            { VISION_TENSOR_POST_NORM,       "v.post_norm" },
+        }
+    }
+};
+
 static llm_arch llm_arch_from_string(const std::string & name) {
     for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
         if (kv.second == name) {
@@ -1540,6 +1622,49 @@ static llm_arch llm_arch_from_string(const std::string & name) {
     return LLM_ARCH_UNKNOWN;
 }
 
+template<typename Tname, typename Ttensor>
+struct BASE_TN {
+    Tname arch;
+    std::map<Tname, std::map<Ttensor, std::string>> name_mapping;
+
+    BASE_TN(Tname arch, std::map<Tname, std::map<Ttensor, std::string>> name_mapping) : arch(arch), name_mapping(name_mapping) {}
+
+    std::string operator()(Ttensor tensor) const {
+        if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) {
+            return "__missing__";
+        }
+        return name_mapping.at(arch).at(tensor);
+    }
+
+    std::string operator()(Ttensor tensor, const std::string & suffix) const {
+        if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) {
+            return "__missing__";
+        }
+        return name_mapping.at(arch).at(tensor) + "." + suffix;
+    }
+
+    std::string operator()(Ttensor tensor, int bid) const {
+        if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) {
+            return "__missing__";
+        }
+        return ::format(name_mapping.at(arch).at(tensor).c_str(), bid);
+    }
+
+    std::string operator()(Ttensor tensor, const std::string & suffix, int bid) const {
+        if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) {
+            return "__missing__";
+        }
+        return ::format(name_mapping.at(arch).at(tensor).c_str(), bid) + "." + suffix;
+    }
+
+    std::string operator()(Ttensor tensor, const std::string & suffix, int bid, int xid) const {
+        if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) {
+            return "__missing__";
+        }
+        return ::format(name_mapping.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix;
+    }
+};
+
 // helper to handle gguf constants
 // usage:
 //
@@ -1549,45 +1674,12 @@ static llm_arch llm_arch_from_string(const std::string & name) {
 //    std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias");     -> "token_embd.bias"
 //    std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight"
 //
-struct LLM_TN {
-    LLM_TN(llm_arch arch) : arch(arch) {}
-
-    llm_arch arch;
-
-    std::string operator()(llm_tensor tensor) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return LLM_TENSOR_NAMES.at(arch).at(tensor);
-    }
-
-    std::string operator()(llm_tensor tensor, const std::string & suffix) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix;
-    }
-
-    std::string operator()(llm_tensor tensor, int bid) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid);
-    }
-
-    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix;
-    }
-
-    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix;
-    }
-};
+struct LLM_TN : BASE_TN<llm_arch, llm_tensor> {
+    LLM_TN(llm_arch arch) : BASE_TN(arch, LLM_TENSOR_NAMES) {}
+};
+
+struct VISION_TN : BASE_TN<vision_arch, vision_tensor> {
+    VISION_TN(vision_arch arch) : BASE_TN(arch, VISION_TENSOR_NAMES) {}
+};
 
 //
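The refactor keeps LLM_TN's call syntax and lets the vision loader reuse the same lookup/formatting logic through VISION_TN. Illustrative outputs, derived from the name tables above:

// Hypothetical usage of the tensor-name helpers defined above:
const auto tn = VISION_TN(VISION_ARCH_LLAVA);

std::string cls = tn(VISION_TENSOR_ENC_EMBD_CLS);            // "v.enc.embd.cls"
std::string mm  = tn(VISION_TENSOR_MMPROJ_A,   "weight");    // "v.mmproj_a.weight"
std::string q3  = tn(VISION_TENSOR_ENC_ATTN_Q, "weight", 3); // "v.enc.blk.3.attn_q.weight"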
@@ -2458,6 +2550,9 @@ struct llama_hparams {
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    bool has_vision = false;
+    clip_hparams clip;
+
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
         if (this->n_vocab    != other.n_vocab)    return true;
@@ -2908,6 +3003,8 @@ struct llama_model {
 
     std::vector<llama_layer> layers;
 
+    clip_vision_model clip;
+
     llama_split_mode split_mode;
     int main_gpu;
     int n_gpu_layers;
@@ -5476,6 +5573,30 @@ static void llm_load_hparams(
         hparams.n_embd_head_v = 0;
     }
 
+    std::string vision_type;
+    ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
+    if (vision_type == "clip") {
+        hparams.has_vision = true;
+        ml.get_key(LLM_KV_VISION_IMAGE_SIZE,               hparams.clip.image_size,     true);
+        ml.get_key(LLM_KV_VISION_PATCH_SIZE,               hparams.clip.patch_size,     true);
+        ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH,    hparams.clip.hidden_size,    true);
+        ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT,         hparams.clip.n_layer,        true);
+        ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true);
+        ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT,          hparams.clip.n_head,         true);
+        ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS,       hparams.clip.eps,            true);
+        // TODO: add image_std
+        std::string arch;
+        ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true);
+        for (auto & it : VISION_ARCH_NAMES) {
+            if (arch == it.second) {
+                hparams.clip.arch = it.first;
+                break;
+            }
+        }
+    } else if (!vision_type.empty()) {
+        throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str()));
+    }
+
     // arch-specific KVs
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
@@ -6123,6 +6244,15 @@ static void llm_load_hparams(
         default: (void)0;
     }
 
+    // arch-specific CLIP hparams
+    switch (hparams.clip.arch) {
+        case VISION_ARCH_LLAVA:
+            {
+                ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, hparams.clip.max_pos_embd, true);
+            } break;
+        default: (void)0;
+    }
+
     model.ftype = ml.ftype;
 
     if (hparams.f_max_alibi_bias > 0.0f) {
@@ -8811,7 +8941,69 @@ static bool llm_load_tensors(
                     }
                 } break;
             default:
-                throw std::runtime_error("unknown architecture");
+                throw std::runtime_error("unknown llm architecture");
         }
     }
 
+    // load tensors for vision model
+    if (hparams.has_vision) {
+        const int64_t n_layer      = hparams.clip.n_layer;
+        const int64_t n_embd       = hparams.clip.hidden_size;
+        const int64_t n_ff         = hparams.clip.n_intermediate;
+        const int64_t max_pos_embd = hparams.clip.max_pos_embd;
+        const int64_t n_channel    = 3; // always RGB
+        const int64_t patch_size   = hparams.clip.patch_size;
+        const auto tn = VISION_TN(hparams.clip.arch);
+
+        ggml_context * ctx_vision = ctx_map.at(model.buft_input.buft); // TODO: make dedicated buft for vision
+        auto ctx_for_layer = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
+
+        model.clip.layers.resize(n_layer);
+
+        switch (hparams.clip.arch) {
+            case VISION_ARCH_LLAVA:
+                {
+                    model.clip.mm_a_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_A, "weight"), {n_embd, n_ff});
+                    model.clip.mm_a_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_A, "bias"  ), {n_ff});
+                    model.clip.mm_b_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_B, "weight"), {n_ff, n_ff});
+                    model.clip.mm_b_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_B, "bias"  ), {n_ff});
+
+                    model.clip.class_embedding     = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_CLS            ), {n_embd});
+                    model.clip.patch_embeddings    = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_embd});
+                    model.clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS,   "weight"), {n_embd, max_pos_embd});
+
+                    model.clip.pre_norm_w  = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM,  "weight"), {n_embd});
+                    model.clip.pre_norm_b  = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM,  "bias"  ), {n_embd});
+                    model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd});
+                    model.clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias"  ), {n_embd});
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        auto & layer = model.clip.layers[i];
+
+                        layer.k_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd});
+                        layer.k_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_K, "bias"  , i), {n_embd});
+                        layer.v_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd});
+                        layer.v_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_V, "bias"  , i), {n_embd});
+                        layer.q_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.q_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_Q, "bias"  , i), {n_embd});
+
+                        layer.ffn_up_w   = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_UP,   "bias"  , i), {n_ff});
+                        layer.ffn_down_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias"  , i), {n_embd});
+
+                        layer.norm_in_w  = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_INPUT_NORM,  "weight", i), {n_embd});
+                        layer.norm_in_b  = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_INPUT_NORM,  "bias"  , i), {n_embd});
+                        layer.norm_out_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_embd});
+                        layer.norm_out_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias"  , i), {n_embd});
+
+                        layer.output_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_embd, n_embd});
+                        layer.output_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT, "bias"  , i), {n_embd});
+                    }
+                } break;
+            default:
+                throw std::runtime_error("unknown vision architecture");
+        }
+    }