From da444fafd76331240a4668114cd034622bc6c97c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Wed, 25 Sep 2024 11:56:47 +0200 Subject: [PATCH] compress: remove sampling.cpp dependency --- examples/compress/compress.cpp | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/examples/compress/compress.cpp b/examples/compress/compress.cpp index bd2756afa..62981ec19 100644 --- a/examples/compress/compress.cpp +++ b/examples/compress/compress.cpp @@ -1,7 +1,6 @@ #include "arg.h" #include "common.h" #include "sampling.h" -#include "sampling.cpp" #include "log.h" #include "llama.h" @@ -60,16 +59,15 @@ std::vector encode(llama_context *ctx, std::vector inp, gp std::vector sample_ids; - smpl->set_logits(ctx, num_raw_tokens_header - 1); + gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true); for (int index = num_raw_tokens_header; index < inp.size(); index++) { - auto &cur_p = smpl->cur_p; // initialized by set_logits - llama_sampler_apply(smpl->chain, &cur_p); + auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits int match = -1; - for (int i = 0; i < cur_p.size; i++) + for (int i = 0; i < cur_p->size; i++) { - auto tok = cur_p.data[i]; + auto tok = cur_p->data[i]; llama_token candidate = tok.id; if (candidate == inp[index]) { @@ -91,7 +89,7 @@ std::vector encode(llama_context *ctx, std::vector inp, gp LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); exit(1); } - smpl->set_logits(ctx, 0); + gpt_sampler_sample(smpl, ctx, 0, true); } // bit pack sample_ids @@ -247,7 +245,7 @@ std::vector decode(llama_context *ctx, gpt_sampler *smpl, std::vect exit(1); } - smpl->set_logits(ctx, num_raw_tokens_header - 1); + gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true); int index = 0; int bit_index = (1 + num_raw_tokens_header * 4) * 8; @@ -268,10 +266,9 @@ std::vector decode(llama_context *ctx, gpt_sampler *smpl, std::vect sample_id |= 
(int)sample_ids_bitpacked[i + (bit_index / 8)]; } - auto &cur_p = smpl->cur_p; // initialized by set_logits - llama_sampler_apply(smpl->chain, &cur_p); + auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by gpt_sampler_sample - auto token_id = cur_p.data[sample_id].id; + auto token_id = cur_p->data[sample_id].id; out.push_back(token_id); @@ -303,7 +300,8 @@ std::vector decode(llama_context *ctx, gpt_sampler *smpl, std::vect LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); exit(1); } - smpl->set_logits(ctx, 0); + gpt_sampler_sample(smpl, ctx, 0, true); + index++; bit_index += 8 * (fixed_token_cost + bytesize); @@ -328,10 +326,9 @@ std::vector decode(llama_context *ctx, gpt_sampler *smpl, std::vect { int sample_id = id; - auto &cur_p = smpl->cur_p; // initialized by set_logits - llama_sampler_apply(smpl->chain, &cur_p); + auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by gpt_sampler_sample - auto token_id = cur_p.data[sample_id].id; + auto token_id = cur_p->data[sample_id].id; out.push_back(token_id); if (!inp.size() || token_id == inp[num_raw_tokens_header + index]) { @@ -350,7 +347,7 @@ std::vector decode(llama_context *ctx, gpt_sampler *smpl, std::vect LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); exit(1); } - smpl->set_logits(ctx, 0); + gpt_sampler_sample(smpl, ctx, 0, true); } index++;