compress: cleanup

Stéphane du Hamel 2024-09-24 23:52:09 +02:00
parent 1146007610
commit bd5b24e8b6


@@ -37,6 +37,8 @@ int msB_log256(int x)
const int block_header_size = 2;
const int fixed_token_cost = 1;
int total_pad = 0;
std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gpt_sampler *smpl, int num_raw_tokens_header)
{
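A note on msB_log256, the context function named in the hunk header above: its body is not part of this diff, but it is used below as bytesize = (uint8_t)msB_log256(sample_id), and exactly bytesize bytes of the id are then written out, so it evidently computes the number of base-256 digits (bytes) a value occupies. A minimal sketch consistent with that usage (an assumption, not the file's actual body):

int msB_log256(int x)
{
    // bytes needed to represent x in base 256 (assumed: x = 0 still takes 1 byte)
    int n = 1;
    while (x >>= 8)
        n++;
    return n;
}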
@@ -62,7 +64,6 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
for (int index = num_raw_tokens_header; index < inp.size(); index++)
{
auto &cur_p = smpl->cur_p; // initialized by set_logits
// llama_sampler_apply(smpl->grmr, &cur_p);
llama_sampler_apply(smpl->chain, &cur_p);
int match = -1;
@@ -121,12 +122,10 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
int sample_id = sample_ids[i];
uint8_t PAD = (8 - bit_offset % 8) % 8;
uint8_t bytesize = (uint8_t)msB_log256(sample_id);
// LOG("pos: %d, bs: %d\n",sample_id, bytesize);
// Rank is a big number: cheaper to store it as an explicit token record
if (sample_id > PAD + (block_header_size + fixed_token_cost + bytesize) * 8)
{
// LOG("End block\n");
// Close current block (0b1010 is block marker)
if (was_block)
{
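Reading the cost test above: both sides are measured in bits. Keeping the rank inside the current bit block appears to cost on the order of sample_id bits, while escaping to an explicit token record costs the PAD bits needed to reach a byte boundary plus (block_header_size + fixed_token_cost + bytesize) whole bytes. Worked through with the constants at the top of the file: for PAD = 3 and a one-byte id, the threshold is 3 + (2 + 1 + 1) * 8 = 35, so a rank is promoted to a token record only when it exceeds 35.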
@@ -151,21 +150,18 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
}
}
bit_offset += PAD;
total_pad += PAD;
if (bit_offset % 8)
{
LOG_ERR("Unreachable");
exit(-1);
}
// LOG("\n%d",bit_offset/8);
// 0b0101 is token marker
sample_ids_bitpacked.push_back(0b01010000 | bytesize);
// put token bytes into sample_ids_bitpacked
// LOG("\n%d -> ",sample_id);
for (int j = 0; j < bytesize; j++)
{
sample_ids_bitpacked.push_back(sample_id & 0xff);
LOG("%02x ", sample_id & 0xff);
sample_id >>= 8;
}
if (sample_id)
@@ -217,6 +213,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
int block_size = (bit_offset + PAD) / 8 - block_start;
// endianness: big endian
sample_ids_bitpacked[block_start + 1] = block_size & 0xff;
total_pad += PAD;
}
llama_batch_free(batch);
return sample_ids_bitpacked;
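Putting the pieces of encode together: each escaped token becomes a record with a one-byte header, marker nibble 0b0101 in the high bits and the payload size in the low bits, followed by the id bytes pushed least-significant first. A sketch of a matching reader, assuming that layout (read_token_record is illustrative, not a function in this file):

#include <cstdint>
#include <vector>

int read_token_record(const std::vector<uint8_t> &buf, size_t &pos)
{
    // header byte: 0b0101'BBBB, where BBBB is the payload byte count
    int bytesize = buf[pos++] & 0x0f;
    int sample_id = 0;
    for (int j = 0; j < bytesize; j++)
        sample_id |= (int)buf[pos++] << (8 * j); // bytes were written LSB first
    return sample_id;
}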
@@ -245,7 +242,6 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
auto token_str = llama_token_to_piece(ctx, token);
LOG("%s", token_str.c_str());
}
LOG("\u001b[0m\u001b[37m");
if (llama_decode(ctx, batch))
{
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -275,6 +271,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
auto &cur_p = smpl->cur_p; // initialized by set_logits
llama_sampler_apply(smpl->chain, &cur_p);
auto token_id = cur_p.data[sample_id].id;
out.push_back(token_id);
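This line is the heart of the scheme: the encoder stored, for each position, the rank of the true next token inside the sampler's candidate list, and the decoder replays the same model and sampler state, so indexing cur_p at the stored rank yields the original token:

// encode: find r such that cur_p.data[r].id == inp[i], store r
// decode: token_id = cur_p.data[r].id   // identical cur_p on both sides

Both sides must therefore produce identical logits; the mismatch diagnostics below exist to catch exactly the case where they do not.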
@@ -288,12 +285,10 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
// print in red
LOG("\u001b[31m%s", llama_token_to_piece(ctx, token_id).c_str());
LOG("\nExpected: %s", llama_token_to_piece(ctx, inp[num_raw_tokens_header + index]).c_str());
// LOG("\n%d", num_raw_tokens_header + index);
LOG("\n, Id: %d != %d", token_id, inp[num_raw_tokens_header + index]);
LOG("\nPos: %d, bs:%d", sample_id, bytesize);
// print sample_id bytes in hex
// LOG("\n %02x %02x", sample_ids_bitpacked[bit_index / 8], sample_ids_bitpacked[bit_index / 8 + 1]);
LOG("\n");
for (int i = bytesize; i > 0; i--)
{
@@ -335,8 +330,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
int sample_id = id;
auto &cur_p = smpl->cur_p; // initialized by set_logits
// llama_sampler_apply(smpl->grmr, &cur_p);
llama_sampler_apply(smpl->chain, &cur_p);
auto token_id = cur_p.data[sample_id].id;
out.push_back(token_id);
if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
@@ -363,7 +358,6 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
id = 0;
}
}
// LOG("\n(%d+%d)/8= %d\n",bit_index,PAD,(bit_index+PAD)/8);
bit_index += PAD;
}
}
@@ -554,10 +548,12 @@ int main(int argc, char **argv)
if (!params.no_perf) {
LOG("\nInput: %d characters (%d tokens)", (int)params.prompt.length(), (int)inp.size());
float compressed_byte_per_token = (float)sample_ids_bitpacked.size() / (float)inp.size();
float compressed_bits_per_token = 8 * (float)sample_ids_bitpacked.size() / (float)inp.size();
float compressed_bits_per_char = 8 * (float)sample_ids_bitpacked.size() / (float)params.prompt.length();
LOG("\n%d compressed bytes,(%04f bytes per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_byte_per_token, compressed_bits_per_char);
LOG("\n%d compressed bytes,(%04f bits per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_bits_per_token, compressed_bits_per_char);
LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad/(float)params.prompt.length());
LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token-total_pad/(float)inp.size()),exp2(compressed_bits_per_token));
}
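The PPL line uses the standard identity between coding cost and perplexity: a coder that spends H bits per token implies a perplexity of 2^H. Since the bit-packing pays integer-bit and padding overhead on top of the model's true cross-entropy, exp2(compressed_bits_per_token) can only over-estimate the model's perplexity, which is why the log calls it an (over)estimation; subtracting total_pad / inp.size() first removes the padding share. For example, 2.5 compressed bits per token implies a perplexity estimate of 2^2.5 ≈ 5.66.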
// maybe this needs to be changed
if (params.out_file != "imatrix.dat") {
@@ -630,7 +626,7 @@ int main(int argc, char **argv)
ofs.write((char*)&out_str[0], out_str.size());
ofs.close();
}
llama_free(ctx);
llama_free_model(model);