compress: remove sampling.cpp dependency

Author: Stéphane du Hamel
Date: 2024-09-25 11:56:47 +02:00
parent bec83989be
commit da444fafd7

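The change drops the direct #include of sampling.cpp and stops reaching into gpt_sampler internals (smpl->set_logits, smpl->cur_p, smpl->chain), going through the public interface declared in sampling.h instead. A minimal sketch of the new pattern, assuming the gpt_sampler API from llama.cpp's common library at the time of this commit (idx selects which batch position's logits to sample from; the final argument is grammar_first):

    // Prime the sampler at logit index idx; this also fills the internal
    // candidate array, which is then read through the public accessor.
    gpt_sampler_sample(smpl, ctx, idx, true);
    llama_token_data_array * cur_p = gpt_sampler_get_candidates(smpl);
    for (size_t i = 0; i < cur_p->size; i++) {
        llama_token candidate = cur_p->data[i].id;
        // ... compare against the target token, as encode() does below ...
    }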

@@ -1,7 +1,6 @@
 #include "arg.h"
 #include "common.h"
 #include "sampling.h"
-#include "sampling.cpp"
 #include "log.h"
 #include "llama.h"
@@ -60,16 +59,15 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
     std::vector<int> sample_ids;
-    smpl->set_logits(ctx, num_raw_tokens_header - 1);
+    gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
     for (int index = num_raw_tokens_header; index < inp.size(); index++)
     {
-        auto &cur_p = smpl->cur_p; // initialized by set_logits
-        llama_sampler_apply(smpl->chain, &cur_p);
+        auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
         int match = -1;
-        for (int i = 0; i < cur_p.size; i++)
+        for (int i = 0; i < cur_p->size; i++)
         {
-            auto tok = cur_p.data[i];
+            auto tok = cur_p->data[i];
             llama_token candidate = tok.id;
             if (candidate == inp[index])
             {
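For context, this encode loop is a rank coder: for each input token it scans the sampler's candidate list (sorted by the sampler chain) and records the position at which the token appears, so well-predicted tokens get small ids. The search, factored into a hypothetical helper (rank_of is illustrative, not part of the example; types come from llama.h):

    // Hypothetical helper: index of `target` in the candidate list, or -1.
    static int rank_of(const llama_token_data_array * cur_p, llama_token target) {
        for (size_t i = 0; i < cur_p->size; i++) {
            if (cur_p->data[i].id == target) {
                return (int) i;
            }
        }
        return -1;
    }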
@@ -91,7 +89,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
             exit(1);
         }
-        smpl->set_logits(ctx, 0);
+        gpt_sampler_sample(smpl, ctx, 0, true);
     }
     // bit pack sample_ids
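The packing layout itself is outside this diff (decode below walks it via bit_index, bytesize and fixed_token_cost). Purely to illustrate the idea, not the example's actual format, a fixed-width packer for small ids could look like:

    // Illustrative only: pack each id into `width` bits, MSB first.
    #include <cstdint>
    #include <vector>

    static std::vector<uint8_t> pack_bits(const std::vector<int> & ids, int width) {
        std::vector<uint8_t> out((ids.size() * width + 7) / 8, 0);
        size_t bit = 0;
        for (int id : ids) {
            for (int b = width - 1; b >= 0; b--, bit++) {
                if ((id >> b) & 1) {
                    out[bit / 8] |= 1u << (7 - bit % 8);
                }
            }
        }
        return out;
    }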
@@ -247,7 +245,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
         exit(1);
     }
-    smpl->set_logits(ctx, num_raw_tokens_header - 1);
+    gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
     int index = 0;
     int bit_index = (1 + num_raw_tokens_header * 4) * 8;
@@ -268,10 +266,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
             sample_id |= (int)sample_ids_bitpacked[i + (bit_index / 8)];
         }
-        auto &cur_p = smpl->cur_p; // initialized by set_logits
-        llama_sampler_apply(smpl->chain, &cur_p);
+        auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
-        auto token_id = cur_p.data[sample_id].id;
+        auto token_id = cur_p->data[sample_id].id;
         out.push_back(token_id);
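Decoding inverts this without any search: the rank read back from the bitstream indexes directly into the candidate list. As a hypothetical helper (token_from_rank is illustrative, not in the example):

    // Hypothetical helper: a decoded rank selects the token directly from
    // the sampler's current candidate array.
    static llama_token token_from_rank(gpt_sampler * smpl, int sample_id) {
        llama_token_data_array * cur_p = gpt_sampler_get_candidates(smpl);
        return cur_p->data[sample_id].id;
    }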
@@ -303,7 +300,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
             exit(1);
         }
-        smpl->set_logits(ctx, 0);
+        gpt_sampler_sample(smpl, ctx, 0, true);
         index++;
         bit_index += 8 * (fixed_token_cost + bytesize);
@@ -328,10 +326,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
         {
             int sample_id = id;
-            auto &cur_p = smpl->cur_p; // initialized by set_logits
-            llama_sampler_apply(smpl->chain, &cur_p);
+            auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
-            auto token_id = cur_p.data[sample_id].id;
+            auto token_id = cur_p->data[sample_id].id;
             out.push_back(token_id);
             if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
             {
@@ -350,7 +347,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
                 LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
                 exit(1);
             }
-            smpl->set_logits(ctx, 0);
+            gpt_sampler_sample(smpl, ctx, 0, true);
         }
         index++;
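Both sides advance in lockstep: after each emitted token is evaluated, the sampler is re-primed at logit index 0 so the next candidate lookup sees fresh logits. One caveat worth noting (an observation, not something stated in the commit): gpt_sampler_sample actually draws a sample and updates the sampler's internal state, whereas set_logits presumably only loaded logits, so staying on the public API trades a little extra work per token for not depending on sampling.cpp internals.

    // Per-token cadence shared by encode and decode (sketch; the llama_decode
    // call that evaluates the chosen token is elided):
    gpt_sampler_sample(smpl, ctx, 0, true);
    auto cur_p = gpt_sampler_get_candidates(smpl); // fresh candidate array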