fix bug where base64 string was not removed from the prompt

Damian Stewart 2023-10-14 18:24:55 +02:00
parent f21af512cd
commit 708928c649
4 changed files with 86 additions and 82 deletions

View file

@@ -7,6 +7,13 @@
 #include "llava.h"
 #include "llava-utils.h"
 
+struct llava_context {
+    struct clip_ctx * ctx_clip = NULL;
+    struct llama_context * ctx_llama = NULL;
+    struct llama_model * model = NULL;
+};
+
 static void show_additional_info(int /*argc*/, char ** argv) {
     printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
@@ -25,7 +32,7 @@ static bool load_image(llava_context * ctx_llava, gpt_params * params, float **i
             fprintf(stderr, "%s: can't load image from prompt\n", __func__);
             return false;
         }
-        prompt = remove_image_from_prompt(prompt);
+        params->prompt = remove_image_from_prompt(prompt);
     } else {
         if (!clip_image_load_from_file(params->image.c_str(), &img)) {
             fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
@@ -49,8 +56,7 @@ static void process_prompt(struct llava_context * ctx_llava, float * image_embd,
     // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
     // GG: are we sure that the should be a trailing whitespace at the end of this string?
     eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
-    printf("embedding image, n_img_pos is %d\n", n_img_pos);
-    eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
+    llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
@@ -70,6 +76,61 @@ static void process_prompt(struct llava_context * ctx_llava, float * image_embd,
 }
 
+static struct llava_context * llava_init(gpt_params * params) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
+    llama_backend_init(params->numa);
+
+    llama_model_params model_params = llama_model_default_params();
+    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
+    if (model == NULL) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+        return NULL;
+    }
+
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
+    ctx_params.n_threads = params->n_threads;
+    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
+
+    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx_llama == NULL) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return NULL;
+    }
+
+    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
+
+    ctx_llava->ctx_llama = ctx_llama;
+    ctx_llava->ctx_clip = ctx_clip;
+    ctx_llava->model = model;
+    return ctx_llava;
+}
+
+static void llava_free(struct llava_context * ctx_llava) {
+    if (ctx_llava->ctx_clip) {
+        clip_free(ctx_llava->ctx_clip);
+        ctx_llava->ctx_clip = NULL;
+    }
+    llama_free(ctx_llava->ctx_llama);
+    llama_free_model(ctx_llava->model);
+    llama_backend_free();
+}
+
 int main(int argc, char ** argv) {
     ggml_time_init();
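
For orientation, the chat-format comment in the third hunk above describes how the example stitches the image into the conversation. The fragment below restates that flow as a hedged sketch, not the example's actual process_prompt: it assumes the eval_string helper from llava-utils.h with the signature implied by its call sites in this diff, and uses the new public llava_eval_image_embd.

#include "llava.h"
#include "llava-utils.h"

// Illustrative sketch only: lay out "<system_prompt>USER: <image_embeddings><textual_prompt>\nASSISTANT:"
// in the llama context. eval_string's exact signature is assumed from the calls shown above.
static void sketch_process_prompt(llama_context * ctx_llama, float * image_embd, int n_img_pos,
                                  const char * user_prompt, int n_batch) {
    int n_past = 0;

    // system prompt plus the user turn marker
    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. "
                           "The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ",
                n_batch, &n_past);

    // the image: n_img_pos embedding rows are decoded straight into the context
    llava_eval_image_embd(ctx_llama, image_embd, n_img_pos, n_batch, &n_past);

    // the textual prompt, then the assistant turn marker; generation would continue from n_past
    eval_string(ctx_llama, user_prompt, n_batch, &n_past);
    eval_string(ctx_llama, "\nASSISTANT:", n_batch, &n_past);
}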

View file

@@ -11,24 +11,6 @@
 #include <cstdlib>
 #include <vector>
 
-inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
-    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
-    for (int i = 0; i < N; i += n_batch) {
-        int n_eval = N - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
-        if (llama_decode(ctx_llama, batch)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return false;
-        }
-        *n_past += n_eval;
-    }
-    return true;
-}
-
 inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
     int N = (int) tokens.size();
     for (int i = 0; i < N; i += n_batch) {
@@ -37,7 +19,7 @@ inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
             n_eval = n_batch;
         }
         if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+            fprintf(stderr, "%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
         *n_past += n_eval;
@@ -194,6 +176,6 @@ inline std::string remove_image_from_prompt(const std::string& prompt, const cha
         return prompt;
     }
     auto pre = prompt.substr(0, begin);
-    auto post = prompt.substr(end+1);
+    auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
     return pre + replacement + post;
 }
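
The offset change above is the header-side half of the fix: the old end+1 skipped a single character where the full closing tag needed to be skipped, so the tail of the tag was left in the prompt. The caller-side half (writing the cleaned string back to params->prompt in the first file) is what actually keeps the base64 payload out of text evaluation. A minimal self-contained sketch of the intended behaviour follows; only the IMG_BASE64_TAG_END identifier comes from this diff, the literal marker values and the helper's name are assumptions for illustration.

#include <cstring>
#include <iostream>
#include <string>

// Assumed marker values; the real constants live in llava-utils.h.
static const char * IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
static const char * IMG_BASE64_TAG_END   = "\">";

// Simplified stand-in for remove_image_from_prompt with the corrected offset:
// skip the whole closing tag (strlen(IMG_BASE64_TAG_END)), not a single character.
static std::string remove_image_from_prompt_sketch(const std::string & prompt, const char * replacement = "") {
    size_t begin = prompt.find(IMG_BASE64_TAG_BEGIN);
    if (begin == std::string::npos) {
        return prompt; // no embedded image
    }
    size_t end = prompt.find(IMG_BASE64_TAG_END, begin);
    if (end == std::string::npos) {
        return prompt; // malformed tag, leave the prompt untouched
    }
    auto pre  = prompt.substr(0, begin);
    auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
    return pre + replacement + post;
}

int main() {
    std::string prompt = "what is this? <img src=\"data:image/jpeg;base64,AAAA\"> answer briefly.";
    std::cout << remove_image_from_prompt_sketch(prompt) << "\n";
    // prints "what is this?  answer briefly." -- tag and payload are both gone
}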

View file

@@ -68,56 +68,21 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
 }
 
-struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
-    llama_backend_init(params->numa);
-
-    llama_model_params model_params = llama_model_default_params();
-    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
-    if (model == NULL) {
-        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return NULL;
-    }
-
-    llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
-    ctx_params.n_threads = params->n_threads;
-    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
-
-    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
-
-    if (ctx_llama == NULL) {
-        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return NULL;
-    }
-
-    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
-
-    ctx_llava->ctx_llama = ctx_llama;
-    ctx_llava->ctx_clip = ctx_clip;
-    ctx_llava->model = model;
-    return ctx_llava;
-}
-
-void llava_free(struct llava_context * ctx_llava) {
-    if (ctx_llava->ctx_clip) {
-        clip_free(ctx_llava->ctx_clip);
-        ctx_llava->ctx_clip = NULL;
-    }
-    llama_free(ctx_llava->ctx_llama);
-    llama_free_model(ctx_llava->model);
-    llama_backend_free();
-}
+bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past) {
+    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
+    for (int i = 0; i < n_image_pos; i += n_batch) {
+        int n_eval = n_image_pos - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        llama_batch batch = {int32_t(n_eval), nullptr, (image_embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        if (llama_decode(ctx_llama, batch)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return false;
+        }
+        *n_past += n_eval;
+    }
+    return true;
+}

View file

@@ -10,18 +10,14 @@ struct clip_ctx;
 extern "C" {
 #endif
 
-struct llava_context {
-    struct clip_ctx * ctx_clip = NULL;
-    struct llama_context * ctx_llama = NULL;
-    struct llama_model * model = NULL;
-};
-
-struct llava_context * llava_init(gpt_params * params);
-void llava_free(struct llava_context * ctx_llava);
-
-/** build a llava image embedding from the passed-in clip image `img`. result is returned as image_embd_out, size n_image_pos_out */
+/** using ctx_clip, build a llava image embedding from the passed-in image `img` (see clip.h for methods to load img).
+ * result is returned as image_embd_out, size n_image_pos_out */
 LLAMA_API bool llava_build_img_embed(const struct llama_context * ctx_llama, struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out);
 
+/** write the image represented by image_embd (size n_image_pos) into the llama context with batch size n_batch,
+ * starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
+LLAMA_API bool llava_eval_image_embd(struct llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past);
+
 #ifdef __cplusplus
 }
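
Put together, the header now exposes a two-step image path: build the embedding from a clip image, then write it into a llama context. The sketch below strings those calls together in the same way the example's llava_init and load_image code earlier in this commit does. It is a hedged outline, not code from the repository: the model and image paths, n_threads and n_batch values are placeholders, NULL checks after model and context creation are omitted, and freeing image_embd with free() assumes llava_build_img_embed heap-allocates the buffer.

#include "clip.h"
#include "llava.h"
#include "llama.h"

#include <cstdio>
#include <cstdlib>

int main() {
    llama_backend_init(false /*numa*/);

    // clip side: projector model and input image
    clip_ctx * ctx_clip = clip_model_load("mmproj-model-f16.gguf", /*verbosity=*/ 1);
    clip_image_u8 img;
    if (!clip_image_load_from_file("image.jpg", &img)) {
        fprintf(stderr, "failed to load image\n");
        return 1;
    }

    // llama side: model and context (larger n_ctx to leave room for the image embeddings)
    llama_model_params model_params = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("ggml-model-q5_k.gguf", model_params);
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = 2048;
    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);

    // step 1: build the image embedding
    float * image_embd = NULL;
    int n_image_pos = 0;
    if (!llava_build_img_embed(ctx_llama, ctx_clip, /*n_threads=*/ 4, &img, &image_embd, &n_image_pos)) {
        fprintf(stderr, "failed to build image embedding\n");
        return 1;
    }

    // step 2: write it into the context; n_past then points just past the image
    int n_past = 0;
    llava_eval_image_embd(ctx_llama, image_embd, n_image_pos, /*n_batch=*/ 512, &n_past);

    // text evaluation and sampling would continue from n_past here

    free(image_embd);
    clip_free(ctx_clip);
    llama_free(ctx_llama);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}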