Move llama_progress_handler into llama_context_params

Jed Fox 2023-03-23 13:36:43 -04:00
parent e47924fd4b
commit ab02a2441c
No known key found for this signature in database
GPG key ID: 0B61D18EA54B47E1
3 changed files with 24 additions and 23 deletions
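
For reference, a minimal caller-side sketch of the API after this change (the handler body, the "load" tag string, and the model path are illustrative assumptions, not part of the commit): the progress handler and its context pointer now live in llama_context_params instead of being passed to llama_init_from_file as a separate llama_progress_handler struct.

    #include <stdio.h>
    #include "llama.h"

    // Illustrative handler matching the llama_progress_handler typedef:
    // receives a value between 0 and 1 plus the user-supplied context pointer.
    static void print_progress(double progress, void *ctx) {
        const char *tag = (const char *) ctx;
        fprintf(stderr, "%s: %3.0f%%\n", tag, progress * 100.0);
    }

    int main(void) {
        struct llama_context_params params = llama_context_default_params();
        params.progress_callback = print_progress;   // pass NULL to disable progress reporting
        params.progress_ctx      = (void *) "load";  // forwarded to the callback unchanged

        // llama_init_from_file now takes only the path and the params struct.
        struct llama_context *ctx = llama_init_from_file("models/7B/ggml-model.bin", params);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        llama_free(ctx);
        return 0;
    }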


@@ -108,12 +108,14 @@ struct llama_context {
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
         /*.n_ctx             =*/ 512,
         /*.n_parts           =*/ -1,
         /*.seed              =*/ 0,
         /*.f16_kv            =*/ false,
         /*.logits_all        =*/ false,
         /*.vocab_only        =*/ false,
+        /*.progress_callback =*/ nullptr,
+        /*.progress_ctx      =*/ nullptr,
     };

     return result;
@@ -130,7 +132,8 @@ static bool llama_model_load(
         int n_parts,
         ggml_type memory_type,
         bool vocab_only,
-        llama_progress_handler progress) {
+        llama_progress_handler progress_callback,
+        void *progress_ctx) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

     const int64_t t_start_us = ggml_time_us();
@@ -397,8 +400,8 @@ static bool llama_model_load(
     std::vector<uint8_t> tmp;

-    if (progress.handler) {
-        progress.handler(0, progress.ctx);
+    if (progress_callback) {
+        progress_callback(0.0, progress_ctx);
     }

     for (int i = 0; i < n_parts; ++i) {
@@ -591,9 +594,9 @@ static bool llama_model_load(
                 //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
                 if (++n_tensors % 8 == 0) {
-                    if (progress.handler) {
+                    if (progress_callback) {
                         double current_progress = (double(i) + (double(fin.tellg()) / double(file_size))) / double(n_parts);
-                        progress.handler(current_progress, progress.ctx);
+                        progress_callback(current_progress, progress_ctx);
                     }
                     fprintf(stderr, ".");
                     fflush(stderr);
@@ -612,8 +615,8 @@ static bool llama_model_load(
     lctx.t_load_us = ggml_time_us() - t_start_us;

-    if (progress.handler) {
-        progress.handler(1, progress.ctx);
+    if (progress_callback) {
+        progress_callback(1.0, progress_ctx);
     }

     return true;
@@ -1414,8 +1417,7 @@ bool llama_model_quantize_internal(const std::string & fname_inp, const std::str
 struct llama_context * llama_init_from_file(
         const char * path_model,
-        struct llama_context_params params,
-        llama_progress_handler progress) {
+        struct llama_context_params params) {
     ggml_time_init();

     llama_context * ctx = new llama_context;
@@ -1429,7 +1431,7 @@ struct llama_context * llama_init_from_file(
     ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

-    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, progress)) {
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.progress_callback, params.progress_ctx)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;

llama.h (11 lines changed)

@@ -45,6 +45,8 @@ extern "C" {
     } llama_token_data;

+    typedef void (*llama_progress_handler)(double progress, void *ctx);
+
     struct llama_context_params {
         int n_ctx;   // text context
         int n_parts; // -1 for default
@@ -53,11 +55,9 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
-    };

-    struct llama_progress_handler {
-        void (*handler)(double progress, void *ctx);
-        void *ctx;
+        llama_progress_handler progress_callback; // called with a progress value between 0 and 1, pass NULL to disable
+        void * progress_ctx;                      // context pointer passed to the progress callback
     };

     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -67,8 +67,7 @@ extern "C" {
     // Return NULL on failure
     LLAMA_API struct llama_context * llama_init_from_file(
             const char * path_model,
-            struct llama_context_params params,
-            struct llama_progress_handler progress);
+            struct llama_context_params params);

     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);


@@ -32,7 +32,7 @@ int main(int argc, char **argv) {
     lparams.vocab_only = true;

-    ctx = llama_init_from_file(fname.c_str(), lparams, {NULL, NULL});
+    ctx = llama_init_from_file(fname.c_str(), lparams);

     if (ctx == NULL) {
         fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());