Use llama_ prefix for structs in global namespace.

This commit is contained in:
Matt Pulver 2023-08-25 09:36:34 -04:00
parent 93daad763d
commit fa33614b4d

View file

@ -4349,7 +4349,7 @@ struct llama_beam {
};
// A struct for calculating logit-related info.
struct logit_info {
struct llama_logit_info {
const float * const logits;
const int n_vocab;
const float max_l;
@ -4358,7 +4358,7 @@ struct logit_info {
float max_l;
float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
};
logit_info(llama_context * ctx)
llama_logit_info(llama_context * ctx)
: logits(llama_get_logits(ctx))
, n_vocab(llama_n_vocab(ctx))
, max_l(*std::max_element(logits, logits + n_vocab))
@ -4393,7 +4393,7 @@ struct logit_info {
}
};
struct beam_search {
struct llama_beam_search_data {
llama_context * ctx;
size_t n_beams;
int n_past;
@ -4408,7 +4408,7 @@ struct beam_search {
// Used to communicate to/from callback on beams state.
std::vector<llama_beam_view> beam_views;
beam_search(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
: ctx(ctx)
, n_beams(n_beams)
, n_past(n_past)
@ -4452,7 +4452,7 @@ struct beam_search {
if (!beam.tokens.empty()) {
llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
}
logit_info logit_info(ctx);
llama_logit_info logit_info(ctx);
std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
size_t i=0;
if (next_beams.size() < n_beams) {
@ -4569,9 +4569,9 @@ void llama_beam_search(llama_context * ctx,
assert(ctx);
const int64_t t_start_sample_us = ggml_time_us();
beam_search beam_search(ctx, n_beams, n_past, n_predict, n_threads);
llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads);
beam_search.loop(callback, callback_data);
beam_search_data.loop(callback, callback_data);
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
ctx->n_sample++;