From 684da25926e5c505f725b4f10b5485b218fa1fc7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Apr 2023 19:29:48 +0300 Subject: [PATCH 1/6] ggml : fix quantize_row_q4_1() ARM_NEON (close #876) --- ggml.c | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ggml.c b/ggml.c index 326b8e842..9616eb9fd 100644 --- a/ggml.c +++ b/ggml.c @@ -599,10 +599,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); - // absolute max - const float amax = MAX( - MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), - MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); + const float amax = vmaxvq_f32(amaxv[0]); const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0f/d : 0.0f; @@ -924,7 +921,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int float32x4_t minv[8]; float32x4_t maxv[8]; - for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l); for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]); for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]); @@ -947,7 +944,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int for (int l = 0; l < 8; l++) { const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id); - const int32x4_t vi = vcvtq_s32_f32(v); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest + const int32x4_t vi = vcvtq_s32_f32(vf); y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); From d9a239c4104c888eafda672c1e42c9bbc5084cb8 Mon Sep 17 00:00:00 2001 From: Marco Matthies <71844+marcom@users.noreply.github.com> Date: Mon, 10 Apr 2023 19:57:59 +0200 Subject: [PATCH 2/6] Simplify to include lower-case windows.h always, fix compile on mingw32 (#747) --- ggml.c | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/ggml.c b/ggml.c index 9616eb9fd..a817f8321 100644 --- a/ggml.c +++ b/ggml.c @@ -26,14 +26,9 @@ #define static_assert(cond, msg) struct global_scope_noop_trick #endif -#if defined _MSC_VER || defined(__MINGW32__) +#if defined(_WIN32) -#if !defined(__MINGW32__) -#include -#else -// ref: https://github.com/ggerganov/whisper.cpp/issues/168 #include -#endif typedef volatile LONG atomic_int; typedef atomic_int atomic_bool; From 9d634ef452d0fc24fcd49592952d13d0ab0f41b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Apr 2023 19:32:45 +0300 Subject: [PATCH 3/6] ggml : remove trailing whitespaces --- ggml.c | 80 ++++++++++++++++++++++++++++------------------------------ 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/ggml.c b/ggml.c index a817f8321..6db6fde8d 100644 --- a/ggml.c +++ b/ggml.c @@ -1944,7 +1944,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); - /* Prepare the constants we will need during execution */ + /* Prepare the constants we will need during execution */ const __m256i lowMask = _mm256_set1_epi8( 0xF ); const __m256i offset_8 = _mm256_set1_epi16( 8 ); @@ -1954,61 +1954,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest // Main loop for (int i = 0; i < nb; i+=UNROLL_COUNT) { - - // This loop will be unrolled by the compiler + // This loop will be unrolled by the compiler for (int u=0;u we now have a vector of 8 int_32t */ - __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q ); + /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */ + __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q ); - /* Convert to vectore of 8 int32_t to 8 floats */ - __m256 q = _mm256_cvtepi32_ps( xy_q ); + /* Convert to vectore of 8 int32_t to 8 floats */ + __m256 q = _mm256_cvtepi32_ps( xy_q ); - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( scale, q, acc ); + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( scale, q, acc ); } - - } + } // Return horizontal sum of the acc vector __m128 res = _mm256_extractf128_ps( acc, 1 ); From c3ac702e5ee3533457e0489df4906ee112fe88e7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Apr 2023 22:40:28 +0300 Subject: [PATCH 4/6] ggml : add ggml_cont() + optimize ggml_cpy() for contiguous dst --- ggml.c | 254 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-- ggml.h | 6 ++ 2 files changed, 252 insertions(+), 8 deletions(-) diff --git a/ggml.c b/ggml.c index 6db6fde8d..4f6420678 100644 --- a/ggml.c +++ b/ggml.c @@ -2609,6 +2609,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "SCALE", "CPY", + "CONT", "RESHAPE", "VIEW", "PERMUTE", @@ -2624,7 +2625,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; -static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); +static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2653,6 +2654,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "x*v", "x-\\>y", + "cont(x)", "reshape(x)", "view(x)", "permute(x)", @@ -2668,7 +2670,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; -static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); +static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -4301,6 +4303,41 @@ struct ggml_tensor * ggml_cpy_inplace( return ggml_cpy_impl(ctx, a, b, true); } +// ggml_cont + +struct ggml_tensor * ggml_cont_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_CONT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, true); +} + // ggml_reshape struct ggml_tensor * ggml_reshape( @@ -4843,6 +4880,85 @@ static void ggml_compute_forward_dup_f16( // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + if (ggml_is_contiguous(dst)) { + if (src0->nb[0] == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + return; + } + // dst counters int64_t i10 = 0; int64_t i11 = 0; @@ -4937,6 +5053,105 @@ static void ggml_compute_forward_dup_f32( return; } + if (src0->type == dst->type && + src0->ne[0] == dst->ne[0] && + src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + // TODO: simplify + if (src0->nb[0] == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + + return; + } + // dst counters int64_t i10 = 0; int64_t i11 = 0; @@ -5057,14 +5272,18 @@ static void ggml_compute_forward_add_f32( GGML_ASSERT(nb00 == sizeof(float)); if (nb10 == sizeof(float)) { - const int j0 = (n/nth)*ith; - const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); - - for (int j = j0; j < j1; j++) { + for (int j = ith; j < n; j += nth) { +#ifdef GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + j*nb01), 1, + (float *) ((char *) src1->data + j*nb11), 1, + (float *) ((char *) dst->data + j*nb1), 1, nc); +#else ggml_vec_add_f32(nc, (float *) ((char *) dst->data + j*nb1), (float *) ((char *) src0->data + j*nb01), (float *) ((char *) src1->data + j*nb11)); +#endif } } else { // src1 is not contiguous @@ -6812,6 +7031,15 @@ static void ggml_compute_forward_cpy( ggml_compute_forward_dup(params, src0, dst); } +// ggml_compute_forward_cont + +static void ggml_compute_forward_cont( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + // ggml_compute_forward_reshape static void ggml_compute_forward_reshape( @@ -8642,6 +8870,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_cpy(params, tensor->src0, tensor); } break; + case GGML_OP_CONT: + { + ggml_compute_forward_cont(params, tensor->src0, tensor); + } break; case GGML_OP_RESHAPE: { ggml_compute_forward_reshape(params, tensor->src0, tensor); @@ -8886,8 +9118,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src1->grad = ggml_add_impl(ctx, src1->grad, - // TODO: fix transpose, the node will break the graph connections - ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad), + ggml_mul_mat(ctx, + ggml_cont(ctx, ggml_transpose(ctx, src0)), + tensor->grad), inplace); } } break; @@ -8899,6 +9132,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CONT: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_RESHAPE: { GGML_ASSERT(false); // TODO: not implemented @@ -9353,6 +9590,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) node->n_tasks = n_threads; } break; case GGML_OP_CPY: + case GGML_OP_CONT: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: diff --git a/ggml.h b/ggml.h index af16c647c..a5245a8ae 100644 --- a/ggml.h +++ b/ggml.h @@ -236,6 +236,7 @@ enum ggml_op { GGML_OP_SCALE, GGML_OP_CPY, + GGML_OP_CONT, GGML_OP_RESHAPE, GGML_OP_VIEW, GGML_OP_PERMUTE, @@ -525,6 +526,11 @@ struct ggml_tensor * ggml_cpy( struct ggml_tensor * a, struct ggml_tensor * b); +// make contiguous +struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a); + // return view(a), b specifies the new shape // TODO: when we start computing gradient, make a copy instead of view struct ggml_tensor * ggml_reshape( From 461ba9e66ed3885f80680d71495e055580573c74 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Apr 2023 23:20:01 +0300 Subject: [PATCH 5/6] ggml : fix WASM build --- ggml.c | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml.c b/ggml.c index 4f6420678..ada3bbbdc 100644 --- a/ggml.c +++ b/ggml.c @@ -2067,18 +2067,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest float sum1 = 0.0f; for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &px[i + 0]; - const block_q4_0 * restrict y0 = &py[i + 0]; - const block_q4_0 * restrict x1 = &px[i + 1]; - const block_q4_0 * restrict y1 = &py[i + 1]; + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict y0 = &y[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q4_0 * restrict y1 = &y[i + 1]; const v128_t m4b = wasm_u8x16_splat(0xf); const v128_t s8b = wasm_i8x16_splat(0x8); - const v128_t v0_0 = wasm_v128_load(x0.qs); - const v128_t v0_1 = wasm_v128_load(y0.qs); - const v128_t v1_0 = wasm_v128_load(x1.qs); - const v128_t v1_1 = wasm_v128_load(y1.qs); + const v128_t v0_0 = wasm_v128_load(x0->qs); + const v128_t v0_1 = wasm_v128_load(y0->qs); + const v128_t v1_0 = wasm_v128_load(x1->qs); + const v128_t v1_1 = wasm_v128_load(y1->qs); // 4-bit -> 8-bit const v128_t v0_0l = wasm_v128_and(v0_0, m4b); From a0caa34b162449b5c13b8d604573053300ff54a1 Mon Sep 17 00:00:00 2001 From: qouoq Date: Tue, 11 Apr 2023 04:41:53 +0800 Subject: [PATCH 6/6] Add BAIR's Koala to supported models (#877) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 5ef4318eb..ef82855e4 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ New features will probably be added mostly through community contributions. - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne) - [X] [Vicuna](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5533894) +- [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) **Bindings:**