From 0a81051ae2c7c881fc1ce74f95f28cba977fbe9b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 18 Jan 2025 20:56:35 +0100 Subject: [PATCH] llama : second attempt to refactor vision API --- common/arg.cpp | 2 +- common/common.h | 1 + examples/CMakeLists.txt | 1 + examples/vision/CMakeLists.txt | 5 + examples/vision/README.md | 3 + examples/vision/vision.cpp | 211 +++++++++++++++++++++++++++++++++ include/llama.h | 26 +++- src/CMakeLists.txt | 1 + src/llama-arch.cpp | 77 +++++++++++- src/llama-arch.h | 71 ++++++++++- src/llama-batch.cpp | 42 ++++++- src/llama-batch.h | 2 + src/llama-context.cpp | 2 +- src/llama-context.h | 4 + src/llama-model-loader.cpp | 2 + src/llama-model.cpp | 114 ++++++++++++++++++ src/llama-model.h | 5 + src/llama-vision.cpp | 164 ++++++++++++------------- src/llama-vision.h | 87 +++++++------- src/llama.cpp | 20 +++- 20 files changed, 695 insertions(+), 145 deletions(-) create mode 100644 examples/vision/CMakeLists.txt create mode 100644 examples/vision/README.md create mode 100644 examples/vision/vision.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 9069950eb..710b61c6d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1403,7 +1403,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.image.emplace_back(value); } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION})); if (llama_supports_rpc()) { add_opt(common_arg( {"--rpc"}, "SERVERS", diff --git a/common/common.h b/common/common.h index 691141d6b..8fc982cf5 100644 --- a/common/common.h +++ b/common/common.h @@ -79,6 +79,7 @@ enum llama_example { LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_TTS, + LLAMA_EXAMPLE_VISION, LLAMA_EXAMPLE_COUNT, }; diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 66cfab2c3..41d968ed6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -53,6 +53,7 @@ else() add_subdirectory(tokenize) add_subdirectory(tts) add_subdirectory(gen-docs) + add_subdirectory(vision) if (NOT GGML_BACKEND_DL) # these examples use the backends directly and cannot be built with dynamic loading add_subdirectory(convert-llama2c-to-ggml) diff --git a/examples/vision/CMakeLists.txt b/examples/vision/CMakeLists.txt new file mode 100644 index 000000000..ab009157a --- /dev/null +++ b/examples/vision/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-vision) +add_executable(${TARGET} vision.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/vision/README.md b/examples/vision/README.md new file mode 100644 index 000000000..c2468444c --- /dev/null +++ b/examples/vision/README.md @@ -0,0 +1,3 @@ +# llama.cpp/example/simple-vision + +Minimal demo for vision API diff --git a/examples/vision/vision.cpp b/examples/vision/vision.cpp new file mode 100644 index 000000000..73f8ef1b6 --- /dev/null +++ b/examples/vision/vision.cpp @@ -0,0 +1,211 @@ +#include "llama.h" +#include "common.h" +#include "arg.h" +#include "log.h" +#include "sampling.h" +#include +#include +#include +#include +#include + +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [--image img_path] [-p prompt]\n", argv[0]); + printf("\n"); +} + +static 
llama_vision_bitmap * load_image_from_file(const char * fname) {
+    std::ifstream file(fname, std::ios::binary);
+    if (!file) {
+        throw std::runtime_error("Unable to open file");
+    }
+    std::vector<char> image_bytes = std::vector<char>(
+        std::istreambuf_iterator<char>(file),
+        std::istreambuf_iterator<char>());
+    // decode image to byte array
+    int nx, ny, nc;
+    auto * bytes = (unsigned char *) image_bytes.data();
+    auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3);
+    if (!img) {
+        throw std::runtime_error("failed to decode image bytes");
+    }
+    // printf("nx=%d ny=%d nc=%d\n", nx, ny, nc);
+    // GGML_ASSERT(nc == 3);
+    // for (int y = 0; y < ny; y++) {
+    //     for (int x = 0; x < nx; x++) {
+    //         unsigned char * pix = img + x*nc + y*nc*nx;
+    //         printf("%02x%02x%02x ", pix[0], pix[1], pix[2]);
+    //     }
+    //     printf("\n");
+    // }
+    // printf("\n");
+    llama_vision_bitmap * result = llama_vision_bitmap_init(nx, ny);
+    memcpy(result->data, img, nx*ny*3);
+    stbi_image_free(img);
+    return result;
+}
+
+// split string by a `std::string delim` instead of `char delim`
+static std::vector<std::string> string_split(std::string s, const std::string & delimiter) {
+    std::vector<std::string> tokens;
+    size_t pos = 0;
+    std::string token;
+    while ((pos = s.find(delimiter)) != std::string::npos) {
+        token = s.substr(0, pos);
+        tokens.push_back(token);
+        s.erase(0, pos + delimiter.length());
+    }
+    tokens.push_back(s);
+    return tokens;
+}
+
+struct tokenized_part {
+    llama_tokens tokens;
+    bool is_image;
+};
+
+// TODO: this function is hacky, need to be improved
+// static const llama_token TOKEN_IMG_PLACEMENT = -1000;
+static const std::string IMG_PLACEMENT = "";
+static std::vector<tokenized_part> tokenize_with_img_placement(
+        const llama_vocab * vocab,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    std::vector<std::string> parts = string_split(text, IMG_PLACEMENT);
+    std::vector<tokenized_part> output;
+    for (const auto & part : parts) {
+        //printf("tokenizing part: %s\n", part.c_str());
+        bool add_bos = &parts.front() == &part;
+        auto tokens = common_tokenize(vocab, part, add_special && add_bos, parse_special);
+        if (tokens.empty()) {
+            continue;
+        }
+        output.push_back({std::move(tokens), false});
+        if (&parts.back() != &part) {
+            // add image token to middle of 2 parts
+            output.push_back({{}, true});
+        }
+    }
+    return output;
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    // default prompt for llava 1.5
+    params.prompt = "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n" + "USER:\nwhat did you see?\nASSISTANT:"; + params.n_predict = 64; + params.n_batch = 2048; + params.n_ubatch = 1024; + params.n_gpu_layers = 99; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_VISION, print_usage)) { + return 1; + } + + common_init(); + common_init_result llama_init = common_init_from_params(params); + llama_context * ctx = llama_init.context.get(); + const llama_model * model = llama_init.model.get(); + const llama_vocab * vocab = llama_model_get_vocab(model); + + struct common_sampler * smpl = common_sampler_init(model, params.sampling); + + llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); + int n_past = 0; + int n_prompt = 0; + + // process image + llama_vision_patches * img_patches = nullptr; + { + const char * img_path = params.image[0].c_str(); + if (params.image[0].empty()) { + LOG_ERR("no image path provided\n"); + return 1; + } + llama_vision_bitmap * img = load_image_from_file(img_path); + LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny); + img_patches = llama_vision_patches_init(ctx, img); + if (!img_patches) { + LOG_ERR("failed to create image patches\n"); + return 1; + } + if (llama_vision_encode(ctx, img_patches)) { + LOG_ERR("failed to encode image\n"); + return 1; + } + LOG_INF("encoded image\n"); + } + + // process prompt + { + std::vector parts = tokenize_with_img_placement(vocab, params.prompt, true, true); + for (const tokenized_part & part : parts) { + if (!part.is_image) { + for (const llama_token & token : part.tokens) { + //LOG_INF("%d -> %s\n", token, common_token_to_piece(ctx, token).c_str()); + common_batch_add(batch, token, n_past++, {0}, &part == &parts.back()); + } + LOG_INF("eval text batch (%d tokens)\n", batch.n_tokens); + if (llama_decode(ctx, batch)) { + LOG_ERR("failed to decode text prompt\n"); + return 1; + } + } else { + auto * img_embd = llama_vision_get_output_tensor(ctx); + // std::vector output_debug(ggml_nelements(img_embd)); + // ggml_backend_tensor_get(img_embd, output_debug.data(), 0, ggml_nbytes(img_embd)); + // for (int row = 0; row < 10; row++) { + // int off = row * img_embd->ne[0]; + // printf("... 
%f %f %f\n", output_debug[off], output_debug[off+1], output_debug[off+2]); + // } + // exit(1); + llama_batch batch_img = llama_batch_get_one_from_tensor(img_embd, n_past, 0); + n_past += batch_img.n_tokens; + LOG_INF("eval image batch (%d embeddings)\n", batch_img.n_tokens); + if (llama_decode(ctx, batch_img)) { + LOG_ERR("failed to decode image prompt\n"); + return 1; + } + llama_batch_free(batch_img); + } + } + n_prompt = n_past; + LOG_INF("prompt processed, %d tokens\n", n_prompt); + } + + // generate response + while (true){ + int n_generated = n_past - n_prompt; + if (n_generated > params.n_predict) { + printf("\n"); + break; + } + + llama_token token_id = common_sampler_sample(smpl, ctx, -1); + common_sampler_accept(smpl, token_id, true); + printf("%s", common_token_to_piece(ctx, token_id).c_str()); + fflush(stdout); + + if (llama_vocab_is_eog(vocab, token_id)) { + printf("\n"); + break; + } + + // eval the token + common_batch_clear(batch); + common_batch_add(batch, token_id, n_past++, {0}, true); + if (llama_decode(ctx, batch)) { + LOG_ERR("failed to decode token\n"); + break; + } + } + + return 0; +} diff --git a/include/llama.h b/include/llama.h index 6049d2382..5013e96e7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -229,6 +229,8 @@ extern "C" { bool sorted; } llama_token_data_array; + struct llama_vision_patches; + // represent an RGB image // size of data must be equal to 3*nx*ny typedef struct llama_vision_bitmap { @@ -237,8 +239,6 @@ extern "C" { unsigned char * data; } llama_vision_bitmap; - struct llama_vision_patches; - typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_decode @@ -263,6 +263,8 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + + struct ggml_tensor * embd_tensor; } llama_batch; enum llama_model_kv_override_type { @@ -854,6 +856,10 @@ extern "C" { int32_t embd, int32_t n_seq_max); + // Allocates a batch based on a tensor, only used by vision API for now + // Unlike llama_batch_get_one, this will need to be freed after use + LLAMA_API struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id); + // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); @@ -1272,6 +1278,22 @@ extern "C" { // TODO: extend in the future //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); + // + // Vision API + // + + // Container for RGB bitmap + LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny); + LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp); + + // Create patches from the RGB bitmap + LLAMA_API struct llama_vision_patches * llama_vision_patches_init(struct llama_context * ctx, llama_vision_bitmap * bmp); + LLAMA_API void llama_vision_patches_free(struct llama_vision_patches * p); + + // Encode patches into embeddings + LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_patches * p); + LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx); + // // Model split // diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index aeb75bf3e..1f3b454fa 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,6 +24,7 @@ add_library(llama llama-quant.cpp llama-sampling.cpp llama-vocab.cpp + llama-vision.cpp unicode.h unicode.cpp unicode-data.cpp diff 
--git a/src/llama-arch.cpp b/src/llama-arch.cpp index d7d277e72..dcfbdab3e 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -65,6 +65,11 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_UNKNOWN, "(unknown)" }, }; +static const std::map VISION_ARCH_NAMES = { + { VISION_ARCH_LLAVA, "llava" }, + { VISION_ARCH_UNKNOWN, "(unknown)" }, +}; + static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_TYPE, "general.type" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, @@ -189,6 +194,27 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_VISION_TYPE, "vision.type" }, + { LLM_KV_VISION_IMAGE_SIZE, "vision.image_size" }, + { LLM_KV_VISION_PATCH_SIZE, "vision.patch_size" }, + { LLM_KV_VISION_IMAGE_MEAN, "vision.image_mean" }, + { LLM_KV_VISION_IMAGE_STD, "vision.image_std" }, + { LLM_KV_VISION_CLIP_ARCHITECTURE, "vision.clip.architecture" }, + { LLM_KV_VISION_CLIP_CONTEXT_LENGTH, "vision.clip.context_length" }, + { LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, "vision.clip.embedding_length" }, + { LLM_KV_VISION_CLIP_BLOCK_COUNT, "vision.clip.block_count" }, + { LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, "vision.clip.feed_forward_length" }, + { LLM_KV_VISION_CLIP_PROJECTION_TYPE, "vision.clip.projection_type" }, + { LLM_KV_VISION_CLIP_PROJECTION_DIM, "vision.clip.projection_dim" }, + { LLM_KV_VISION_CLIP_USE_GELU, "vision.clip.use_gelu" }, + { LLM_KV_VISION_CLIP_MAX_POS_EMBD, "vision.clip.max_position_embeddings" }, + { LLM_KV_VISION_CLIP_MAX_SLICES, "vision.clip.max_slices" }, + { LLM_KV_VISION_CLIP_PROJECTOR_TYPE, "vision.clip.projector_type" }, + { LLM_KV_VISION_CLIP_SELECT_LAYER, "vision.clip.select_layer" }, + { LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, "vision.clip.patch_merge_type" }, + { LLM_KV_VISION_CLIP_HEAD_COUNT, "vision.clip.attention.head_count" }, + { LLM_KV_VISION_CLIP_LAYERNORM_EPS, "vision.clip.attention.layer_norm_epsilon" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -1300,6 +1326,28 @@ static const std::map> LLM_TENSOR_N }, }; +static const std::map> VISION_TENSOR_NAMES = { + { + VISION_ARCH_LLAVA, + { + { VISION_TENSOR_MMPROJ, "v.mmproj_%d" }, + { VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" }, + { VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { VISION_TENSOR_PRE_NORM, "v.pre_norm" }, + { VISION_TENSOR_POST_NORM, "v.post_norm" }, + } + } +}; + static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, @@ -1449,7 +1497,8 @@ std::string LLM_KV::operator()(llm_kv kv) const { return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); } -std::string LLM_TN_IMPL::str() const { +template<> +std::string BASE_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } @@ -1464,6 
+1513,22 @@ std::string LLM_TN_IMPL::str() const { return name; } +template<> +std::string BASE_TN_IMPL::str() const { + if (VISION_TENSOR_NAMES.at(arch).find(tensor) == VISION_TENSOR_NAMES.at(arch).end()) { + return "__missing__"; + } + + std::string name = ::format(VISION_TENSOR_NAMES.at(arch).at(tensor), bid, xid); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; +} + const char * llm_arch_name(llm_arch arch) { auto it = LLM_ARCH_NAMES.find(arch); if (it == LLM_ARCH_NAMES.end()) { @@ -1482,6 +1547,16 @@ llm_arch llm_arch_from_string(const std::string & name) { return LLM_ARCH_UNKNOWN; } +vision_arch vision_arch_from_string(const std::string & name) { + for (const auto & kv : VISION_ARCH_NAMES) { // NOLINT + if (kv.second == name) { + return kv.first; + } + } + + return VISION_ARCH_UNKNOWN; +} + const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { return LLM_TENSOR_INFOS.at(tensor); } diff --git a/src/llama-arch.h b/src/llama-arch.h index 349844790..ce89b15f5 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,11 @@ enum llm_arch { LLM_ARCH_UNKNOWN, }; +enum vision_arch { + VISION_ARCH_UNKNOWN, + VISION_ARCH_LLAVA, +}; + enum llm_kv { LLM_KV_GENERAL_TYPE, LLM_KV_GENERAL_ARCHITECTURE, @@ -193,6 +198,27 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_VISION_TYPE, + LLM_KV_VISION_IMAGE_SIZE, + LLM_KV_VISION_PATCH_SIZE, + LLM_KV_VISION_IMAGE_MEAN, + LLM_KV_VISION_IMAGE_STD, + LLM_KV_VISION_CLIP_ARCHITECTURE, + LLM_KV_VISION_CLIP_CONTEXT_LENGTH, + LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, + LLM_KV_VISION_CLIP_BLOCK_COUNT, + LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, + LLM_KV_VISION_CLIP_PROJECTION_TYPE, + LLM_KV_VISION_CLIP_PROJECTION_DIM, + LLM_KV_VISION_CLIP_USE_GELU, + LLM_KV_VISION_CLIP_MAX_POS_EMBD, + LLM_KV_VISION_CLIP_MAX_SLICES, + LLM_KV_VISION_CLIP_PROJECTOR_TYPE, + LLM_KV_VISION_CLIP_SELECT_LAYER, + LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, + LLM_KV_VISION_CLIP_HEAD_COUNT, + LLM_KV_VISION_CLIP_LAYERNORM_EPS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -328,6 +354,23 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_OUT, }; +enum vision_tensor { + VISION_TENSOR_MMPROJ, + VISION_TENSOR_ENC_EMBD_CLS, + VISION_TENSOR_ENC_EMBD_PATCH, + VISION_TENSOR_ENC_EMBD_POS, + VISION_TENSOR_ENC_ATTN_Q, + VISION_TENSOR_ENC_ATTN_K, + VISION_TENSOR_ENC_ATTN_V, + VISION_TENSOR_ENC_INPUT_NORM, + VISION_TENSOR_ENC_OUTPUT, + VISION_TENSOR_ENC_OUTPUT_NORM, + VISION_TENSOR_ENC_FFN_UP, + VISION_TENSOR_ENC_FFN_DOWN, + VISION_TENSOR_PRE_NORM, + VISION_TENSOR_POST_NORM, +}; + enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, @@ -351,9 +394,10 @@ struct LLM_KV { // std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" // std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" // -struct LLM_TN_IMPL { - const llm_arch arch; - const llm_tensor tensor; +template +struct BASE_TN_IMPL { + const Tname arch; + const Ttensor tensor; const char * const suffix; const int bid; const int xid; @@ -364,15 +408,16 @@ struct LLM_TN_IMPL { return str(); } - friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) { + friend bool operator==(const std::string & str, const BASE_TN_IMPL & tn) { return str == tn.str(); } - friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) { + friend bool operator!=(const std::string & str, const BASE_TN_IMPL & tn) { return str != tn.str(); } }; +using LLM_TN_IMPL = 
BASE_TN_IMPL; struct LLM_TN { LLM_TN(llm_arch arch) : arch(arch) {} @@ -387,6 +432,20 @@ struct LLM_TN { } }; +struct VISION_TN { + VISION_TN(vision_arch arch) : arch(arch) {} + + vision_arch arch; + + BASE_TN_IMPL operator()(vision_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const { + return { arch, tensor, suffix, bid, xid }; + } + + BASE_TN_IMPL operator()(vision_tensor tensor, int bid = -1, int xid = -1) const { + return { arch, tensor, nullptr, bid, xid }; + } +}; + struct llm_tensor_info { llm_tensor_layer layer; @@ -397,4 +456,6 @@ const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); +vision_arch vision_arch_from_string(const std::string & name); + const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57f..5ed32d859 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -31,6 +31,7 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { /*n_seq_id =*/ ubatch_n_seq_id.data(), /*seq_id =*/ ubatch_seq_id.data(), /*output =*/ ubatch_output.data(), + /*embd_tensor =*/ nullptr, }; return ubatch; } @@ -55,7 +56,9 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s } else { ubatch.token = nullptr; } - if (batch->embd) { + if (batch->embd_tensor) { + ubatch.embd_tensor = batch->embd_tensor; + } else if (batch->embd) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { memcpy( @@ -139,7 +142,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); ubatch.equal_seqs = false; if (!seq.empty()) { llama_sbatch_seq & s = seq[0]; @@ -152,7 +155,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); if (!seq.empty()) { size_t length = 0; size_t n_tokens_in_ubatch = 0; @@ -179,7 +182,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); if (!seq.empty()) { llama_sbatch_seq & s = seq[seq.size() - 1]; size_t length = s.length < n_ubatch ? 
s.length : n_ubatch; @@ -320,6 +323,7 @@ struct llama_batch llama_batch_get_one( /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, + /*embd_tensor =*/ nullptr, }; } @@ -332,6 +336,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, + /*embd_tensor =*/ nullptr, }; if (embd) { @@ -353,6 +358,35 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ return batch; } +struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id) { + GGML_ASSERT(tensor->ne[2] == 1 && tensor->ne[3] == 1); + int32_t n_tokens = tensor->ne[1]; + llama_batch batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*embd_tensor =*/ tensor, + }; + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.pos [i] = p0 + i; + batch.seq_id [i] = (llama_seq_id *) malloc(sizeof(llama_seq_id)); + batch.seq_id [i][0] = seq_id; + batch.n_seq_id[i] = 1; + } + batch.seq_id[n_tokens] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + void llama_batch_free(struct llama_batch batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); diff --git a/src/llama-batch.h b/src/llama-batch.h index 773c3808b..a5e6f1d49 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -21,6 +21,8 @@ struct llama_ubatch { int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] + + struct ggml_tensor * embd_tensor; }; struct llama_sbatch_seq { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 671d2a81a..47cb701a3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -73,7 +73,7 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); } - if (ubatch.embd) { + if (ubatch.embd && !ubatch.embd_tensor) { const int64_t n_embd = hparams.n_embd; const int64_t n_tokens = ubatch.n_tokens; diff --git a/src/llama-context.h b/src/llama-context.h index a9268b292..10c839f55 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -6,6 +6,7 @@ #include "llama-model.h" #include "llama-kv-cache.h" #include "llama-adapter.h" +#include "llama-vision.h" #include "ggml-cpp.h" @@ -107,6 +108,9 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + + // vision + clip_context vctx; }; // TODO: make these methods of llama_context diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 75073bf61..2045fcfa5 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -375,6 +375,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_key (enum llm_kv kid, bool & result, bool required); template bool llama_model_loader::get_key (enum llm_kv kid, float & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, int32_t & result, bool required); template bool 
llama_model_loader::get_key (enum llm_kv kid, uint32_t & result, bool required); template bool llama_model_loader::get_key(enum llm_kv kid, std::string & result, bool required); @@ -439,6 +440,7 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2d23a8d3..42cc230ce 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1245,6 +1245,54 @@ void llama_model::load_hparams(llama_model_loader & ml) { } hparams.rope_type = llama_model_rope_type(this); + + // vision model + auto & vparams = clip.hparams; + std::string vision_type; + ml.get_key(LLM_KV_VISION_TYPE, vision_type, false); + if (vision_type == "clip-vit") { + LLAMA_LOG_INFO("%s: loading clip-vit vision model\n", __func__); + has_vision = true; + ml.get_key(LLM_KV_VISION_IMAGE_SIZE, vparams.image_size, true); + ml.get_key(LLM_KV_VISION_PATCH_SIZE, vparams.patch_size, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, vparams.image_mean, 3, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD, vparams.image_std, 3, true); + ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, vparams.hidden_size, true); + ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, vparams.n_layer, true); + ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true); + ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true); + ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true); + ml.get_key(LLM_KV_VISION_CLIP_SELECT_LAYER, vparams.select_layer, true); + { + std::string name; + ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, name, true); + vparams.proj_type = clip_projector_type_from_name(name); + if (vparams.proj_type == CLIP_PROJECTOR_TYPE_UNKNOWN) { + throw std::runtime_error(format("unsupported clip projector type: %s", name.c_str())); + } + } + { + std::string name; + ml.get_key(LLM_KV_VISION_CLIP_PATCH_MERGE_TYPE, name, false); + vparams.mm_patch_merge_type = mm_patch_merge_from_name(name); + } + { + std::string arch; + ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true); + vparams.arch = vision_arch_from_string(arch); + } + } else if (!vision_type.empty()) { + throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); + } + + // arch-specific CLIP hparams + switch (vparams.arch) { + case VISION_ARCH_LLAVA: + { + ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true); + } break; + default: (void)0; + } } void llama_model::load_vocab(llama_model_loader & ml) { @@ -3359,6 +3407,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } + // load tensors for vision model + auto & vparams = clip.hparams; + if (has_vision) { + const int64_t n_layer = vparams.n_layer; + const int64_t n_embd = vparams.hidden_size; + const int64_t n_ff = vparams.n_intermediate; + const int64_t max_pos_embd = vparams.max_pos_embd; + const int64_t n_channel = 3; // always RGB + const int64_t patch_size = vparams.patch_size; + const auto tn = VISION_TN(vparams.arch); + + // clip is CPU-only for now + clip.buft = ggml_backend_cpu_buffer_type(); + ggml_context * ctx_vision = 
ctx_map.at(clip.buft); + clip.layers.resize(n_layer); + + switch (vparams.arch) { + case VISION_ARCH_LLAVA: + { + clip.mm_1_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 1), {n_embd, n_ff}); + clip.mm_1_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 1), {n_ff}); + clip.mm_2_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 2), {n_ff, n_ff}); + clip.mm_2_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 2), {n_ff}); + + clip.class_embedding = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_CLS ), {n_embd}); + clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_embd}); + clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd}); + + clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd}); + clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd}); + clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = clip.layers[i]; + + layer.k_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd}); + layer.k_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "bias" , i), {n_embd}); + layer.v_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd}); + layer.v_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "bias" , i), {n_embd}); + layer.q_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.q_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "bias" , i), {n_embd}); + + layer.ffn_up_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "bias" , i), {n_ff}); + layer.ffn_down_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias" , i), {n_embd}); + + layer.norm_in_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "weight", i), {n_embd}); + layer.norm_in_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "bias" , i), {n_embd}); + layer.norm_out_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_embd}); + layer.norm_out_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias" , i), {n_embd}); + + layer.output_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_embd, n_embd}); + layer.output_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "bias" , i), {n_embd}); + } + } break; + default: + throw std::runtime_error("unknown vision architecture"); + } + + if (clip_n_mmproj_embd(clip) != hparams.n_embd) { + std::runtime_error("model has vision, but n_mmproj_embd != n_embd"); + } + } + ml.done_getting_tensors(); ml.init_mappings(true, use_mlock ? 
&pimpl->mlock_mmaps : nullptr); diff --git a/src/llama-model.h b/src/llama-model.h index a7c304447..fd3820f1e 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -4,6 +4,7 @@ #include "llama-arch.h" #include "llama-hparams.h" #include "llama-vocab.h" +#include "llama-vision.h" #include #include @@ -362,6 +363,10 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; + // vision + bool has_vision = false; + clip_vision_model clip; + private: struct impl; std::unique_ptr pimpl; diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp index 87a33c181..b419627e6 100644 --- a/src/llama-vision.cpp +++ b/src/llama-vision.cpp @@ -1,6 +1,7 @@ #include "llama.h" #include "llama-vision.h" #include "llama-impl.h" +#include "llama-context.h" #include // memcpy #include @@ -43,15 +44,22 @@ struct clip_image_u8_batch { size_t size; }; -static int clip_n_patches(const clip_context & ctx) { +static int clip_n_patches_x(const clip_context & ctx) { auto & hparams = ctx.model->hparams; - int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size); - return n_patches; + return hparams.image_size / hparams.patch_size; } -int clip_n_mmproj_embd(const clip_context & ctx) { - if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { - return ctx.model->mm_2_b->ne[0]; +static int clip_n_patches_y(const clip_context & ctx) { + return clip_n_patches_x(ctx); +} + +static int clip_n_patches(const clip_context & ctx) { + return clip_n_patches_x(ctx) * clip_n_patches_y(ctx); +} + +uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) { + if (clip_model.hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { + return clip_model.mm_2_b->ne[0]; } else { GGML_ASSERT(false && "invalid proj type"); } @@ -242,11 +250,11 @@ static llama_vision_patches clip_image_preprocess(const clip_context & ctx, cons pad_to_square = false; } - llama_vision_patches output_imgs; - output_imgs.px = clip_n_patches(ctx); - output_imgs.py = clip_n_patches(ctx); - output_imgs.n_px = params.image_size / output_imgs.px; - output_imgs.n_py = params.image_size / output_imgs.py; + llama_vision_patches output_patches; + output_patches.n_px = clip_n_patches_x(ctx); + output_patches.n_py = clip_n_patches_y(ctx); + output_patches.px = params.patch_size; + output_patches.py = params.patch_size; // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 @@ -296,13 +304,13 @@ static llama_vision_patches clip_image_preprocess(const clip_context & ctx, cons bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square patches.insert(patches.begin(), image_original_resize); // clip_image_f32_batch_init(patches.size()); - output_imgs.buf.resize(patches.size()); + output_patches.buf.resize(patches.size()); int num = 0; for (auto & patch : patches) { - normalize_image_u8_to_f32(patch, output_imgs.buf[num], params.image_mean, params.image_std); + normalize_image_u8_to_f32(patch, output_patches.buf[num], params.image_mean, params.image_std); num++; } - return output_imgs; + return output_patches; } else { temp.nx = img.nx; temp.ny = img.ny; @@ -367,10 +375,10 @@ static llama_vision_patches clip_image_preprocess(const clip_context & ctx, cons } } - output_imgs.buf.resize(1); - output_imgs.buf[0] = std::move(res); + 
output_patches.buf.resize(1); + output_patches.buf[0] = std::move(res); - return output_imgs; + return output_patches; } static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) { @@ -556,14 +564,16 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, } } + embeddings = ggml_cont(ctx0, embeddings); + // build the graph ggml_build_forward_expand(gf, embeddings); ggml_free(ctx0); return gf; } -static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_batch & imgs, std::vector & output) { - int batch_size = imgs.size(); +static int32_t clip_image_encode(clip_context & ctx, const llama_vision_patches & patches) { + int batch_size = patches.buf.size(); auto & model = *ctx.model; auto & hparams = ctx.model->hparams; @@ -595,15 +605,15 @@ static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_ float * data = (float *)malloc(ggml_nbytes(inp_raw)); for (int i = 0; i < batch_size; i++) { - const int nx = imgs[i].nx; - const int ny = imgs[i].ny; + const int nx = patches.px * patches.n_px; + const int ny = patches.py * patches.n_py; const int n = nx * ny; for (int b = 0; b < batch_size; b++) { for (int k = 0; k < 3; k++) { for (int y = 0; y < ny; y++) { for (int x = 0; x < nx; x++) { - data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].buf[3 * (y * nx + x) + k]; + data[(b * 3 * n) + k * n + y * nx + x] = patches.buf[b][3 * (y * nx + x) + k]; } } } @@ -644,45 +654,71 @@ static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_ } // compute - ggml_backend_sched_graph_compute_async(ctx.sched, gf); + ggml_backend_sched_graph_compute(ctx.sched, gf); // the last node is the embedding tensor - struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(ctx.sched, embeddings); + struct ggml_tensor * output_node = ggml_graph_node(gf, -1); + //LLAMA_LOG_INFO("%s: output tensor shape = %lld %lld %lld %lld\n", __func__, output->ne[0], output->ne[1], output->ne[2], output->ne[3]); - // copy the embeddings to the location passed by the user - size_t out_nbytes = clip_n_patches(ctx)*clip_n_mmproj_embd(ctx)*sizeof(float); - GGML_ASSERT(out_nbytes == ggml_nbytes(embeddings)); - output.resize(out_nbytes); - ggml_backend_tensor_get_async(backend_embd, embeddings, output.data(), 0, ggml_nbytes(embeddings)); - - ggml_backend_sched_synchronize(ctx.sched); + // copy output node to context + if (ctx.ctx_ggml) { + ggml_free(ctx.ctx_ggml); + } + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ctx.ctx_ggml = ggml_init(params); + ctx.output = ggml_dup_tensor(ctx.ctx_ggml, output_node); + ggml_backend_alloc_ctx_tensors_from_buft(ctx.ctx_ggml, ctx.model->buft); + ggml_backend_tensor_copy(output_node, ctx.output); return 0; } -static int32_t clip_image_encode(clip_context & ctx, const clip_image_f32 & img, std::vector & output) { - clip_image_f32_batch imgs{img}; - return clip_image_batch_encode(ctx, imgs, output); +//////////////////////////////////////////////////////////////////////////////////////// +// public API + +struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny) { + llama_vision_bitmap * bmp = new llama_vision_bitmap; + bmp->nx = nx; + bmp->ny = ny; + bmp->data = (unsigned char *)malloc(3 * nx * ny); + return bmp; } -static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, 
std::vector & output_embd) { - clip_image_u8 img_u8(img); - clip_image_f32_batch img_res_v; - auto & hparams = ctx.model->hparams; - // bmp_export(img_u8, "test_inp.bmp"); +void llama_vision_bitmap_free(llama_vision_bitmap * bmp) { + free(bmp->data); + delete bmp; +} - if (!clip_image_preprocess(ctx, img_u8, img_res_v)) { - LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__); - return -2; +struct llama_vision_patches * llama_vision_patches_init( + struct llama_context * ctx, + llama_vision_bitmap * bmp) { + clip_context & vctx = ctx->vctx; + llama_vision_patches p = clip_image_preprocess(vctx, *bmp); + return new llama_vision_patches(p); +} + +void llama_vision_patches_free(llama_vision_patches * p) { + delete p; +} + +int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_patches * p) { + if (p->buf.empty()) { + LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__); + return -1; } + clip_context & vctx = ctx->vctx; + auto & hparams = vctx.model->hparams; switch (hparams.mm_patch_merge_type) { case MM_PATCH_MERGE_FLAT: { // flat / default llava-1.5 type embedding // n_output = clip_n_patches(ctx); - int32_t encoded = clip_image_encode(ctx, img_res_v[0], output_embd); + int32_t encoded = clip_image_encode(vctx, *p); if (encoded != 0) { LLAMA_LOG_ERROR("Unable to encode image\n"); return encoded; @@ -700,44 +736,8 @@ static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, s return 0; } -//////////////////////////////////////////////////////////////////////////////////////// -// public API - -int32_t llama_encode_vision_internal(clip_context & ctx, llama_batch_img * batch) { - if (batch->n_imgs == 0) { - return 0; - } - - // TODO: batching is not working atm, should be fixed later - const int n_embd = clip_n_mmproj_embd(ctx); - const int n_tokens_per_img = clip_n_patches(ctx); - const int n_pos = n_tokens_per_img*batch->n_imgs; - - ctx.out_embd.resize(n_embd*n_pos); - ctx.out_pos.resize(n_pos); - - for (int i = 0; i < batch->n_imgs; i++) { - std::vector output_single; - int32_t status = encode_image_with_clip(ctx, *batch->imgs[i], output_single); - if (status != 0) { - return status; - } - // copy output embeddings to result - for (int k = 0; k < n_embd*n_tokens_per_img; k++) { - ctx.out_embd[n_embd*n_tokens_per_img*i + k] = output_single[k]; - } - // fill position for all output tokens - for (int p = 0; p < n_tokens_per_img; p++) { - ctx.out_pos[n_tokens_per_img*i + p] = batch->pos[i] + p; - } - } - - return 0; -} - -void llama_vision_clear_output(clip_context & ctx) { - ctx.out_embd.clear(); - ctx.out_pos.clear(); +struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) { + return ctx->vctx.output; } //////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/llama-vision.h b/src/llama-vision.h index d7c922d99..56c6b49c9 100644 --- a/src/llama-vision.h +++ b/src/llama-vision.h @@ -2,15 +2,11 @@ #include "ggml.h" #include "llama.h" +#include "llama-arch.h" #include #include -enum vision_arch { - VISION_ARCH_UNKNOWN, - VISION_ARCH_LLAVA, -}; - enum clip_projector_type { CLIP_PROJECTOR_TYPE_UNKNOWN, CLIP_PROJECTOR_TYPE_MLP, @@ -50,72 +46,76 @@ struct clip_hparams { struct clip_layer { // attention - struct ggml_tensor * k_w = NULL; - struct ggml_tensor * k_b = NULL; - struct ggml_tensor * q_w = NULL; - struct ggml_tensor * q_b = NULL; - struct ggml_tensor * v_w = NULL; - struct ggml_tensor * v_b = NULL; + struct ggml_tensor * k_w = nullptr; + struct ggml_tensor * k_b = nullptr; + struct 
ggml_tensor * q_w = nullptr;
+    struct ggml_tensor * q_b = nullptr;
+    struct ggml_tensor * v_w = nullptr;
+    struct ggml_tensor * v_b = nullptr;
 
-    struct ggml_tensor * output_w = NULL;
-    struct ggml_tensor * output_b = NULL;
+    struct ggml_tensor * output_w = nullptr;
+    struct ggml_tensor * output_b = nullptr;
 
     // layernorm 1
-    struct ggml_tensor * norm_in_w = NULL;
-    struct ggml_tensor * norm_in_b = NULL;
+    struct ggml_tensor * norm_in_w = nullptr;
+    struct ggml_tensor * norm_in_b = nullptr;
 
     // ff
-    struct ggml_tensor * ffn_up_w = NULL;
-    struct ggml_tensor * ffn_up_b = NULL;
+    struct ggml_tensor * ffn_up_w = nullptr;
+    struct ggml_tensor * ffn_up_b = nullptr;
 
-    struct ggml_tensor * ffn_down_w = NULL;
-    struct ggml_tensor * ffn_down_b = NULL;
+    struct ggml_tensor * ffn_down_w = nullptr;
+    struct ggml_tensor * ffn_down_b = nullptr;
 
     // layernorm 2
-    struct ggml_tensor * norm_out_w = NULL;
-    struct ggml_tensor * norm_out_b = NULL;
+    struct ggml_tensor * norm_out_w = nullptr;
+    struct ggml_tensor * norm_out_b = nullptr;
 };
 
 struct clip_vision_model {
     struct clip_hparams hparams;
+    ggml_backend_buffer_type_t buft;
 
     // embeddings
-    struct ggml_tensor * class_embedding = NULL;
-    struct ggml_tensor * patch_embeddings = NULL;
-    struct ggml_tensor * patch_bias = NULL;
-    struct ggml_tensor * position_embeddings = NULL;
+    struct ggml_tensor * class_embedding = nullptr;
+    struct ggml_tensor * patch_embeddings = nullptr;
+    struct ggml_tensor * patch_bias = nullptr;
+    struct ggml_tensor * position_embeddings = nullptr;
 
-    struct ggml_tensor * pre_norm_w = NULL;
-    struct ggml_tensor * pre_norm_b = NULL;
+    struct ggml_tensor * pre_norm_w = nullptr;
+    struct ggml_tensor * pre_norm_b = nullptr;
 
     std::vector<clip_layer> layers;
 
-    struct ggml_tensor * post_norm_w = NULL;
-    struct ggml_tensor * post_norm_b = NULL;
+    struct ggml_tensor * post_norm_w = nullptr;
+    struct ggml_tensor * post_norm_b = nullptr;
 
-    struct ggml_tensor * projection = NULL;
+    struct ggml_tensor * projection = nullptr;
 
     // LLaVA projection
-    struct ggml_tensor * mm_1_w = NULL;
-    struct ggml_tensor * mm_1_b = NULL;
-    struct ggml_tensor * mm_2_w = NULL;
-    struct ggml_tensor * mm_2_b = NULL;
+    struct ggml_tensor * mm_1_w = nullptr;
+    struct ggml_tensor * mm_1_b = nullptr;
+    struct ggml_tensor * mm_2_w = nullptr;
+    struct ggml_tensor * mm_2_b = nullptr;
 
-    struct ggml_tensor * image_newline = NULL;
+    struct ggml_tensor * image_newline = nullptr;
 };
 
 struct clip_context {
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_t sched = nullptr;
+    struct ggml_context * ctx_ggml = nullptr;
 
     const clip_vision_model * model;
 
     // temporary output data, to be picked up by llama_decode()
-    std::vector<float> out_embd; // size == n_tokens * n_embd
-    std::vector<llama_pos> out_pos; // position of each token
+    struct ggml_tensor * output;
 };
 
+// for now, this only contains:
+// - the instruction for ggml_conv_2d to break the image into patches
+// - the pre-processed image data in f32
 struct llama_vision_patches {
     uint32_t px; // size of patch
     uint32_t py; // size of patch
@@ -126,7 +126,7 @@ struct llama_vision_patches {
     std::vector<std::vector<float>> buf; // preprocessed image data
 };
 
-mm_patch_merge mm_patch_merge_from_name(std::string & name) {
+inline mm_patch_merge mm_patch_merge_from_name(std::string & name) {
     if (name == "flat") {
         return MM_PATCH_MERGE_FLAT;
     } else if (name == "spatial_unpad") {
@@ -135,17 +135,14 @@ mm_patch_merge mm_patch_merge_from_name(std::string & name) {
         return MM_PATCH_MERGE_UNKNOWN;
     }
 }
 
-clip_projector_type 
clip_projector_type_from_name(std::string & name) { +inline clip_projector_type clip_projector_type_from_name(std::string & name) { if (name == "mlp") { return CLIP_PROJECTOR_TYPE_MLP; } return CLIP_PROJECTOR_TYPE_UNKNOWN; } -llama_vision_patches * llama_vision_patches_init(llama_vision_bitmap * bmp); -void llama_vision_patches_free(llama_vision_patches * p); +// only for sanity check: must be equal to n_embd of language model +uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model); -int32_t llama_vision_encode_impl(clip_context & ctx, llama_vision_patches * p); - -// dimension of the output embeddings, must be equal to n_embd of language model -int clip_n_mmproj_embd(const clip_context & ctx); +struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx); diff --git a/src/llama.cpp b/src/llama.cpp index e8cfe5012..6170a655a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -138,6 +138,9 @@ static struct ggml_tensor * llm_build_inp_embd( ), scale); inpL = ggml_add(ctx, inpL, inpL_delta); } + } else if (ubatch.embd_tensor) { + inpL = ubatch.embd_tensor; + ggml_set_input(ubatch.embd_tensor); } else { lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens); inpL = lctx.inp_embd; @@ -8466,7 +8469,9 @@ static int llama_decode_impl( const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + GGML_ASSERT((batch.token && !batch.embd && !batch.embd_tensor) + || (!batch.token && batch.embd && !batch.embd_tensor) + || (!batch.token && !batch.embd && batch.embd_tensor)); if (batch.token) { for (uint32_t i = 0; i < n_tokens_all; ++i) { @@ -9232,7 +9237,7 @@ static void llama_kv_cache_update_impl(struct llama_context & lctx) { uint32_t n_seqs = 1; // TODO: worst-case number of sequences uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -9785,7 +9790,7 @@ struct llama_context * llama_init_from_model( uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true); // reserve pp graph first so that buffers are only allocated once @@ -9794,7 +9799,7 @@ struct llama_context * llama_init_from_model( int n_nodes_pp = ggml_graph_n_nodes(gf_pp); // reserve with tg graph to get the number of splits and nodes - llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf_tg = 
llama_build_graph(*ctx, ubatch_tg, true); ggml_backend_sched_reserve(ctx->sched.get(), gf_tg); int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get()); @@ -9832,6 +9837,13 @@ struct llama_context * llama_init_from_model( } } + if (model->has_vision) { + ctx->vctx.model = &model->clip; + ctx->vctx.sched = ctx->sched.get(); + const size_t max_nodes = 1024; + ctx->vctx.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + } + return ctx; }
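
For quick reference, the public call sequence this patch introduces (exercised in full by examples/vision/vision.cpp) condenses to the sketch below. This is a minimal outline, not part of the patch itself: it assumes `ctx` is a llama_context created from a model whose GGUF carries the `vision.*` metadata handled above, and `rgb_data`, `nx`, `ny` and `n_past` are caller-provided (hypothetical) values; error checking, prompt tokenization and sampling are omitted.

    // allocate an RGB bitmap and hand it to the vision preprocessor
    llama_vision_bitmap * bmp = llama_vision_bitmap_init(nx, ny);
    memcpy(bmp->data, rgb_data, 3 * nx * ny);                  // 3 bytes per pixel, row-major

    llama_vision_patches * patches = llama_vision_patches_init(ctx, bmp);
    llama_vision_encode(ctx, patches);                         // runs the CLIP encoder graph

    // the encoder output is exposed as a ggml tensor; wrap it in a batch and
    // decode it like any other embedding input
    struct ggml_tensor * embd = llama_vision_get_output_tensor(ctx);
    struct llama_batch batch_img = llama_batch_get_one_from_tensor(embd, /*p0=*/n_past, /*seq_id=*/0);
    llama_decode(ctx, batch_img);
    n_past += batch_img.n_tokens;

    llama_batch_free(batch_img);
    llama_vision_patches_free(patches);
    llama_vision_bitmap_free(bmp);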