llama : refactor sampling v2 (#9294)
- Add `struct llama_sampler` and `struct llama_sampler_i` - Add `llama_sampler_` API - Add `llama_sampler_chain_` API for chaining multiple samplers - Remove `LLAMA_API_INTERNAL` - Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
parent
947538acb8
commit
df270ef745
48 changed files with 3497 additions and 2914 deletions
|
@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
|
|||
return true;
|
||||
}
|
||||
|
||||
static const char * sample(struct llama_sampling_context * ctx_sampling,
|
||||
static const char * sample(struct gpt_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past) {
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
||||
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
||||
const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
|
||||
gpt_sampler_accept(smpl, id, true);
|
||||
static std::string ret;
|
||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||
ret = "</s>";
|
||||
|
@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
|||
|
||||
LOG_TEE("\n");
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
|
||||
if (!ctx_sampling) {
|
||||
struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
|
||||
if (!smpl) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
|
@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
|||
fflush(stdout);
|
||||
}
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
gpt_sampler_free(smpl);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
|
@ -310,7 +310,7 @@ int main(int argc, char ** argv) {
|
|||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
|
@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
|
|||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue