diff --git a/Makefile b/Makefile
index a49ec0154..5e181cda0 100644
--- a/Makefile
+++ b/Makefile
@@ -926,6 +926,7 @@ OBJ_LLAMA = \
 	src/llama-vocab.o \
 	src/llama-grammar.o \
 	src/llama-sampling.o \
+	src/llama-vision.o \
 	src/unicode.o \
 	src/unicode-data.o

@@ -937,6 +938,7 @@ OBJ_COMMON = \
 	common/ngram-cache.o \
 	common/sampling.o \
 	common/train.o \
+	common/vision.o \
 	common/build-info.o \
 	common/json-schema-to-grammar.o

@@ -1221,6 +1223,12 @@ common/ngram-cache.o: \
 	common/ngram-cache.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

+common/vision.o: \
+	common/vision.cpp \
+	common/vision.h \
+	common/stb_image.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 $(LIB_COMMON): \
 	$(OBJ_COMMON) \
 	$(LIB_LLAMA) \
diff --git a/common/vision.cpp b/common/vision.cpp
new file mode 100644
index 000000000..7b5c1d995
--- /dev/null
+++ b/common/vision.cpp
@@ -0,0 +1,38 @@
+#include "vision.h"
+
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+
+#include <fstream>
+#include <vector>
+
+llama_img * load_image_from_file(const char * fname) {
+    std::ifstream file(fname, std::ios::binary);
+    if (!file) {
+        throw std::runtime_error("Unable to open file");
+    }
+    std::vector<char> image_bytes = std::vector<char>(
+        std::istreambuf_iterator<char>(file),
+        std::istreambuf_iterator<char>());
+    // decode image to byte array
+    int nx, ny, nc;
+    auto * bytes = (unsigned char *) image_bytes.data();
+    auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3);
+    if (!img) {
+        throw std::runtime_error("failed to decode image bytes");
+    }
+    // printf("nx=%d ny=%d nc=%d\n", nx, ny, nc);
+    // GGML_ASSERT(nc == 3);
+    // for (int y = 0; y < ny; y++) {
+    //     for (int x = 0; x < nx; x++) {
+    //         unsigned char * pix = img + x*nc + y*nc*nx;
+    //         printf("%02x%02x%02x ", pix[0], pix[1], pix[2]);
+    //     }
+    //     printf("\n");
+    // }
+    // printf("\n");
+    llama_img * result = llama_img_alloc(nx, ny);
+    memcpy(result->data, img, nx*ny*3); // img is always 3-channel RGB (stbi req_comp = 3)
+    stbi_image_free(img);
+    return result;
+}
diff --git a/common/vision.h b/common/vision.h
new file mode 100644
index 000000000..16c6325fd
--- /dev/null
+++ b/common/vision.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include "llama.h"
+
+#include <string>
+#include <vector>
+
+llama_img * load_image_from_file(const char * fname);
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index c2b7267c8..303d31e05 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "vision.h"

 #include <vector>

@@ -61,6 +62,19 @@ int main(int argc, char ** argv) {

     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

+
+
+    // TODO: this is for testing; DELETE ME
+    llama_img_batch ibatch;
+    ibatch.n_imgs = 1;
+    ibatch.imgs = (llama_img **) malloc(1024);
+    ibatch.imgs[0] = load_image_from_file("media/llama0-logo.png");
+    llama_vision_encode(ctx, &ibatch);
+    return 0;
+
+
+
     // tokenize the prompt

     std::vector<llama_token> tokens_list;
diff --git a/include/llama.h b/include/llama.h
index 9aa17ffd1..49a32d66b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -234,8 +234,8 @@ extern "C" {
     // Input data for llama_vision_decode
     typedef struct llama_img_batch {
-        int32_t     n_imgs;
-        llama_img * imgs;
+        int32_t      n_imgs;
+        llama_img ** imgs;
     } llama_img_batch;

     // Input data for llama_decode
@@ -893,6 +893,10 @@ extern "C" {
     //
     // Vision
     //

+    // create a new RGB image for input
+    LLAMA_API llama_img * llama_img_alloc(int width, int height);
+    LLAMA_API void llama_img_free(llama_img * img);
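+    // note: img->data holds width*height*3 bytes of interleaved RGB, released by llama_img_free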
+
     // encode image into embeddings
     LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch);
diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp
index 75fdfc703..ff6dea4f4 100644
--- a/src/llama-vision.cpp
+++ b/src/llama-vision.cpp
@@ -1,8 +1,22 @@
 #include "llama.h"
-
 #include "llama-vision.h"
 #include "llama-impl.h"

+#include <cstring> // memcpy
+#include <cmath>
+#include <vector>
+
+#ifndef NDEBUG
+// for debugging
+#include <cstdint>
+#include <fstream>
+#include <string>
+
+// export clip_image_u8 to bmp file for debugging
+// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
+static int bmp_export(const clip_image_u8 &img, const std::string &location);
+#endif
+
 struct clip_image_size {
     int width;
     int height;
@@ -39,8 +53,23 @@ struct clip_image_f32 {
 using clip_image_f32_batch = std::vector<clip_image_f32>;
 using clip_image_u8_batch  = std::vector<clip_image_u8>;

-int32_t clip_image_encode      (const clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output);
-int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output);
+static int clip_n_patches(const clip_context & ctx) {
+    auto & hparams = ctx.model->hparams;
+    int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
+    return n_patches;
+}
+
+static int clip_n_mmproj_embd(const clip_context & ctx) {
+    if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
+        return ctx.model->mm_b_b->ne[0];
+    } else {
+        GGML_ASSERT(false && "invalid proj type");
+    }
+}
+
+static int clip_n_embd(const clip_context & ctx) {
+    return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx);
+}

 /**
  * Selects the best resolution from a list of possible resolutions based on the original size.
@@ -221,9 +250,9 @@ static void normalize_image_u8_to_f32(const clip_image_u8 src, clip_image_f32 ds

 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
-bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img, clip_image_f32_batch & output_imgs) {
+static bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img, clip_image_f32_batch & output_imgs) {
     bool pad_to_square = true;
-    auto & params = ctx.model.hparams;
+    auto & params = ctx.model->hparams;
     // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
     if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) {
         pad_to_square = false;
@@ -357,58 +386,356 @@ bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img,
     return true;
 }

-int clip_n_patches(const clip_context & ctx) {
-    auto & hparams = ctx.model.hparams;
-    int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
-    return n_patches;
-}
+static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) {
+    auto & model = *ctx.model;
+    auto & hparams = ctx.model->hparams;

-static bool encode_image_with_clip(clip_context & ctx_clip, const llama_img img) {
-    clip_image_u8 img_u8(img);
-    clip_image_f32_batch img_res_v;
-    std::vector<float> image_embd; // output vectors
-    auto & hparams = ctx_clip.model.hparams;
-    int n_output;
+    const int hidden_size = hparams.hidden_size;
+    const int n_head      = hparams.n_head;
+    const int d_head      = hidden_size / n_head;
+    const int patch_size  = hparams.patch_size;
+    const float eps       = hparams.eps;
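+    // one patch per patch_size x patch_size tile of the input image, e.g. a
+    // 336x336 llava-1.5 image with 14x14 patches gives 24*24 = 576 patches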
+    const int num_patches   = ((image_size.width / patch_size) * (image_size.height / patch_size));
+    const int num_positions = num_patches + (model.class_embedding ? 1 : 0);

-    if (!clip_image_preprocess(ctx_clip, img_u8, img_res_v)) {
-        LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);
-        return false;
+    LLAMA_LOG_DEBUG("%s: num_patches = %d\n", __func__, num_patches);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx.buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx.buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // input
+    struct ggml_tensor * embeddings;
+    {
+        struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size.width, image_size.height, 3, batch_size);
+        ggml_set_name(inp_raw, "inp_raw");
+        ggml_set_input(inp_raw);
+
+        struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+        inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+        }
+        // auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]);
+
+        embeddings = inp;
+        if (model.class_embedding) {
+            embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+            ggml_set_name(embeddings, "embeddings");
+            ggml_set_input(embeddings);
+            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+            embeddings = ggml_acc(ctx0, embeddings, inp,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+        }
+
+        struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        embeddings = ggml_add(ctx0,
+            embeddings,
+            ggml_get_rows(ctx0, model.position_embeddings, positions));
     }

-    if (hparams.mm_patch_merge_type != MM_PATCH_MERGE_SPATIAL_UNPAD) {
-        // flat / default llava-1.5 type embedding
-        n_output = clip_n_patches(ctx_clip);
-        bool encoded = clip_image_encode(ctx_clip, img_res_v[0], image_embd);
-        if (!encoded) {
-            LLAMA_LOG_ERROR("Unable to encode image\n");
-            return false;
+    // pre-layernorm
+    if (model.pre_norm_w) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "pre_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b);
+    }
+
+    // loop over layers
+    for (int il = 0; il < (int)hparams.n_layer - 1; il++) {
+        struct ggml_tensor * cur = embeddings;
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0,
+                ggml_mul(ctx0, cur, model.layers[il].norm_in_w),
+                model.layers[il].norm_in_b);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.layers[il].q_w, cur),
+                model.layers[il].q_b);
+
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.layers[il].k_w, cur),
+                model.layers[il].k_b);
+
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.layers[il].v_w, cur),
+                model.layers[il].v_b);
+
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
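+            // scaled dot-product attention; Q was pre-scaled by 1.0f/sqrt(d_head) above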
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b);
+
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0,
+                ggml_mul(ctx0, cur, model.layers[il].norm_out_w),
+                model.layers[il].norm_out_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b);
+
+        if (hparams.use_gelu) {
+            cur = ggml_gelu_inplace(ctx0, cur);
+        } else {
+            cur = ggml_gelu_quick_inplace(ctx0, cur);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // llava projector
+    {
+        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
+
+        struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+        ggml_set_name(patches, "patches");
+        ggml_set_input(patches);
+
+        // shape [1, 576, 1024]
+        // ne is whcn, ne = [1024, 576, 1, 1]
+        embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+        if (hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
+            embeddings = ggml_mul_mat(ctx0, model.mm_a_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_a_b);
+
+            embeddings = ggml_gelu(ctx0, embeddings);
+            embeddings = ggml_mul_mat(ctx0, model.mm_b_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_b_b);
+        } else {
+            GGML_ASSERT(false && "unsupported proj type");
         }
     }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+    ggml_free(ctx0);
+    return gf;
 }

-int32_t clip_image_encode(const clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output) {
+static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output) {
+    int batch_size = imgs.size();
+    auto & model = *ctx.model;
+    auto & hparams = ctx.model->hparams;
+
+    if (hparams.arch == VISION_ARCH_LLAVA) {
+        GGML_ASSERT(batch_size == 1); // TODO: support multiple images
+    }
+
+    clip_image_size image_size{(int)hparams.image_size, (int)hparams.image_size};
+    const int patch_size    = hparams.patch_size;
+    const int num_patches   = ((image_size.width / patch_size) * (image_size.height / patch_size));
+    const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
+
+    LLAMA_LOG_DEBUG("%s: image_size = %d\n", __func__, hparams.image_size);
+    LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions);
+
+    // build the inference graph
+    ggml_cgraph * gf = clip_image_build_graph(ctx, batch_size, image_size);
+
+    // alloc memory for graph
+    bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
+    if (!ok) {
+        LLAMA_LOG_ERROR("failed to alloc memory for graph\n");
+        return -1;
+    }
+
+    // set raw input
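+    // (convert interleaved RGB from the preprocessor into the planar
+    // per-channel layout expected by the conv2d input tensor)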
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        float * data = (float *) malloc(ggml_nbytes(inp_raw));
+
+        for (int i = 0; i < batch_size; i++) {
+            const int nx = imgs[i].nx;
+            const int ny = imgs[i].ny;
+            const int n  = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].buf[3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
+    }
+
+    if (model.class_embedding) {
+        struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+        void * zero_mem = malloc(ggml_nbytes(embeddings));
+        memset(zero_mem, 0, ggml_nbytes(embeddings));
+        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+        free(zero_mem);
+    }
+
+    {
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+        int * positions_data = (int *) malloc(ggml_nbytes(positions));
+        for (int i = 0; i < num_positions; i++) {
+            positions_data[i] = i;
+        }
+        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+        free(positions_data);
+    }
+
+    {
+        struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+        int * patches_data = (int *) malloc(ggml_nbytes(patches));
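+        // row 0 of the embeddings is the class token, so patch rows start at index 1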
+        for (int i = 0; i < num_patches; i++) {
+            patches_data[i] = i + 1;
+        }
+        ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+        free(patches_data);
+    }
+
+    // compute
+    ggml_backend_sched_graph_compute_async(ctx.sched, gf);
+
+    // the last node is the embedding tensor
+    struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+    ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(ctx.sched, embeddings);
+
+    // copy the embeddings to the location passed by the user
+    output.resize(clip_n_embd(ctx));
+    ggml_backend_tensor_get_async(backend_embd, embeddings, output.data(), 0, ggml_nbytes(embeddings));
+
+    ggml_backend_sched_synchronize(ctx.sched);
+
+    return 0;
+}
+
+static int32_t clip_image_encode(clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output) {
     clip_image_f32_batch imgs{img};
     return clip_image_batch_encode(ctx, imgs, output);
 }

-int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output) {
-    int batch_size = imgs.size();
+static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, std::vector<float> & output_embd) {
+    clip_image_u8 img_u8(img);
+    clip_image_f32_batch img_res_v;
+    auto & hparams = ctx.model->hparams;
+
+    if (!clip_image_preprocess(ctx, img_u8, img_res_v)) {
+        LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);
+        return -2;
+    }
+
+    switch (hparams.mm_patch_merge_type) {
+        case MM_PATCH_MERGE_FLAT:
+            {
+                // flat / default llava-1.5 type embedding
+                // n_output = clip_n_patches(ctx);
+                int32_t encoded = clip_image_encode(ctx, img_res_v[0], output_embd);
+                if (encoded != 0) {
+                    LLAMA_LOG_ERROR("Unable to encode image\n");
+                    return encoded;
+                }
+            } break;
+        case MM_PATCH_MERGE_SPATIAL_UNPAD:
+            {
+                // TODO: support llava-1.6
+                (void)0;
+            } break;
+        default:
+            GGML_ASSERT(false && "unsupported mm_patch_merge_type");
+    }
+
+    return 0;
 }

-////////////////////////////////////////////////////////////////////////////////////////
+// public API
+
+int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch) {
+    if (batch->n_imgs == 0) {
+        return 0;
+    }
+
+    // TODO: batching is not working atm, should be fixed later
+    const int n_embd = clip_n_embd(ctx);
+    ctx.output.resize(n_embd * batch->n_imgs);
+    ctx.n_output = batch->n_imgs;
+
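+    // encode the images one at a time and pack the results contiguously
+    // into ctx.output, n_embd floats per image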
+    for (int i = 0; i < batch->n_imgs; i++) {
+        std::vector<float> output_single;
+        int32_t status = encode_image_with_clip(ctx, *batch->imgs[i], output_single);
+        if (status != 0) {
+            return status;
+        }
+        // copy output embeddings to result
+        for (int k = 0; k < n_embd; k++) {
+            ctx.output[n_embd*i + k] = output_single[k];
+            // if (k<10) printf("%f\n", output_single[k]);
+        }
+    }
+
+    return 0;
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////
 // for debugging
 #ifndef NDEBUG

-#include <cstdint>
-#include <fstream>
-#include <iostream>
-#include <string>
-
-// export clip_image_u8 to bmp file for debugging
-// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
-inline int bmp_export(const clip_image_u8 &img, const std::string &location) {
+static int bmp_export(const clip_image_u8 &img, const std::string &location) {
     const uint32_t width = img.nx;
     const uint32_t height = img.ny;
     const std::vector<uint8_t> &buffer = img.buf;
@@ -445,7 +772,7 @@ inline int bmp_export(const clip_image_u8 &img, const std::string &location) {
     const uint32_t blueBitmask = (hasAlphaChannel) ? 0xFF000000 : 0;
     const uint32_t alphaBitmask = (hasAlphaChannel) ? 0x000000FF : 0;

-    //Writing the file header and information header to the file 
+    //Writing the file header and information header to the file
     std::vector<char> header(offset, 0);
     header[0] = signature[0];
     header[1] = signature[1];
diff --git a/src/llama-vision.h b/src/llama-vision.h
index e7404ea18..c14c880c4 100644
--- a/src/llama-vision.h
+++ b/src/llama-vision.h
@@ -3,6 +3,7 @@
 #include "ggml.h"

 #include <vector>
+#include <cstdint>

 enum vision_arch {
     VISION_ARCH_LLAVA,
@@ -29,6 +30,7 @@ struct clip_hparams {
     uint32_t n_head;
     uint32_t n_layer;
     uint32_t max_pos_embd;
+    bool use_gelu = false;

     float eps;

@@ -44,50 +46,50 @@ struct clip_hparams {
 struct clip_layer {
     // attention
-    struct ggml_tensor * k_w;
-    struct ggml_tensor * k_b;
-    struct ggml_tensor * q_w;
-    struct ggml_tensor * q_b;
-    struct ggml_tensor * v_w;
-    struct ggml_tensor * v_b;
+    struct ggml_tensor * k_w = NULL;
+    struct ggml_tensor * k_b = NULL;
+    struct ggml_tensor * q_w = NULL;
+    struct ggml_tensor * q_b = NULL;
+    struct ggml_tensor * v_w = NULL;
+    struct ggml_tensor * v_b = NULL;

-    struct ggml_tensor * output_w;
-    struct ggml_tensor * output_b;
+    struct ggml_tensor * output_w = NULL;
+    struct ggml_tensor * output_b = NULL;

     // layernorm 1
-    struct ggml_tensor * norm_in_w;
-    struct ggml_tensor * norm_in_b;
+    struct ggml_tensor * norm_in_w = NULL;
+    struct ggml_tensor * norm_in_b = NULL;

     // ff
-    struct ggml_tensor * ffn_up_w;
-    struct ggml_tensor * ffn_up_b;
+    struct ggml_tensor * ffn_up_w = NULL;
+    struct ggml_tensor * ffn_up_b = NULL;

-    struct ggml_tensor * ffn_down_w;
-    struct ggml_tensor * ffn_down_b;
+    struct ggml_tensor * ffn_down_w = NULL;
+    struct ggml_tensor * ffn_down_b = NULL;

     // layernorm 2
-    struct ggml_tensor * norm_out_w;
-    struct ggml_tensor * norm_out_b;
+    struct ggml_tensor * norm_out_w = NULL;
+    struct ggml_tensor * norm_out_b = NULL;
 };

 struct clip_vision_model {
     struct clip_hparams hparams;

     // embeddings
-    struct ggml_tensor * class_embedding;
-    struct ggml_tensor * patch_embeddings;
-    struct ggml_tensor * patch_bias;
-    struct ggml_tensor * position_embeddings;
+    struct ggml_tensor * class_embedding     = NULL;
+    struct ggml_tensor * patch_embeddings    = NULL;
+    struct ggml_tensor * patch_bias          = NULL;
+    struct ggml_tensor * position_embeddings = NULL;

-    struct ggml_tensor * pre_norm_w;
-    struct ggml_tensor * pre_norm_b;
+    struct ggml_tensor * pre_norm_w = NULL;
+    struct ggml_tensor * pre_norm_b = NULL;

     std::vector<clip_layer> layers;

-    struct ggml_tensor * post_norm_w;
-    struct ggml_tensor * post_norm_b;
+    struct ggml_tensor * post_norm_w = NULL;
+    struct ggml_tensor * post_norm_b = NULL;

-    struct ggml_tensor * projection;
+    struct ggml_tensor * projection = NULL;

     // LLaVA projection
     struct ggml_tensor * mm_a_w = NULL;
@@ -99,9 +101,15 @@ struct clip_vision_model {
 };

 struct clip_context {
-    struct ggml_context * ctx_ggml;
-    clip_vision_model model;
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
+    ggml_backend_sched_t sched = nullptr;

-    int32_t n_output;
-    float * output;
+    const clip_vision_model * model;
+
+    // temporary output data
+    int n_output;
+    std::vector<float> output; // size == n_output * n_embd
 };
+
+int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch);
diff --git a/src/llama.cpp b/src/llama.cpp
index 2860c7094..b9d64764f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2552,9 +2552,6 @@ struct llama_hparams {
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;

-    bool has_vision = false;
-    clip_hparams clip;
-
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
         if (this->n_vocab    != other.n_vocab)    return true;
@@ -3005,6 +3002,7 @@ struct llama_model {

     std::vector<llama_layer> layers;

+    bool has_vision = false;
     clip_vision_model clip;

     llama_split_mode split_mode;
@@ -3502,6 +3500,9 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    // vision
+    clip_context clip;
 };

 struct llama_lora_weight {
@@ -6223,23 +6224,24 @@ static void llm_load_hparams(
     }

     // vision model
+    auto & vparams = model.clip.hparams;
     std::string vision_type;
     ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
     if (vision_type == "clip") {
-        hparams.has_vision = true;
+        model.has_vision = true;
         std::string proj_type;
-        ml.get_key(LLM_KV_VISION_IMAGE_SIZE, hparams.clip.image_size, true);
-        ml.get_key(LLM_KV_VISION_PATCH_SIZE, hparams.clip.patch_size, true);
-        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, hparams.clip.image_mean, 3, true);
-        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD, hparams.clip.image_std, 3, true);
-        ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, hparams.clip.hidden_size, true);
-        ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, hparams.clip.n_layer, true);
-        ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true);
-        ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, hparams.clip.n_head, true);
-        ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, hparams.clip.eps, true);
+        ml.get_key(LLM_KV_VISION_IMAGE_SIZE, vparams.image_size, true);
+        ml.get_key(LLM_KV_VISION_PATCH_SIZE, vparams.patch_size, true);
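+        // image_mean / image_std are per-channel RGB normalization constants (3 values each)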
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, vparams.image_mean, 3, true);
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD, vparams.image_std, 3, true);
+        ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, vparams.hidden_size, true);
+        ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, vparams.n_layer, true);
+        ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true);
+        ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true);
+        ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true);
         ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
         if (proj_type == "mlp") {
-            hparams.clip.proj_type = CLIP_PROJECTOR_TYPE_MLP;
+            vparams.proj_type = CLIP_PROJECTOR_TYPE_MLP;
         } else {
             throw std::runtime_error(format("unsupported clip projector type: %s", proj_type.c_str()));
         }
@@ -6247,7 +6249,7 @@ static void llm_load_hparams(
         ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true);
         for (auto & it : VISION_ARCH_NAMES) {
             if (arch == it.second) {
-                hparams.clip.arch = it.first;
+                vparams.arch = it.first;
                 break;
             }
         }
@@ -6256,10 +6258,10 @@ static void llm_load_hparams(
         }

         // arch-specific CLIP hparams
-        switch (hparams.clip.arch) {
+        switch (vparams.arch) {
             case VISION_ARCH_LLAVA:
                 {
-                    ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, hparams.clip.max_pos_embd, true);
+                    ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true);
                 } break;
             default: (void)0;
         }
@@ -8957,21 +8959,22 @@ static bool llm_load_tensors(
     }

     // load tensors for vision model
-    if (hparams.has_vision) {
-        const int64_t n_layer      = hparams.clip.n_layer;
-        const int64_t n_embd       = hparams.clip.hidden_size;
-        const int64_t n_ff         = hparams.clip.n_intermediate;
-        const int64_t max_pos_embd = hparams.clip.max_pos_embd;
+    auto & vparams = model.clip.hparams;
+    if (model.has_vision) {
+        const int64_t n_layer      = vparams.n_layer;
+        const int64_t n_embd       = vparams.hidden_size;
+        const int64_t n_ff         = vparams.n_intermediate;
+        const int64_t max_pos_embd = vparams.max_pos_embd;
         const int64_t n_channel    = 3; // always RGB
-        const int64_t patch_size   = hparams.clip.patch_size;
-        const auto tn = VISION_TN(hparams.clip.arch);
+        const int64_t patch_size   = vparams.patch_size;
+        const auto tn = VISION_TN(vparams.arch);
         ggml_context * ctx_vision = ctx_map.at(model.buft_input.buft); // TODO: make dedicated buft for vision

         auto ctx_for_layer = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };

         model.clip.layers.resize(n_layer);

-        switch (hparams.clip.arch) {
+        switch (vparams.arch) {
             case VISION_ARCH_LLAVA:
                 {
                     model.clip.mm_a_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_A, "weight"), {n_embd, n_ff});
@@ -19637,6 +19640,14 @@ struct llama_context * llama_new_context_with_model(
         }
     }

+    // initialize vision context
+    if (model->has_vision) {
+        ctx->clip.model = &model->clip;
+        ctx->clip.sched = ctx->sched;
+        const size_t max_nodes = llama_model_max_nodes(*model);
+        ctx->clip.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    }
+
     return ctx;
 }

@@ -21780,6 +21791,30 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }

+//
+// vision
+//
+
+llama_img * llama_img_alloc(int width, int height) {
+    llama_img * img = new llama_img();
+    img->nx = width;
+    img->ny = height;
+    img->data = (unsigned char *) malloc(width*height*3);
+    return img;
+}
+void llama_img_free(llama_img * img) {
+    free(img->data);
+    delete img;
+}
+
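+// returns 0 on success, or a negative error code on failure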
+int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch) {
+    return llama_vision_encode_internal(ctx->clip, batch);
+}
+
+float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx) {
+    return ctx->clip.output.data(); // TODO: use idx to select a single image's embeddings
+}
+
 //
 // model split
 //