fix bug where base64 string was not removed from the prompt
parent f21af512cd
commit 708928c649
4 changed files with 86 additions and 82 deletions
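The fix is easiest to see in the two hunks below that touch remove_image_from_prompt: the cleaned prompt is now written back into params->prompt (before, it was assigned to a local copy and discarded), and the helper now skips the full closing tag via strlen(IMG_BASE64_TAG_END) instead of a single character, so no tail of the tag leaks into the evaluated prompt. A minimal self-contained sketch of the corrected removal logic follows; the tag strings and the standalone main are illustrative assumptions, not the exact definitions from llava-utils.h.

#include <cstring>
#include <iostream>
#include <string>

// Illustrative tag markers; the real IMG_BASE64_TAG_BEGIN/IMG_BASE64_TAG_END
// live in llava-utils.h and may differ.
static const char * IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
static const char * IMG_BASE64_TAG_END   = "\">";

static std::string remove_image_from_prompt(const std::string & prompt, const char * replacement = "") {
    size_t begin = prompt.find(IMG_BASE64_TAG_BEGIN);
    if (begin == std::string::npos) {
        return prompt; // nothing to remove
    }
    size_t end = prompt.find(IMG_BASE64_TAG_END, begin);
    if (end == std::string::npos) {
        return prompt;
    }
    std::string pre  = prompt.substr(0, begin);
    // before the fix: prompt.substr(end + 1) left most of the closing tag in place
    std::string post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
    return pre + replacement + post;
}

int main() {
    std::string prompt = "describe this: <img src=\"data:image/jpeg;base64,AAAA\"> please";
    // the cleaned string must be stored back (the other half of the bug:
    // assigning the result to a local copy instead of params->prompt had no effect)
    prompt = remove_image_from_prompt(prompt);
    std::cout << prompt << "\n"; // prints: describe this:  please
}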
@@ -7,6 +7,13 @@
 #include "llava.h"
+#include "llava-utils.h"
+
+struct llava_context {
+    struct clip_ctx * ctx_clip = NULL;
+    struct llama_context * ctx_llama = NULL;
+    struct llama_model * model = NULL;
+};


 static void show_additional_info(int /*argc*/, char ** argv) {
     printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
@@ -25,7 +32,7 @@ static bool load_image(llava_context * ctx_llava, gpt_params * params, float **i
             fprintf(stderr, "%s: can't load image from prompt\n", __func__);
             return false;
         }
-        prompt = remove_image_from_prompt(prompt);
+        params->prompt = remove_image_from_prompt(prompt);
     } else {
         if (!clip_image_load_from_file(params->image.c_str(), &img)) {
             fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
@@ -49,8 +56,7 @@ static void process_prompt(struct llava_context * ctx_llava, float * image_embd,
     // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
     // GG: are we sure that there should be a trailing whitespace at the end of this string?
     eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
-    printf("embedding image, n_img_pos is %d\n", n_img_pos);
-    eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
+    llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
@@ -70,6 +76,61 @@ static void process_prompt(struct llava_context * ctx_llava, float * image_embd,
 }

+
+static struct llava_context * llava_init(gpt_params * params) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
+    llama_backend_init(params->numa);
+
+    llama_model_params model_params = llama_model_default_params();
+    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
+    if (model == NULL) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+        return NULL;
+    }
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
+    ctx_params.n_threads = params->n_threads;
+    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
+
+    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx_llama == NULL) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return NULL;
+    }
+
+    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
+
+    ctx_llava->ctx_llama = ctx_llama;
+    ctx_llava->ctx_clip = ctx_clip;
+    ctx_llava->model = model;
+    return ctx_llava;
+}
+
+
+static void llava_free(struct llava_context * ctx_llava) {
+    if (ctx_llava->ctx_clip) {
+        clip_free(ctx_llava->ctx_clip);
+        ctx_llava->ctx_clip = NULL;
+    }
+
+    llama_free(ctx_llava->ctx_llama);
+    llama_free_model(ctx_llava->model);
+    llama_backend_free();
+}
+
+
 int main(int argc, char ** argv) {
     ggml_time_init();

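An aside on the "llava chat format" comment in the process_prompt hunk above: the eval_string calls wrap the image embedding in a Vicuna-style conversation. The sketch below only spells out the text the model effectively sees; build_llava_prompt and the <image_embeddings> placeholder are illustrations, not part of the patch (the image really is injected as raw embeddings between "USER: " and the textual prompt, not as a string).

#include <iostream>
#include <string>

// illustration of the assembled llava chat layout used by process_prompt
static std::string build_llava_prompt(const std::string & user_prompt) {
    return std::string("A chat between a curious human and an artificial intelligence assistant. "
                       "The assistant gives helpful, detailed, and polite answers to the human's questions.")
         + "\nUSER: " + "<image_embeddings>" + user_prompt + "\nASSISTANT:";
}

int main() {
    std::cout << build_llava_prompt("describe the image in detail.") << "\n";
}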
@@ -11,24 +11,6 @@
 #include <cstdlib>
 #include <vector>

-inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
-    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
-
-    for (int i = 0; i < N; i += n_batch) {
-        int n_eval = N - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
-        if (llama_decode(ctx_llama, batch)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return false;
-        }
-        *n_past += n_eval;
-    }
-    return true;
-}
-
 inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
     int N = (int) tokens.size();
     for (int i = 0; i < N; i += n_batch) {
@@ -37,7 +19,7 @@ inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
             n_eval = n_batch;
         }
         if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+            fprintf(stderr, "%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
         *n_past += n_eval;
@@ -194,6 +176,6 @@ inline std::string remove_image_from_prompt(const std::string& prompt, const cha
         return prompt;
     }
     auto pre = prompt.substr(0, begin);
-    auto post = prompt.substr(end+1);
+    auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
     return pre + replacement + post;
 }

@@ -68,56 +68,21 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
 }

-struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
+bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past) {
+    int n_embd = llama_n_embd(llama_get_model(ctx_llama));

-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
+    for (int i = 0; i < n_image_pos; i += n_batch) {
+        int n_eval = n_image_pos - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
         }

-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
-    llama_backend_init(params->numa);
-
-    llama_model_params model_params = llama_model_default_params();
-    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
-    if (model == NULL) {
-        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return NULL;
+        llama_batch batch = {int32_t(n_eval), nullptr, (image_embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        if (llama_decode(ctx_llama, batch)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return false;
         }

-    llama_context_params ctx_params = llama_context_default_params();
-
-    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
-    ctx_params.n_threads = params->n_threads;
-    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
-
-    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
-
-    if (ctx_llama == NULL) {
-        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return NULL;
+        *n_past += n_eval;
     }

-    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
-
-    ctx_llava->ctx_llama = ctx_llama;
-    ctx_llava->ctx_clip = ctx_clip;
-    ctx_llava->model = model;
-    return ctx_llava;
+    return true;
 }

-void llava_free(struct llava_context * ctx_llava) {
-    if (ctx_llava->ctx_clip) {
-        clip_free(ctx_llava->ctx_clip);
-        ctx_llava->ctx_clip = NULL;
-    }
-
-    llama_free(ctx_llava->ctx_llama);
-    llama_free_model(ctx_llava->model);
-    llama_backend_free();
-}

@@ -10,18 +10,14 @@ struct clip_ctx;
 extern "C" {
 #endif

-struct llava_context {
-    struct clip_ctx * ctx_clip = NULL;
-    struct llama_context * ctx_llama = NULL;
-    struct llama_model * model = NULL;
-};
-
-struct llava_context * llava_init(gpt_params * params);
-void llava_free(struct llava_context * ctx_llava);
-
-/** build a llava image embedding from the passed-in clip image `img`. result is returned as image_embd_out, size n_image_pos_out */
+/** using ctx_clip, build a llava image embedding from the passed-in image `img` (see clip.h for methods to load img).
+ * result is returned as image_embd_out, size n_image_pos_out */
 LLAMA_API bool llava_build_img_embed(const struct llama_context * ctx_llama, struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out);

+/** write the image represented by image_embd (size n_image_pos) into the llama context with batch size n_batch,
+ * starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
+LLAMA_API bool llava_eval_image_embd(struct llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past);

 #ifdef __cplusplus
 }
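Taken together, the header above now exposes a two-step flow: build an image embedding with clip, then write it into the llama context. Below is a rough usage sketch, not code from this patch; it assumes the clip.h loaders referenced elsewhere in the diff (clip_model_load, clip_image_load_from_file), an already-created llama_context, and that the embedding buffer returned by llava_build_img_embed is heap-allocated and owned by the caller.

#include "clip.h"
#include "llama.h"
#include "llava.h"

#include <cstdio>
#include <cstdlib>

// hypothetical caller: feed one image into an existing llama context
static bool feed_image(llama_context * ctx_llama, const char * mmproj_path, const char * image_path, int n_threads, int n_batch, int * n_past) {
    clip_ctx * ctx_clip = clip_model_load(mmproj_path, /*verbosity=*/ 1);
    if (!ctx_clip) {
        fprintf(stderr, "failed to load clip model %s\n", mmproj_path);
        return false;
    }

    clip_image_u8 img;
    if (!clip_image_load_from_file(image_path, &img)) {
        fprintf(stderr, "is %s really an image file?\n", image_path);
        clip_free(ctx_clip);
        return false;
    }

    float * image_embd  = NULL;
    int     n_image_pos = 0;
    if (!llava_build_img_embed(ctx_llama, ctx_clip, n_threads, &img, &image_embd, &n_image_pos)) {
        clip_free(ctx_clip);
        return false;
    }

    // writes n_image_pos embedding positions into the context; on success
    // *n_past points to the first position after the image
    const bool ok = llava_eval_image_embd(ctx_llama, image_embd, n_image_pos, n_batch, n_past);

    free(image_embd); // assumption: buffer is malloc'd by llava_build_img_embed
    clip_free(ctx_clip);
    return ok;
}

The rest of the diff is consistent with this split: llava_init and llava_free move out of the library into the example as static helpers, leaving llava.h with just these two stateless entry points.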