Don't double count the sample time

This commit is contained in:
Howard Su 2023-07-05 09:10:54 +08:00
parent 7f0e9a775e
commit 7b007d4b5e

View file

@ -1896,10 +1896,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
return;
}
const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(ctx, candidates);
const int64_t t_start_sample_us = ggml_time_us();
// Compute the cumulative probabilities
float cum_sum = 0.0f;
size_t last_idx = candidates->size;
@ -1928,9 +1928,8 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
return;
}
const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
// Compute the first and second derivatives
std::vector<float> first_derivatives(candidates->size - 1);
@ -1982,11 +1981,11 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
return;
}
const int64_t t_start_sample_us = ggml_time_us();
// Compute the softmax of logits and calculate entropy
llama_sample_softmax(nullptr, candidates);
const int64_t t_start_sample_us = ggml_time_us();
float entropy = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) {
entropy += -candidates->data[i].p * logf(candidates->data[i].p);
@ -2155,13 +2154,11 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
ctx->n_sample++;
}
return X;
}
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
assert(ctx);
int64_t t_start_sample_us;
t_start_sample_us = ggml_time_us();
@ -2176,13 +2173,14 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
candidates->size = 1;
}
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
// Normalize the probabilities of the remaining words
llama_sample_softmax(ctx, candidates);
// Sample the next word X from the remaining words
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
llama_token X = llama_sample_token(ctx, candidates);
t_start_sample_us = ggml_time_us();