llama : llama_perf + option to disable timings during decode (#9355)
* llama : llama_perf + option to disable timings during decode ggml-ci * common : add llama_arg * Update src/llama.cpp Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com> * perf : separate functions in the API ggml-ci * perf : safer pointer handling + naming update ggml-ci * minor : better local var name * perf : abort on invalid sampler pointer ggml-ci --------- Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
bd35cb0ae3
commit
0abc6a2c25
23 changed files with 135 additions and 91 deletions
|
@ -1669,3 +1669,37 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
|||
|
||||
return LLAMA_DEFAULT_SEED;
|
||||
}
|
||||
|
||||
// perf
|
||||
|
||||
struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
|
||||
struct llama_perf_sampler_data data = {};
|
||||
|
||||
if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
|
||||
GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
|
||||
}
|
||||
|
||||
const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
|
||||
|
||||
data.t_sample_ms = 1e-3 * ctx->t_sample_us;
|
||||
data.n_sample = std::max(0, ctx->n_sample);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
void llama_perf_sampler_print(const struct llama_sampler * chain) {
|
||||
const auto data = llama_perf_sampler(chain);
|
||||
|
||||
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
|
||||
}
|
||||
|
||||
void llama_perf_sampler_reset(struct llama_sampler * chain) {
|
||||
if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
|
||||
GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
|
||||
}
|
||||
|
||||
auto * ctx = (struct llama_sampler_chain *) chain->ctx;
|
||||
|
||||
ctx->t_sample_us = ctx->n_sample = 0;
|
||||
}
|
||||
|
|
103
src/llama.cpp
103
src/llama.cpp
|
@ -2486,6 +2486,7 @@ struct llama_cparams {
|
|||
bool causal_attn;
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
bool no_perf;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
|
@ -6661,8 +6662,6 @@ static bool llm_load_tensors(
|
|||
bool use_mlock,
|
||||
llama_progress_callback progress_callback,
|
||||
void * progress_callback_user_data) {
|
||||
model.t_start_us = ggml_time_us();
|
||||
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
model.split_mode = split_mode;
|
||||
|
@ -8593,14 +8592,13 @@ static bool llm_load_tensors(
|
|||
}
|
||||
}
|
||||
|
||||
// loading time will be recalculated after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
model.t_load_us = ggml_time_us() - model.t_start_us;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
|
||||
model.t_start_us = ggml_time_us();
|
||||
|
||||
try {
|
||||
llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
|
||||
|
||||
|
@ -8662,6 +8660,10 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
|||
return -1;
|
||||
}
|
||||
|
||||
// loading time will be recalculated after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
model.t_load_us = ggml_time_us() - model.t_start_us;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -17949,6 +17951,7 @@ struct llama_context_params llama_context_default_params() {
|
|||
/*.embeddings =*/ false,
|
||||
/*.offload_kqv =*/ true,
|
||||
/*.flash_attn =*/ false,
|
||||
/*.no_perf =*/ true,
|
||||
/*.abort_callback =*/ nullptr,
|
||||
/*.abort_callback_data =*/ nullptr,
|
||||
};
|
||||
|
@ -18159,6 +18162,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
cparams.embeddings = params.embeddings;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
|
@ -20077,10 +20081,14 @@ void llama_synchronize(struct llama_context * ctx) {
|
|||
|
||||
// add the evaluation to the stats
|
||||
if (ctx->n_queued_tokens == 1) {
|
||||
ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
|
||||
if (!ctx->cparams.no_perf) {
|
||||
ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
|
||||
}
|
||||
ctx->n_eval++;
|
||||
} else if (ctx->n_queued_tokens > 1) {
|
||||
ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
|
||||
if (!ctx->cparams.no_perf) {
|
||||
ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
|
||||
}
|
||||
ctx->n_p_eval += ctx->n_queued_tokens;
|
||||
}
|
||||
|
||||
|
@ -20688,65 +20696,40 @@ const char * llama_print_system_info(void) {
|
|||
return s.c_str();
|
||||
}
|
||||
|
||||
void llama_perf_print(const void * ctx, enum llama_perf_type type) {
|
||||
switch (type) {
|
||||
case LLAMA_PERF_TYPE_CONTEXT:
|
||||
{
|
||||
const auto * p = (const struct llama_context *) ctx;
|
||||
struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
|
||||
struct llama_perf_context_data data = {};
|
||||
|
||||
const double t_start_ms = 1e-3 * p->t_start_us;
|
||||
const double t_end_ms = 1.00 * ggml_time_ms();
|
||||
const double t_load_ms = 1e-3 * p->t_load_us;
|
||||
const double t_p_eval_ms = 1e-3 * p->t_p_eval_us;
|
||||
const double t_eval_ms = 1e-3 * p->t_eval_us;
|
||||
|
||||
const int32_t n_p_eval = std::max(0, p->n_p_eval);
|
||||
const int32_t n_eval = std::max(1, p->n_eval);
|
||||
|
||||
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, t_load_ms);
|
||||
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, t_p_eval_ms, n_p_eval, t_p_eval_ms / n_p_eval, 1e3 / t_p_eval_ms * n_p_eval);
|
||||
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, t_eval_ms, n_eval, t_eval_ms / n_eval, 1e3 / t_eval_ms * n_eval);
|
||||
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - t_start_ms), (n_p_eval + n_eval));
|
||||
} break;
|
||||
case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
|
||||
{
|
||||
const auto * smpl = (const struct llama_sampler *) ctx;
|
||||
const auto * p = (const struct llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
const double t_sampler_ms = 1e-3 * p->t_sample_us;
|
||||
|
||||
const int32_t n_sampler = std::max(0, p->n_sample);
|
||||
|
||||
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, t_sampler_ms, n_sampler, t_sampler_ms / n_sampler, 1e3 / t_sampler_ms * n_sampler);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("invalid perf type");
|
||||
if (ctx == nullptr) {
|
||||
return data;
|
||||
}
|
||||
|
||||
data.t_start_ms = 1e-3 * ctx->t_start_us;
|
||||
data.t_load_ms = 1e-3 * ctx->t_load_us;
|
||||
data.t_p_eval_ms = 1e-3 * ctx->t_p_eval_us;
|
||||
data.t_eval_ms = 1e-3 * ctx->t_eval_us;
|
||||
data.n_p_eval = std::max(1, ctx->n_p_eval);
|
||||
data.n_eval = std::max(1, ctx->n_eval);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
void llama_perf_reset(void * ctx, enum llama_perf_type type) {
|
||||
switch (type) {
|
||||
case LLAMA_PERF_TYPE_CONTEXT:
|
||||
{
|
||||
auto * p = (struct llama_context *) ctx;
|
||||
void llama_perf_context_print(const struct llama_context * ctx) {
|
||||
const auto data = llama_perf_context(ctx);
|
||||
|
||||
p->t_start_us = ggml_time_us();
|
||||
p->t_eval_us = p->n_eval = 0;
|
||||
p->t_p_eval_us = p->n_p_eval = 0;
|
||||
} break;
|
||||
case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
|
||||
{
|
||||
auto * smpl = (struct llama_sampler *) ctx;
|
||||
auto * p = (struct llama_sampler_chain *) smpl->ctx;
|
||||
const double t_end_ms = 1e-3 * ggml_time_us();
|
||||
|
||||
p->t_sample_us = p->n_sample = 0;
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("invalid perf type");
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
|
||||
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
|
||||
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
||||
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
||||
}
|
||||
|
||||
// Restart the performance measurement of a context.
//
// Sets the start timestamp to the current ggml_time_us() value and zeroes the
// eval and prompt-eval timers together with their counters.
//
// A null ctx is ignored, in line with llama_perf_context(), which returns empty
// data for a null context instead of dereferencing it.
void llama_perf_context_reset(struct llama_context * ctx) {
    if (ctx == nullptr) {
        return;
    }

    ctx->t_start_us  = ggml_time_us();
    ctx->t_eval_us   = ctx->n_eval   = 0;
    ctx->t_p_eval_us = ctx->n_p_eval = 0;
}
|
||||
|
||||
void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue