compress: remove sampling.cpp dependency
This commit is contained in:
parent
bec83989be
commit
da444fafd7
1 changed files with 13 additions and 16 deletions
|
@ -1,7 +1,6 @@
|
|||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "sampling.cpp"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
|
@ -60,16 +59,15 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
|
|||
|
||||
std::vector<int> sample_ids;
|
||||
|
||||
smpl->set_logits(ctx, num_raw_tokens_header - 1);
|
||||
gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
|
||||
for (int index = num_raw_tokens_header; index < inp.size(); index++)
|
||||
{
|
||||
auto &cur_p = smpl->cur_p; // initialized by set_logits
|
||||
llama_sampler_apply(smpl->chain, &cur_p);
|
||||
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
|
||||
|
||||
int match = -1;
|
||||
for (int i = 0; i < cur_p.size; i++)
|
||||
for (int i = 0; i < cur_p->size; i++)
|
||||
{
|
||||
auto tok = cur_p.data[i];
|
||||
auto tok = cur_p->data[i];
|
||||
llama_token candidate = tok.id;
|
||||
if (candidate == inp[index])
|
||||
{
|
||||
|
@ -91,7 +89,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
|
|||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||
exit(1);
|
||||
}
|
||||
smpl->set_logits(ctx, 0);
|
||||
gpt_sampler_sample(smpl, ctx, 0, true);
|
||||
}
|
||||
|
||||
// bit pack sample_ids
|
||||
|
@ -247,7 +245,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
|||
exit(1);
|
||||
}
|
||||
|
||||
smpl->set_logits(ctx, num_raw_tokens_header - 1);
|
||||
gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
|
||||
|
||||
int index = 0;
|
||||
int bit_index = (1 + num_raw_tokens_header * 4) * 8;
|
||||
|
@ -268,10 +266,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
|||
sample_id |= (int)sample_ids_bitpacked[i + (bit_index / 8)];
|
||||
}
|
||||
|
||||
auto &cur_p = smpl->cur_p; // initialized by set_logits
|
||||
llama_sampler_apply(smpl->chain, &cur_p);
|
||||
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
|
||||
|
||||
auto token_id = cur_p.data[sample_id].id;
|
||||
auto token_id = cur_p->data[sample_id].id;
|
||||
|
||||
out.push_back(token_id);
|
||||
|
||||
|
@ -303,7 +300,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
|||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||
exit(1);
|
||||
}
|
||||
smpl->set_logits(ctx, 0);
|
||||
gpt_sampler_sample(smpl, ctx, 0, true);
|
||||
|
||||
index++;
|
||||
|
||||
bit_index += 8 * (fixed_token_cost + bytesize);
|
||||
|
@ -328,10 +326,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
|||
{
|
||||
int sample_id = id;
|
||||
|
||||
auto &cur_p = smpl->cur_p; // initialized by set_logits
|
||||
llama_sampler_apply(smpl->chain, &cur_p);
|
||||
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
|
||||
|
||||
auto token_id = cur_p.data[sample_id].id;
|
||||
auto token_id = cur_p->data[sample_id].id;
|
||||
out.push_back(token_id);
|
||||
if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
|
||||
{
|
||||
|
@ -350,7 +347,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
|||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||
exit(1);
|
||||
}
|
||||
smpl->set_logits(ctx, 0);
|
||||
gpt_sampler_sample(smpl, ctx, 0, true);
|
||||
}
|
||||
index++;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue