llava cgraph ok

parent 6854ad4057
commit 5648e30d3e

8 changed files with 536 additions and 95 deletions
Makefile (+8)

@@ -926,6 +926,7 @@ OBJ_LLAMA = \
 	src/llama-vocab.o \
 	src/llama-grammar.o \
 	src/llama-sampling.o \
+	src/llama-vision.o \
 	src/unicode.o \
 	src/unicode-data.o

@@ -937,6 +938,7 @@ OBJ_COMMON = \
 	common/ngram-cache.o \
 	common/sampling.o \
 	common/train.o \
+	common/vision.o \
 	common/build-info.o \
 	common/json-schema-to-grammar.o

@@ -1221,6 +1223,12 @@ common/ngram-cache.o: \
 	common/ngram-cache.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+common/vision.o: \
+	common/vision.cpp \
+	common/vision.h \
+	common/stb_image.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 $(LIB_COMMON): \
 	$(OBJ_COMMON) \
 	$(LIB_LLAMA) \
common/vision.cpp (new file, +37)

@@ -0,0 +1,37 @@
+#include "vision.h"
+
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+
+#include <vector>
+#include <fstream>
+
+llama_img * load_image_from_file(const char * fname) {
+    std::ifstream file(fname, std::ios::binary);
+    if (!file) {
+        throw std::runtime_error("Unable to open file");
+    }
+    std::vector<char> image_bytes = std::vector<char>(
+        std::istreambuf_iterator<char>(file),
+        std::istreambuf_iterator<char>());
+    // decode image to byte array (the last stbi argument forces 3 output channels)
+    int nx, ny, nc;
+    auto * bytes = (unsigned char *) image_bytes.data();
+    auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3);
+    if (!img) {
+        throw std::runtime_error("failed to decode image bytes");
+    }
+    // printf("nx=%d ny=%d nc=%d\n", nx, ny, nc);
+    // GGML_ASSERT(nc == 3);
+    // for (int y = 0; y < ny; y++) {
+    //     for (int x = 0; x < nx; x++) {
+    //         unsigned char * pix = img + x*nc + y*nc*nx;
+    //         printf("%02x%02x%02x ", pix[0], pix[1], pix[2]);
+    //     }
+    //     printf("\n");
+    // }
+    // printf("\n");
+    llama_img * result = llama_img_alloc(nx, ny);
+    memcpy(result->data, img, nx*ny*3); // copy the decoded RGB pixels (not the raw file bytes)
+    stbi_image_free(img);
+    return result;
+}
common/vision.h (new file, +8)

@@ -0,0 +1,8 @@
+#pragma once
+
+#include "llama.h"
+
+#include <string>
+#include <vector>
+
+llama_img * load_image_from_file(const char * fname);
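A minimal sketch of how this helper is meant to be used (not part of the commit; the image path is a placeholder, and llama_img_free is the matching deallocator added to llama.h in this same change):

#include "vision.h"

#include <cstdio>
#include <exception>

int main() {
    try {
        // decode a file into a tightly packed RGB llama_img
        llama_img * img = load_image_from_file("media/llama0-logo.png"); // placeholder path
        // img->nx, img->ny are the decoded dimensions; img->data holds nx*ny*3 bytes
        llama_img_free(img);
    } catch (const std::exception & e) {
        fprintf(stderr, "image load failed: %s\n", e.what());
        return 1;
    }
    return 0;
}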
(file header not captured here; the hunks below are from an example program's main() that exercises the new API)

@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "vision.h"
 
 #include <vector>
 
@@ -61,6 +62,19 @@ int main(int argc, char ** argv) {
 
     llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
 
+
+
+    // TODO: this is for testing; DELETE ME
+    llama_img_batch ibatch;
+    ibatch.n_imgs = 1;
+    ibatch.imgs = (llama_img **) malloc(1024);
+    ibatch.imgs[0] = load_image_from_file("media/llama0-logo.png");
+    llama_vision_encode(ctx, &ibatch);
+    return 0;
+
+
+
     // tokenize the prompt
 
     std::vector<llama_token> tokens_list;
llama.h

@@ -234,8 +234,8 @@ extern "C" {
 
     // Input data for llama_vision_decode
     typedef struct llama_img_batch {
-        int32_t     n_imgs;
-        llama_img * imgs;
+        int32_t      n_imgs;
+        llama_img ** imgs;
     } llama_img_batch;
 
     // Input data for llama_decode

@@ -893,6 +893,10 @@ extern "C" {
     // Vision
     //
 
+    // create new RGB image for input
+    LLAMA_API llama_img * llama_img_alloc(int width, int height);
+    LLAMA_API void llama_img_free(llama_img * img);
+
     // encode image into embeddings
     LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch);
 
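Since llama_img_batch now carries an array of image pointers, here is a hedged sketch of building one with a correctly sized allocation (the hard-coded malloc(1024) in the test code above is a throwaway; make_img_batch is an illustrative name, not part of the API):

#include "llama.h"

#include <cstdlib>

// one pointer slot per image; the caller keeps ownership of the images
static llama_img_batch make_img_batch(llama_img ** images, int32_t n) {
    llama_img_batch batch;
    batch.n_imgs = n;
    batch.imgs   = (llama_img **) malloc(sizeof(llama_img *) * n);
    for (int32_t i = 0; i < n; i++) {
        batch.imgs[i] = images[i];
    }
    return batch;
}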
src/llama-vision.cpp

@@ -1,8 +1,22 @@
 #include "llama.h"
 
 #include "llama-vision.h"
 #include "llama-impl.h"
 
 #include <string.h> // memcpy
 #include <limits>
 #include <cmath>
 
+#ifndef NDEBUG
+// for debugging
+#include <fstream>
+#include <cstdint>
+#include <iostream>
+
+// export clip_image_u8 to bmp file for debugging
+// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
+static int bmp_export(const clip_image_u8 &img, const std::string &location);
+#endif
+
+struct clip_image_size {
+    int width;
+    int height;
@@ -39,8 +53,23 @@ struct clip_image_f32 {
 using clip_image_f32_batch = std::vector<clip_image_f32>;
 using clip_image_f8_batch  = std::vector<clip_image_u8>;
 
+int32_t clip_image_encode      (const clip_context & ctx, const clip_image_f32       & img,  std::vector<float> & output);
+int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output);
+
+static int clip_n_patches(const clip_context & ctx) {
+    auto & hparams = ctx.model->hparams;
+    int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
+    return n_patches;
+}
+
+static int clip_n_mmproj_embd(const clip_context & ctx) {
+    if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
+        return ctx.model->mm_b_b->ne[0];
+    } else {
+        GGML_ASSERT(false && "invalid proj type");
+    }
+}
+
+static int clip_n_embd(const clip_context & ctx) {
+    return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx);
+}
+
 /**
  * Selects the best resolution from a list of possible resolutions based on the original size.
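For intuition, a worked example with the usual LLaVA-1.5 CLIP settings (image_size = 336, patch_size = 14, and a 4096-wide projector output for a 7B LLaMA — these concrete values are assumed for illustration, not stated in the diff):

// hedged arithmetic sketch of what the helpers above compute
int image_size = 336, patch_size = 14;
int n_patches     = (image_size / patch_size) * (image_size / patch_size); // 24 * 24 = 576
int n_mmproj_embd = 4096;                  // mm_b_b->ne[0], assumed value
int n_embd        = n_patches * n_mmproj_embd; // 576 * 4096 = 2,359,296 floats per image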
@@ -221,9 +250,9 @@ static void normalize_image_u8_to_f32(...)
 
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
-bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img, clip_image_f32_batch & output_imgs) {
+static bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img, clip_image_f32_batch & output_imgs) {
     bool pad_to_square = true;
-    auto & params = ctx.model.hparams;
+    auto & params = ctx.model->hparams;
     // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
     if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) {
         pad_to_square = false;
@@ -357,58 +386,356 @@ bool clip_image_preprocess(...)
     return true;
 }
 
-int clip_n_patches(const clip_context & ctx) {
-    auto & hparams = ctx.model.hparams;
-    int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
-    return n_patches;
-}
+static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) {
+    auto & model   = *ctx.model;
+    auto & hparams = ctx.model->hparams;
 
-static bool encode_image_with_clip(clip_context & ctx_clip, const llama_img img) {
-    clip_image_u8 img_u8(img);
-    clip_image_f32_batch img_res_v;
-    std::vector<float> image_embd; // output vectors
-    auto & hparams = ctx_clip.model.hparams;
-    int n_output;
+    const int hidden_size   = hparams.hidden_size;
+    const int n_head        = hparams.n_head;
+    const int d_head        = hidden_size / n_head;
+    const int patch_size    = hparams.patch_size;
+    const float eps         = hparams.eps;
+    const int num_patches   = ((image_size.width / patch_size) * (image_size.height / patch_size));
+    const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
 
-    if (!clip_image_preprocess(ctx_clip, img_u8, img_res_v)) {
-        LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);
-        return false;
-    }
+    LLAMA_LOG_DEBUG("%s: num_patches = %d\n", __func__, num_patches);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx.buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx.buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph  * gf   = ggml_new_graph(ctx0);
+
+    // input
+    struct ggml_tensor * embeddings;
+    {
+        struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size.width, image_size.height, 3, batch_size);
+        ggml_set_name(inp_raw, "inp_raw");
+        ggml_set_input(inp_raw);
+
+        struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+        inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+        }
+        // auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]);
+
+        embeddings = inp;
+        if (model.class_embedding) {
+            embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+            ggml_set_name(embeddings, "embeddings");
+            ggml_set_input(embeddings);
+            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+            embeddings = ggml_acc(ctx0, embeddings, inp,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+        }
+
+        struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        embeddings = ggml_add(ctx0,
+            embeddings,
+            ggml_get_rows(ctx0, model.position_embeddings, positions));
+    }
 
-    if (hparams.mm_patch_merge_type != MM_PATCH_MERGE_SPATIAL_UNPAD) {
-        // flat / default llava-1.5 type embedding
-        n_output = clip_n_patches(ctx_clip);
-        bool encoded = clip_image_encode(ctx_clip, img_res_v[0], image_embd);
-        if (!encoded) {
-            LLAMA_LOG_ERROR("Unable to encode image\n");
-            return false;
+    // pre-layernorm
+    if (model.pre_norm_w) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "pre_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b); // bias term is pre_norm_b (the original line added pre_norm_w twice)
+    }
+
+    // loop over layers
+    for (int il = 0; il < (int)hparams.n_layer - 1; il++) {
+        struct ggml_tensor * cur = embeddings;
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0,
+                ggml_mul(ctx0, cur, model.layers[il].norm_in_w),
+                model.layers[il].norm_in_b);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.layers[il].q_w, cur),
+                model.layers[il].q_b);
+
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.layers[il].k_w, cur),
+                model.layers[il].k_b);
+
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.layers[il].v_w, cur),
+                model.layers[il].v_b);
+
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b);
+
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0,
+                ggml_mul(ctx0, cur, model.layers[il].norm_out_w),
+                model.layers[il].norm_out_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b);
+
+        if (hparams.use_gelu) {
+            cur = ggml_gelu_inplace(ctx0, cur);
+        } else {
+            cur = ggml_gelu_quick_inplace(ctx0, cur);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // llava projector
+    {
+        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
+
+        struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+        ggml_set_name(patches, "patches");
+        ggml_set_input(patches);
+
+        // shape [1, 576, 1024]
+        // ne is whcn, ne = [1024, 576, 1, 1]
+        embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+        if (hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) {
+            embeddings = ggml_mul_mat(ctx0, model.mm_a_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_a_b);
+
+            embeddings = ggml_gelu(ctx0, embeddings);
+            embeddings = ggml_mul_mat(ctx0, model.mm_b_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_b_b);
+        } else {
+            GGML_ASSERT(false && "unsupported proj type");
+        }
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+    ggml_free(ctx0);
+
+    return gf;
+}
+
-int32_t clip_image_encode(const clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output) {
+static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output) {
+    int batch_size = imgs.size();
+    auto & model   = *ctx.model;
+    auto & hparams = ctx.model->hparams;
+
+    if (hparams.arch == VISION_ARCH_LLAVA) {
+        GGML_ASSERT(batch_size == 1); // TODO: support multiple images
+    }
+
+    clip_image_size image_size{(int)hparams.image_size, (int)hparams.image_size};
+    const int patch_size    = hparams.patch_size;
+    const int num_patches   = ((image_size.width / patch_size) * (image_size.height / patch_size));
+    const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
+
+    LLAMA_LOG_DEBUG("%s: image_size    = %d\n", __func__, hparams.image_size);
+    LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions);
+
+    // build the inference graph
+    ggml_cgraph * gf = clip_image_build_graph(ctx, batch_size, image_size);
+
+    // alloc memory for graph
+    bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
+    if (!ok) {
+        LLAMA_LOG_ERROR("failed to alloc memory for graph\n");
+        return -1;
+    }
+
+    // set raw input
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        float * data = (float *) malloc(ggml_nbytes(inp_raw));
+
+        for (int i = 0; i < batch_size; i++) {
+            const int nx = imgs[i].nx;
+            const int ny = imgs[i].ny;
+            const int n  = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].buf[3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
+    }
+
+    if (model.class_embedding) {
+        struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+        void * zero_mem = malloc(ggml_nbytes(embeddings));
+        memset(zero_mem, 0, ggml_nbytes(embeddings));
+        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+        free(zero_mem);
+    }
+
+    {
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+        int * positions_data = (int *) malloc(ggml_nbytes(positions));
+        for (int i = 0; i < num_positions; i++) {
+            positions_data[i] = i;
+        }
+        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+        free(positions_data);
+    }
+
+    {
+        struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+        int * patches_data = (int *) malloc(ggml_nbytes(patches));
+        for (int i = 0; i < num_patches; i++) {
+            patches_data[i] = i + 1;
+        }
+        ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+        free(patches_data);
+    }
+
+    // compute
+    ggml_backend_sched_graph_compute_async(ctx.sched, gf);
+
+    // the last node is the embedding tensor
+    struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+    ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(ctx.sched, embeddings);
+
+    // copy the embeddings to the location passed by the user
+    output.resize(clip_n_embd(ctx));
+    ggml_backend_tensor_get_async(backend_embd, embeddings, output.data(), 0, ggml_nbytes(embeddings));
+
+    ggml_backend_sched_synchronize(ctx.sched);
+
+    return 0;
+}
+
+static int32_t clip_image_encode(clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output) {
+    clip_image_f32_batch imgs{img};
+    return clip_image_batch_encode(ctx, imgs, output);
+}
+
-int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output) {
-    int batch_size = imgs.size();
+static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, std::vector<float> & output_embd) {
+    clip_image_u8 img_u8(img);
+    clip_image_f32_batch img_res_v;
+    auto & hparams = ctx.model->hparams;
+
+    if (!clip_image_preprocess(ctx, img_u8, img_res_v)) {
+        LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);
+        return -2;
+    }
+
+    switch (hparams.mm_patch_merge_type) {
+        case MM_PATCH_MERGE_FLAT:
+            {
+                // flat / default llava-1.5 type embedding
+                // n_output = clip_n_patches(ctx);
+                int32_t encoded = clip_image_encode(ctx, img_res_v[0], output_embd);
+                if (encoded != 0) {
+                    LLAMA_LOG_ERROR("Unable to encode image\n");
+                    return encoded;
+                }
+            } break;
+        case MM_PATCH_MERGE_SPATIAL_UNPAD:
+            {
+                // TODO: support llava-1.6
+                (void)0;
+            } break;
+        default:
+            GGML_ASSERT(false && "unsupported mm_patch_merge_type");
+    }
+
+    return 0;
+}
+
+
+////////////////////////////////////////////////////////////////////////////////////////
+// public API
+
+int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch) {
+    if (batch->n_imgs == 0) {
+        return 0;
+    }
+
+    // TODO: batching is not working atm, should be fixed later
+    const int n_embd = clip_n_embd(ctx);
+    ctx.output.resize(n_embd * batch->n_imgs);
+    ctx.n_output = batch->n_imgs;
+
+    for (int i = 0; i < batch->n_imgs; i++) {
+        std::vector<float> output_single;
+        int32_t status = encode_image_with_clip(ctx, *batch->imgs[i], output_single);
+        if (status != 0) {
+            return status;
+        }
+        // copy output embeddings to result
+        for (int k = 0; k < n_embd; k++) {
+            ctx.output[n_embd*i + k] = output_single[k];
+            // if (k<10) printf("%f\n", output_single[k]);
+        }
+    }
+
+    return 0;
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////
 // for debugging
 #ifndef NDEBUG
 
 #include <fstream>
 #include <vector>
 #include <cstdint>
 #include <iostream>
 
 // export clip_image_u8 to bmp file for debugging
 // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
-inline int bmp_export(const clip_image_u8 &img, const std::string &location) {
+static int bmp_export(const clip_image_u8 &img, const std::string &location) {
     const uint32_t width  = img.nx;
     const uint32_t height = img.ny;
     const std::vector<uint8_t> &buffer = img.buf;
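The quadruple loop in the "set raw input" block above converts stb_image's interleaved layout (RGBRGB... per pixel) into the planar layout the graph's input tensor expects (one full plane per channel); note the outer i loop is redundant since the inner b loop already visits every image. A standalone sketch of the conversion, with illustrative names (not from the commit):

#include <vector>

// interleaved RGB bytes -> planar float layout: plane k holds channel k
std::vector<float> interleaved_to_planar(const unsigned char * rgb, int nx, int ny) {
    const int n = nx * ny;
    std::vector<float> planar(3 * n);
    for (int k = 0; k < 3; k++) {          // channel plane
        for (int y = 0; y < ny; y++) {
            for (int x = 0; x < nx; x++) {
                planar[k * n + y * nx + x] = rgb[3 * (y * nx + x) + k];
            }
        }
    }
    return planar;
}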
@@ -445,7 +772,7 @@ inline int bmp_export(const clip_image_u8 &img, const std::string &location) {
     const uint32_t blueBitmask  = (hasAlphaChannel) ? 0xFF000000 : 0;
     const uint32_t alphaBitmask = (hasAlphaChannel) ? 0x000000FF : 0;
 
-    //Writing the file header and information header to the file
+    //Writing the file header and information header to the file
     std::vector<uint8_t> header(offset, 0);
     header[0] = signature[0];
     header[1] = signature[1];
src/llama-vision.h

@@ -3,6 +3,7 @@
 #include "ggml.h"
 
 #include <vector>
+#include <array>
 
 enum vision_arch {
     VISION_ARCH_LLAVA,

@@ -29,6 +30,7 @@ struct clip_hparams {
     uint32_t n_head;
     uint32_t n_layer;
     uint32_t max_pos_embd;
+    bool use_gelu = false;
 
     float eps;
 

@@ -44,50 +46,50 @@ struct clip_hparams {
 
 struct clip_layer {
     // attention
-    struct ggml_tensor * k_w;
-    struct ggml_tensor * k_b;
-    struct ggml_tensor * q_w;
-    struct ggml_tensor * q_b;
-    struct ggml_tensor * v_w;
-    struct ggml_tensor * v_b;
+    struct ggml_tensor * k_w = NULL;
+    struct ggml_tensor * k_b = NULL;
+    struct ggml_tensor * q_w = NULL;
+    struct ggml_tensor * q_b = NULL;
+    struct ggml_tensor * v_w = NULL;
+    struct ggml_tensor * v_b = NULL;
 
-    struct ggml_tensor * output_w;
-    struct ggml_tensor * output_b;
+    struct ggml_tensor * output_w = NULL;
+    struct ggml_tensor * output_b = NULL;
 
     // layernorm 1
-    struct ggml_tensor * norm_in_w;
-    struct ggml_tensor * norm_in_b;
+    struct ggml_tensor * norm_in_w = NULL;
+    struct ggml_tensor * norm_in_b = NULL;
 
     // ff
-    struct ggml_tensor * ffn_up_w;
-    struct ggml_tensor * ffn_up_b;
+    struct ggml_tensor * ffn_up_w = NULL;
+    struct ggml_tensor * ffn_up_b = NULL;
 
-    struct ggml_tensor * ffn_down_w;
-    struct ggml_tensor * ffn_down_b;
+    struct ggml_tensor * ffn_down_w = NULL;
+    struct ggml_tensor * ffn_down_b = NULL;
 
     // layernorm 2
-    struct ggml_tensor * norm_out_w;
-    struct ggml_tensor * norm_out_b;
+    struct ggml_tensor * norm_out_w = NULL;
+    struct ggml_tensor * norm_out_b = NULL;
 };
 
 struct clip_vision_model {
     struct clip_hparams hparams;
 
     // embeddings
-    struct ggml_tensor * class_embedding;
-    struct ggml_tensor * patch_embeddings;
-    struct ggml_tensor * patch_bias;
-    struct ggml_tensor * position_embeddings;
+    struct ggml_tensor * class_embedding     = NULL;
+    struct ggml_tensor * patch_embeddings    = NULL;
+    struct ggml_tensor * patch_bias          = NULL;
+    struct ggml_tensor * position_embeddings = NULL;
 
-    struct ggml_tensor * pre_norm_w;
-    struct ggml_tensor * pre_norm_b;
+    struct ggml_tensor * pre_norm_w = NULL;
+    struct ggml_tensor * pre_norm_b = NULL;
 
     std::vector<clip_layer> layers;
 
-    struct ggml_tensor * post_norm_w;
-    struct ggml_tensor * post_norm_b;
+    struct ggml_tensor * post_norm_w = NULL;
+    struct ggml_tensor * post_norm_b = NULL;
 
-    struct ggml_tensor * projection;
+    struct ggml_tensor * projection = NULL;
 
     // LLaVA projection
     struct ggml_tensor * mm_a_w = NULL;

@@ -99,9 +101,15 @@ struct clip_vision_model {
 };
 
 struct clip_context {
-    struct ggml_context * ctx_ggml;
-    clip_vision_model model;
-
-    int32_t n_output;
-    float * output;
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
+    ggml_backend_sched_t sched = nullptr;
+
+    const clip_vision_model * model;
+
+    // temporary output data
+    int n_output;
+    std::vector<float> output; // size == n_output * n_embd
 };
+
+int32_t llama_vision_encode_internal(clip_context & ctx, llama_img_batch * batch);
src/llama.cpp

@@ -2552,9 +2552,6 @@ struct llama_hparams {
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
-    bool has_vision = false;
-    clip_hparams clip;
-
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
         if (this->n_vocab    != other.n_vocab)    return true;

@@ -3005,6 +3002,7 @@ struct llama_model {
 
     std::vector<llama_layer> layers;
 
+    bool has_vision = false;
     clip_vision_model clip;
 
     llama_split_mode split_mode;

@@ -3502,6 +3500,9 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
 
+    // vision
+    clip_context clip;
 };
 
 struct llama_lora_weight {

@@ -6223,23 +6224,24 @@ static void llm_load_hparams(
     }
 
     // vision model
+    auto & vparams = model.clip.hparams;
     std::string vision_type;
     ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
     if (vision_type == "clip") {
-        hparams.has_vision = true;
+        model.has_vision = true;
         std::string proj_type;
-        ml.get_key(LLM_KV_VISION_IMAGE_SIZE, hparams.clip.image_size, true);
-        ml.get_key(LLM_KV_VISION_PATCH_SIZE, hparams.clip.patch_size, true);
-        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, hparams.clip.image_mean, 3, true);
-        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD,  hparams.clip.image_std,  3, true);
-        ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH,    hparams.clip.hidden_size,    true);
-        ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT,         hparams.clip.n_layer,        true);
-        ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true);
-        ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT,          hparams.clip.n_head,         true);
-        ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS,       hparams.clip.eps,            true);
+        ml.get_key(LLM_KV_VISION_IMAGE_SIZE, vparams.image_size, true);
+        ml.get_key(LLM_KV_VISION_PATCH_SIZE, vparams.patch_size, true);
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, vparams.image_mean, 3, true);
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD,  vparams.image_std,  3, true);
+        ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH,    vparams.hidden_size,    true);
+        ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT,         vparams.n_layer,        true);
+        ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, vparams.n_intermediate, true);
+        ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT,          vparams.n_head,         true);
+        ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS,       vparams.eps,            true);
         ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
         if (proj_type == "mlp") {
-            hparams.clip.proj_type = CLIP_PROJECTOR_TYPE_MLP;
+            vparams.proj_type = CLIP_PROJECTOR_TYPE_MLP;
         } else {
             throw std::runtime_error(format("unsupported clip projector type: %s", proj_type.c_str()));
         }

@@ -6247,7 +6249,7 @@ static void llm_load_hparams(
         ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true);
         for (auto & it : VISION_ARCH_NAMES) {
             if (arch == it.second) {
-                hparams.clip.arch = it.first;
+                vparams.arch = it.first;
                 break;
             }
         }

@@ -6256,10 +6258,10 @@ static void llm_load_hparams(
     }
 
     // arch-specific CLIP hparams
-    switch (hparams.clip.arch) {
+    switch (vparams.arch) {
         case VISION_ARCH_LLAVA:
             {
-                ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, hparams.clip.max_pos_embd, true);
+                ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true);
             } break;
         default: (void)0;
     }

@@ -8957,21 +8959,22 @@ static bool llm_load_tensors(
     }
 
     // load tensors for vision model
-    if (hparams.has_vision) {
-        const int64_t n_layer      = hparams.clip.n_layer;
-        const int64_t n_embd       = hparams.clip.hidden_size;
-        const int64_t n_ff         = hparams.clip.n_intermediate;
-        const int64_t max_pos_embd = hparams.clip.max_pos_embd;
+    auto & vparams = model.clip.hparams;
+    if (model.has_vision) {
+        const int64_t n_layer      = vparams.n_layer;
+        const int64_t n_embd       = vparams.hidden_size;
+        const int64_t n_ff         = vparams.n_intermediate;
+        const int64_t max_pos_embd = vparams.max_pos_embd;
         const int64_t n_channel    = 3; // always RGB
-        const int64_t patch_size   = hparams.clip.patch_size;
-        const auto tn = VISION_TN(hparams.clip.arch);
+        const int64_t patch_size   = vparams.patch_size;
+        const auto tn = VISION_TN(vparams.arch);
 
         ggml_context * ctx_vision = ctx_map.at(model.buft_input.buft); // TODO: make dedicated buft for vision
         auto ctx_for_layer = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
 
         model.clip.layers.resize(n_layer);
 
-        switch (hparams.clip.arch) {
+        switch (vparams.arch) {
            case VISION_ARCH_LLAVA:
                {
                    model.clip.mm_a_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_A, "weight"), {n_embd, n_ff});

@@ -19637,6 +19640,14 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    // initialize vision context
+    if (model->has_vision) {
+        ctx->clip.model = &model->clip;
+        ctx->clip.sched = ctx->sched;
+        const size_t max_nodes = llama_model_max_nodes(*model);
+        ctx->clip.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    }
+
     return ctx;
 }

@@ -21780,6 +21791,30 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, ...
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
+//
+// vision
+//
+
+llama_img * llama_img_alloc(int width, int height) {
+    llama_img * img = new llama_img();
+    img->nx = width;
+    img->ny = height;
+    img->data = (unsigned char *)malloc(width*height*3);
+    return img;
+}
+void llama_img_free(llama_img * img) {
+    free(img->data);
+    delete img;
+}
+
+int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch) {
+    return llama_vision_encode_internal(ctx->clip, batch);
+}
+
+float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx) {
+    return ctx->clip.output.data();
+}
+
 //
 // model split
 //
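Taken together, the new public API can be exercised end to end roughly like this (a hedged sketch, not part of the commit; the image path is a placeholder and the context is assumed to come from a model loaded with a vision tower):

#include "llama.h"
#include "vision.h"

int run_vision_demo(struct llama_context * ctx) {
    // decode an image file into a llama_img (throws on failure)
    llama_img * img = load_image_from_file("media/llama0-logo.png"); // placeholder path

    // single-image batch: imgs is llama_img**, so &img works for one image
    llama_img_batch batch;
    batch.n_imgs = 1;
    batch.imgs   = &img;

    if (llama_vision_encode(ctx, &batch) != 0) {
        llama_img_free(img);
        return 1; // encode failed
    }

    // read back the projected embeddings (idx selects the image; only 0 here)
    float * embd = llama_vision_get_embeddings(ctx, 0);
    (void) embd; // these would be fed into the language model's batch

    llama_img_free(img);
    return 0;
}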