compress: format

This commit is contained in:
Stéphane du Hamel 2024-09-25 01:26:39 +02:00
parent b9a32f464f
commit bec83989be

View file

@ -78,7 +78,8 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
break; break;
} }
} }
if(match<0){ if (match < 0)
{
LOG_ERR("\n couldn't match %s", llama_token_to_piece(ctx, inp[index]).c_str()); LOG_ERR("\n couldn't match %s", llama_token_to_piece(ctx, inp[index]).c_str());
exit(1); exit(1);
} }
@ -133,14 +134,13 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
int block_size = (bit_offset + PAD) / 8 - block_start; int block_size = (bit_offset + PAD) / 8 - block_start;
if (block_size >= 256) if (block_size >= 256)
{ {
// TODO: handle more than 256 bytes of block data // TODO: handle more than 256 bytes of block data
// (maybe allow multiple blocks in a row) // (maybe allow multiple blocks in a row)
LOG_ERR("Block too big %d >= 256", block_size); LOG_ERR("Block too big %d >= 256", block_size);
exit(-1); exit(-1);
} }
sample_ids_bitpacked[block_start + 1] = block_size & 0xff; sample_ids_bitpacked[block_start + 1] = block_size & 0xff;
// put last bytes // put last bytes
if (PAD) if (PAD)
{ {
@ -212,7 +212,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
int block_size = (bit_offset + PAD) / 8 - block_start; int block_size = (bit_offset + PAD) / 8 - block_start;
// endianness: big endian // endianness: big endian
sample_ids_bitpacked[block_start + 1] = block_size & 0xff; sample_ids_bitpacked[block_start + 1] = block_size & 0xff;
total_pad+=PAD; total_pad += PAD;
} }
llama_batch_free(batch); llama_batch_free(batch);
return sample_ids_bitpacked; return sample_ids_bitpacked;
@ -330,7 +330,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect
auto &cur_p = smpl->cur_p; // initialized by set_logits auto &cur_p = smpl->cur_p; // initialized by set_logits
llama_sampler_apply(smpl->chain, &cur_p); llama_sampler_apply(smpl->chain, &cur_p);
auto token_id = cur_p.data[sample_id].id; auto token_id = cur_p.data[sample_id].id;
out.push_back(token_id); out.push_back(token_id);
if (!inp.size() || token_id == inp[num_raw_tokens_header + index]) if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
@ -482,7 +482,7 @@ int main(int argc, char **argv)
params.sparams.top_p = 1; params.sparams.top_p = 1;
params.sparams.top_k = -1; params.sparams.top_k = -1;
// Avoid temp=0 because greedy sampling breaks stuff // Avoid temp=0 because greedy sampling breaks stuff
params.sparams.temp = 1.; params.sparams.temp = 1.;
gpt_init(); gpt_init();
@ -544,38 +544,43 @@ int main(int argc, char **argv)
auto t_enc_end = ggml_time_us(); auto t_enc_end = ggml_time_us();
LOG("\n"); LOG("\n");
if(!params.no_perf){ if (!params.no_perf)
{
LOG("\nInput: %d characters (%d tokens)", params.prompt.length(), inp.size()); LOG("\nInput: %d characters (%d tokens)", params.prompt.length(), inp.size());
float compressed_bits_per_token = 8 * (float)sample_ids_bitpacked.size() / (float)inp.size(); float compressed_bits_per_token = 8 * (float)sample_ids_bitpacked.size() / (float)inp.size();
float compressed_bits_per_char = 8 * (float)sample_ids_bitpacked.size() / (float)params.prompt.length(); float compressed_bits_per_char = 8 * (float)sample_ids_bitpacked.size() / (float)params.prompt.length();
LOG("\n%d compressed bytes,(%04f bits per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_bits_per_token, compressed_bits_per_char); LOG("\n%d compressed bytes,(%04f bits per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_bits_per_token, compressed_bits_per_char);
LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad/(float)params.prompt.length()); LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad / (float)params.prompt.length());
LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token-total_pad/(float)inp.size()),exp2(compressed_bits_per_token)); LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token - total_pad / (float)inp.size()), exp2(compressed_bits_per_token));
} }
//maybe this needs to be changed // maybe this needs to be changed
if(params.out_file != "imatrix.dat"){ if (params.out_file != "imatrix.dat")
{
// dump uint8array to bin file // dump uint8array to bin file
std::ofstream ofs(params.out_file.c_str(), std::ios::binary); std::ofstream ofs(params.out_file.c_str(), std::ios::binary);
ofs.write((char*)&sample_ids_bitpacked[0], sample_ids_bitpacked.size()); ofs.write((char *)&sample_ids_bitpacked[0], sample_ids_bitpacked.size());
ofs.close(); ofs.close();
}else{ }
else
{
LOG("\n------------\n"); LOG("\n------------\n");
//print as hex to stdout // print as hex to stdout
for (int i = 0; i < sample_ids_bitpacked.size(); i++){ for (int i = 0; i < sample_ids_bitpacked.size(); i++)
{
LOG("%02X ", sample_ids_bitpacked[i]); LOG("%02X ", sample_ids_bitpacked[i]);
} }
} }
} }
else if (params.compress_mode == 2) else if (params.compress_mode == 2)
{ {
//decompress mode // decompress mode
// load sample_ids_bitpacked from params.prompt_file // load sample_ids_bitpacked from params.prompt_file
std::ifstream ifs(params.prompt_file.c_str(), std::ios::binary); std::ifstream ifs(params.prompt_file.c_str(), std::ios::binary);
if (!ifs) { if (!ifs)
{
LOG_ERR("%s: failed to open file\n", __func__); LOG_ERR("%s: failed to open file\n", __func__);
return -1; return -1;
} }
@ -588,14 +593,16 @@ int main(int argc, char **argv)
std::vector<uint8_t> sample_ids_bitpacked(fileSize); std::vector<uint8_t> sample_ids_bitpacked(fileSize);
// Read the ifs into the vector // Read the ifs into the vector
if (!ifs.read(reinterpret_cast<char*>(sample_ids_bitpacked.data()), fileSize)) { if (!ifs.read(reinterpret_cast<char *>(sample_ids_bitpacked.data()), fileSize))
{
LOG_ERR("%s: failed to read file\n", __func__); LOG_ERR("%s: failed to read file\n", __func__);
return -1; return -1;
} }
ifs.close(); ifs.close();
//Debug: print as hex // Debug: print as hex
for (int i = 0; i < sample_ids_bitpacked.size(); i++){ for (int i = 0; i < sample_ids_bitpacked.size(); i++)
{
LOG("%02X ", sample_ids_bitpacked[i]); LOG("%02X ", sample_ids_bitpacked[i]);
} }
LOG("\n"); LOG("\n");
@ -612,23 +619,22 @@ int main(int argc, char **argv)
std::vector<llama_token> out = decode(ctx, smpl, sample_ids_bitpacked); std::vector<llama_token> out = decode(ctx, smpl, sample_ids_bitpacked);
gpt_sampler_free(smpl); gpt_sampler_free(smpl);
auto t_dec_end = ggml_time_us(); auto t_dec_end = ggml_time_us();
//maybe this needs to be changed // maybe this needs to be changed
if(params.out_file != "imatrix.dat"){ if (params.out_file != "imatrix.dat")
{
// dump as string to file // dump as string to file
std::string out_str = ::llama_detokenize(ctx, out); std::string out_str = ::llama_detokenize(ctx, out);
std::ofstream ofs(params.out_file.c_str(), std::ios::binary); std::ofstream ofs(params.out_file.c_str(), std::ios::binary);
ofs.write((char*)&out_str[0], out_str.size()); ofs.write((char *)&out_str[0], out_str.size());
ofs.close(); ofs.close();
} }
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
} }
llama_backend_free(); llama_backend_free();