diff --git a/llama.cpp b/llama.cpp
index 81b2e9e82..0060fafed 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -108,12 +108,14 @@ struct llama_context {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.n_ctx      =*/ 512,
-        /*.n_parts    =*/ -1,
-        /*.seed       =*/ 0,
-        /*.f16_kv     =*/ false,
-        /*.logits_all =*/ false,
-        /*.vocab_only =*/ false,
+        /*.n_ctx             =*/ 512,
+        /*.n_parts           =*/ -1,
+        /*.seed              =*/ 0,
+        /*.f16_kv            =*/ false,
+        /*.logits_all        =*/ false,
+        /*.vocab_only        =*/ false,
+        /*.progress_callback =*/ nullptr,
+        /*.progress_ctx      =*/ nullptr,
     };
 
     return result;
@@ -130,7 +132,8 @@ static bool llama_model_load(
         int n_parts,
         ggml_type memory_type,
         bool vocab_only,
-        llama_progress_handler progress) {
+        llama_progress_handler progress_callback,
+        void *progress_ctx) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     const int64_t t_start_us = ggml_time_us();
@@ -397,8 +400,8 @@ static bool llama_model_load(
 
     std::vector tmp;
 
-    if (progress.handler) {
-        progress.handler(0, progress.ctx);
+    if (progress_callback) {
+        progress_callback(0.0, progress_ctx);
     }
 
     for (int i = 0; i < n_parts; ++i) {
@@ -591,9 +594,9 @@ static bool llama_model_load(
 
                 //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
                 if (++n_tensors % 8 == 0) {
-                    if (progress.handler) {
+                    if (progress_callback) {
                         double current_progress = (double(i) + (double(fin.tellg()) / double(file_size))) / double(n_parts);
-                        progress.handler(current_progress, progress.ctx);
+                        progress_callback(current_progress, progress_ctx);
                     }
                     fprintf(stderr, ".");
                     fflush(stderr);
@@ -612,8 +615,8 @@ static bool llama_model_load(
 
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
-    if (progress.handler) {
-        progress.handler(1, progress.ctx);
+    if (progress_callback) {
+        progress_callback(1.0, progress_ctx);
     }
 
     return true;
@@ -1414,8 +1417,7 @@ bool llama_model_quantize_internal(const std::string & fname_inp, const std::str
 
 struct llama_context * llama_init_from_file(
         const char * path_model,
-        struct llama_context_params params,
-        llama_progress_handler progress) {
+        struct llama_context_params params) {
     ggml_time_init();
 
     llama_context * ctx = new llama_context;
@@ -1429,7 +1431,7 @@ struct llama_context * llama_init_from_file(
 
     ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, progress)) {
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.progress_callback, params.progress_ctx)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
diff --git a/llama.h b/llama.h
index e8bae5181..0e1b43ee3 100644
--- a/llama.h
+++ b/llama.h
@@ -45,6 +45,8 @@ extern "C" {
 
     } llama_token_data;
 
+    typedef void (*llama_progress_handler)(double progress, void *ctx);
+
     struct llama_context_params {
         int n_ctx;   // text context
         int n_parts; // -1 for default
@@ -53,11 +55,9 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
-    };
-
-    struct llama_progress_handler {
-        void (*handler)(double progress, void *ctx);
-        void *ctx;
+
+        llama_progress_handler progress_callback; // called with a progress value between 0 and 1, pass NULL to disable
+        void * progress_ctx;                      // context pointer passed to the progress callback
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -67,8 +67,7 @@ extern "C" {
     // Return NULL on failure
     LLAMA_API struct llama_context * llama_init_from_file(
             const char * path_model,
-            struct llama_context_params params,
-            struct llama_progress_handler progress);
+            struct llama_context_params params);
 
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index 96e4d0cef..49bc232b6 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -32,7 +32,7 @@ int main(int argc, char **argv) {
 
         lparams.vocab_only = true;
 
-        ctx = llama_init_from_file(fname.c_str(), lparams, {NULL, NULL});
+        ctx = llama_init_from_file(fname.c_str(), lparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
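
For context, registering a progress handler with the reworked API would look roughly like this. This is only a sketch against the header above; the model path and the `print_progress` helper are made up for illustration:

```cpp
#include <cstdio>

#include "llama.h"

// Matches the llama_progress_handler typedef: progress is in [0, 1],
// ctx is whatever pointer was stored in params.progress_ctx.
static void print_progress(double progress, void * ctx) {
    const char * tag = (const char *) ctx;
    fprintf(stderr, "\r%s: %3.0f%%", tag, progress * 100.0);
}

int main() {
    llama_context_params params = llama_context_default_params();
    params.progress_callback = print_progress;   // NULL disables reporting
    params.progress_ctx      = (void *) "loading"; // handed back to the callback

    // Hypothetical model path, for illustration only.
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (ctx == NULL) {
        fprintf(stderr, "\nfailed to load model\n");
        return 1;
    }
    fprintf(stderr, "\n");

    llama_free(ctx);
    return 0;
}
```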