Some cleanup

M. Yusuf Sarıgöz 2023-10-09 13:38:48 +03:00
parent 8af7e2103c
commit 54495c9474
4 changed files with 202 additions and 224 deletions

clip.cpp

@@ -89,7 +89,7 @@ static std::string format(const char * fmt, ...) {
// utilities to get data from a gguf file
//
int get_key_idx(const gguf_context * ctx, const char * key) {
static int get_key_idx(const gguf_context * ctx, const char * key) {
int i = gguf_find_key(ctx, key);
if (i == -1) {
fprintf(stderr, "key %s not found in file\n", key);
@@ -99,19 +99,19 @@ int get_key_idx(const gguf_context * ctx, const char * key) {
return i;
}
const uint32_t get_u32(const gguf_context * ctx, std::string key) {
static const uint32_t get_u32(const gguf_context * ctx, std::string key) {
const int i = get_key_idx(ctx, key.c_str());
return gguf_get_val_u32(ctx, i);
}
const float get_f32(const gguf_context * ctx, std::string key) {
static const float get_f32(const gguf_context * ctx, std::string key) {
const int i = get_key_idx(ctx, key.c_str());
return gguf_get_val_f32(ctx, i);
}
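These helpers make a missing GGUF key a hard error instead of a silent default. A minimal sketch of how they might be called while reading hyperparameters at load time (the key strings here are illustrative, not necessarily the exact KEY_* constants defined in clip.cpp):
// assumes a gguf_context obtained earlier via gguf_init_from_file();
// the key names below are placeholders for the real KEY_* constants
const uint32_t image_size = get_u32(ctx, "clip.vision.image_size");
const float eps = get_f32(ctx, "clip.vision.attention.layer_norm_epsilon");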
struct ggml_tensor * get_tensor(struct ggml_context * ctx, std::string name) {
static struct ggml_tensor * get_tensor(struct ggml_context * ctx, std::string name) {
struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str());
if (!cur) {
printf("unable to find tensor %s\n", name.c_str());
@@ -121,7 +121,7 @@ struct ggml_tensor * get_tensor(struct ggml_context * ctx, std::string name) {
return cur;
}
std::string get_ftype(int ftype) {
static std::string get_ftype(int ftype) {
switch (ftype) {
case 0:
return "f32";
@@ -231,20 +231,13 @@ struct clip_ctx {
int32_t ftype = 1;
struct ggml_context * ctx;
struct gguf_context * ctx_gguf;
//struct clip_buffer buf_compute;
// reusable buffer for `struct ggml_graph_plan.work_data`
std::vector<uint8_t> work_buffer;
// memory buffers used to evaluate the model
// memory buffers to evaluate the model
clip_buffer buf_compute;
clip_buffer buf_alloc;
ggml_allocr * alloc = NULL;
};
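The buf_compute/buf_alloc pair and the ggml_allocr pointer replace the old fixed scratch buffer: the graph is first built under a measure allocator to discover how much memory it needs, and only then backed by a real buffer. A rough sketch of that pattern with the ggml-alloc API of this period (the alignment value and clip_buffer exposing data/size/resize are assumptions):
// measure pass: no real memory is assigned, only sizes are tracked
ctx->alloc = ggml_allocr_new_measure(/*alignment=*/ 32);
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); // imgs: a representative batch
const size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf);
ggml_allocr_free(ctx->alloc);
// real pass: back all subsequent graph builds with a buffer of the measured size
ctx->buf_alloc.resize(alloc_size);
ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, /*alignment=*/ 32);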
static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {
if (!ctx->has_vision_encoder) {
@@ -436,7 +429,8 @@ if (!ggml_allocr_is_measure(ctx->alloc)) {
embeddings = cur;
}
if (ctx->has_llava_projector) {
// llava projector
{
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
@@ -457,8 +451,6 @@ if (!ggml_allocr_is_measure(ctx->alloc)) {
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings);
ggml_set_name(embeddings, "check");
}
// build the graph
@@ -551,6 +543,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
GGML_ASSERT(new_clip->has_vision_encoder);
GGML_ASSERT(!new_clip->has_text_encoder);
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
@@ -643,16 +637,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
vision_model.class_embedding = get_tensor(new_clip->ctx, TN_CLASS_EMBD);
vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
if (new_clip->has_llava_projector) {
vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
vision_model.mm_0_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "bias"));
vision_model.mm_2_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_2_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "bias"));
} else {
vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
}
vision_model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = vision_model.layers[il];
@@ -861,7 +851,7 @@ void clip_free(clip_ctx * ctx) {
delete ctx;
}
bool clip_image_encode(const clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec, const bool normalize) {
bool clip_image_encode(const clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
if (!ctx->has_vision_encoder) {
printf("This gguf file seems to have no vision encoder\n");
return false;
@@ -870,37 +860,25 @@ bool clip_image_encode(const clip_ctx * ctx, const int n_threads, clip_image_f32
clip_image_f32_batch imgs{};
imgs.size = 1;
imgs.data = img;
return clip_image_batch_encode(ctx, n_threads, &imgs, vec, normalize);
return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
}
bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec,
const bool normalize) {
bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
if (!ctx->has_vision_encoder) {
printf("This gguf file seems to have no vision encoder\n");
return false;
}
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
const int num_positions = num_patches + 1;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const int n_intermediate = hparams.n_intermediate;
const int projection_dim = hparams.projection_dim;
const float eps = hparams.eps;
int batch_size = imgs->size;
if (ctx->has_llava_projector) {
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
}
// reset alloc buffer to clean the memory from previous invocations
ggml_allocr_reset(ctx->alloc);
// build the inference graph
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
ggml_allocr_alloc_graph(ctx->alloc, gf);
@@ -911,7 +889,10 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
ggml_graph_compute(gf, &plan);
// the last node is the embedding tensor
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
// copy the embeddings to the location passed by the user
memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings));
if (plan.work_size > 0) {
@@ -921,7 +902,6 @@ struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
return true;
}
/*
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
ggml_type type = GGML_TYPE_Q4_1;
@@ -1106,6 +1086,9 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
return true;
}
*/
struct clip_vision_hparams * clip_get_vision_hparams(struct clip_ctx * ctx) { return &ctx->vision_model.hparams; }
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
auto & params = ctx->vision_model.hparams;
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size) * 4096 * sizeof(float);
}
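For scale: with the 336-pixel input and 14-pixel patches this example targets, that is (336 / 14)^2 = 576 patches, so clip_embd_nbytes returns 576 * 4096 * sizeof(float) = 9,437,184 bytes, roughly 9 MB per image; the hardcoded 4096 is the hidden size of the LLaMA model the projector maps into.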

clip.h

@@ -9,17 +9,6 @@ struct clip_ctx;
extern "C" {
#endif
struct clip_text_hparams {
int32_t n_vocab;
int32_t num_positions;
int32_t hidden_size;
int32_t n_intermediate;
int32_t projection_dim;
int32_t n_head;
int32_t n_layer;
float eps;
};
struct clip_vision_hparams {
int32_t image_size;
int32_t patch_size;
@@ -31,18 +20,11 @@ struct clip_vision_hparams {
float eps;
};
typedef int32_t clip_vocab_id;
struct clip_tokens {
clip_vocab_id * data;
size_t size;
};
struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
void clip_free(struct clip_ctx * ctx);
struct clip_text_hparams * clip_get_text_hparams(struct clip_ctx * ctx);
struct clip_vision_hparams * clip_get_vision_hparams(struct clip_ctx * ctx);
size_t clip_embd_nbytes(struct clip_ctx * ctx);
// RGB uint8 image
struct clip_image_u8 {
@@ -71,31 +53,16 @@ struct clip_image_f32_batch {
size_t size;
};
bool clip_tokenize(const struct clip_ctx * ctx, const char * text, struct clip_tokens * tokens);
struct clip_image_u8 * make_clip_image_u8();
struct clip_image_f32 * make_clip_image_f32();
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res);
bool clip_text_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_tokens * tokens, float * vec,
const bool normalize);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec,
const bool normalize);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
void clip_image_batch_preprocess(const struct clip_ctx * ctx, const int n_threads,
const struct clip_image_u8_batch * img_inputs, struct clip_image_f32_batch * imgs_resized);
bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
float * vec, const bool normalize);
// bool image_normalize(const clip_image_u8 *img, clip_image_f32 *res);
bool clip_compare_text_and_image(const struct clip_ctx * ctx, const int n_threads, const char * text,
const struct clip_image_u8 * image, float * score);
float clip_similarity_score(const float * vec1, const float * vec2, const int vec_dim);
bool softmax_with_sorting(float * arr, const int length, float * sorted_scores, int * indices);
bool clip_zero_shot_label_image(struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8 * input_img,
const char ** labels, const size_t n_labels, float * scores, int * indices);
float * vec);
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype);

llava-utils.h

@@ -0,0 +1,141 @@
// this one and the clip lib will eventually be merged into a single lib; let's keep it this way for now
#include <stdio.h>
#include <stdlib.h>
#include <vector>
#include "common.h"
#include "llama.h"
bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < N; i += n_batch) {
int n_eval = N - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
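// field order follows the llama_batch struct of this llama.h revision:
// { n_tokens, token, embd, pos, seq_id, logits, all_pos_0, all_pos_1, all_seq_id }
// token is null because the input is raw embeddings; positions start at
// *n_past and all rows belong to sequence 0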
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
*n_past += n_eval;
}
return true;
}
bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
int N = (int) tokens.size();
for (int i = 0; i < N; i += n_batch) {
int n_eval = (int) tokens.size() - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
*n_past += n_eval;
}
return true;
}
bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
std::vector<llama_token> tokens;
tokens.push_back(id);
return eval_tokens(ctx_llama, tokens, 1, n_past);
}
bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
return true;
}
llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
// user input has been consumed; sample the next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
// const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// TODO: Apply penalties
// float nl_logit = logits[llama_token_nl(ctx)];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, repeat_penalty);
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) {
// logits[llama_token_nl(ctx)] = nl_logit;
// }
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
} else {
if (mirostat == 1) {
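// note: static, so mirostat_mu is initialized once and then persists
// across calls to sample_id for the lifetime of the process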
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token(ctx_llama, &candidates_p);
}
}
}
return id;
}
const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
int id = sample_id(ctx_llama, params);
static std::string ret;
if (id == llama_token_eos(ctx_llama)) {
ret = "</s>";
} else {
ret = llama_token_to_piece(ctx_llama, id);
}
eval_id(ctx_llama, id, n_past);
return ret.c_str();
}
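Together these helpers support the loop in llava.cpp's main() below. A minimal sketch of the generation phase (max_tgt_len and the "</s>" stop check mirror that usage and are not part of this header; strcmp needs <cstring>):
// the prompt has already been evaluated and n_past points past it
for (int i = 0; i < max_tgt_len; i++) {
    const char * tmp = sample(ctx_llama, params, &n_past);
    if (strcmp(tmp, "</s>") == 0) break; // stop at end of sequence
    printf("%s", tmp);
    fflush(stdout);
}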

llava.cpp

@@ -3,149 +3,17 @@
#include <vector>
#include "clip.h"
#include "llava-utils.h"
#include "common.h"
#include "llama.h"
static bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < N; i += n_batch) {
int n_eval = N - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
*n_past += n_eval;
}
return true;
}
static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int N, int * n_past) {
int n_batch = N;
for (int i = 0; i < (int) tokens.size(); i += n_batch) {
int n_eval = (int) tokens.size() - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
*n_past += n_eval;
}
return true;
}
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
std::vector<llama_token> tokens;
tokens.push_back(id);
return eval_tokens(ctx_llama, tokens, 1, n_past);
}
static bool eval_string(struct llama_context * ctx_llama, const char* str, int N, int * n_past){
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
eval_tokens(ctx_llama, embd_inp, N, n_past);
return true;
}
static llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
// const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// TODO: Apply penalties
// float nl_logit = logits[llama_token_nl(ctx)];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, repeat_penalty);
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) {
// logits[llama_token_nl(ctx)] = nl_logit;
// }
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
llama_sample_temp(ctx_llama, &candidates_p, temp);
id = llama_sample_token(ctx_llama, &candidates_p);
}
}
}
return id;
}
const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
int id = sample_id(ctx_llama, params);
static std::string ret;
if (id == llama_token_eos(ctx_llama)) {
ret = "</s>";
} else {
ret = llama_token_to_piece(ctx_llama, id);
}
eval_id(ctx_llama, id, n_past);
return ret.c_str();
}
int main(int argc, char ** argv) {
gpt_params params;
if (argc < 4) {
printf("usage: %s <path/to/llava-v1.5/ggml-model-f16.gguf> <path/to/llava-v1.5/llava-encoder-f16.gguf> <path/to/an/image.jpg> [a text prompt]\n", argv[0]);
printf("usage: %s <path/to/llava-v1.5/ggml-model-q5_k.gguf> <path/to/llava-v1.5/mmproj-model-f16.gguf> <path/to/an/image.jpg> [a text prompt]\n", argv[0]);
return 1;
}
params.model = argv[1];
@@ -160,13 +28,28 @@ int main(int argc, char ** argv) {
params.prompt = "describe the image in detail.";
}
auto ctx_clip = clip_model_load(clip_path, 1);
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
// load and preprocess the image
clip_image_u8 img;
clip_image_f32 img_res;
clip_image_load_from_file(img_path, &img);
clip_image_preprocess(ctx_clip, &img, &img_res);
float * vec = (float *)malloc(4096 * 576 * sizeof(float));
clip_image_encode(ctx_clip, params.n_threads, &img_res, vec, false);
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
if (!image_embd) {
fprintf(stderr, "Unable to allocate memory for CLIP embeddings\n");
return 1;
}
if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
fprintf(stderr, "Unable to encode image\n");
return 1;
}
// we have the embeddings now, so free the memory used by CLIP
clip_free(ctx_clip);
llama_backend_init(params.numa);
@ -191,13 +74,17 @@ int main(int argc, char ** argv) {
return 1;
}
// process the prompt
// llava chat format is "user: <image embeddings>\n<textual prompt>\nassistant:"
int n_past = 0;
int max_tgt_len = 256;
eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
eval_image_embd(ctx_llama, vec, 576, params.n_batch, &n_past);
eval_image_embd(ctx_llama, image_embd, /*n_pos_image=*/ 576, params.n_batch, &n_past);
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
printf("n_past = %d\n", n_past);
// generate the response
const char* tmp;
for (int i=0; i<max_tgt_len; i++) {
@@ -213,7 +100,7 @@ printf("n_past = %d\n", n_past);
llama_free(ctx_llama);
llama_free_model(model);
llama_backend_free();
free(vec);
free(image_embd);
return 0;
}