From d2f92662002d2d04a8d91620ac4e7965d9850c67 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 19 Apr 2023 20:20:44 +0200 Subject: [PATCH] Multi-threading quantization. Not much gain for simple quantizations, bit it will be important for quantizations that require more CPU cycles. --- examples/quantize/quantize.cpp | 7 ++-- ggml.c | 27 +++++++++++++++ ggml.h | 2 ++ llama.cpp | 61 ++++++++++++++++++++++++---------- llama.h | 4 ++- 5 files changed, 79 insertions(+), 22 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 59cb67440..9882e6036 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -10,8 +10,8 @@ int main(int argc, char ** argv) { ggml_time_init(); - if (argc != 4) { - fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + if (argc < 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type [nthread]\n", argv[0]); fprintf(stderr, " type = %d - q4_0\n", LLAMA_FTYPE_MOSTLY_Q4_0); fprintf(stderr, " type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1); fprintf(stderr, " type = %d - q4_2\n", LLAMA_FTYPE_MOSTLY_Q4_2); @@ -29,6 +29,7 @@ int main(int argc, char ** argv) { const std::string fname_out = argv[2]; const enum llama_ftype ftype = (enum llama_ftype)atoi(argv[3]); + int nthread = argc > 4 ? atoi(argv[4]) : 0; const int64_t t_main_start_us = ggml_time_us(); @@ -38,7 +39,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = ggml_time_us(); - if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype)) { + if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/ggml.c b/ggml.c index 431cdb9c9..ad1356505 100644 --- a/ggml.c +++ b/ggml.c @@ -11870,6 +11870,33 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_2*sizeof(block_q4_2)); } +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { + size_t result = 0; + switch (type) { + case GGML_TYPE_Q4_0: + { + assert (start % QK4_0 == 0); + block_q4_0 * block = (block_q4_0*)dst + start / QK4_0; + result = ggml_quantize_q4_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_1: + { + assert (start % QK4_1 == 0); + block_q4_1 * block = (block_q4_1*)dst + start / QK4_1; + result = ggml_quantize_q4_1(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_2: + { + assert (start % QK4_2 == 0); + block_q4_2 * block = (block_q4_2*)dst + start / QK4_2; + result = ggml_quantize_q4_2(src + start, block, n, n, hist); + } break; + default: + assert(false); + } + return result; +} + //////////////////////////////////////////////////////////////////////////////// int ggml_cpu_has_avx(void) { diff --git a/ggml.h b/ggml.h index 570147fc2..b50ca510a 100644 --- a/ggml.h +++ b/ggml.h @@ -809,6 +809,8 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + // // system info // diff --git a/llama.cpp b/llama.cpp index 3ff5dc1e1..f576345cb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -24,6 +24,9 @@ #include #include #include +#include +#include +#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1569,7 +1572,7 @@ static llama_vocab::id llama_sample_top_p_top_k( // quantization // -static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype) { +static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype, int nthread) { ggml_type quantized_type; switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; @@ -1578,6 +1581,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s default: throw format("invalid output file type %d\n", ftype); }; + if (nthread <= 0) nthread = std::thread::hardware_concurrency(); + std::unique_ptr model_loader(new llama_model_loader(fname_inp.c_str(), /*use_mmap*/ false, /*vocab_only*/ false)); llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype); @@ -1586,6 +1591,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s size_t total_size_new = 0; std::vector hist_all(1 << 4, 0); + std::vector workers; + std::mutex mutex; + size_t idx = 0; for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) { llama_buffer read_data; @@ -1639,21 +1647,37 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_data = work.addr; std::vector hist_cur(1 << 4, 0); - switch (new_type) { - case GGML_TYPE_Q4_0: - { - new_size = ggml_quantize_q4_0(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data()); - } break; - case GGML_TYPE_Q4_1: - { - new_size = ggml_quantize_q4_1(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data()); - } break; - case GGML_TYPE_Q4_2: - { - new_size = ggml_quantize_q4_2(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data()); - } break; - default: - LLAMA_ASSERT(false); + int chunk_size = 32 * 512; + int nchunk = (nelements + chunk_size - 1)/chunk_size; + int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; + if (nthread_use < 2) { + new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data()); + } else { + size_t counter = 0; + new_size = 0; + auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements, chunk_size] () { + std::vector local_hist; + size_t local_size = 0; + while (true) { + std::unique_lock lock(mutex); + size_t first = counter; counter += chunk_size; + if (first >= nelements) { + if (!local_hist.empty()) { + for (int j=0; j %8.2f MB | hist: ", tensor.size/1024.0/1024.0, new_size/1024.0/1024.0); @@ -1775,9 +1799,10 @@ void llama_free(struct llama_context * ctx) { int llama_model_quantize( const char * fname_inp, const char * fname_out, - enum llama_ftype ftype) { + enum llama_ftype ftype, + int nthread) { try { - llama_model_quantize_internal(fname_inp, fname_out, ftype); + llama_model_quantize_internal(fname_inp, fname_out, ftype, nthread); return 0; } catch (const std::string & err) { fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.c_str()); diff --git a/llama.h b/llama.h index 208b03d18..280680ab4 100644 --- a/llama.h +++ b/llama.h @@ -92,10 +92,12 @@ extern "C" { // TODO: not great API - very likely to change // Returns 0 on success + // nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given LLAMA_API int llama_model_quantize( const char * fname_inp, const char * fname_out, - enum llama_ftype ftype); + enum llama_ftype ftype, + int nthread); // Apply a LoRA adapter to a loaded model // path_base_model is the path to a higher quality model to use as a base for