compress: cleanup

parent 1146007610
commit bd5b24e8b6

1 changed file with 11 additions and 15 deletions
@@ -37,6 +37,8 @@ int msB_log256(int x)
 const int block_header_size = 2;
 const int fixed_token_cost = 1;

+int total_pad = 0;
+
 std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gpt_sampler *smpl, int num_raw_tokens_header)
 {

@@ -62,7 +64,6 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gpt_sampler *smpl, int num_raw_tokens_header)
 for (int index = num_raw_tokens_header; index < inp.size(); index++)
 {
 auto &cur_p = smpl->cur_p; // initialized by set_logits
-// llama_sampler_apply(smpl->grmr, &cur_p);
 llama_sampler_apply(smpl->chain, &cur_p);

 int match = -1;
@@ -121,12 +122,10 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gpt_sampler *smpl, int num_raw_tokens_header)
 int sample_id = sample_ids[i];
 uint8_t PAD = (8 - bit_offset % 8) % 8;
 uint8_t bytesize = (uint8_t)msB_log256(sample_id);
-// LOG("pos: %d, bs: %d\n",sample_id, bytesize);

 // Big number, better save as token
 if (sample_id > PAD + (block_header_size + fixed_token_cost + bytesize) * 8)
 {
-// LOG("End block\n");
 // Close current block (0b1010 is block marker)
 if (was_block)
 {
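Note (not part of the patch): with the constants declared at the top of the file (block_header_size = 2, fixed_token_cost = 1), the branch above compares the rank against the byte-aligned cost of an explicit token record, i.e. PAD padding bits plus 8 bits each for the block header bytes, the token marker byte, and the bytesize payload bytes; the inequality effectively treats an in-block rank as costing about one bit per unit of rank. For example, a rank of 300 needs bytesize = 2 payload bytes, so a token record costs PAD + (2 + 1 + 2) * 8, at most 7 + 40 = 47 bits; since 300 > 47, the encoder closes the current block and stores the rank as a token record instead.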
@@ -151,21 +150,18 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gpt_sampler *smpl, int num_raw_tokens_header)
 }
 }
 bit_offset += PAD;
+total_pad += PAD;
 if (bit_offset % 8)
 {
 LOG_ERR("Unreachable");
 exit(-1);
 }
-// LOG("\n%d",bit_offset/8);
 // 0b0101 is token marker
-
 sample_ids_bitpacked.push_back(0b01010000 | bytesize);
 // put token bytes into sample_ids_bitpacked
-// LOG("\n%d -> ",sample_id);
 for (int j = 0; j < bytesize; j++)
 {
 sample_ids_bitpacked.push_back(sample_id & 0xff);
-LOG("%02x ", sample_id & 0xff);
 sample_id >>= 8;
 }
 if (sample_id)
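For orientation, the token path in the hunk above amounts to the following self-contained sketch (the helper name append_token_record is mine, not from the patch): the marker byte carries 0b0101 in its high nibble and the payload size in its low nibble, and the rank's bytes follow least-significant byte first.

    // Sketch only: condensed restatement of the token-record serialization above.
    #include <cstdint>
    #include <vector>

    static void append_token_record(std::vector<uint8_t> &out, int sample_id, uint8_t bytesize)
    {
        out.push_back(0b01010000 | bytesize); // high nibble 0b0101 = token marker, low nibble = payload byte count
        for (int j = 0; j < bytesize; j++)    // payload: rank bytes, least-significant first
        {
            out.push_back(sample_id & 0xff);
            sample_id >>= 8;
        }
    }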
@@ -217,6 +213,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gpt_sampler *smpl, int num_raw_tokens_header)
 int block_size = (bit_offset + PAD) / 8 - block_start;
 // endianness: big endian
 sample_ids_bitpacked[block_start + 1] = block_size & 0xff;
+total_pad+=PAD;
 }
 llama_batch_free(batch);
 return sample_ids_bitpacked;
@@ -245,7 +242,6 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
 auto token_str = llama_token_to_piece(ctx, token);
 LOG("%s", token_str.c_str());
 }
-LOG("\u001b[0m\u001b[37m");
 if (llama_decode(ctx, batch))
 {
 LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -275,6 +271,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect

 auto &cur_p = smpl->cur_p; // initialized by set_logits
 llama_sampler_apply(smpl->chain, &cur_p);

 auto token_id = cur_p.data[sample_id].id;

 out.push_back(token_id);
@@ -288,12 +285,10 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
 // print in red
 LOG("\u001b[31m%s", llama_token_to_piece(ctx, token_id).c_str());
 LOG("\nExpected: %s", llama_token_to_piece(ctx, inp[num_raw_tokens_header + index]).c_str());
-// LOG("\n%d", num_raw_tokens_header + index);
 LOG("\n, Id: %d != %d", token_id, inp[num_raw_tokens_header + index]);
 LOG("\nPos: %d, bs:%d", sample_id, bytesize);

 // print sample_id bytes in hex
-// LOG("\n %02x %02x", sample_ids_bitpacked[bit_index / 8], sample_ids_bitpacked[bit_index / 8 + 1]);
 LOG("\n");
 for (int i = bytesize; i > 0; i--)
 {
@@ -335,8 +330,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
 int sample_id = id;

 auto &cur_p = smpl->cur_p; // initialized by set_logits
-// llama_sampler_apply(smpl->grmr, &cur_p);
 llama_sampler_apply(smpl->chain, &cur_p);

 auto token_id = cur_p.data[sample_id].id;
 out.push_back(token_id);
 if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
@@ -363,7 +358,6 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
 id = 0;
 }
 }
-// LOG("\n(%d+%d)/8= %d\n",bit_index,PAD,(bit_index+PAD)/8);
 bit_index += PAD;
 }
 }
@@ -554,10 +548,12 @@ int main(int argc, char **argv)
 if(!params.no_perf){
 LOG("\nInput: %d characters (%d tokens)", params.prompt.length(), inp.size());

-float compressed_byte_per_token = (float)sample_ids_bitpacked.size() / (float)inp.size();
+float compressed_bits_per_token = 8 * (float)sample_ids_bitpacked.size() / (float)inp.size();
 float compressed_bits_per_char = 8 * (float)sample_ids_bitpacked.size() / (float)params.prompt.length();

-LOG("\n%d compressed bytes,(%04f bytes per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_byte_per_token, compressed_bits_per_char);
+LOG("\n%d compressed bytes,(%04f bits per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_bits_per_token, compressed_bits_per_char);
+LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad/(float)params.prompt.length());
+LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token-total_pad/(float)inp.size()),exp2(compressed_bits_per_token));
 }
 //maybe this needs to be changed
 if(params.out_file != "imatrix.dat"){
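The new PPL line reports exp2() of the measured bits per token. Since an ideal coder driven by the model would spend about log2(1/p) bits on a token of probability p, the measured figure can only overshoot that, hence the "(over)estimation" label; subtracting the alignment padding tracked in total_pad gives the tighter of the two numbers. A minimal sketch of that relationship, with made-up example values:

    // Sketch only: mirrors the formula in the new LOG line, with illustrative numbers.
    #include <cmath>
    #include <cstdio>

    int main()
    {
        float bits_per_token = 1.75f; // 8 * compressed_bytes / n_tokens (example value)
        float pad_per_token  = 0.10f; // total_pad / n_tokens            (example value)

        std::printf("PPL (over)estimation: %f (%f with padding)\n",
                    std::exp2(bits_per_token - pad_per_token), // padding excluded: tighter estimate
                    std::exp2(bits_per_token));                // padding included
        return 0;
    }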
@@ -630,7 +626,7 @@ int main(int argc, char **argv)
 ofs.write((char*)&out_str[0], out_str.size());
 ofs.close();
 }

 llama_free(ctx);
 llama_free_model(model);
