add external llava API

parent 0209d39526
commit 3c10d9f3de

2 changed files with 106 additions and 46 deletions
@@ -2,6 +2,7 @@
 #include "llava-utils.h"
 #include "common.h"
 #include "llama.h"
+#include "llava.h"
 
 #include <cstdio>
 #include <cstdlib>
@@ -34,27 +35,13 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
     return true;
 }
 
-int main(int argc, char ** argv) {
-    ggml_time_init();
-
-    gpt_params params;
-
-    if (!gpt_params_parse(argc, argv, params)) {
-        show_additional_info(argc, argv);
-        return 1;
-    }
-
-    if (params.mmproj.empty() || params.image.empty()) {
-        gpt_print_usage(argc, argv, params);
-        show_additional_info(argc, argv);
-        return 1;
-    }
-
-    const char * clip_path = params.mmproj.c_str();
-    const char * img_path = params.image.c_str();
-
-    if (params.prompt.empty()) {
-        params.prompt = "describe the image in detail.";
+struct llava_context * llava_init(gpt_params * params) {
+    const char * clip_path = params->mmproj.c_str();
+    const char * img_path = params->image.c_str();
+
+    if (params->prompt.empty()) {
+        params->prompt = "describe the image in detail.";
     }
 
     auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
@@ -65,47 +52,48 @@ int main(int argc, char ** argv) {
     if (!clip_image_load_from_file(img_path, &img)) {
         fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
         clip_free(ctx_clip);
-        return 1;
+        return NULL;
     }
 
     float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
     if (!image_embd) {
         fprintf(stderr, "Unable to allocate memory for image embeddings\n");
-        return 1;
+        return NULL;
     }
 
     int n_img_embd;
     int n_img_pos;
     float t_img_enc_ms;
-    if (!encode_image_with_clip(ctx_clip, params.n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
+    if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
         fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
         clip_free(ctx_clip);
-        return 1;
+        return NULL;
     }
 
     // we get the embeddings, free up the memory required for CLIP
     clip_free(ctx_clip);
+    ctx_clip = NULL;
 
-    llama_backend_init(params.numa);
+    llama_backend_init(params->numa);
 
     llama_model_params model_params = llama_model_default_params();
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return 1;
+        return NULL;
     }
 
     llama_context_params ctx_params = llama_context_default_params();
 
-    ctx_params.n_ctx = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
-    ctx_params.n_threads = params.n_threads;
-    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
+    ctx_params.n_threads = params->n_threads;
+    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
 
     llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
 
     if (ctx_llama == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return 1;
+        return NULL;
     }
 
     // make sure that the correct mmproj was used, i.e., compare apples to apples
@@ -118,28 +106,49 @@ int main(int argc, char ** argv) {
         llama_backend_free();
         free(image_embd);
 
-        return 1;
+        return NULL;
     }
 
-    // process the prompt
-    // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
+    {
+        printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
+    }
+
+    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
+
+    ctx_llava->ctx_llama = ctx_llama;
+    ctx_llava->ctx_clip = ctx_clip;
+    ctx_llava->model = model;
+    ctx_llava->image_embd = image_embd;
+    ctx_llava->n_img_pos = n_img_pos;
+    return ctx_llava;
+
+}
+
+void llava_free(struct llava_context * ctx_llava) {
+    llama_free(ctx_llava->ctx_llama);
+    llama_free_model(ctx_llava->model);
+    llama_backend_free();
+    free(ctx_llava->image_embd);
+}
+
+void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt) {
     int n_past = 0;
 
-    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
+    const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
 
     // GG: are we sure that the should be a trailing whitespace at the end of this string?
-    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
-    eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
-    eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
-    eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
+    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
+    eval_image_embd(ctx_llava->ctx_llama, ctx_llava->image_embd, ctx_llava->n_img_pos, params->n_batch, &n_past);
+    eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
+    eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
 
     // generate the response
 
     printf("\n");
 
     for (int i = 0; i < max_tgt_len; i++) {
-        const char * tmp = sample(ctx_llama, params, &n_past);
+        const char * tmp = sample(ctx_llava->ctx_llama, *params, &n_past);
         if (strcmp(tmp, "</s>") == 0) break;
 
         printf("%s", tmp);
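The eval_string, eval_image_embd, and sample helpers that llava_process_prompt now calls live in llava-utils.h, which this commit does not touch. For orientation only, here is a sketch of the contract the text-side helper satisfies; the name eval_string_sketch and the batching details are assumptions, not the actual contents of that file: the prompt is tokenized and decoded in n_batch-sized chunks while n_past tracks the next position.

```cpp
#include "common.h"
#include "llama.h"

#include <vector>

// hypothetical sketch: tokenize a string and decode it in n_batch-sized
// chunks, advancing *n_past by the number of tokens consumed
static bool eval_string_sketch(llama_context * ctx_llama, const char * str, int n_batch, int * n_past) {
    std::vector<llama_token> tokens = ::llama_tokenize(ctx_llama, std::string(str), true);
    for (int i = 0; i < (int) tokens.size(); i += n_batch) {
        int n_eval = (int) tokens.size() - i;
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
        // positions start at *n_past, sequence id 0
        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
            return false;
        }
        *n_past += n_eval;
    }
    return true;
}
```

eval_image_embd presumably does the same chunked decode but feeds the raw float embeddings produced by CLIP instead of tokens, advancing n_past by n_img_pos positions in total.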
@@ -148,16 +157,36 @@ int main(int argc, char ** argv) {
 
     printf("\n");
 
-    {
-        printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    gpt_params params;
+
+    if (!gpt_params_parse(argc, argv, params)) {
+        show_additional_info(argc, argv);
+        return 1;
+    }
+    if (params.mmproj.empty() || params.image.empty()) {
+        gpt_print_usage(argc, argv, params);
+        show_additional_info(argc, argv);
+        return 1;
     }
 
-    llama_print_timings(ctx_llama);
+    auto ctx_llava = llava_init(&params);
+    if (ctx_llava == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava\n", __func__);
+        return 1;
+    }
 
-    llama_free(ctx_llama);
-    llama_free_model(model);
-    llama_backend_free();
-    free(image_embd);
+    // process the prompt
+    // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
+    llava_process_prompt(ctx_llava, &params, params.prompt.c_str());
+
+    llama_print_timings(ctx_llava->ctx_llama);
+
+    llava_free(ctx_llava);
 
     return 0;
 }
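The rewritten main() above is itself the first consumer of the new three-call API. A minimal external caller would follow the same pattern; in the sketch below, describe_image is a hypothetical wrapper, and it assumes the caller links against this example code and fills in a gpt_params from common.h (model, mmproj, image, prompt).

```cpp
#include "llava.h"
#include "common.h"

#include <cstdio>

// hypothetical external caller, mirroring the pattern of main() above
int describe_image(gpt_params & params) {
    // llava_init loads the CLIP mmproj, encodes params.image into embeddings,
    // and sets up the llama model/context; it returns NULL on any failure
    struct llava_context * ctx_llava = llava_init(&params);
    if (ctx_llava == NULL) {
        fprintf(stderr, "failed to init llava\n");
        return 1;
    }

    // evaluates "<system_prompt>USER: <image>\n<prompt>\nASSISTANT:" and prints the response
    llava_process_prompt(ctx_llava, &params, params.prompt.c_str());

    // frees the llama context and model, the backend, and the image embeddings
    llava_free(ctx_llava);
    return 0;
}
```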
examples/llava/llava.h (new file, 31 additions)
@@ -0,0 +1,31 @@
+#ifndef LLAVA_H
+#define LLAVA_H
+
+#include "ggml.h"
+
+struct clip_ctx;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct llava_context {
+    struct clip_ctx * ctx_clip = NULL;
+    struct llama_context * ctx_llama = NULL;
+    struct llama_model * model = NULL;
+
+    int n_img_pos = 0;
+    float * image_embd = NULL;
+};
+
+struct llava_context * llava_init(gpt_params * params);
+void llava_free(struct llava_context * ctx_llava);
+
+void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt);
+
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
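One design note on this header: despite the extern "C" guard, the struct uses C++ default member initializers (= NULL, = 0) and the prototypes reference gpt_params from common.h, so as committed it is only consumable from C++. A hypothetical C-compatible variant of the struct would simply drop the initializers, which is safe because llava_init assigns every field before returning:

```cpp
/* hypothetical C-compatible form: C forbids default member initializers,
 * so the fields are left uninitialized and filled in by llava_init */
struct llava_context {
    struct clip_ctx      * ctx_clip;
    struct llama_context * ctx_llama;
    struct llama_model   * model;

    int     n_img_pos;
    float * image_embd;
};
```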