From 98ed16557432d7a5179c57eddcc3a08a7ae6d54d Mon Sep 17 00:00:00 2001 From: Robert Sung-wook Shin Date: Sat, 10 Jun 2023 01:24:40 +0900 Subject: [PATCH 1/6] OpenCL: Add release memory (#1741) * Add opencl release memory * Rename function name --- ggml-opencl.cpp | 9 +++++++++ ggml-opencl.h | 2 ++ llama.cpp | 6 +++++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 81a975cf8..7b6daf4a8 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -662,6 +662,15 @@ static void ggml_cl_pool_free(cl_mem mem, size_t size) { clReleaseMemObject(mem); } +void ggml_cl_free_data(const struct ggml_tensor* tensor) { + if (tensor->backend != GGML_BACKEND_GPU) { + return; + } + + cl_mem mem = (cl_mem)tensor->data; + clReleaseMemObject(mem); +} + static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cl_event* ev) { cl_int err; const uint64_t ne0 = src->ne[0]; diff --git a/ggml-opencl.h b/ggml-opencl.h index c850bb8ad..bf95e5cd0 100644 --- a/ggml-opencl.h +++ b/ggml-opencl.h @@ -16,6 +16,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor void * ggml_cl_host_malloc(size_t size); void ggml_cl_host_free(void * ptr); +void ggml_cl_free_data(const struct ggml_tensor* tensor); + void ggml_cl_transform_tensor(struct ggml_tensor * tensor); void ggml_cl_load_data(const char * fname, struct ggml_tensor * tensor, size_t offset); diff --git a/llama.cpp b/llama.cpp index 16d6f6ef1..f40c5afa2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -210,7 +210,11 @@ struct llama_model { for (size_t i = 0; i < tensors_by_name.size(); ++i) { ggml_cuda_free_data(tensors_by_name[i].second); } -#endif // GGML_USE_CUBLAS +#elif defined(GGML_USE_CLBLAST) + for (size_t i = 0; i < tensors_by_name.size(); ++i) { + ggml_cl_free_data(tensors_by_name[i].second); + } +#endif } }; From 555275a693843273759230547001f9ae07fb537e Mon Sep 17 00:00:00 2001 From: rankaiyx Date: Sat, 10 Jun 2023 14:41:59 +0800 Subject: [PATCH 2/6] make : add SSSE3 compilation use case (#1659) --- Makefile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Makefile b/Makefile index 39265164b..39ebfd048 100644 --- a/Makefile +++ b/Makefile @@ -107,6 +107,10 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686)) # Usage AVX-only #CFLAGS += -mfma -mf16c -mavx #CXXFLAGS += -mfma -mf16c -mavx + + # Usage SSSE3-only (Not is SSE3!) 
+ #CFLAGS += -mssse3 + #CXXFLAGS += -mssse3 endif ifneq ($(filter ppc64%,$(UNAME_M)),) From ef3171d16241c18581d4d08374f0b9e396ade6b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xingchen=20Song=28=E5=AE=8B=E6=98=9F=E8=BE=B0=29?= Date: Sat, 10 Jun 2023 15:49:40 +0800 Subject: [PATCH 3/6] ggml : workaround for missing _mm256_setr_m128i in GCC < 8 (#1638) --- ggml.c | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml.c b/ggml.c index 567dbc1e1..9dc81fe08 100644 --- a/ggml.c +++ b/ggml.c @@ -492,6 +492,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // quantization // +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) // multiply int8_t, add results pairwise twice static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { @@ -551,7 +553,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) { static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); const __m256i lowMask = _mm256_set1_epi8( 0xF ); return _mm256_and_si256(lowMask, bytes); } @@ -624,7 +626,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) { bytesh = _mm_or_si128(bytesh, bit_mask); bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return _mm256_set_m128i(bytesh, bytesl); + return MM256_SET_M128I(bytesh, bytesl); } // Unpack 32 4-bit fields into 32 bytes @@ -637,7 +639,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) const __m128i lowMask = _mm_set1_epi8(0xF); tmpl = _mm_and_si128(lowMask, tmpl); tmph = _mm_and_si128(lowMask, tmph); - return _mm256_set_m128i(tmph, tmpl); + return MM256_SET_M128I(tmph, tmpl); } // add int16_t pairwise and return as float vector @@ -645,7 +647,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { const __m128i ones = _mm_set1_epi16(1); const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); return _mm256_cvtepi32_ps(summed_pairs); } @@ -2350,7 +2352,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const __m128i i32_1 = mul_sum_i8_pairs(bx, by); // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1)); + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); // Apply the scale, and accumulate acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); @@ -2826,7 +2828,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * __m128i bxh = _mm256_extractf128_si256(bx, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = _mm256_set_m128i(bxh, bxl); + bx = MM256_SET_M128I(bxh, bxl); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -3082,7 +3084,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * __m128i bxh = _mm256_extractf128_si256(bx, 1); bxl = _mm_or_si128(bxl, bxhil); bxh = _mm_or_si128(bxh, bxhih); - bx = _mm256_set_m128i(bxh, bxl); + bx = MM256_SET_M128I(bxh, bxl); const __m256 
dy = _mm256_set1_ps(y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); From 4f0154b0bad775ac4651bf73b5c216eb43c45cdc Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Sat, 10 Jun 2023 01:59:17 -0600 Subject: [PATCH 4/6] llama : support requantizing models instead of only allowing quantization from 16/32bit (#1691) * Add support for quantizing already quantized models * Threaded dequantizing and f16 to f32 conversion * Clean up thread blocks with spares calculation a bit * Use std::runtime_error exceptions. --- examples/quantize/quantize.cpp | 57 ++++++++++++------ llama.cpp | 103 +++++++++++++++++++++++++++------ llama.h | 14 +++-- 3 files changed, 134 insertions(+), 40 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 947b40202..c6bf1b723 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -3,6 +3,7 @@ #include "llama.h" #include +#include #include #include @@ -53,27 +54,49 @@ bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::st // usage: // ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads] // +void usage(const char * executable) { + fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n", executable); + fprintf(stderr, " --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); + fprintf(stderr, " --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); + fprintf(stderr, "Allowed quantization types:\n"); + for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { + fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second); + } + exit(1); +} + int main(int argc, char ** argv) { if (argc < 3) { - fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]); - for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { - fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second); + usage(argv[0]); + } + + llama_model_quantize_params params = llama_model_quantize_default_params(); + + int arg_idx = 1; + + for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { + if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { + params.quantize_output_tensor = false; + } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) { + params.allow_requantize = true; + } else { + usage(argv[0]); } - return 1; + } + + if (argc - arg_idx < 3) { + usage(argv[0]); } llama_init_backend(); // parse command line arguments - const std::string fname_inp = argv[1]; + const std::string fname_inp = argv[arg_idx]; + arg_idx++; std::string fname_out; - int nthread; - llama_ftype ftype; - int arg_idx = 2; std::string ftype_str; - if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) { - // argv[2] is the ftype + if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) { std::string fpath; const size_t pos = fname_inp.find_last_of('/'); if (pos != std::string::npos) { @@ -84,7 +107,6 @@ int main(int argc, char ** argv) { arg_idx++; } else { - // argv[2] is the output path fname_out = argv[arg_idx]; arg_idx++; @@ -92,8 +114,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: missing ftype\n", __func__); 
return 1; } - // argv[3] is the ftype - if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) { + if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) { fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]); return 1; } @@ -103,21 +124,19 @@ int main(int argc, char ** argv) { // parse nthreads if (argc > arg_idx) { try { - nthread = std::stoi(argv[arg_idx]); + params.nthread = std::stoi(argv[arg_idx]); } catch (const std::exception & e) { fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what()); return 1; } - } else { - nthread = 0; } fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str()); - if (nthread > 0) { - fprintf(stderr, " using %d threads", nthread); + if (params.nthread > 0) { + fprintf(stderr, " using %d threads", params.nthread); } fprintf(stderr, "\n"); @@ -129,7 +148,7 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = llama_time_us(); - if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) { + if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ¶ms)) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } diff --git a/llama.cpp b/llama.cpp index f40c5afa2..e100e2bc9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -886,6 +886,17 @@ struct llama_context_params llama_context_default_params() { return result; } +struct llama_model_quantize_params llama_model_quantize_default_params() { + struct llama_model_quantize_params result = { + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, + }; + + return result; +} + bool llama_mmap_supported() { return llama_mmap::SUPPORTED; } @@ -2231,9 +2242,70 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra // quantization // -static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype, int nthread) { +static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llama_buffer & output, const int nelements, const int nthread) { + if (output.size < nelements * sizeof(float)) { + output.resize(nelements * sizeof(float)); + } + float * f32_output = (float *) output.addr; + + quantize_fns_t qtype; + if (ggml_is_quantized(tensor.type)) { + qtype = ggml_internal_get_quantize_fn(tensor.type); + if (qtype.dequantize_row_q == NULL) { + throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor.type))); + } + } else if (tensor.type != GGML_TYPE_F16) { + throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor.type))); + } + + if (nthread < 2) { + if (tensor.type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor.data, f32_output, nelements); + } else if (ggml_is_quantized(tensor.type)) { + qtype.dequantize_row_q(tensor.data, f32_output, nelements); + } else { + LLAMA_ASSERT(false); // unreachable + } + return; + } + + auto block_size = tensor.type == GGML_TYPE_F16 ? 
1 : (size_t)ggml_blck_size(tensor.type); + auto block_size_bytes = ggml_type_size(tensor.type); + + LLAMA_ASSERT(nelements % block_size == 0); + auto nblocks = nelements / block_size; + auto blocks_per_thread = nblocks / nthread; + auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count + + std::vector workers; + for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) { + auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread + auto thr_elems = thr_blocks * block_size; // number of elements for this thread + auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread + + auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { + if (typ == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); + } else { + qtype.dequantize_row_q(inbuf, outbuf, nels); + } + }; + workers.push_back(std::thread(compute, tensor.type, tensor.data + in_buff_offs, f32_output + out_buff_offs, thr_elems)); + in_buff_offs += thr_block_bytes; + out_buff_offs += thr_elems; + } + for (auto & worker : workers) { + worker.join(); + } + +} + +static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type quantized_type; - switch (ftype) { + llama_ftype ftype = params->ftype; + int nthread = params->nthread; + + switch (params->ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; @@ -2259,7 +2331,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::unique_ptr model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false, /*vocab_only*/ false)); - llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype); + llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), params->ftype); int n_attention_wv = 0; int n_feed_forward_w2 = 0; @@ -2301,9 +2373,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s quantize &= (tensor.ne.size() == 2); // uncomment this to keep the output layer in FP16 - //if (tensor.name == "output.weight") { - // quantize = false; - //} + if (!params->quantize_output_tensor && tensor.name == "output.weight") { + quantize = false; + } + quantize = quantize && quantized_type != tensor.type; enum ggml_type new_type; void * new_data; @@ -2346,17 +2419,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s float * f32_data; size_t nelements = tensor.ne.at(0) * tensor.ne.at(1); llama_buffer f32_conv_buf; + if (tensor.type == GGML_TYPE_F32) { f32_data = (float *) tensor.data; - } else if (tensor.type == GGML_TYPE_F16) { - f32_conv_buf.resize(nelements * sizeof(float)); - f32_data = (float *) f32_conv_buf.addr; - const auto * f16_data = (const ggml_fp16_t *) tensor.data; - for (size_t i = 0; i < nelements; i++) { - f32_data[i] = ggml_fp16_to_fp32(f16_data[i]); - } + } else if (ggml_is_quantized(tensor.type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor.type))); } else { - throw std::runtime_error(format("type %s unsupported for integer quantization", ggml_type_name(tensor.type))); + 
llama_convert_tensor_internal(tensor, f32_conv_buf, nelements, nthread); + f32_data = (float *) f32_conv_buf.addr; } printf("quantizing .. "); @@ -2566,10 +2636,9 @@ void llama_free(struct llama_context * ctx) { int llama_model_quantize( const char * fname_inp, const char * fname_out, - enum llama_ftype ftype, - int nthread) { + const llama_model_quantize_params *params) { try { - llama_model_quantize_internal(fname_inp, fname_out, ftype, nthread); + llama_model_quantize_internal(fname_inp, fname_out, params); return 0; } catch (const std::exception & err) { fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.what()); diff --git a/llama.h b/llama.h index dc033b71d..7c7fd481c 100644 --- a/llama.h +++ b/llama.h @@ -115,7 +115,16 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors }; + // model quantization parameters + typedef struct llama_model_quantize_params { + int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + } llama_model_quantize_params; + LLAMA_API struct llama_context_params llama_context_default_params(); + LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(); LLAMA_API bool llama_mmap_supported(); LLAMA_API bool llama_mlock_supported(); @@ -137,14 +146,11 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); - // TODO: not great API - very likely to change // Returns 0 on success - // nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given LLAMA_API int llama_model_quantize( const char * fname_inp, const char * fname_out, - enum llama_ftype ftype, - int nthread); + const llama_model_quantize_params * params); // Apply a LoRA adapter to a loaded model // path_base_model is the path to a higher quality model to use as a base for From e9b66ee9829039d4ab54550d6222e42a0b31e52a Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Sat, 10 Jun 2023 11:28:11 +0300 Subject: [PATCH 5/6] metal : add Q4_1 implementation (#1785) 23.3 ms / token, so just ~1% slower than q4_0. Achieves 290 GB/s memory throughput. 
Co-authored-by: Iwan Kawrakow --- ggml-metal.m | 16 +++++- ggml-metal.metal | 123 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) diff --git a/ggml-metal.m b/ggml-metal.m index 5c9ecd76e..167ebd467 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -50,12 +50,14 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); + GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q2_k); GGML_METAL_DECL_KERNEL(get_rows_q4_k); GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); @@ -141,12 +143,14 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); + GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q2_k); GGML_METAL_ADD_KERNEL(get_rows_q4_k); GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); @@ -545,6 +549,15 @@ void ggml_metal_graph_compute( nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + } break; case GGML_TYPE_Q2_K: { GGML_ASSERT(ne02 == 1); @@ -596,7 +609,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - if (src0t == GGML_TYPE_Q4_0) { + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q2_K) { @@ -623,6 +636,7 @@ void ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; diff --git a/ggml-metal.metal b/ggml-metal.metal index c94ef83f9..ccd36386b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -11,6 +11,13 @@ typedef struct { uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; +#define QK4_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; + static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) { const int qk = QK4_0; @@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i } } 
+static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) { + const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const half d = x[i].d; + const half m = x[i].m; + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + kernel void kernel_add( device const float * src0, device const float * src1, @@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0( (device float *) ((device char *) dst + i*nb1), ne00); } +kernel void kernel_get_rows_q4_1( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q4_1( + (device const block_q4_1 *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + kernel void kernel_rms_norm( device const void * src0, device float * dst, @@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32( //} } +kernel void kernel_mul_mat_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + const int nb = ne00/QK4_1; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb; + device const float * y = (device const float *) src1 + r1*ne10; + + const uint nth = tptg.x*tptg.y; + const uint ith = tptg.y*tpitg.x + tpitg.y; + + const int ix = tpitg.y/4; // 0 or 1 + const int iy = tpitg.y - 4*ix; // 0...3 + + const int first = 4 * iy; + + float sumf = 0; + + for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { + + const float d = (float)x[i].d; + const float m = (float)x[i].m; + + device const uint8_t * xl = x[i].qs + first; + device const float * yl = y + i * QK4_1 + first; + + float2 acc = {0.0f, 0.0f}; + + for (int j = 0; j < 4; ++j) { + + acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m); + acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m); + + } + + sumf += acc[0] + acc[1]; + } + + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } +} + kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1, From 17c10acfb44ecb7af25e37fb67b9501cbc0034d2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 10 Jun 2023 12:06:45 +0300 Subject: [PATCH 6/6] ggml : force no_alloc == 
false when creating opt tensors (close #1699) This is needed to make operators like ggml_view() be able to store their parameters in the ggml context's memory and not get discarded when no_alloc is true --- ggml.c | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ggml.c b/ggml.c index 9dc81fe08..a13de5115 100644 --- a/ggml.c +++ b/ggml.c @@ -3721,6 +3721,7 @@ struct ggml_context { void * mem_buffer; bool mem_buffer_owned; bool no_alloc; + bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers int n_objects; @@ -4055,6 +4056,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size), /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, /*.no_alloc =*/ params.no_alloc, + /*.no_alloc_save =*/ params.no_alloc, /*.n_objects =*/ 0, /*.objects_begin =*/ NULL, /*.objects_end =*/ NULL, @@ -4132,11 +4134,18 @@ size_t ggml_get_mem_size(struct ggml_context * ctx) { // operators when using scratch buffers // TODO: implement a better way void ggml_scratch_save(struct ggml_context * ctx) { + // this is needed to allow opt tensors to store their data + // TODO: again, need to find a better way + ctx->no_alloc_save = ctx->no_alloc; + ctx->no_alloc = false; + ctx->scratch_save = ctx->scratch; ctx->scratch.data = NULL; } void ggml_scratch_load(struct ggml_context * ctx) { + ctx->no_alloc = ctx->no_alloc_save; + ctx->scratch = ctx->scratch_save; }
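
For context, a minimal caller-side sketch of the situation this last change addresses (illustrative only, not part of the patch series; the buffer size and tensor shapes are arbitrary): when a graph is built with no_alloc == true, tensor data is left unallocated, but ops such as ggml_view_1d() still have to store their offset parameter in a small tensor inside the context, which is what ggml_scratch_save()/ggml_scratch_load() now temporarily re-enable allocation for.

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,   // build the graph without allocating tensor data
        };
        struct ggml_context * ctx = ggml_init(params);

        // data pointer stays NULL because no_alloc == true
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

        // ggml_view_1d() must still record its byte offset in a small parameter
        // tensor inside ctx; the no_alloc_save/restore added above is what lets
        // that small allocation go through even though no_alloc is set
        struct ggml_tensor * v = ggml_view_1d(ctx, a, 512, 0);

        (void) v;
        ggml_free(ctx);
        return 0;
    }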