This commit is contained in:
Jed Fox 2023-03-23 16:15:55 -04:00
parent ab02a2441c
commit 1f9592baf3
No known key found for this signature in database
GPG key ID: 0B61D18EA54B47E1
2 changed files with 19 additions and 17 deletions

View file

@@ -115,7 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
/*.progress_callback =*/ nullptr,
/*.progress_ctx =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
};
return result;
@@ -132,8 +132,8 @@ static bool llama_model_load(
int n_parts,
ggml_type memory_type,
bool vocab_only,
llama_progress_handler progress_callback,
void *progress_ctx) {
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
const int64_t t_start_us = ggml_time_us();
@@ -401,7 +401,7 @@ static bool llama_model_load(
std::vector<uint8_t> tmp;
if (progress_callback) {
progress_callback(0.0, progress_ctx);
progress_callback(0.0, progress_callback_user_data);
}
for (int i = 0; i < n_parts; ++i) {
@@ -596,7 +596,7 @@ static bool llama_model_load(
if (++n_tensors % 8 == 0) {
if (progress_callback) {
double current_progress = (double(i) + (double(fin.tellg()) / double(file_size))) / double(n_parts);
progress_callback(current_progress, progress_ctx);
progress_callback(current_progress, progress_callback_user_data);
}
fprintf(stderr, ".");
fflush(stderr);
@@ -616,7 +616,7 @@ static bool llama_model_load(
lctx.t_load_us = ggml_time_us() - t_start_us;
if (progress_callback) {
progress_callback(1.0, progress_ctx);
progress_callback(1.0, progress_callback_user_data);
}
return true;
@@ -1431,7 +1431,7 @@ struct llama_context * llama_init_from_file(
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.progress_callback, params.progress_ctx)) {
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
delete ctx;
return nullptr;

View file

@@ -45,7 +45,7 @@ extern "C" {
} llama_token_data;
typedef void (*llama_progress_handler)(double progress, void *ctx);
typedef void (*llama_progress_callback)(double progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
@@ -56,8 +56,10 @@ extern "C" {
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
llama_progress_handler progress_callback; // called with a progress value between 0 and 1, pass NULL to disable
void * progress_ctx; // context pointer passed to the progress callback
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
};
LLAMA_API struct llama_context_params llama_context_default_params();