llava : code formatting, rename files, fix compile warnings

Georgi Gerganov 2023-10-12 15:35:44 +03:00
parent 346e3c1605
commit 4bc5c9c5d5
9 changed files with 90 additions and 79 deletions

Makefile

@@ -627,8 +627,8 @@ convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggm
 llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-llava: examples/llava/llava.cpp examples/llava/llava-utils.h examples/llava/clip.cpp examples/llava/clip.h examples/llava/stb_image.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
-	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+llava: examples/llava/llava.cpp examples/llava/llava-utils.h examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
 baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

examples/llava/CMakeLists.txt

@@ -3,8 +3,9 @@ add_library(${TARGET} clip.cpp clip.h)
 install(TARGETS ${TARGET} LIBRARY)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h
 if(TARGET BUILD_INFO)
     add_dependencies(${TARGET} BUILD_INFO)
 endif()
 
 set(TARGET llava)
@@ -13,5 +14,5 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama clip ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
 if(TARGET BUILD_INFO)
     add_dependencies(${TARGET} BUILD_INFO)
 endif()
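As a point of reference, building the `llava` target configured above with CMake would look roughly like this (a minimal sketch; the out-of-tree build directory name and default generator are assumptions, not part of this commit):

```sh
# configure an out-of-tree build, then build only the llava example target
mkdir -p build
cd build
cmake ..
cmake --build . --target llava
```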

examples/llava/README.md

@@ -7,6 +7,7 @@ and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
 models are available.
 After API is confirmed, more models will be supported / uploaded.
 
 ## Usage
 Build with cmake or run `make llava` to build it.
@@ -28,16 +29,16 @@ git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
 git clone https://huggingface.co/openai/clip-vit-large-patch14-336
 ```
-2. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
+2. Use `llava-surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
 ```sh
-python ./examples/llava/llava_surgery.py -m ../llava-v1.5-7b
+python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b
 ```
-3. Use `convert_image_encoder_to_gguf.py` to convert the LLaVA image encoder to GGUF:
+3. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF:
 ```sh
-python ./examples/llava/convert_image_encoder_to_gguf -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
+python ./examples/llava/convert-image-encoder-to-gguf -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
 ```
 4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:
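For orientation, the steps after this hunk typically continue roughly as follows (a sketch only; the converted model filename and image path are assumptions not shown in this diff, while the `-m`/`--mmproj`/`--image` flags match the usage string printed by `llava.cpp` further down):

```sh
# illustrative: convert the LLaMA part of LLaVA to GGUF, then run the llava example
python ./convert.py ../llava-v1.5-7b
./llava -m ../llava-v1.5-7b/ggml-model-f16.gguf --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg
```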

examples/llava/clip.cpp

@@ -97,19 +97,19 @@ static int get_key_idx(const gguf_context * ctx, const char * key) {
     return i;
 }
 
-static const uint32_t get_u32(const gguf_context * ctx, std::string key) {
+static uint32_t get_u32(const gguf_context * ctx, const std::string & key) {
     const int i = get_key_idx(ctx, key.c_str());
     return gguf_get_val_u32(ctx, i);
 }
 
-static const float get_f32(const gguf_context * ctx, std::string key) {
+static float get_f32(const gguf_context * ctx, const std::string & key) {
     const int i = get_key_idx(ctx, key.c_str());
     return gguf_get_val_f32(ctx, i);
 }
 
-static struct ggml_tensor * get_tensor(struct ggml_context * ctx, std::string name) {
+static struct ggml_tensor * get_tensor(struct ggml_context * ctx, const std::string & name) {
     struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str());
     if (!cur) {
         printf("unable to find tensor %s\n", name.c_str());
@@ -123,25 +123,18 @@ static std::string get_ftype(int ftype) {
     switch (ftype) {
         case 0:
             return "f32";
-            break;
         case 1:
             return "f16";
-            break;
         case 2:
             return "q4_0";
-            break;
         case 3:
             return "q4_1";
-            break;
         case 6:
             return "q5_0";
-            break;
         case 7:
             return "q5_1";
-            break;
         case 8:
             return "q8_0";
-            break;
         default:
             throw std::runtime_error(format("Unrecognized file type: %d\n", ftype));
     }
@@ -237,7 +230,6 @@ struct clip_ctx {
 };
 
 static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {
     if (!ctx->has_vision_encoder) {
         printf("This gguf file seems to have no vision encoder\n");
         return nullptr;
@@ -254,15 +246,15 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     const int n_layer = hparams.n_layer;
-    const int n_intermediate = hparams.n_intermediate;
-    const int projection_dim = hparams.projection_dim;
+    //const int n_intermediate = hparams.n_intermediate;
+    //const int projection_dim = hparams.projection_dim;
     const float eps = hparams.eps;
 
     int batch_size = imgs->size;
     if(ctx->has_llava_projector) {
         GGML_ASSERT(batch_size == 1);
     }
 
-    auto & buf_compute = ctx->buf_compute;
+    const auto & buf_compute = ctx->buf_compute;
 
     struct ggml_init_params params = {
         /*.mem_size =*/ buf_compute.size,
@@ -281,9 +273,9 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     if (!ggml_allocr_is_measure(ctx->alloc)) {
         float * data = (float *)ggml_get_data(inp_raw);
 
-        for (int b = 0; b < imgs->size; b++) {
-            const int nx = imgs->data[b].nx;
-            const int ny = imgs->data[b].ny;
+        for (size_t i = 0; i < imgs->size; i++) {
+            const int nx = imgs->data[i].nx;
+            const int ny = imgs->data[i].ny;
             GGML_ASSERT(nx == image_size && ny == image_size);
 
             const int n = nx * ny;
@@ -339,17 +331,17 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
                                   ggml_repeat(ctx0, model.pre_ln_b, embeddings));
     }
 
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(ctx->alloc, KQ_scale);
     if (!ggml_allocr_is_measure(ctx->alloc)) {
         ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
     }
 
     // loop over layers
     for (int il = 0; il < n_layer - 1; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
 
-        const size_t nb_q_w = model.layers[il].q_w->nb[0];
+        //const size_t nb_q_w = model.layers[il].q_w->nb[0];
 
         // layernorm1
         {
@@ -730,7 +722,7 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
     uint8_t bc[3] = {122, 116, 104}; // bakground color in RGB from LLaVA
 
     // fill with background color
-    for (int i = 0; i < temp.size; i++) {
+    for (size_t i = 0; i < temp.size; i++) {
         temp.data[i] = bc[i % 3];
     }
 
@@ -963,7 +955,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
                 if (conv_buf.size() < n_elms) {
                     conv_buf.resize(n_elms);
                 }
-                for (int j = 0; j < n_elms; ++j) {
+                for (size_t j = 0; j < n_elms; ++j) {
                     conv_buf[j] = ggml_fp16_to_fp32(((ggml_fp16_t *)cur->data)[j]);
                 }
                 f32_data = (float *)conv_buf.data();
@@ -981,28 +973,28 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
             std::vector<int64_t> hist_cur(1 << 4, 0);
 
             switch (new_type) {
                 case GGML_TYPE_Q4_0: {
                     new_size = ggml_quantize_q4_0(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data());
                 } break;
                 case GGML_TYPE_Q4_1: {
                     new_size = ggml_quantize_q4_1(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data());
                 } break;
                 case GGML_TYPE_Q5_0: {
                     new_size = ggml_quantize_q5_0(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data());
                 } break;
                 case GGML_TYPE_Q5_1: {
                     new_size = ggml_quantize_q5_1(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data());
                 } break;
                 case GGML_TYPE_Q8_0: {
                     new_size = ggml_quantize_q8_0(f32_data, new_data, n_elms, cur->ne[0], hist_cur.data());
                 } break;
                 default: {
                     fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, new_type);
                     return false;
                 }
             }
 
-            for (int j = 0; j < hist_cur.size(); ++j) {
+            for (size_t j = 0; j < hist_cur.size(); ++j) {
                 hist_all[j] += hist_cur[j];
             }
         } else {
@@ -1017,7 +1009,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
         gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size);
         fout.write((const char *)new_data, new_size);
         size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size;
-        for (int j = 0; j < pad; ++j) {
+        for (size_t j = 0; j < pad; ++j) {
             fout.put(0);
         }

examples/llava/llava-utils.h

@@ -1,12 +1,15 @@
+#pragma once
+
 // this one and clip lib will be eventually merged to a single lib, let's keep it this way for now
 
-#include <stdio.h>
-#include <stdlib.h>
-#include <vector>
 #include "common.h"
 #include "llama.h"
 
-bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+
+inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llama));
 
     for (int i = 0; i < N; i += n_batch) {
@@ -24,7 +27,7 @@ bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch
     return true;
 }
 
-bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
+inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
     int N = (int) tokens.size();
     for (int i = 0; i < N; i += n_batch) {
         int n_eval = (int) tokens.size() - i;
@@ -40,20 +43,21 @@ bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> toke
     return true;
 }
 
-bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
+inline bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
     std::vector<llama_token> tokens;
     tokens.push_back(id);
     return eval_tokens(ctx_llama, tokens, 1, n_past);
 }
 
-bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){
+inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){
     std::string str2 = str;
     std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
     eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
     return true;
 }
 
-llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
+// TODO: use common/sampling.h
+inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
     // out of user input, sample next token
     const float temp = params.sampling_params.temp;
     const int32_t top_k = params.sampling_params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.sampling_params.top_k;
@@ -128,7 +132,7 @@ llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
     return id;
 }
 
-const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
+inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
     int id = sample_id(ctx_llama, params);
     static std::string ret;
     if (id == llama_token_eos(ctx_llama)) {

examples/llava/llava.cpp

@@ -1,13 +1,13 @@
-#include <stdio.h>
-#include <stdlib.h>
-#include <vector>
-
 #include "clip.h"
 #include "llava-utils.h"
 #include "common.h"
 #include "llama.h"
 
-static void show_additional_info(int argc, char ** argv) {
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+
+static void show_additional_info(int /*argc*/, char ** argv) {
     printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
     printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
 }
@@ -40,6 +40,7 @@ int main(int argc, char ** argv) {
     // load and preprocess the image
     clip_image_u8 img;
     clip_image_f32 img_res;
 
     if (!clip_image_load_from_file(img_path, &img)) {
         fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
@@ -54,8 +55,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
     int n_img_pos = clip_n_patches(ctx_clip);
     int n_img_embd = clip_n_mmproj_embd(ctx_clip);
 
     float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
     if (!image_embd) {
@@ -84,11 +86,13 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.n_ctx = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
     llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
 
     if (ctx_llama == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@@ -111,26 +115,35 @@ int main(int argc, char ** argv) {
     // process the prompt
     // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
 
     int n_past = 0;
 
     const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
 
+    // GG: are we sure that the should be a trailing whitespace at the end of this string?
     eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
     eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
     eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
     eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
 
     // generate the response
 
-    const char* tmp;
-    for (int i=0; i < max_tgt_len; i++) {
-        tmp = sample(ctx_llama, params, &n_past);
-        if (strcmp(tmp, "</s>")==0) break;
+    printf("\n");
+
+    for (int i = 0; i < max_tgt_len; i++) {
+        const char * tmp = sample(ctx_llama, params, &n_past);
+        if (strcmp(tmp, "</s>") == 0) break;
         printf("%s", tmp);
         fflush(stdout);
     }
     printf("\n");
 
-    const float img_enc_duration = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
-    printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, img_enc_duration, img_enc_duration / n_img_pos);
+    {
+        const float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
+        printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
+    }
 
     llama_print_timings(ctx_llama);