From bec83989bed431266ac4d26535b722a6361fade0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Wed, 25 Sep 2024 01:26:39 +0200 Subject: [PATCH] compress: format --- examples/compress/compress.cpp | 60 +++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/examples/compress/compress.cpp b/examples/compress/compress.cpp index a0f79005f..bd2756afa 100644 --- a/examples/compress/compress.cpp +++ b/examples/compress/compress.cpp @@ -78,7 +78,8 @@ std::vector encode(llama_context *ctx, std::vector inp, gp break; } } - if(match<0){ + if (match < 0) + { LOG_ERR("\n couldn't match %s", llama_token_to_piece(ctx, inp[index]).c_str()); exit(1); } @@ -133,14 +134,13 @@ std::vector encode(llama_context *ctx, std::vector inp, gp int block_size = (bit_offset + PAD) / 8 - block_start; if (block_size >= 256) { - // TODO: handle more than 256 bytes of block data + // TODO: handle more than 256 bytes of block data // (maybe allow multiple blocks in a row) LOG_ERR("Block too big %d >= 256", block_size); exit(-1); } sample_ids_bitpacked[block_start + 1] = block_size & 0xff; - // put last bytes if (PAD) { @@ -212,7 +212,7 @@ std::vector encode(llama_context *ctx, std::vector inp, gp int block_size = (bit_offset + PAD) / 8 - block_start; // endianness: big endian sample_ids_bitpacked[block_start + 1] = block_size & 0xff; - total_pad+=PAD; + total_pad += PAD; } llama_batch_free(batch); return sample_ids_bitpacked; @@ -330,7 +330,7 @@ std::vector decode(llama_context *ctx, gpt_sampler *smpl, std::vect auto &cur_p = smpl->cur_p; // initialized by set_logits llama_sampler_apply(smpl->chain, &cur_p); - + auto token_id = cur_p.data[sample_id].id; out.push_back(token_id); if (!inp.size() || token_id == inp[num_raw_tokens_header + index]) @@ -482,7 +482,7 @@ int main(int argc, char **argv) params.sparams.top_p = 1; params.sparams.top_k = -1; // Avoid temp=0 because greedy sampling breaks stuff - params.sparams.temp = 1.; + params.sparams.temp = 1.; gpt_init(); @@ -544,38 +544,43 @@ int main(int argc, char **argv) auto t_enc_end = ggml_time_us(); LOG("\n"); - if(!params.no_perf){ + if (!params.no_perf) + { LOG("\nInput: %d characters (%d tokens)", params.prompt.length(), inp.size()); float compressed_bits_per_token = 8 * (float)sample_ids_bitpacked.size() / (float)inp.size(); float compressed_bits_per_char = 8 * (float)sample_ids_bitpacked.size() / (float)params.prompt.length(); LOG("\n%d compressed bytes,(%04f bits per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_bits_per_token, compressed_bits_per_char); - LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad/(float)params.prompt.length()); - LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token-total_pad/(float)inp.size()),exp2(compressed_bits_per_token)); + LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad / (float)params.prompt.length()); + LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token - total_pad / (float)inp.size()), exp2(compressed_bits_per_token)); } - //maybe this needs to be changed - if(params.out_file != "imatrix.dat"){ + // maybe this needs to be changed + if (params.out_file != "imatrix.dat") + { // dump uint8array to bin file std::ofstream ofs(params.out_file.c_str(), std::ios::binary); - ofs.write((char*)&sample_ids_bitpacked[0], sample_ids_bitpacked.size()); + ofs.write((char *)&sample_ids_bitpacked[0], sample_ids_bitpacked.size()); ofs.close(); - }else{ + } + else + { LOG("\n------------\n"); - //print as hex to stdout - for (int i = 0; i < sample_ids_bitpacked.size(); i++){ + // print as hex to stdout + for (int i = 0; i < sample_ids_bitpacked.size(); i++) + { LOG("%02X ", sample_ids_bitpacked[i]); } } - } else if (params.compress_mode == 2) { - //decompress mode - // load sample_ids_bitpacked from params.prompt_file + // decompress mode + // load sample_ids_bitpacked from params.prompt_file std::ifstream ifs(params.prompt_file.c_str(), std::ios::binary); - if (!ifs) { + if (!ifs) + { LOG_ERR("%s: failed to open file\n", __func__); return -1; } @@ -588,14 +593,16 @@ int main(int argc, char **argv) std::vector sample_ids_bitpacked(fileSize); // Read the ifs into the vector - if (!ifs.read(reinterpret_cast(sample_ids_bitpacked.data()), fileSize)) { + if (!ifs.read(reinterpret_cast(sample_ids_bitpacked.data()), fileSize)) + { LOG_ERR("%s: failed to read file\n", __func__); return -1; } ifs.close(); - //Debug: print as hex - for (int i = 0; i < sample_ids_bitpacked.size(); i++){ + // Debug: print as hex + for (int i = 0; i < sample_ids_bitpacked.size(); i++) + { LOG("%02X ", sample_ids_bitpacked[i]); } LOG("\n"); @@ -612,23 +619,22 @@ int main(int argc, char **argv) std::vector out = decode(ctx, smpl, sample_ids_bitpacked); - gpt_sampler_free(smpl); auto t_dec_end = ggml_time_us(); - //maybe this needs to be changed - if(params.out_file != "imatrix.dat"){ + // maybe this needs to be changed + if (params.out_file != "imatrix.dat") + { // dump as string to file std::string out_str = ::llama_detokenize(ctx, out); std::ofstream ofs(params.out_file.c_str(), std::ios::binary); - ofs.write((char*)&out_str[0], out_str.size()); + ofs.write((char *)&out_str[0], out_str.size()); ofs.close(); } llama_free(ctx); llama_free_model(model); - } llama_backend_free();