From bd5b24e8b6705a8bfd0e45b508712b1b0dc622aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?=
Date: Tue, 24 Sep 2024 23:52:09 +0200
Subject: [PATCH] compress: cleanup

---
 examples/compress/compress.cpp | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/examples/compress/compress.cpp b/examples/compress/compress.cpp
index b800ab645..736a5bf6a 100644
--- a/examples/compress/compress.cpp
+++ b/examples/compress/compress.cpp
@@ -37,6 +37,8 @@ int msB_log256(int x)
 
 const int block_header_size = 2;
 const int fixed_token_cost = 1;
 
+int total_pad = 0;
+
 std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gpt_sampler *smpl, int num_raw_tokens_header)
 {
@@ -62,7 +64,6 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
     for (int index = num_raw_tokens_header; index < inp.size(); index++)
     {
         auto &cur_p = smpl->cur_p; // initialized by set_logits
-        // llama_sampler_apply(smpl->grmr, &cur_p);
         llama_sampler_apply(smpl->chain, &cur_p);
 
         int match = -1;
@@ -121,12 +122,10 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
             int sample_id = sample_ids[i];
             uint8_t PAD = (8 - bit_offset % 8) % 8;
             uint8_t bytesize = (uint8_t)msB_log256(sample_id);
-            // LOG("pos: %d, bs: %d\n",sample_id, bytesize);
 
             // Big number, better save as token
             if (sample_id > PAD + (block_header_size + fixed_token_cost + bytesize) * 8)
             {
-                // LOG("End block\n");
                 // Close current block (0b1010 is block marker)
                 if (was_block)
                 {
@@ -151,21 +150,18 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
                     }
                 }
                 bit_offset += PAD;
+                total_pad += PAD;
                 if (bit_offset % 8)
                 {
                     LOG_ERR("Unreachable");
                     exit(-1);
                 }
-                // LOG("\n%d",bit_offset/8);
                 // 0b0101 is token marker
-
                 sample_ids_bitpacked.push_back(0b01010000 | bytesize);
                 // put token bytes into sample_ids_bitpacked
-                // LOG("\n%d -> ",sample_id);
                 for (int j = 0; j < bytesize; j++)
                 {
                     sample_ids_bitpacked.push_back(sample_id & 0xff);
-                    LOG("%02x ", sample_id & 0xff);
                     sample_id >>= 8;
                 }
                 if (sample_id)
@@ -217,6 +213,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
         int block_size = (bit_offset + PAD) / 8 - block_start;
         // endianness: big endian
         sample_ids_bitpacked[block_start + 1] = block_size & 0xff;
+        total_pad += PAD;
     }
     llama_batch_free(batch);
     return sample_ids_bitpacked;
@@ -245,7 +242,6 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
         auto token_str = llama_token_to_piece(ctx, token);
         LOG("%s", token_str.c_str());
     }
-    LOG("\u001b[0m\u001b[37m");
     if (llama_decode(ctx, batch))
     {
         LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -275,6 +271,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
 
             auto &cur_p = smpl->cur_p; // initialized by set_logits
             llama_sampler_apply(smpl->chain, &cur_p);
+
             auto token_id = cur_p.data[sample_id].id;
             out.push_back(token_id);
 
@@ -288,12 +285,10 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
                 // print in red
                 LOG("\u001b[31m%s", llama_token_to_piece(ctx, token_id).c_str());
                 LOG("\nExpected: %s", llama_token_to_piece(ctx, inp[num_raw_tokens_header + index]).c_str());
-                // LOG("\n%d", num_raw_tokens_header + index);
                 LOG("\n, Id: %d != %d", token_id, inp[num_raw_tokens_header + index]);
                 LOG("\nPos: %d, bs:%d", sample_id, bytesize);
 
                 // print sample_id bytes in hex
-                // LOG("\n %02x %02x", sample_ids_bitpacked[bit_index / 8], sample_ids_bitpacked[bit_index / 8 + 1]);
                 LOG("\n");
                 for (int i = bytesize; i > 0; i--)
                 {
@@ -335,8 +330,8 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
 
                     int sample_id = id;
                     auto &cur_p = smpl->cur_p; // initialized by set_logits
-                    // llama_sampler_apply(smpl->grmr, &cur_p);
                     llama_sampler_apply(smpl->chain, &cur_p);
+
                     auto token_id = cur_p.data[sample_id].id;
                     out.push_back(token_id);
                     if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
@@ -363,7 +358,6 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
                         id = 0;
                     }
                 }
-                // LOG("\n(%d+%d)/8= %d\n",bit_index,PAD,(bit_index+PAD)/8);
                 bit_index += PAD;
             }
         }
@@ -554,10 +548,12 @@ int main(int argc, char **argv)
     if(!params.no_perf){
         LOG("\nInput: %d characters (%d tokens)", params.prompt.length(), inp.size());
 
-        float compressed_byte_per_token = (float)sample_ids_bitpacked.size() / (float)inp.size();
+        float compressed_bits_per_token = 8 * (float)sample_ids_bitpacked.size() / (float)inp.size();
         float compressed_bits_per_char = 8 * (float)sample_ids_bitpacked.size() / (float)params.prompt.length();
 
-        LOG("\n%d compressed bytes,(%04f bytes per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_byte_per_token, compressed_bits_per_char);
+        LOG("\n%d compressed bytes,(%04f bits per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_bits_per_token, compressed_bits_per_char);
+        LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad / (float)params.prompt.length());
+        LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token - total_pad / (float)inp.size()), exp2(compressed_bits_per_token));
     }
     //maybe this needs to be changed
     if(params.out_file != "imatrix.dat"){
@@ -630,7 +626,7 @@ int main(int argc, char **argv)
         ofs.write((char*)&out_str[0], out_str.size());
         ofs.close();
     }
-    
+
     llama_free(ctx);
     llama_free_model(model);
 
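
Note (editorial addition, not part of the patch): the perplexity figures logged by the new code in main() rest on the identity PPL = 2^(bits per token). An ideal coder driven by the model spends about -log2 p(token) bits per token, so the average compressed size in bits per token upper-bounds the model's cross entropy, which is why the patch calls it a PPL "(over)estimation"; subtracting total_pad / inp.size() first removes the byte-alignment padding the encoder inserts. Below is a minimal standalone sketch of the same arithmetic; the numeric values are hypothetical stand-ins for sample_ids_bitpacked.size(), inp.size(), params.prompt.length() and total_pad.

// stats_sketch.cpp -- illustration only; all numbers are hypothetical.
#include <cmath>
#include <cstdio>

int main() {
    float compressed_bytes = 1200.0f;  // stand-in for sample_ids_bitpacked.size()
    float n_tokens         = 3000.0f;  // stand-in for inp.size()
    float n_chars          = 12000.0f; // stand-in for params.prompt.length()
    float total_pad_bits   = 800.0f;   // stand-in for the patch's total_pad counter

    float bits_per_token = 8.0f * compressed_bytes / n_tokens; // 3.2
    float bits_per_char  = 8.0f * compressed_bytes / n_chars;  // 0.8

    // 2^(bits per token) converts average code length into perplexity;
    // removing the padding bits first gives the tighter estimate.
    std::printf("bits/token: %f, bits/char: %f\n", bits_per_token, bits_per_char);
    std::printf("PPL (over)estimation: %f (%f with padding)\n",
                std::exp2(bits_per_token - total_pad_bits / n_tokens),
                std::exp2(bits_per_token));
    return 0;
}

Compiled on its own (e.g. g++ stats_sketch.cpp), this prints the same kind of figures the patched main() logs after a real encode.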