Use llama_ prefix for structs in global namespace.

Matt Pulver 2023-08-25 09:36:34 -04:00
parent 93daad763d
commit fa33614b4d


@@ -4349,7 +4349,7 @@ struct llama_beam {
 };
 
 // A struct for calculating logit-related info.
-struct logit_info {
+struct llama_logit_info {
     const float * const logits;
     const int n_vocab;
     const float max_l;
@@ -4358,7 +4358,7 @@ struct logit_info {
         float max_l;
         float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
     };
-    logit_info(llama_context * ctx)
+    llama_logit_info(llama_context * ctx)
       : logits(llama_get_logits(ctx))
       , n_vocab(llama_n_vocab(ctx))
       , max_l(*std::max_element(logits, logits + n_vocab))
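Aside: the constructor above pairs max_l with the sum_exp fold to build a numerically stable softmax normalizer. A minimal standalone sketch of the same idea (illustrative names, not code from this patch):

// Sketch of the stable-softmax normalization set up by the constructor.
// Subtracting the maximum logit before exponentiating keeps std::exp from
// overflowing; p_i = exp(l_i - max_l) / sum_j exp(l_j - max_l).
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

float stable_softmax_normalizer(const std::vector<float> & logits) {
    const float max_l = *std::max_element(logits.begin(), logits.end());
    const float sum   = std::accumulate(logits.begin(), logits.end(), 0.0f,
        [max_l](float s, float l) { return s + std::exp(l - max_l); });
    return 1.0f / sum; // multiply by std::exp(l_i - max_l) to get a probability
}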
@@ -4393,7 +4393,7 @@ struct logit_info {
     }
 };
 
-struct beam_search {
+struct llama_beam_search_data {
     llama_context * ctx;
     size_t n_beams;
     int n_past;
@@ -4408,7 +4408,7 @@ struct beam_search {
     // Used to communicate to/from callback on beams state.
     std::vector<llama_beam_view> beam_views;
 
-    beam_search(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+    llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
       : ctx(ctx)
       , n_beams(n_beams)
       , n_past(n_past)
@@ -4452,7 +4452,7 @@ struct beam_search {
         if (!beam.tokens.empty()) {
            llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
         }
-        logit_info logit_info(ctx);
+        llama_logit_info logit_info(ctx);
         std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
         size_t i=0;
         if (next_beams.size() < n_beams) {
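Aside: the top_k(n_beams) member called above is not shown in this diff. A hypothetical sketch of what such a helper could look like over the public llama_token_data struct (illustrative only, not the file's actual implementation):

// Hypothetical top-k over raw logits; llama_token_data {id, logit, p} comes
// from llama.h. std::partial_sort orders only the first k slots.
#include <algorithm>
#include <vector>
#include "llama.h"

std::vector<llama_token_data> top_k_logits(const float * logits, int n_vocab, size_t k) {
    std::vector<llama_token_data> data;
    data.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        data.push_back({id, logits[id], 0.0f});
    }
    std::partial_sort(data.begin(), data.begin() + k, data.end(),
        [](const llama_token_data & a, const llama_token_data & b) {
            return a.logit > b.logit;
        });
    data.resize(k);
    return data;
}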
@@ -4569,9 +4569,9 @@ void llama_beam_search(llama_context * ctx,
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
-    beam_search beam_search(ctx, n_beams, n_past, n_predict, n_threads);
+    llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);
 
-    beam_search.loop(callback, callback_data);
+    beam_search_data.loop(callback, callback_data);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
     ctx->n_sample++;
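For context, a hedged usage sketch of the public llama_beam_search() entry point whose internals are renamed here. The callback shape follows the llama_beam_view / llama_beams_state API of this era of llama.h; treat the details as assumptions rather than a definitive recipe:

// Sketch: driving llama_beam_search() from user code. Assumes an initialized
// llama_context * ctx whose prompt has already been evaluated (n_past tokens).
#include <cstdio>
#include "llama.h"

static void on_beams(void * callback_data, llama_beams_state beams_state) {
    (void) callback_data;
    // Inspect the current beams, e.g. log each beam's probability.
    for (size_t i = 0; i < beams_state.n_beams; ++i) {
        const llama_beam_view & bv = beams_state.beam_views[i];
        fprintf(stderr, "beam %zu: p=%.4f, %zu token(s)\n", i, bv.p, bv.n_tokens);
    }
}

void run_beam_search(llama_context * ctx, int n_past) {
    llama_beam_search(ctx, on_beams, /*callback_data=*/nullptr,
                      /*n_beams=*/4, n_past, /*n_predict=*/64, /*n_threads=*/4);
}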