compress: remove sampling.cpp dependency
This commit is contained in:
parent
bec83989be
commit
da444fafd7
1 changed files with 13 additions and 16 deletions
|
@ -1,7 +1,6 @@
|
||||||
#include "arg.h"
|
#include "arg.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "sampling.cpp"
|
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
@ -60,16 +59,15 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
|
||||||
|
|
||||||
std::vector<int> sample_ids;
|
std::vector<int> sample_ids;
|
||||||
|
|
||||||
smpl->set_logits(ctx, num_raw_tokens_header - 1);
|
gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
|
||||||
for (int index = num_raw_tokens_header; index < inp.size(); index++)
|
for (int index = num_raw_tokens_header; index < inp.size(); index++)
|
||||||
{
|
{
|
||||||
auto &cur_p = smpl->cur_p; // initialized by set_logits
|
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
|
||||||
llama_sampler_apply(smpl->chain, &cur_p);
|
|
||||||
|
|
||||||
int match = -1;
|
int match = -1;
|
||||||
for (int i = 0; i < cur_p.size; i++)
|
for (int i = 0; i < cur_p->size; i++)
|
||||||
{
|
{
|
||||||
auto tok = cur_p.data[i];
|
auto tok = cur_p->data[i];
|
||||||
llama_token candidate = tok.id;
|
llama_token candidate = tok.id;
|
||||||
if (candidate == inp[index])
|
if (candidate == inp[index])
|
||||||
{
|
{
|
||||||
|
@ -91,7 +89,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
|
||||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
smpl->set_logits(ctx, 0);
|
gpt_sampler_sample(smpl, ctx, 0, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// bit pack sample_ids
|
// bit pack sample_ids
|
||||||
|
@ -247,7 +245,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
smpl->set_logits(ctx, num_raw_tokens_header - 1);
|
gpt_sampler_sample(smpl, ctx, num_raw_tokens_header - 1, true);
|
||||||
|
|
||||||
int index = 0;
|
int index = 0;
|
||||||
int bit_index = (1 + num_raw_tokens_header * 4) * 8;
|
int bit_index = (1 + num_raw_tokens_header * 4) * 8;
|
||||||
|
@ -268,10 +266,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
||||||
sample_id |= (int)sample_ids_bitpacked[i + (bit_index / 8)];
|
sample_id |= (int)sample_ids_bitpacked[i + (bit_index / 8)];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto &cur_p = smpl->cur_p; // initialized by set_logits
|
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
|
||||||
llama_sampler_apply(smpl->chain, &cur_p);
|
|
||||||
|
|
||||||
auto token_id = cur_p.data[sample_id].id;
|
auto token_id = cur_p->data[sample_id].id;
|
||||||
|
|
||||||
out.push_back(token_id);
|
out.push_back(token_id);
|
||||||
|
|
||||||
|
@ -303,7 +300,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
||||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
smpl->set_logits(ctx, 0);
|
gpt_sampler_sample(smpl, ctx, 0, true);
|
||||||
|
|
||||||
index++;
|
index++;
|
||||||
|
|
||||||
bit_index += 8 * (fixed_token_cost + bytesize);
|
bit_index += 8 * (fixed_token_cost + bytesize);
|
||||||
|
@ -328,10 +326,9 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
||||||
{
|
{
|
||||||
int sample_id = id;
|
int sample_id = id;
|
||||||
|
|
||||||
auto &cur_p = smpl->cur_p; // initialized by set_logits
|
auto cur_p = gpt_sampler_get_candidates(smpl); // initialized by set_logits
|
||||||
llama_sampler_apply(smpl->chain, &cur_p);
|
|
||||||
|
|
||||||
auto token_id = cur_p.data[sample_id].id;
|
auto token_id = cur_p->data[sample_id].id;
|
||||||
out.push_back(token_id);
|
out.push_back(token_id);
|
||||||
if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
|
if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
|
||||||
{
|
{
|
||||||
|
@ -350,7 +347,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
|
||||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
smpl->set_logits(ctx, 0);
|
gpt_sampler_sample(smpl, ctx, 0, true);
|
||||||
}
|
}
|
||||||
index++;
|
index++;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue