separate vision ctx and llm ctx
parent ff77b15845
commit fa55281759
7 changed files with 139 additions and 35 deletions
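
The diff below moves the vision encoder state out of the text llama_context and into a dedicated llama_vision_context with its own params, backends and scheduler. The following sketch is not part of the commit; it shows roughly how the split API is meant to be used, assuming model, the text ctx and img_path are set up as in the changed example program, and reusing that example's load_image_from_file helper:

    // sketch only: lifecycle of the new llama_vision_context
    static bool encode_image(llama_model * model, llama_context * ctx, const char * img_path) {
        llama_vision_context_params vparams = llama_vision_context_default_params();
        vparams.n_threads = llama_n_threads(ctx);

        // the vision encoder now has its own context, separate from the llm ctx
        llama_vision_context * vctx = llama_vision_init_from_model(model, vparams);
        if (!vctx) {
            LOG_ERR("model does not have vision encoder\n");
            return false;
        }

        llama_vision_bitmap * img = load_image_from_file(img_path);
        llama_vision_tokens * img_tokens = llama_vision_tokenize(vctx, img); // takes vctx, not ctx
        if (!img_tokens || llama_vision_encode(vctx, img_tokens) != 0) {
            LOG_ERR("failed to encode image\n");
            llama_vision_free(vctx);
            return false;
        }

        // the embedding tensor is now read back from the vision context
        struct ggml_tensor * img_embd = llama_vision_get_output_tensor(vctx);
        (void) img_embd; // would be fed to the llm context through the batch

        llama_vision_tokens_free(img_tokens);
        llama_vision_bitmap_free(img);
        llama_vision_free(vctx);
        return true;
    }

With this split, llama_context no longer carries a vctx member, and the vision side allocates its own CPU backend, scheduler and compute buffers inside llama_vision_init_from_model.
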
@@ -120,6 +120,14 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    llama_vision_context_params vparams = llama_vision_context_default_params();
+    vparams.n_threads = llama_n_threads(ctx);
+    llama_vision_context * vctx = llama_vision_init_from_model(model, vparams);
+    if (!vctx) {
+        LOG_ERR("model does not have vision encoder\n");
+        return 1;
+    }
+
     struct common_sampler * smpl = common_sampler_init(model, params.sampling);
 
     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -136,12 +144,12 @@ int main(int argc, char ** argv) {
         }
         llama_vision_bitmap * img = load_image_from_file(img_path);
         LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny);
-        img_tokens = llama_vision_tokenize(ctx, img);
+        img_tokens = llama_vision_tokenize(vctx, img);
         if (!img_tokens) {
             LOG_ERR("failed to create image tokens\n");
             return 1;
         }
-        if (llama_vision_encode(ctx, img_tokens)) {
+        if (llama_vision_encode(vctx, img_tokens)) {
             LOG_ERR("failed to encode image\n");
             return 1;
         }
@@ -163,7 +171,7 @@ int main(int argc, char ** argv) {
             return 1;
         }
     } else {
-        auto * img_embd = llama_vision_get_output_tensor(ctx);
+        auto * img_embd = llama_vision_get_output_tensor(vctx);
         // std::vector<float> output_debug(ggml_nelements(img_embd));
         // ggml_backend_tensor_get(img_embd, output_debug.data(), 0, ggml_nbytes(img_embd));
         // for (int row = 0; row < 10; row++) {

@@ -229,6 +229,8 @@ extern "C" {
         bool sorted;
     } llama_token_data_array;
 
+    struct llama_vision_context;
+
     // Structure represents the basic input unit of vision model
     // This can be a processed image or slices of images under the hood
     struct llama_vision_tokens;
@@ -365,6 +367,10 @@ extern "C" {
        void * abort_callback_data;
    };
 
+    struct llama_vision_context_params {
+        int32_t n_threads;
+    };
+
    // model quantization parameters
    typedef struct llama_model_quantize_params {
        int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
@@ -402,6 +408,7 @@ extern "C" {
    // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
    LLAMA_API struct llama_model_params llama_model_default_params(void);
    LLAMA_API struct llama_context_params llama_context_default_params(void);
+    LLAMA_API struct llama_vision_context_params llama_vision_context_default_params(void);
    LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
    LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
 
@@ -1297,20 +1304,30 @@ extern "C" {
    // Vision API
    //
 
+    // Vision context
+    LLAMA_API struct llama_vision_context * llama_vision_init_from_model(
+            const struct llama_model * model,
+            struct llama_vision_context_params params);
+    LLAMA_API void llama_vision_free(struct llama_vision_context * ctx);
+
    // Container for RGB bitmap
    LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny);
    LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp);
 
    // Create image tokens from the RGB bitmap
-    LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(struct llama_context * ctx, llama_vision_bitmap * bmp);
+    LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(
+            struct llama_vision_context * ctx,
+            struct llama_vision_bitmap * bmp);
    LLAMA_API void llama_vision_tokens_free(struct llama_vision_tokens * img_tokens);
 
    // User must reserve N number of tokens in tokenized text prompt for each image
    // LLAMA_API int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens);
 
    // Encode patches into embeddings
-    LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_tokens * img_tokens);
-    LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx);
+    LLAMA_API int32_t llama_vision_encode(
+            struct llama_vision_context * ctx,
+            struct llama_vision_tokens * img_tokens);
+    LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx);
 
    //
    // Model split

@@ -1576,8 +1576,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_V_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_V_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
    {LLM_TENSOR_V_RESMPL_POS_EMBD_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_ADD}},
    {LLM_TENSOR_V_RESMPL_ATTN_Q, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_V_RESMPL_ATTN_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},

@@ -108,9 +108,6 @@ struct llama_context {
    struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
    struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-
-    // vision
-    llama_vision_context vctx;
 };
 
 // TODO: make these methods of llama_context

@@ -982,7 +982,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
    }
 
    // alloc memory for graph
-    bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
+    bool ok = ggml_backend_sched_alloc_graph(ctx.sched.get(), gf);
    if (!ok) {
        LLAMA_LOG_ERROR("failed to alloc memory for graph\n");
        return -1;
@@ -1064,7 +1064,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
    // compute
    LLAMA_LOG_DEBUG("%s: compute start\n", __func__);
    int64_t t_start = ggml_time_ms();
-    ggml_backend_sched_graph_compute(ctx.sched, gf);
+    ggml_backend_sched_graph_compute(ctx.sched.get(), gf);
 
    // the last node is the embedding tensor
    struct ggml_tensor * output_node = ggml_graph_node(gf, -1);
@@ -1091,6 +1091,92 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
 ////////////////////////////////////////////////////////////////////////////////////////
 // public API
 
+struct llama_vision_context_params llama_vision_context_default_params() {
+    return {
+        /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+    };
+}
+
+struct llama_vision_context * llama_vision_init_from_model(const struct llama_model * model, struct llama_vision_context_params params) {
+    if (!model->has_vision) {
+        return nullptr;
+    }
+
+    llama_vision_context * ctx = new llama_vision_context;
+    ctx->model = &model->vit;
+
+    // TODO: this looks ugly, mostly copied from llama.cpp, refactor it in the future
+
+    // init backends
+    {
+        // add CPU backend
+        ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        if (ctx->backend_cpu == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
+            llama_vision_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.emplace_back(ctx->backend_cpu);
+
+        // create a list of the set_n_threads functions in the backends
+        for (auto & backend : ctx->backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                ggml_backend_set_n_threads_fn(backend.get(), params.n_threads);
+            }
+        }
+    }
+
+    // scheduler and compute buffers
+    {
+        // buffer types used for the compute buffer of each backend
+        std::vector<ggml_backend_buffer_type_t> backend_buft;
+        std::vector<ggml_backend_t> backend_ptrs;
+        for (auto & backend : ctx->backends) {
+            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
+                // use the host buffer of the first device CPU for faster transfer of the intermediate state
+                auto * dev = model->devices[0];
+                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+                if (host_buft) {
+                    buft = host_buft;
+                }
+            }
+            backend_buft.push_back(buft);
+            backend_ptrs.push_back(backend.get());
+        }
+
+        const size_t max_nodes = model->max_nodes();
+
+        // buffer used to store the computation graph and the tensor meta data
+        ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+        // TODO: support pipeline_parallel
+        const bool pipeline_parallel = false;
+
+        ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+
+        if (pipeline_parallel) {
+            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
+        }
+    }
+
+    const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
+    ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    return ctx;
+}
+
+void llama_vision_free(struct llama_vision_context * ctx) {
+    if (ctx->ctx_ggml) {
+        ggml_free(ctx->ctx_ggml);
+    }
+    delete ctx;
+}
+
 struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny) {
    llama_vision_bitmap * bmp = new llama_vision_bitmap;
    bmp->nx = nx;
@@ -1105,16 +1191,15 @@ void llama_vision_bitmap_free(llama_vision_bitmap * bmp) {
 }
 
 struct llama_vision_tokens * llama_vision_tokenize(
-        struct llama_context * ctx,
-        llama_vision_bitmap * bmp) {
-    llama_vision_context & vctx = ctx->vctx;
-    switch (vctx.model->hparams.arch) {
+        struct llama_vision_context * ctx,
+        struct llama_vision_bitmap * bmp) {
+    switch (ctx->model->hparams.arch) {
        case LLM_ARCH_VISION_LLAVA:
        case LLM_ARCH_VISION_MOBILEVLM:
        case LLM_ARCH_VISION_IDEFICS3:
-            return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
+            return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
        case LLM_ARCH_VISION_MINICPMV:
-            return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
+            return new llama_vision_tokens(llama_vision_processor_llava(*ctx).tokenize(*bmp));
        default:
            GGML_ASSERT(false && "unsupported arch");
    }
@@ -1124,19 +1209,18 @@ void llama_vision_tokens_free(llama_vision_tokens * p) {
    delete p;
 }
 
-int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p) {
+int32_t llama_vision_encode(struct llama_vision_context * ctx, struct llama_vision_tokens * p) {
    if (p->buf.empty()) {
        LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__);
        return -1;
    }
 
-    llama_vision_context & vctx = ctx->vctx;
-    auto & hparams = vctx.model->hparams;
+    auto & hparams = ctx->model->hparams;
    switch (hparams.mm_patch_merge_type) {
        case MM_PATCH_MERGE_FLAT:
            {
                // flat / default llava-1.5 type embedding
-                int32_t encoded = llama_vision_encode_impl(vctx, *p);
+                int32_t encoded = llama_vision_encode_impl(*ctx, *p);
                if (encoded != 0) {
                    LLAMA_LOG_ERROR("Unable to encode image\n");
                    return encoded;
@@ -1154,8 +1238,8 @@ int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p)
    return 0;
 }
 
-struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) {
-    return ctx->vctx.output;
+struct ggml_tensor * llama_vision_get_output_tensor(struct llama_vision_context * ctx) {
+    return ctx->output;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////

@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+#include "ggml-cpp.h"
 #include "llama.h"
 #include "llama-arch.h"
 
@@ -142,12 +143,14 @@ struct llama_vision_model {
 struct llama_vision_context {
    // memory buffers used to evaluate the model
    std::vector<uint8_t> buf_compute_meta;
-    ggml_backend_sched_t sched = nullptr;
-    struct ggml_context * ctx_ggml = nullptr;
+    ggml_backend_sched_ptr sched;
+    std::vector<ggml_backend_ptr> backends;
+    ggml_backend_t backend_cpu;
 
    const llama_vision_model * model;
 
    // temporary output data, to be picked up by llama_decode()
+    struct ggml_context * ctx_ggml = nullptr;
    struct ggml_tensor * output;
 };
 

@@ -8460,7 +8460,9 @@ static int llama_prepare_sbatch(
    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    GGML_ASSERT((batch.token && !batch.embd && !batch.embd_tensor)
+             || (!batch.token && batch.embd && !batch.embd_tensor)
+             || (!batch.token && !batch.embd && batch.embd_tensor)); // NOLINT
    if (batch.token) {
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
@@ -9893,13 +9895,6 @@ struct llama_context * llama_init_from_model(
        }
    }
 
-    if (model->has_vision) {
-        ctx->vctx.model = &model->vit;
-        ctx->vctx.sched = ctx->sched.get();
-        const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
-        ctx->vctx.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
-    }
-
    return ctx;
 }
 