From a1c004ef2e056cdeffcd47aaac196883bb123a3a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 18 Jan 2024 17:42:55 +0200
Subject: [PATCH 01/57] ggml : add ggml_flash_attn_ext API

---
 ggml-metal.m               |  50 +++++++
 ggml-metal.metal           |  29 ++++
 ggml.c                     | 298 ++++++++++++++++++++++++++++++++++++-
 ggml.h                     |   9 ++
 llama.cpp                  |  80 +++++-----
 tests/test-backend-ops.cpp |  28 ++++
 6 files changed, 456 insertions(+), 38 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 912ddc83f..6d88d5c36 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -147,6 +147,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -511,6 +512,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -665,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2161,6 +2164,53 @@ static bool ggml_metal_graph_compute(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F16);
+
+                struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+                struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+                size_t offs_src2 = 0;
+                size_t offs_src3 = 0;
+
+                id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil;
+                id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil;
+
+                float scale;
+                memcpy(&scale, dst->op_params, sizeof(float));
+
+                id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline;
+
+                // TODO: extend if necessary
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
+                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
+                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
+                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
+                [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
+                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
+                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
+                [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13];
+                [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14];
+                [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15];
+                [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16];
+                [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18];
+                [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19];
+                [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20];
+                [encoder setBytes:&scale length:sizeof( float) atIndex:21];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 029578dc5..b79a1ba56 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ?
src0[tpig] : src0[tpig] * slope; } +kernel void kernel_flash_attn_ext_f16( + device const half * q, + device const half * k, + device const half * v, + device const half * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & scale, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // TODO: implement +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml.c b/ggml.c index cbf2d4bdd..e01d938ce 100644 --- a/ggml.c +++ b/ggml.c @@ -1650,6 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", @@ -1674,7 +1675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1736,6 +1737,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", @@ -1760,7 +1762,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5678,6 +5680,46 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13212,6 +13254,251 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (mask) { + const float * mp = (float *)((char *) mask->data + 
(ir%mask->ne[1])*mask->nb[1]); + ggml_vec_acc_f32(M, S, mp); + } + + // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -14717,6 +15004,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); @@ -15713,6 +16004,7 @@ static void ggml_compute_backward(struct 
ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -16438,6 +16730,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -16769,6 +17062,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); diff --git a/ggml.h b/ggml.h index de8162b81..d76fe9d5c 100644 --- a/ggml.h +++ b/ggml.h @@ -452,6 +452,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, @@ -1619,6 +1620,14 @@ extern "C" { struct ggml_tensor * v, bool masked); + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index d28382f7d..cec23c23f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4205,38 +4205,6 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -4246,8 +4214,49 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + // TODO: determine if we can use flash attention + const bool supports_flash_attn = true; + + struct ggml_tensor * kqv; + + if (supports_flash_attn) { + kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = 
ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + } struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); @@ -9490,8 +9499,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, - cparams.n_ctx, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55ce14e0d..5693c2197 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1384,6 +1384,32 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const ggml_type typeq; + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nt; // tokens + + std::string vars() override { + return VARS_TO_STR5(typeq, hs, nh, kv, nt); + } + + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1650,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); From fa7ebcca993ec0d47f6ed6a47a8d5ac4f7407262 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jan 2024 20:06:26 +0200 Subject: [PATCH 02/57] ggml : fix GQA support in ggml_flash_attn_ext --- ggml-metal.metal | 8 ++++---- ggml.c | 23 +++++++++++++++-------- llama.cpp | 4 ++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b79a1ba56..28847794c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const half * mask, + device const half * q, + 
device const half * k, + device const half * v, + device const float * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, diff --git a/ggml.c b/ggml.c index e01d938ce..9cf4784ce 100644 --- a/ggml.c +++ b/ggml.c @@ -13307,6 +13307,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + if (params->type == GGML_TASK_INIT) { return; } @@ -13347,8 +13354,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { for (int64_t ic = 0; ic < nek1; ++ic) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13362,8 +13369,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( } else { for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; const int ik1 = ic; // S indices @@ -13452,8 +13459,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16(nev0, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), @@ -13468,8 +13475,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int i3 = iq3; // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; ggml_vec_dot_f16_unroll(nev0, nbv1, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), diff --git a/llama.cpp b/llama.cpp index cec23c23f..d4bebe520 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4220,6 +4220,10 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv; if (supports_flash_attn) { + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); From a9681febd65cbd3f372badc5f4a4d8bc1336d2d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 12:26:49 +0200 Subject: [PATCH 03/57] ggml : online attention (CPU) --- ggml-metal.m | 8 +- ggml-metal.metal | 3 +- ggml.c | 263 ++++++++++++++++++------------------- ggml.h | 5 + llama.cpp | 136 +++++++++++-------- tests/test-backend-ops.cpp | 14 +- 6 files changed, 231 insertions(+), 198 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d88d5c36..4d85dd3dd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2207,9 +2207,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + const int nwarps = 4; + + // each warp needs n_embd_head elements + GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + const int nth = MIN(1024, ne0); - [encoder 
dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 28847794c..a1e1755a3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1981,7 +1981,8 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & scale, + constant float & scale, + threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { diff --git a/ggml.c b/ggml.c index 9cf4784ce..e64a328fa 100644 --- a/ggml.c +++ b/ggml.c @@ -817,7 +817,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1323,6 +1323,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1407,6 +1438,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float 
* y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -5704,8 +5764,9 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); float params[] = { scale }; ggml_set_op_params(result, params, sizeof(params)); @@ -13281,12 +13342,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); + GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); @@ -13295,11 +13353,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(neq0 == D); GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); + GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -13339,151 +13397,87 @@ static void ggml_compute_forward_flash_attn_ext_f16( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + float S = 0.0f; + float M = -INFINITY; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + memset(V16, 0, D*sizeof(ggml_fp16_t)); - // S indices - const int i1 = ik1; + const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - const int ik1 = ic; + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; - // S indices - const int i1 = ik1; + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } - - // scale - ggml_vec_scale_f32(nek1, S, scale); - - if (mask) { - const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); - ggml_vec_acc_f32(M, S, mp); - } - - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. 
- // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? mp[ic] : 0.0f; + if (mv == -INFINITY) { + continue; } - assert(sum > 0.0); + float s; - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + ggml_vec_dot_f16(D, + &s, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); } -#endif + + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; } - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; } - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). 
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; - - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); } } @@ -17069,7 +17063,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: - case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); @@ -17081,6 +17074,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index d76fe9d5c..7bca02f2a 100644 --- a/ggml.h +++ b/ggml.h @@ -1620,6 +1620,11 @@ extern "C" { struct ggml_tensor * v, bool masked); + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch, 1, 1] + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index f0a63afef..4e6c9f9cc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -95,6 +95,8 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN + // // logging // @@ -4167,23 +4169,34 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! 
+ struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -4343,7 +4356,60 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - // split cached v into n_head heads + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); + + cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); + + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + // split cached v into n_head heads (transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], n_kv, n_embd_head_v, n_head_kv, @@ -4352,59 +4418,15 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - // TODO: determine if we can use flash attention - const bool supports_flash_attn = true; - - struct ggml_tensor * kqv; - - if (supports_flash_attn) { - //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], 
q->ne[3]); - //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); - //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); - //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); - kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - - kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - } + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif cur = ggml_mul_mat(ctx, wo, cur); if (wo_b) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5693c2197..a56c0d6c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case { const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size - const int64_t nt; // tokens + const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nt); + return VARS_TO_STR5(typeq, hs, nh, kv, nb); } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } From 1173f49c3bbe30810af4aeb77219eba7e05f658d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 Jan 2024 17:32:28 +0200 Subject: [PATCH 04/57] metal : initial implementation --- ggml-metal.m | 69 
+++++++++++++------ ggml-metal.metal | 138 ++++++++++++++++++++++++++++++++++--- ggml.c | 2 +- tests/test-backend-ops.cpp | 4 ++ 4 files changed, 180 insertions(+), 33 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4d85dd3dd..556c53482 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -278,6 +278,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); @@ -316,13 +320,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { //[options setFastMathEnabled:false]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } } } - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } } // print MTL GPU family: @@ -396,6 +399,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ @@ -2171,12 +2177,28 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + size_t offs_src2 = 0; size_t offs_src3 = 0; - id id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + GGML_ASSERT(src2); + id id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2); + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? 
src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + float scale; memcpy(&scale, dst->op_params, sizeof(float)); @@ -2197,25 +2219,28 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 4; + const int nwarps = 1; - // each warp needs n_embd_head elements - GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0]; + GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index a1e1755a3..5986bcb42 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32( } kernel void kernel_flash_attn_ext_f16( - device const half * q, - device const half * k, - device const half * v, - device const float * mask, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1973,20 +1973,138 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & 
ne1, constant int64_t & ne2, constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, constant float & scale, threadgroup float * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - // TODO: implement + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + + if (iq1 >= ne01) { + return; + } + + const int64_t D = ne00; + + // TODO: can we move this to the stack? + threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + + // initialize with zeros + for (int64_t d = 0; d < D; ++d) { + V16[d] = 0.0h; + } + + threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + + half S = 0.0h; + half M = -INFINITY; + + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + // load Q to shared memory + for (int64_t d = 0; d < D; ++d) { + pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + } + + for (int64_t ic = 0; ic < ne11; ++ic) { + const half mv = mp ? 
mp[ic] : 0.0h; + if (mv == -INFINITY) { + continue; + } + + half s = 0.0f; + + //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t d = 0; d < D; ++d) { + s += pk[d] * pq[d]; + } + + s = s*scale + mv; + + const half Mold = M; + + half ms = 1.0f; + half vs = 1.0f; + + if (s > M) { + M = s; + ms = exp(Mold - M); + + // V = V*exp(Mold - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] *= ms; + } + } else { + vs = exp(s - M); + } + + device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + // V += v*exp(s - M) + for (int64_t d = 0; d < D; ++d) { + V16[d] += pv[d] * vs; + } + + S = S*ms + vs; + } + + for (int64_t d = 0; d < D; ++d) { + V16[d] /= S; + } + + // dst indices + const int64_t i1 = iq1; + const int64_t i2 = iq2; + const int64_t i3 = iq3; + + for (int64_t d = 0; d < D; ++d) { + dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + } } kernel void kernel_cpy_f16_f16( diff --git a/ggml.c b/ggml.c index e64a328fa..10df03c9c 100644 --- a/ggml.c +++ b/ggml.c @@ -13419,8 +13419,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ik2 = iq2 / rk2; // v indices - const int iv2 = iq2 / rv2; const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; // online softmax / attention // loop over n_kv and n_head_kv diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a56c0d6c5..51a33c662 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1396,6 +1396,10 @@ struct test_flash_attn_ext : public test_case { return VARS_TO_STR5(typeq, hs, nh, kv, nb); } + double max_nmse_err() override { + return 5e-4; + } + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} From 528da7515ef874ab1188ab8f691c36d3e9e0cb20 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:13:24 +0200 Subject: [PATCH 05/57] metal : f16 precision --- ggml-metal.m | 6 ++++-- ggml-metal.metal | 40 ++++++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 556c53482..e67a7c4ef 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2237,8 +2237,10 @@ static bool ggml_metal_graph_compute( const int nwarps = 1; - GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0]; + const size_t shalf = sizeof(float)/2; + + GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5986bcb42..e4e89b5b3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1988,7 +1988,7 @@ kernel void kernel_flash_attn_ext_f16( constant int64_t & ne2, constant int64_t & ne3, constant float & scale, - threadgroup float * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]], @@ -2003,16 +2003,17 @@ kernel void kernel_flash_attn_ext_f16( } const int64_t D = ne00; + 
const int64_t D4 = D/4; // TODO: can we move this to the stack? - threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); + threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); // initialize with zeros - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] = 0.0h; } - threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); half S = 0.0h; half M = -INFINITY; @@ -2045,8 +2046,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load Q to shared memory - for (int64_t d = 0; d < D; ++d) { - pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + for (int64_t d = 0; d < D4; ++d) { + pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; } for (int64_t ic = 0; ic < ne11; ++ic) { @@ -2055,15 +2056,16 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half s = 0.0f; + half4 s4 = 0.0f; - //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t d = 0; d < D; ++d) { - s += pk[d] * pq[d]; + for (int64_t d = 0; d < D4; ++d) { + s4 += pk4[d] * pq4[d]; } + half s = s4.x + s4.y + s4.z + s4.w; + s = s*scale + mv; const half Mold = M; @@ -2076,24 +2078,24 @@ kernel void kernel_flash_attn_ext_f16( ms = exp(Mold - M); // V = V*exp(Mold - M) - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] *= ms; } } else { vs = exp(s - M); } - device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); // V += v*exp(s - M) - for (int64_t d = 0; d < D; ++d) { - V16[d] += pv[d] * vs; + for (int64_t d = 0; d < D4; ++d) { + V16[d] += pv4[d] * vs; } S = S*ms + vs; } - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < D4; ++d) { V16[d] /= S; } @@ -2102,8 +2104,10 @@ kernel void kernel_flash_attn_ext_f16( const int64_t i2 = iq2; const int64_t i3 = iq3; - for (int64_t d = 0; d < D; ++d) { - dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d]; + device float4 * dst4 = (device float4 *) dst; + + for (int64_t d = 0; d < D4; ++d) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; } } From 52ae085750afd37affc4ed18fe092d92c9ccdc5f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 11:38:17 +0200 Subject: [PATCH 06/57] metal : reduce branches --- ggml-metal.metal | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e4e89b5b3..f3a7efafa 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16( continue; } - half4 s4 = 0.0f; + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13)); + half4 s4 
= 0.0h; for (int64_t d = 0; d < D4; ++d) { s4 += pk4[d] * pq4[d]; } - half s = s4.x + s4.y + s4.z + s4.w; - - s = s*scale + mv; + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; const half Mold = M; - half ms = 1.0f; - half vs = 1.0f; + M = max(M, s); - if (s > M) { - M = s; - ms = exp(Mold - M); + const half ms = exp(Mold - M); + const half vs = exp(s - M); - // V = V*exp(Mold - M) - for (int64_t d = 0; d < D4; ++d) { - V16[d] *= ms; - } - } else { - vs = exp(s - M); - } - - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - - // V += v*exp(s - M) for (int64_t d = 0; d < D4; ++d) { - V16[d] += pv4[d] * vs; + V16[d] = V16[d]*ms + pv4[d]*vs; } S = S*ms + vs; From b97325800a7727244e737715fa7b5e2bc41afb21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:01:55 +0200 Subject: [PATCH 07/57] metal : specialize for head size --- ggml-metal.m | 259 +++++++++++++++++++++++++---------------------- ggml-metal.metal | 42 +++++++- 2 files changed, 179 insertions(+), 122 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e67a7c4ef..046643146 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,7 +147,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -412,125 +414,127 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, 
get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - 
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, 
mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, 
ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; @@ -2172,6 +2176,7 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_FLASH_ATTN_EXT: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(src0->type == GGML_TYPE_F16); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; @@ -2202,7 +2207,19 @@ static bool ggml_metal_graph_compute( float scale; memcpy(&scale, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + id pipeline = nil; + + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } // TODO: extend if necessary [encoder setComputePipelineState:pipeline]; diff --git a/ggml-metal.metal b/ggml-metal.metal index f3a7efafa..d97952f2b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,43 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); + +template // head size kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2002,7 +2039,6 @@ kernel void kernel_flash_attn_ext_f16( return; } - const int64_t D = ne00; const int64_t D4 = D/4; // TODO: can we move this to the stack? 
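A side note on what this commit changes, with an illustrative sketch (the helper names `dot_head` and `select_dot_head` below are hypothetical, not from the patch): the head size D stops being read from `ne00` at run time and becomes a compile-time template parameter of the kernel, and the host code in ggml-metal.m picks one of the pre-built pipelines (H64/H80/H128) by switching on `ne00`, asserting on any other size. In plain C++ the same pattern looks roughly like this:

    #include <cstdint>

    // Illustrative C++ analogue (not part of the patch): make the head size a
    // compile-time constant so the per-head loops can be unrolled, and select
    // the instantiation from the runtime head size ne00.
    template <int64_t D>
    static float dot_head(const float * q, const float * k) {
        // D is a compile-time constant here, mirroring kernel_flash_attn_ext_f16<D>
        float s = 0.0f;
        for (int64_t d = 0; d < D; ++d) {
            s += q[d]*k[d];
        }
        return s;
    }

    using dot_head_t = float (*)(const float *, const float *);

    static dot_head_t select_dot_head(int64_t ne00) {
        // mirrors the switch over ne00 in ggml_metal_graph_compute that picks the
        // FLASH_ATTN_EXT_F16_H64/H80/H128 pipeline
        switch (ne00) {
            case  64: return dot_head<64>;
            case  80: return dot_head<80>;
            case 128: return dot_head<128>;
            default:  return nullptr; // unsupported head size: caller must handle (GGML_ASSERT in the patch)
        }
    }

On the Metal side the same idea appears in the hunk below as `template [[host_name("kernel_flash_attn_ext_f16_hN")]]` instantiations for N = 64, 80, 128, which is what lets the host look the specialized pipelines up by name.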
@@ -2097,6 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( } } +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, From 8cde449b8be4e481db2a8790d9320c743b3ed65e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 12:23:22 +0200 Subject: [PATCH 08/57] wip : 8 rows per simd group --- ggml-metal.m | 10 +-- ggml-metal.metal | 175 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 140 insertions(+), 45 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 046643146..0b1119c4e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 1; + const int64_t nwarps = 2; - const size_t shalf = sizeof(float)/2; + const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); - GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index d97952f2b..789b19bad 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + //const int64_t iq3 = tgpig[2]; + //const int64_t iq2 = tgpig[1]; + //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - if (iq1 >= ne01) { + const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq1 = tgpig[0]; + + if (iq2 >= ne02) { return; } - const int64_t D4 = D/4; - - // TODO: can we move this to the stack? - threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); - - // initialize with zeros - for (int64_t d = 0; d < D4; ++d) { - V16[d] = 0.0h; - } - - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); - - half S = 0.0h; - half M = -INFINITY; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? 
(device const float *) (mask + (ir%ne31)*nb31) : nullptr; - // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; - // load Q to shared memory - for (int64_t d = 0; d < D4; ++d) { - pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + +// const int64_t D4 = D/4; +// +// // TODO: can we move this to the stack? +// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); +// +// // initialize with zeros +// for (int64_t d = 0; d < D4; ++d) { +// +// } +// +// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); +// +// // load Q to shared memory +// for (int64_t d = 0; d < D4; ++d) { +// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; +// } +// +// half S = 0.0h; +// half M = -INFINITY; +// +// for (int64_t ic = 0; ic < ne11; ++ic) { +// const half mv = mp ? mp[ic] : 0.0h; +// if (mv == -INFINITY) { +// continue; +// } +// +// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); +// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); +// +// half4 s4 = 0.0h; +// +// for (int64_t d = 0; d < D4; ++d) { +// s4 += pk4[d] * pq4[d]; +// } +// +// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; +// +// const half Mold = M; +// +// M = max(M, s); +// +// const half ms = exp(Mold - M); +// const half vs = exp(s - M); +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] = V16[d]*ms + pv4[d]*vs; +// } +// +// S = S*ms + vs; +// } +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] /= S; +// } +// +// // dst indices +// const int64_t i1 = iq1; +// const int64_t i2 = iq2; +// const int64_t i3 = iq3; +// +// device float4 * dst4 = (device float4 *) dst; +// +// for (int64_t d = 0; d < D4; ++d) { +// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; +// } + + const int64_t D4 = D/4; + + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + + const uint tiih = tiisg%4; // thread index in head + const uint hiisg = tiisg/4; // head index in simdgroup + + // load 8 heads from Q to shared memory + for (int64_t i = 0; i < D4/4; ++i) { + pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; + ps4[hiisg*D4 + 4*i + tiih] = 0.0h; } + simdgroup_barrier(mem_flags::mem_threadgroup); + + half S = 0.0h; + half M = -INFINITY; + for (int64_t ic = 0; ic < ne11; ++ic) { const half mv = mp ? 
mp[ic] : 0.0h; if (mv == -INFINITY) { @@ -2097,30 +2170,52 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t d = 0; d < D4; ++d) { - s4 += pk4[d] * pq4[d]; + for (int64_t i = 0; i < D4/4; ++i) { + s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + ss4[hiisg*4 + tiih] = s4; - const half Mold = M; + simdgroup_barrier(mem_flags::mem_threadgroup); - M = max(M, s); + if (tiih == 0) { + s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; - const half ms = exp(Mold - M); - const half vs = exp(s - M); + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; - for (int64_t d = 0; d < D4; ++d) { - V16[d] = V16[d]*ms + pv4[d]*vs; + const half Mold = M; + + M = max(M, s); + + const half ms = exp(Mold - M); + const half vs = exp(s - M); + + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; } - S = S*ms + vs; + simdgroup_barrier(mem_flags::mem_threadgroup); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + + for (int64_t i = 0; i < D4/4; ++i) { + ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; + } } - for (int64_t d = 0; d < D4; ++d) { - V16[d] /= S; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] /= S; + } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // dst indices const int64_t i1 = iq1; const int64_t i2 = iq2; @@ -2128,8 +2223,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t d = 0; d < D4; ++d) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; + for (int64_t i = 0; i < D4/4; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; } } From f31955f5d12da67f35aa459996a171975fdf269b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:01:28 +0200 Subject: [PATCH 09/57] wip : 4 rows per simd group --- ggml-metal.m | 6 +++--- ggml-metal.metal | 39 +++++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0b1119c4e..abb96d6ec 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 2; + const int64_t nwarps = 4; - const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 789b19bad..6fdd7fdad 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2038,7 +2038,7 @@ kernel void kernel_flash_attn_ext_f16( const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2140,18 @@ kernel 
void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - const uint tiih = tiisg%4; // thread index in head - const uint hiisg = tiisg/4; // head index in simdgroup + const uint tiih = tiisg%8; // thread index in head + const uint hiisg = tiisg/8; // head index in simdgroup // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/4; ++i) { - pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; - ps4[hiisg*D4 + 4*i + tiih] = 0.0h; + for (int64_t i = 0; i < D4/8; ++i) { + pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; + ps4[hiisg*D4 + 8*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,16 +2170,18 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t i = 0; i < D4/4; ++i) { - s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; } - ss4[hiisg*4 + tiih] = s4; + ss4[hiisg*8 + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; + s4 = ss4[8*hiisg + 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + + ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2201,8 +2203,9 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; - for (int64_t i = 0; i < D4/4; ++i) { - ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; +#pragma unroll(D4/8) + for (int64_t i = 0; i < D4/8; ++i) { + ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; } } @@ -2223,8 +2226,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/4; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; + for (int64_t i = 0; i < D4/8; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; } } From a4b6341c7b2a1977c29e79b17a0e5de3e31a5420 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 18:24:13 +0200 Subject: [PATCH 10/57] wip : template for rows per warp --- ggml-metal.m | 7 ++++--- ggml-metal.metal | 54 +++++++++++++++++++++++++----------------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index abb96d6ec..d521df43a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) 
atIndex:27]; - const int64_t nwarps = 4; + const int64_t nwarps = 8; + const int64_t nhpw = 4; // heads per warp - const size_t smem = nwarps*(2*4*ne00 + 128)*(sizeof(float)/2); + const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 4*nwarps - 1)/(4*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 6fdd7fdad..c9876c103 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size +template // head size, rows per warp kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2036,9 +2036,10 @@ kernel void kernel_flash_attn_ext_f16( //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(4*nsg) + 4*sgitg + tiisg/8; + const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2140,18 +2141,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) ); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 4*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*4*D + 128) + 2*4*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*4*D + 128) + 2*4*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - const uint tiih = tiisg%8; // thread index in head - const uint hiisg = tiisg/8; // head index in simdgroup + const uint tiih = tiisg%tph; // thread index in head + const uint hiisg = tiisg/tph; // head index in simdgroup - // load 8 heads from Q to shared memory - for (int64_t i = 0; i < D4/8; ++i) { - pq4[hiisg*D4 + 8*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[8*i + tiih]; - ps4[hiisg*D4 + 8*i + tiih] = 0.0h; + // load R heads from Q to shared memory + for (int64_t i = 0; i < D4/tph; ++i) { + pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2170,18 +2171,20 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - s4 += pk4[8*i + tiih] * pq4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*8 + tiih] = s4; + ss4[hiisg*tph + tiih] = s4; simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = ss4[8*hiisg 
+ 0] + ss4[8*hiisg + 1] + ss4[8*hiisg + 2] + ss4[8*hiisg + 3] + - ss4[8*hiisg + 4] + ss4[8*hiisg + 5] + ss4[8*hiisg + 6] + ss4[8*hiisg + 7]; + s4 = 0.0h; + + for (int64_t i = 0; i < tph; ++i) { + s4 += ss4[hiisg*tph + i]; + } half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; @@ -2203,9 +2206,8 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; -#pragma unroll(D4/8) - for (int64_t i = 0; i < D4/8; ++i) { - ps4[hiisg*D4 + 8*i + tiih] = ps4[hiisg*D4 + 8*i + tiih]*ms + pv4[8*i + tiih]*vs; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } } @@ -2226,14 +2228,14 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/8; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 8*i + tiih] = (float4) ps4[hiisg*D4 + 8*i + tiih]; + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; kernel void kernel_cpy_f16_f16( device const half * src0, From 77d08f3272c62900b40d110bf0de7f4466675c71 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 21:04:15 +0200 Subject: [PATCH 11/57] metal : parallelize across KV size --- ggml-metal.m | 8 +-- ggml-metal.metal | 137 +++++++++++++++++------------------------------ 2 files changed, 52 insertions(+), 93 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index d521df43a..a60dd779a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,15 +2252,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 8; - const int64_t nhpw = 4; // heads per warp + const int64_t nwarps = 16; + const int64_t nhptg = 4; // heads per threadgroup - const size_t smem = nwarps*(2*nhpw*ne00 + 128)*(sizeof(float)/2); + const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhpw*nwarps - 1)/(nhpw*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index c9876c103..539e26c91 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per warp +template // head size, rows per 
threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,15 +2031,11 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - //const int64_t iq3 = tgpig[2]; - //const int64_t iq2 = tgpig[1]; - //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - - const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups + const uint tph = N_SIMDWIDTH/R; // threads per head const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*(R*nsg) + R*sgitg + tiisg/tph; + const int64_t iq2 = tgpig[1]*R + tiisg/tph; const int64_t iq1 = tgpig[0]; if (iq2 >= ne02) { @@ -2073,94 +2069,30 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; -// const int64_t D4 = D/4; -// -// // TODO: can we move this to the stack? -// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); -// -// // initialize with zeros -// for (int64_t d = 0; d < D4; ++d) { -// -// } -// -// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); -// -// // load Q to shared memory -// for (int64_t d = 0; d < D4; ++d) { -// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; -// } -// -// half S = 0.0h; -// half M = -INFINITY; -// -// for (int64_t ic = 0; ic < ne11; ++ic) { -// const half mv = mp ? mp[ic] : 0.0h; -// if (mv == -INFINITY) { -// continue; -// } -// -// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); -// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); -// -// half4 s4 = 0.0h; -// -// for (int64_t d = 0; d < D4; ++d) { -// s4 += pk4[d] * pq4[d]; -// } -// -// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; -// -// const half Mold = M; -// -// M = max(M, s); -// -// const half ms = exp(Mold - M); -// const half vs = exp(s - M); -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] = V16[d]*ms + pv4[d]*vs; -// } -// -// S = S*ms + vs; -// } -// -// for (int64_t d = 0; d < D4; ++d) { -// V16[d] /= S; -// } -// -// // dst indices -// const int64_t i1 = iq1; -// const int64_t i2 = iq2; -// const int64_t i3 = iq3; -// -// device float4 * dst4 = (device float4 *) dst; -// -// for (int64_t d = 0; d < D4; ++d) { -// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; -// } - const int64_t D4 = D/4; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 1*R*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*R*D + 128) + 2*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*R*D + 128) + 2*R*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup // load R heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + if (sgitg == 0) { 
+ pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + } + ps4[hiisg*D4 + tph*i + tiih] = 0.0h; } - simdgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); half S = 0.0h; half M = -INFINITY; - for (int64_t ic = 0; ic < ne11; ++ic) { + for (int64_t ic = sgitg; ic < ne11; ic += nsg) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { continue; @@ -2175,18 +2107,18 @@ kernel void kernel_flash_attn_ext_f16( s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } - ss4[hiisg*tph + tiih] = s4; + ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); simdgroup_barrier(mem_flags::mem_threadgroup); if (tiih == 0) { - s4 = 0.0h; + half s = 0.0h; for (int64_t i = 0; i < tph; ++i) { - s4 += ss4[hiisg*tph + i]; + s += ss[hiisg*tph + i]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + s = s*scale + mv; const half Mold = M; @@ -2211,9 +2143,34 @@ kernel void kernel_flash_attn_ext_f16( } } - simdgroup_barrier(mem_flags::mem_threadgroup); - if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // reduce the warps + if (sgitg == 0 && tiih == 0) { + for (int64_t sg = 1; sg < nsg; ++sg) { + const half S0 = S; + const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + + const half M0 = M; + const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + } + } + for (int64_t i = 0; i < D4; ++i) { ps4[hiisg*D4 + i] /= S; } @@ -2228,8 +2185,10 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + if (sgitg == 0) { + for (int64_t i = 0; i < D4/tph; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + } } } From 17720fad669eed6171ddf17184da5bab50adeb72 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 21 Jan 2024 22:44:41 +0200 Subject: [PATCH 12/57] metal : parallel reduce across heads --- ggml-metal.m | 4 ++-- ggml-metal.metal | 32 ++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a60dd779a..fdfb50d3d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,8 +2252,8 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 16; - const int64_t nhptg = 4; // heads per threadgroup + const int64_t nwarps = 32; + const int64_t nhptg = 2; // heads per threadgroup const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 539e26c91..919119c8d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2103,6 +2103,7 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; } @@ -2114,17 +2115,18 @@ kernel void kernel_flash_attn_ext_f16( if (tiih == 0) { half s = 0.0h; +#pragma unroll for (int64_t i = 0; i < tph; ++i) { s += ss[hiisg*tph + i]; } s = s*scale + mv; - const half Mold = M; + const 
half m = M; M = max(M, s); - const half ms = exp(Mold - M); + const half ms = exp(m - M); const half vs = exp(s - M); S = S*ms + vs; @@ -2138,6 +2140,7 @@ kernel void kernel_flash_attn_ext_f16( const half ms = ss[2*hiisg + 0]; const half vs = ss[2*hiisg + 1]; +#pragma unroll for (int64_t i = 0; i < D4/tph; ++i) { ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; } @@ -2151,12 +2154,12 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - if (sgitg == 0 && tiih == 0) { + if (sgitg == 0) { for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = S; + const half S0 = ss[ 2*hiisg + 0]; const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; - const half M0 = M; + const half M0 = ss[ 2*hiisg + 1]; const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; M = max(M0, M1); @@ -2166,13 +2169,18 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] = ps4[hiisg*D4 + i]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + i]*ms1; + if (tiih == 0) { + ss[2*hiisg + 0] = S; + ss[2*hiisg + 1] = M; + } + + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; } } - for (int64_t i = 0; i < D4; ++i) { - ps4[hiisg*D4 + i] /= S; + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; } } @@ -2192,9 +2200,9 @@ kernel void kernel_flash_attn_ext_f16( } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; kernel void kernel_cpy_f16_f16( device const half * src0, From 1446a12b29f422a0c0040e62c16715a3fb7ce1cb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 23 Jan 2024 18:27:54 +0200 Subject: [PATCH 13/57] metal : efficient flash_attn_f16 implementation --- ggml-metal.m | 14 +- ggml-metal.metal | 293 +++++++++++++++++++++++-------------- tests/test-backend-ops.cpp | 6 +- 3 files changed, 195 insertions(+), 118 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index fdfb50d3d..7b161c69d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute( struct ggml_tensor * src3 = gf->nodes[i]->src[3]; GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); size_t offs_src2 = 0; size_t offs_src3 = 0; @@ -2252,15 +2253,20 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 32; - const int64_t nhptg = 2; // heads per threadgroup + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. 
warps) - const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2); + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); + const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 919119c8d..9b6ceec4e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, rows per threadgroup +template // head size, heads per threadgroup, queries per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,178 +2031,247 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups - const uint tph = N_SIMDWIDTH/R; // threads per head + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]*R + tiisg/tph; - const int64_t iq1 = tgpig[0]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*Q; if (iq2 >= ne02) { return; } - // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; - - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; - - // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; - - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; - - // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; - - // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? 
(device const float *) (mask + (ir%ne31)*nb31) : nullptr; - const int64_t D4 = D/4; + const int64_t N4 = N_SIMDWIDTH; + const int64_t L4 = (D4 + N4 - 1)/N4; + const int64_t D8 = D/8; - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*R*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(R*D + 32) + 2*R*D); + const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 - const uint tiih = tiisg%tph; // thread index in head - const uint hiisg = tiisg/tph; // head index in simdgroup + threadgroup half * pq = (threadgroup half *) (shared + 0*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); + threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); - // load R heads from Q to shared memory - for (int64_t i = 0; i < D4/tph; ++i) { - if (sgitg == 0) { - pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; + for (int64_t i = 0; i < L4; ++i) { + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + if (iq1 + j < ne01) { + pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + } else { + pq4[j*T4 + N4*i + tiisg] = 0.0h; + } } - ps4[hiisg*D4 + tph*i + tiih] = 0.0h; + // zero out shared memory + for (int64_t j = 0; j < Q; ++j) { + ps4[j*T4 + N4*i + tiisg] = 0.0h; + } + } + + if (tiisg < C) { + for (int64_t j = 0; j < Q; ++j) { + ss[j*T + 0 + tiisg] = 0.0h; + } } threadgroup_barrier(mem_flags::mem_threadgroup); - half S = 0.0h; - half M = -INFINITY; + { + half S[Q] = { 0.0h }; + half M[Q] = { -INFINITY }; - for (int64_t ic = sgitg; ic < ne11; ic += nsg) { - const half mv = mp ? 
mp[ic] : 0.0h; - if (mv == -INFINITY) { - continue; + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + simdgroup_half8x8 mq[D8]; + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[i], pq + i*8, T); } - device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + // TODO: this can be improved + device const float * mp[Q]; - half4 s4 = 0.0h; + { + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih]; + for (int64_t j = 0; j < Q; ++j) { + if (iq1 + j < ne01) { + mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); + } else { + mp[j] = nullptr; + } + } } - ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { + // skip -INF blocks + // TODO: double-check this + { + float smc = -INFINITY; - simdgroup_barrier(mem_flags::mem_threadgroup); + for (int64_t j = 0; j < Q; ++j) { + const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; + smc = simd_max(max(smc, mc)); + } - if (tiih == 0) { - half s = 0.0h; - -#pragma unroll - for (int64_t i = 0; i < tph; ++i) { - s += ss[hiisg*tph + i]; + if (smc == -INFINITY) { + continue; + } } - s = s*scale + mv; + // Q*K^T + { + simdgroup_half8x8 mk; - const half m = M; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - M = max(M, s); + device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - const half ms = exp(m - M); - const half vs = exp(s - M); + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mk, pk + i*8, nb11/2, 0, true); - S = S*ms + vs; + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } - ss[2*hiisg + 0] = ms; - ss[2*hiisg + 1] = vs; + simdgroup_store(mqk, ss + 8*cc, T, 0, false); + } + } + + // online softmax + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; + + const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + + half m = M[j]; + + M[j] = simd_max(max(M[j], s)); + + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); + + S[j] = S[j]*ms + simd_sum(vs); + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] *= ms; + } + + ss[j*T + p] = vs; + } + + // (Q*K^T)*V + { + simdgroup_half8x8 mv; + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_half8x8 mp[C/8]; + simdgroup_half8x8 mqkv; + + simdgroup_load(mqkv, ps + i*8, T, 0, false); + + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + + for (int cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + + simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + } + + simdgroup_store(mqkv, ps + i*8, T, 0, false); + } + } } - simdgroup_barrier(mem_flags::mem_threadgroup); - - const half ms = ss[2*hiisg + 0]; - const half vs = ss[2*hiisg + 1]; - -#pragma unroll - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs; + for (int64_t j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } } } - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; - } - threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps + // TODO: try parallel reduce if (sgitg == 0) { + half S = { 0.0h }; + half M = { -INFINITY }; + for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = ss[ 2*hiisg + 0]; - const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0]; + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - const half M0 = ss[ 2*hiisg + 1]; - const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1]; + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - M = max(M0, M1); + M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + S = S0*ms0 + S1*ms1; - if (tiih == 0) { - ss[2*hiisg + 0] = S; - ss[2*hiisg + 1] = M; + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } } - - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1; - } - } - - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S; } } simdgroup_barrier(mem_flags::mem_threadgroup); - // dst indices - const int64_t i1 = iq1; - const int64_t i2 = iq2; - const int64_t i3 = iq3; - device float4 * dst4 = (device float4 *) dst; if (sgitg == 0) { - for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih]; + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = 0; i < L4; ++i) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>; +template 
[[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 51a33c662..41ddfcca5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1397,7 +1397,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-4; + return 5e-5; } test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, @@ -1680,7 +1680,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From d917746ddb053b73e868fd6e1854ac17b62bd863 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:00:49 +0200 Subject: [PATCH 14/57] metal : avoid redundant loads of the attention --- ggml-metal.metal | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9b6ceec4e..785a60e50 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } + simdgroup_barrier(mem_flags::mem_none); + // (Q*K^T)*V { simdgroup_half8x8 mv; + simdgroup_half8x8 mp[C/8]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mp[C/8]; simdgroup_half8x8 mqkv; simdgroup_load(mqkv, ps + i*8, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } - for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); From 432ad04ffaa445a3837b92dce1c03513009ab4ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:47:52 +0200 Subject: [PATCH 15/57] metal : scale and mask in matrix form --- ggml-metal.metal | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 785a60e50..ae8f5caea 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,6 +2127,9 @@ kernel void kernel_flash_attn_ext_f16( } } + // prepare diagonal scale matrix + simdgroup_half8x8 mscale(scale); + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { // skip -INF blocks // TODO: double-check this @@ -2153,11 +2156,16 @@ kernel void kernel_flash_attn_ext_f16( device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/2, 0, true); + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } + // mqk = mqk*scale + mask + 
simdgroup_float8x8 mm; + simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } @@ -2166,7 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); + const half s = ss[j*T + p]; half m = M[j]; @@ -2203,7 +2212,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - simdgroup_load(mv, pv + i*8, nb21/2, 0, false); + simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); } From 40ea8cd1aca61294e1987bcb1051317827f1b145 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 16:31:39 +0200 Subject: [PATCH 16/57] metal : fix comment --- ggml-metal.metal | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ae8f5caea..9ab9e16c3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, heads per threadgroup, queries per threadgroup +template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2031,16 +2031,12 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const uint nsg = ntg.y; // number of simdgroups + const uint nsg = ntg.y; // number of simdgroups const int64_t iq3 = tgpig[2]; const int64_t iq2 = tgpig[1]; const int64_t iq1 = tgpig[0]*Q; - if (iq2 >= ne02) { - return; - } - const int64_t D4 = D/4; const int64_t N4 = N_SIMDWIDTH; const int64_t L4 = (D4 + N4 - 1)/N4; From f9ca5dcbe86a10cfa873814d5f754b7c9108f339 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:46:07 +0200 Subject: [PATCH 17/57] llama : avoid ggml_cast, use F32 query --- ggml-metal.m | 4 ++-- ggml-metal.metal | 3 ++- ggml.c | 31 +++++++++++++++++++++++++++---- ggml.h | 4 ++++ llama.cpp | 3 ++- tests/test-backend-ops.cpp | 16 +++++++--------- 6 files changed, 44 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b161c69d..7b6762e6d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute( case GGML_OP_FLASH_ATTN_EXT: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F32); struct ggml_tensor * src2 = gf->nodes[i]->src[2]; struct ggml_tensor * src3 = gf->nodes[i]->src[3]; @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! 
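        // [illustrative note, not part of the original patch] rough shared-memory budget for
        // these constants, using the smem formula introduced earlier in this series,
        // smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2), where ne00 is the
        // head size D and sizeof(float)/2 == 2 bytes per half:
        //
        //   D = 128, nqptg = 8, ncpsg = 32:
        //     nsg = 2  (large batch) -> 8*(128 +  2*(128 + 32))*2 =  7168 bytes
        //     nsg = 12 (small batch) -> 8*(128 + 12*(128 + 32))*2 = 32768 bytes
        //
        // the small-batch case lands right at the 32 KiB maxThreadgroupMemoryLength that many
        // Apple GPUs report, which is what the
        // GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength) guard added earlier in
        // this series protects against.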
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9ab9e16c3..c9e4dcfe9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; + pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; } else { pq4[j*T4 + N4*i + tiisg] = 0.0h; } diff --git a/ggml.c b/ggml.c index 10df03c9c..5e515c03f 100644 --- a/ggml.c +++ b/ggml.c @@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext( return result; } +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(ne2 == N); GGML_ASSERT(P >= 0); - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); @@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( float M = -INFINITY; float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); memset(V16, 0, D*sizeof(ggml_fp16_t)); @@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16( float s; + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + ggml_vec_dot_f16(D, &s, (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + Q16); s = s*scale + mv; @@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F16: + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: { ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; default: { + // TODO: implement F32 precision GGML_ASSERT(false); } break; } diff --git a/ggml.h b/ggml.h index 7bca02f2a..e2f74412f 100644 --- a/ggml.h +++ b/ggml.h @@ -1633,6 +1633,10 @@ extern "C" { struct ggml_tensor * mask, float scale); + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 4e6c9f9cc..550caced4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4368,7 +4368,8 @@ static struct ggml_tensor * 
llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 41ddfcca5..db1244876 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case { // GGML_OP_FLASH_ATTN_EXT struct test_flash_attn_ext : public test_case { - const ggml_type typeq; const int64_t hs; // head size const int64_t nh; // num heads const int64_t kv; // kv size const int64_t nb; // batch size std::string vars() override { - return VARS_TO_STR5(typeq, hs, nh, kv, nb); + return VARS_TO_STR4(hs, nh, kv, nb); } double max_nmse_err() override { return 5e-5; } - test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, - int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1); + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); @@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 6fea843b246409a3c4b26156745a89e4ba01029b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 17:59:41 +0200 Subject: [PATCH 18/57] metal : add parallel reduce version (disabled) --- ggml-metal.m | 2 +- ggml-metal.metal | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7b6762e6d..cf7880c82 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&scale length:sizeof( float) atIndex:27]; // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! 
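        // [illustrative sketch, not part of the original patch] the serial reduction kept under
        // "#if 1" and the disabled parallel variant added under "#else" in this patch both merge
        // per-simdgroup partial results (S, M, O) with the standard online-softmax identity;
        // in scalar form:
        //
        //   M   = max(M0, M1);
        //   ms0 = exp(M0 - M);
        //   ms1 = exp(M1 - M);
        //   S   = S0*ms0 + S1*ms1;   // merged normalizer
        //   O   = O0*ms0 + O1*ms1;   // merged unnormalized output row
        //
        // the final attention output is O/S; a later patch in this series ("metal : improve
        // precision") additionally guards the M0/M1 == -INFINITY cases before calling exp().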
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) diff --git a/ggml-metal.metal b/ggml-metal.metal index c9e4dcfe9..6eb2825df 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2230,7 +2230,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - // TODO: try parallel reduce +#if 1 if (sgitg == 0) { half S = { 0.0h }; half M = { -INFINITY }; @@ -2261,6 +2261,46 @@ kernel void kernel_flash_attn_ext_f16( } } } +#else + // parallel reduce + // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps + { + half S = { 0.0h }; + half M = { -INFINITY }; + + for (int64_t sg = nsg/2; sg > 0; sg /= 2) { + if (sgitg >= sg) { + continue; + } + + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*(D + 1*C) + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*(D + 1*C) + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + } + + for (int64_t i = 0; i < L4; ++i) { + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +#endif simdgroup_barrier(mem_flags::mem_threadgroup); From 77f6976a87f6d034cf0f7a77e14a011da7901911 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 13:15:00 +0200 Subject: [PATCH 19/57] metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments --- ggml-metal.m | 12 +-- ggml-metal.metal | 272 ++++++++++++++++++++++------------------------- 2 files changed, 136 insertions(+), 148 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index eabc16f41..a7e126bff 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,14 +2213,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) + const int64_t ncpsg = 32; // cache values per simdgroup - //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? 
MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); diff --git a/ggml-metal.metal b/ggml-metal.metal index 6eb2825df..b564f014d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,6 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); +// ref: https://arxiv.org/pdf/2307.08691.pdf template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, @@ -2038,39 +2039,45 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iq1 = tgpig[0]*Q; const int64_t D4 = D/4; - const int64_t N4 = N_SIMDWIDTH; - const int64_t L4 = (D4 + N4 - 1)/N4; const int64_t D8 = D/8; + const int64_t NW = N_SIMDWIDTH; + const int64_t L4 = (D4 + NW - 1)/NW; + const int64_t SH = (C + Q); // shared memory per simdgroup in (half) - const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = D + nsg*SH; // shared memory size per query in (half) + const int64_t T4 = T/4; // shared memory size per query in (half4) - threadgroup half * pq = (threadgroup half *) (shared + 0*D); - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); - threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; for (int64_t i = 0; i < L4; ++i) { // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); if (iq1 + j < ne01) { - pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg]; + sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; } else { - pq4[j*T4 + N4*i + tiisg] = 0.0h; + sq4[j*T4 + NW*i + tiisg] = 0.0h; } } - - // zero out shared memory - for (int64_t j = 0; j < Q; ++j) { - ps4[j*T4 + N4*i + tiisg] = 0.0h; - } } + // zero out lo + for (int64_t i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH if (tiisg < C) { for (int64_t j = 0; j < Q; ++j) { - ss[j*T + 0 + tiisg] = 0.0h; + ss[j*T + tiisg] = 0.0h; + if (tiisg < Q) { + ss[j*T + C + tiisg] = 0.0h; + } } } @@ -2103,46 +2110,24 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; + // load the queries from shared memory into local memory simdgroup_half8x8 mq[D8]; for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], pq + i*8, T); + simdgroup_load(mq[i], sq + i*8, T); } - // TODO: this can be improved - device const float * mp[Q]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - { - const int64_t ir = iq3*ne02*ne01 + 
iq2*ne01 + iq1; - - for (int64_t j = 0; j < Q; ++j) { - if (iq1 + j < ne01) { - mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); - } else { - mp[j] = nullptr; - } - } - } + // pointer to the mask + device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix simdgroup_half8x8 mscale(scale); - for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { - // skip -INF blocks - // TODO: double-check this - { - float smc = -INFINITY; - - for (int64_t j = 0; j < Q; ++j) { - const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; - smc = simd_max(max(smc, mc)); - } - - if (smc == -INFINITY) { - continue; - } - } - + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { simdgroup_half8x8 mk; @@ -2150,7 +2135,7 @@ kernel void kernel_flash_attn_ext_f16( for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); - device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); @@ -2160,65 +2145,77 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask simdgroup_float8x8 mm; - simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, T, 0, false); } } + // used to detect blocks full of -INF + half smax = -INFINITY; + // online softmax for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; - //const half s = ss[j*T + p]*scale + (mp[j][iic + p]); const half s = ss[j*T + p]; - half m = M[j]; - + smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); + const half m = M[j]; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] *= ms; + // create an 8x8 diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; } + // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } - simdgroup_barrier(mem_flags::mem_none); + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } - // (Q*K^T)*V + // O = diag(ms)*O + { + simdgroup_half8x8 mm; + + simdgroup_load(mm, ss + C, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V { simdgroup_half8x8 mv; - simdgroup_half8x8 mp[C/8]; for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } + simdgroup_half8x8 mp; + simdgroup_load(mp, ss + 8*cc, T, 0, false); - for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mqkv; - - simdgroup_load(mqkv, ps + i*8, T, 0, false); - - for (int cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t i = 0; i < D8; ++i) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); + simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); } - - simdgroup_store(mqkv, ps + i*8, T, 0, false); } } } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int64_t j = 0; j < Q; ++j) { if (tiisg == 0) { ss[j*T + 0] = S[j]; @@ -2227,91 +2224,82 @@ kernel void kernel_flash_attn_ext_f16( } } - threadgroup_barrier(mem_flags::mem_threadgroup); + // reduce the warps sequentially + for (int64_t sg = 1; sg < nsg; ++sg) { + half S = { 0.0h }; + half M = { -INFINITY }; - // reduce the warps -#if 1 + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq4 + if (sgitg == sg) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_half8x8 ms0; + simdgroup_half8x8 ms1; + + simdgroup_load(ms0, ss + C, T, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq4) if (sgitg == 0) { - half S = { 0.0h }; - half M = { -INFINITY }; - - for (int64_t sg = 1; sg < nsg; ++sg) { - for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - - M = max(M0, M1); - - const half ms0 = exp(M0 - M); - const half 
ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; - } - - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; - } - } + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); } } -#else - // parallel reduce - // NOTE: this is significantly slower than the serial version above, likely due to the small number of warps - { - half S = { 0.0h }; - half M = { -INFINITY }; - - for (int64_t sg = nsg/2; sg > 0; sg /= 2) { - if (sgitg >= sg) { - continue; - } - - for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*(D + 1*C) + 0]; - - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*(D + 1*C) + 1]; - - M = max(M0, M1); - - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; - } - - for (int64_t i = 0; i < L4; ++i) { - ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - } -#endif - - simdgroup_barrier(mem_flags::mem_threadgroup); device float4 * dst4 = (device float4 *) dst; + // final rescale with 1/S and store to global memory if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; } } } From ecc466a460abc7ad73df3b22a3e0957170bcf7b9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 15:42:57 +0200 Subject: [PATCH 20/57] metal : add tests, fix scaling, support C > 32 --- ggml-metal.m | 6 ++-- ggml-metal.metal | 64 ++++++++++++++++++++------------------ tests/test-backend-ops.cpp | 14 ++++++--- 3 files changed, 47 insertions(+), 37 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a7e126bff..484ef8939 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,12 +2213,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? 
MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index b564f014d..7b604eb61 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; const int64_t NW = N_SIMDWIDTH; - const int64_t L4 = (D4 + NW - 1)/NW; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) const int64_t T = D + nsg*SH; // shared memory size per query in (half) @@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16( // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[D8]; - for (int64_t i = 0; i < L4; ++i) { - // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { - sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; + sq4[j*T4 + i] = (half4) q4[i]; } else { - sq4[j*T4 + NW*i + tiisg] = 0.0h; + sq4[j*T4 + i] = 0.0h; } } } @@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16( } // zero out shared memory SH - if (tiisg < C) { - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + tiisg] = 0.0h; - if (tiisg < Q) { - ss[j*T + C + tiisg] = 0.0h; - } + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = tiisg; i < SH; i += NW) { + ss[j*T + i] = 0.0h; } } @@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16( // online softmax for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; - - const half s = ss[j*T + p]; - - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); - const half m = M[j]; - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 
0.0h : exp(s - M[j]); + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - S[j] = S[j]*ms + simd_sum(vs); + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } + + const half ms = exp(m - M[j]); + + S[j] = S[j]*ms; // create an 8x8 diagonal matrix for rescaling the output - if (p == j) { + if (tiisg == j) { ss[j*T + C + j] = ms; } - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } // skip -INF blocks @@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); - // each simdgroup stores its output to shared memory, reusing sq4 + // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16( } } - // store result to shared memory (reuse sq4) + // store result to shared memory (reuse sq) if (sgitg == 0) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; + for (int64_t i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4c98bef7c..4093a52f2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-5; + return 5e-4; } test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) @@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 3a428a10973a751af72b55b9ef396de9c305c6ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 17:47:22 +0200 Subject: [PATCH 21/57] metal : improve precision --- ggml-metal.metal | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 
7b604eb61..b6b5fd997 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2120,7 +2120,7 @@ kernel void kernel_flash_attn_ext_f16( device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); // prepare diagonal scale matrix - simdgroup_half8x8 mscale(scale); + simdgroup_float8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2163,7 +2163,7 @@ kernel void kernel_flash_attn_ext_f16( M[j] = simd_max(max(M[j], s)); } - const half ms = exp(m - M[j]); + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); S[j] = S[j]*ms; @@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = exp(s - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); S[j] = S[j] + simd_sum(vs); @@ -2255,8 +2255,8 @@ kernel void kernel_flash_attn_ext_f16( M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); S = S0*ms0 + S1*ms1; From 8612864108760897261d0d10101f68355899b03f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 18:10:16 +0200 Subject: [PATCH 22/57] ggml : fix f16 mad --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 6bba840d9..fc0886aec 100644 --- a/ggml.c +++ b/ggml.c @@ -1344,12 +1344,12 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const // leftovers for (int i = np; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #else // scalar for (int i = 0; i < n; ++i) { - y[i] += GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i])*v); + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); } #endif } From 134c81c78dfdeaca988ea2505cc6f0c0aec2d243 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 22:23:40 +0200 Subject: [PATCH 23/57] metal : minor --- ggml-metal.metal | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index b6b5fd997..ad6a4a318 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2127,15 +2127,14 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { // Q*K^T { - simdgroup_half8x8 mk; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } @@ -2192,7 +2191,6 @@ kernel void kernel_flash_attn_ext_f16( // O = diag(ms)*O { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); for (int64_t i = 0; i < D8; ++i) { @@ -2202,8 +2200,6 @@ kernel void kernel_flash_attn_ext_f16( // O = O + (Q*K^T)*V { - simdgroup_half8x8 mv; - for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mp; simdgroup_load(mp, ss + 8*cc, T, 0, false); @@ -2211,6 +2207,7 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t i = 0; i < D8; ++i) { device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mv; 
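                        // [illustrative note, not part of the original patch] scalar equivalent
                        // of the tile update performed by the two calls below; mp holds the 8x8
                        // block of softmax weights (P) for this KV chunk and mv the 8x8 tile of
                        // V for head dims 8*i..8*i+7 (row/col/kk are hypothetical loop names):
                        //
                        //   for (int row = 0; row < 8; ++row)       // query within the block
                        //       for (int col = 0; col < 8; ++col)   // head dim within slice i
                        //           for (int kk = 0; kk < 8; ++kk)  // KV position in 8*cc..8*cc+7
                        //               lo[i][row][col] += P[row][kk] * V[kk][col];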
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); From 1db22d7032fd55a612e400164cb70ad238bbc055 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 23:08:31 +0200 Subject: [PATCH 24/57] metal : support Q > 8 --- examples/batched-bench/batched-bench.cpp | 2 +- ggml-metal.m | 7 ++- ggml-metal.metal | 80 +++++++++++++++--------- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 7924db267..4992b57f6 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = 512; + ctx_params.n_batch = 2048; ctx_params.mul_mat_q = mmq; ctx_params.n_threads = params.n_threads; diff --git a/ggml-metal.m b/ggml-metal.m index ef799ef57..a0dd1d0df 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index ad6a4a318..08c000cc4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2040,6 +2040,7 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; + const int64_t Q8 = Q/8; const int64_t NW = N_SIMDWIDTH; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) @@ -2051,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; + simdgroup_half8x8 lo[Q8][D8]; // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { @@ -2067,8 +2068,10 @@ kernel void kernel_flash_attn_ext_f16( } // zero out lo - for (int64_t i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + lo[j][i] = make_filled_simdgroup_matrix(0.0h); + } } // zero out shared memory SH @@ -2108,10 +2111,12 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; + simdgroup_half8x8 mq[Q8][D8]; - for (int64_t i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + } } const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; @@ -2128,7 +2133,10 @@ kernel void kernel_flash_attn_ext_f16( // Q*K^T { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mqk = 
make_filled_simdgroup_matrix(0.h); + simdgroup_half8x8 mqk[Q8]; + for (int64_t j = 0; j < Q8; ++j) { + mqk[j] = make_filled_simdgroup_matrix(0.h); + } device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2136,15 +2144,19 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_half8x8 mk; simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + } } // mqk = mqk*scale + mask - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_float8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); - simdgroup_store(mqk, ss + 8*cc, T, 0, false); + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + } } } @@ -2166,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms; - // create an 8x8 diagonal matrix for rescaling the output + // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } @@ -2189,28 +2201,30 @@ kernel void kernel_flash_attn_ext_f16( } // O = diag(ms)*O - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 mm; - simdgroup_load(mm, ss + C, T, 0, false); + simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); + simdgroup_multiply(lo[j][i], mm, lo[j][i]); } } // O = O + (Q*K^T)*V { for (int cc = 0; cc < C/8; ++cc) { - simdgroup_half8x8 mp; - simdgroup_load(mp, ss + 8*cc, T, 0, false); + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); for (int64_t i = 0; i < D8; ++i) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - simdgroup_half8x8 mv; - simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false); + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mv; + simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); - simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]); + simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); + } } } } @@ -2234,8 +2248,10 @@ kernel void kernel_flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } @@ -2267,19 +2283,19 @@ kernel void kernel_flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { + for (int64_t j = 0; j < Q8; ++j) { simdgroup_half8x8 t; simdgroup_half8x8 ms0; simdgroup_half8x8 ms1; - simdgroup_load(ms0, ss + C, T, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, T, 0, false); + simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); for (int64_t i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); simdgroup_multiply(t, ms1, t); - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + 
simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); } } } @@ -2287,8 +2303,10 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (int64_t i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } } } From 4794821a31d5778b3398b8375d29fa63a539c8c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 16:44:55 +0200 Subject: [PATCH 25/57] tests : add ATTN tests --- tests/test-backend-ops.cpp | 70 +++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c18ff07ea..0ce498e9e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1418,6 +1418,48 @@ struct test_flash_attn_ext : public test_case { } }; +// Attention +struct test_attn : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string op_desc(ggml_tensor * t) override { + return "ATTN"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + + struct ggml_tensor * cur; + + cur = ggml_mul_mat (ctx, k, q); + cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_mul_mat (ctx, v, cur); + cur = ggml_permute (ctx, cur, 0, 2, 1, 3); + cur = ggml_cont_2d (ctx, cur, hs*nh, nb); + + return cur; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1684,15 +1726,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); + test_cases.emplace_back(new test_attn(64, 32, 512, 8)); + test_cases.emplace_back(new test_attn(64, 32, 512, 7)); + test_cases.emplace_back(new test_attn(64, 32, 512, 1)); + test_cases.emplace_back(new test_attn(80, 32, 512, 8)); + test_cases.emplace_back(new test_attn(80, 32, 512, 7)); + test_cases.emplace_back(new test_attn(80, 32, 512, 1)); + test_cases.emplace_back(new test_attn(128, 32, 512, 8)); + test_cases.emplace_back(new test_attn(128, 32, 512, 7)); + test_cases.emplace_back(new 
test_attn(128, 32, 512, 1)); + + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From abeaf0d90ee82096a0aba20785f1e37bd1f3aa41 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:12:24 +0200 Subject: [PATCH 26/57] metal : disable buffer allocation logs --- ggml-metal.m | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a0dd1d0df..a637f0487 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2421,10 +2421,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2434,10 +2437,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2471,8 +2479,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2549,7 +2556,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2572,7 +2579,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2581,8 +2589,6 @@ GGML_CALL ggml_backend_buffer_t 
ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } From c6c1132e5e6658b3c209433ed5ef75067ef31a2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:22:28 +0200 Subject: [PATCH 27/57] tests : more --- ggml-metal.m | 9 +++++++++ ggml-metal.metal | 3 +++ ggml.c | 5 ----- tests/test-backend-ops.cpp | 29 ++++++++++------------------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a637f0487..4b5fd0bb8 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -137,7 +137,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -505,7 +508,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute( switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 08c000cc4..be059d78f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16( template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel 
flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/ggml.c b/ggml.c index e8a5fcfa4..57271a1ad 100644 --- a/ggml.c +++ b/ggml.c @@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; - const int64_t P = nek1 - N; GGML_ASSERT(ne0 == D); GGML_ASSERT(ne2 == N); - GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); @@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted @@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0ce498e9e..f57e8ab1a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_attn(64, 32, 512, 8)); - test_cases.emplace_back(new test_attn(64, 32, 512, 7)); - test_cases.emplace_back(new test_attn(64, 32, 512, 1)); - test_cases.emplace_back(new test_attn(80, 32, 512, 8)); - test_cases.emplace_back(new test_attn(80, 32, 512, 7)); - test_cases.emplace_back(new test_attn(80, 32, 512, 1)); - test_cases.emplace_back(new test_attn(128, 32, 512, 8)); - test_cases.emplace_back(new test_attn(128, 32, 512, 7)); - test_cases.emplace_back(new test_attn(128, 32, 512, 1)); - - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); + for (int hs : { 64, 80, 96, 112, 128, 256, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, 2048, 4096, }) { + for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 5fcb9c1c5af108056c8ad51fc1995de9d7707d2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 19:46:22 +0200 Subject: [PATCH 28/57] metal : faster inner loop for C == 32 --- ggml-metal.metal | 63 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git 
a/ggml-metal.metal b/ggml-metal.metal index be059d78f..db4c7cfde 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2048,8 +2048,8 @@ kernel void kernel_flash_attn_ext_f16( const int64_t T4 = T/4; // shared memory size per query in (half4) threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention - threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[Q8][D8]; @@ -2164,35 +2164,60 @@ kernel void kernel_flash_attn_ext_f16( half smax = -INFINITY; // online softmax - for (int64_t j = 0; j < Q; ++j) { - const half m = M[j]; + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; - for (int64_t p = tiisg; p < C; p += NW) { + const half m = M[j]; const half s = ss[j*T + p]; smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - } - - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - - S[j] = S[j]*ms; - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; - } - - for (int64_t p = tiisg; p < C; p += NW) { - const half s = ss[j*T + p]; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j] + simd_sum(vs); + S[j] = S[j]*ms + simd_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } + + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + + S[j] = S[j]*ms; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*T + C + j] = ms; + } + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + } } // skip -INF blocks From d073e4f93337560e552f0d3de4b6b07bf13ef3f5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:45:32 +0200 Subject: [PATCH 29/57] metal : fix array initialization --- ggml-metal.metal | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index db4c7cfde..41f6169de 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2084,8 +2084,8 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { 0.0h }; - half M[Q] = { -INFINITY }; + half S[Q] = { [0 ... Q-1] = 0.0h }; + half M[Q] = { [0 ... 
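Independent of the Metal specifics, the C == 32 fast path above is the usual online-softmax recurrence: keep a running maximum M and running sum S, and rescale whatever has already been accumulated whenever the maximum grows. A scalar sketch of the same update (hypothetical helper, not part of the patch):

    #include <math.h>

    // streaming softmax-weighted sum: computes sum_i softmax(s)_i * v_i in one
    // pass over the scores, without materializing the probabilities
    static float online_softmax_weighted_sum(const float * s, const float * v, int n) {
        float M = -INFINITY; // running maximum
        float S = 0.0f;      // running sum of exp(s - M)
        float o = 0.0f;      // running (unnormalized) output

        for (int i = 0; i < n; ++i) {
            const float m = M;                                           // previous maximum
            M = fmaxf(M, s[i]);                                          // new running maximum
            const float ms = m    == -INFINITY ? 0.0f : expf(m    - M); // rescales old terms
            const float vs = s[i] == -INFINITY ? 0.0f : expf(s[i] - M); // new term

            S = S*ms + vs;
            o = o*ms + vs*v[i];
        }

        return o / S; // final 1/S rescale, as done at the end of the kernel
    }

In the kernel the ms factor is additionally written into the QxQ diagonal matrix so that the previously accumulated output tiles can be rescaled with a single simdgroup matrix multiply, and vs is stored into ss as the P matrix for the subsequent P*V multiply.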
Q-1] = -INFINITY }; // assume K and V are same shape const int64_t ne22 = ne12; From 78df5527e4e9eafb181200384fbed80c8116042e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 Jan 2024 21:46:49 +0200 Subject: [PATCH 30/57] tests : ifdef --- tests/test-backend-ops.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f57e8ab1a..07182c6d8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,6 +1726,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); +#if 0 for (int hs : { 64, 80, 96, 112, 128, 256, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { @@ -1736,6 +1737,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } +#else + for (int hs : { 128, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, 512 }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } +#endif #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer From 8ad92dc1ec9aa6549c68900daa7ab93b57fa3ae5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 31 Jan 2024 19:17:16 +0200 Subject: [PATCH 31/57] ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext --- ggml-cuda.cu | 20 +++++++++---------- ggml-metal.m | 6 ++++++ ggml-metal.metal | 40 ++++++++++++++++++-------------------- ggml.c | 13 +++++++++---- ggml.h | 12 +++++++----- llama.cpp | 40 ++++++++++++++++++++++---------------- tests/test-backend-ops.cpp | 10 +++++----- 7 files changed, 79 insertions(+), 62 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e56595742..c57a031e4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } template -static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; @@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds if (need_check && col_data + 0 >= ncols_data) { val.x = -INFINITY; } else { - val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); + val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f); } if (need_check && col_data + WARP_SIZE >= ncols_data) { val.y = -INFINITY; } else { - val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); + val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f); } if (!need_check || col_smem < (vals_smem ? 
ncols_smem : ncols_data)) { vals[col_smem] = val; @@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds } template -static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { +static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (y ? y[iy] : 0.0f); + const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); } @@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con } } -static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max( GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max( #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX if (use_f16_soft_max) { - soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } else { - soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + soft_max_f32_cuda(src0_dd, src1 ? 
(const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); } (void) dst; diff --git a/ggml-metal.m b/ggml-metal.m index 15e5568f9..e00069624 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); + int nth = 32; // SIMD width id pipeline = nil; @@ -2213,6 +2215,10 @@ static bool ggml_metal_graph_compute( id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index b2e40715d..04c1aaf9c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -349,9 +349,9 @@ kernel void kernel_sum_rows( } kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -366,9 +366,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float lmax = -INFINITY; @@ -435,14 +435,14 @@ kernel void kernel_soft_max( } kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, - threadgroup float * buf [[threadgroup(0)]], + threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], @@ -452,15 +452,15 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const half4 * pmask = src1 != src0 ? 
(device const half4 *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; // parallel max float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -486,7 +486,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16( } } - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - // pointer to the mask - device const float * mp = (device const float *) (mask + (ir%ne31)*nb31); + device const half * mp = (device const half *) (mask + iq1*nb31); // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); + simdgroup_half8x8 mscale(scale); // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns @@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16( // mqk = mqk*scale + mask for (int64_t j = 0; j < Q8; ++j) { - simdgroup_float8x8 mm; - simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false); + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); diff --git a/ggml.c b/ggml.c index 466a8cdec..59a4c05a1 100644 --- a/ggml.c +++ b/ggml.c @@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl( bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); @@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } @@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp[i]); + } } #ifndef NDEBUG @@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(V16, 0, D*sizeof(ggml_fp16_t)); - const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? 
(ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; // k indices const int ik3 = iq3 / rk3; @@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? mp[ic] : 0.0f; + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } diff --git a/ggml.h b/ggml.h index a83ff8035..74ce1abd4 100644 --- a/ggml.h +++ b/ggml.h @@ -1646,11 +1646,13 @@ extern "C" { struct ggml_tensor * v, bool masked); - // q: [n_embd, n_batch, n_head, 1] - // k: [n_embd, n_kv, n_head_kv, 1] - // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch, 1, 1] - // res: [n_embd, n_head, n_batch, 1] !! permuted !! +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 1f8ecc19b..fe2583966 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4721,7 +4721,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -4905,7 +4905,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5026,7 +5026,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5148,7 +5148,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -5245,7 +5245,7 @@ struct llm_build_context { 
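Every graph builder in llama.cpp now applies the same transformation to the mask: view the F32 input mask with its second dimension padded up to a multiple of GGML_KQ_MASK_PAD, then cast it to F16 before it is consumed by ggml_soft_max_ext / ggml_flash_attn_ext. A condensed sketch of that contract (only the ggml calls are taken from the patch; the surrounding names are illustrative):

    // n_tokens can be any batch size; the mask rows are padded up to a
    // multiple of GGML_KQ_MASK_PAD (32), e.g. n_tokens = 7 -> n_batch_pad = 32
    const int64_t n_batch_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);

    // view the padded region of the F32 input mask and convert it to F16;
    // the padded rows are zeros because the input buffer is cleared at init
    struct ggml_tensor * KQ_mask = ggml_cast(ctx0,
            ggml_view_2d(ctx0, lctx.inp_KQ_mask,
                         n_kv, n_batch_pad,
                         n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0),
            GGML_TYPE_F16);

The padding is what lets the Metal and CUDA kernels load whole 8x8 (respectively 16x16) mask tiles without per-row bounds checks.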
cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); if (do_rope_shift) { @@ -5448,7 +5448,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5538,7 +5538,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -5631,7 +5631,7 @@ struct llm_build_context { cb(inpL, "inp_embd", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -5731,7 +5731,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5854,7 +5854,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -5968,7 +5968,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, 
GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6089,7 +6089,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6211,7 +6211,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6318,7 +6318,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); @@ -6416,7 +6416,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -6524,7 +6524,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); // shift the entire K-cache if needed @@ -10250,7 +10250,10 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; - cparams.n_batch = params.n_batch; + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. 
ggml_flash_attn_ext) + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -10430,6 +10433,9 @@ struct llama_context * llama_new_context_with_model( ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); + // zero-out the input buffer to prevent NaNs in padded tensors + ggml_backend_buffer_clear(ctx->buf_input, 0); + LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(ctx->buf_input), ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0f31c00f9..b1b30b91c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1101,7 +1101,7 @@ struct test_soft_max : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * b = nullptr; - if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } + if (mask) { b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale); return out; } @@ -1472,7 +1472,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); return out; } @@ -1506,7 +1506,7 @@ struct test_attn : public test_case { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed - ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1); struct ggml_tensor * cur; @@ -1793,7 +1793,7 @@ struct test_llama : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -1915,7 +1915,7 @@ struct test_falcon : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); From 910b15bb4006409fe24b41da171cc562cdb1f3a4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:41:02 
+0200 Subject: [PATCH 32/57] ggml : fix ggml_soft_max mask requirement --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 59a4c05a1..ebd9c6b34 100644 --- a/ggml.c +++ b/ggml.c @@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(mask->ne[3] == 1); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } bool is_node = false; From 2e460137490a4e002a60a60aed052e90179bb65b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 16:47:20 +0200 Subject: [PATCH 33/57] cuda : fix soft_max to use correct mask size --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c57a031e4..15fc6154f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -9064,7 +9064,7 @@ static void ggml_cuda_op_soft_max( const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded! float scale = 1.0f; memcpy(&scale, dst->op_params, sizeof(float)); From 5a19a9f6d0899becbc71a19454a27c0225edddf7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 19:47:11 +0200 Subject: [PATCH 34/57] cuda : add flash_attn kernel (wip) --- ggml-cuda.cu | 735 ++++++++++++++++++++++++++++++++++++++++++++++++++- llama.cpp | 3 +- 2 files changed, 735 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 15fc6154f..60d228a61 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -108,6 +108,7 @@ #include #include #include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED @@ -655,6 +656,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } +static __device__ __forceinline__ half warp_reduce_sum(half x) { +#if __CUDA_ARCH__ >= CC_VOLTA +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x); + } + return x; +#else + (void) x; + NO_DEVICE_CODE; +#endif +} + static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -676,6 +690,18 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX } +static __device__ __forceinline__ half warp_reduce_max(half x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + (void) x; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -989,6 +1015,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr if (lane_id == 0) { s_sum[warp_id] = tmp; } + __syncthreads(); tmp = s_sum[lane_id]; tmp = warp_reduce_sum(tmp); @@ -6249,6 
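Patches 32 and 33 are both follow-ups to the padded mask: the mask may now have more rows than the tensor being soft-maxed, so the exact broadcast check is relaxed to mask->ne[1] >= a->ne[1], and the CUDA soft_max indexes mask rows by the number of queries (src0->ne[1]) instead of ggml_nrows(src1). A toy illustration of the shapes that are now accepted (names and numbers are illustrative):

    // KQ   : [n_kv = 512, n_tokens    =  7]  F32 scores to be soft-maxed
    // mask : [n_kv = 512, n_batch_pad = 32]  F16, padded to GGML_KQ_MASK_PAD
    //
    // row j of KQ is combined with row j of the mask; rows 7..31 of the mask
    // exist only for padding and are never read here
    cur = ggml_soft_max_ext(ctx0, KQ, KQ_mask, kq_scale);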
+6276,528 @@ static __global__ void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } +#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256 + +template +static __global__ void flash_attn_f32( + const float* __restrict__ q, + const float* __restrict__ k, + const float* __restrict__ v, + float* __restrict__ kqv, + float kq_scale, + int head_dim, int seq_len, int num_heads) { + const int head = blockIdx.x / seq_len; + const int head_size = head_dim * seq_len; + const int s = blockIdx.x % seq_len; + + extern __shared__ char flash_attn_shmem_f32[]; + float* S = (float*)flash_attn_shmem_f32; + float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float)); + + // QK^T + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + const int key_offset = is * head_dim + head * head_size; + const int query_offset = s * head_dim + head * head_size; + + float tmp = 0.0f; + for(int d = 0; d < head_dim; d++) { + tmp += k[key_offset + d] * q[query_offset + d]; + } + S[is] = tmp * kq_scale; + } + __syncthreads(); + + float max_val = -INFINITY; + // get the max + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + max_val = fmaxf(max_val , S[is]); + } + + max_val = warp_reduce_max(max_val); + + { // get max from all threads + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = max_val; + } + __syncthreads(); + max_val = warp_data[lane_id]; + max_val = warp_reduce_max(max_val); + } + + // softmax(QK^T) + float sum = 0.0f; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + float tmp = expf(S[is] - max_val); + sum += tmp; + S[is] = tmp; + } + __syncthreads(); + + sum = warp_reduce_sum(sum); + { // softmax sum partials + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + warp_data[warp_id] = sum; + } + __syncthreads(); + sum = warp_data[lane_id]; + sum = warp_reduce_sum(sum); + } + + float inv_sum = 1.0f / sum; + #pragma unroll + for(int is0 = 0; is0 < k_seq_len; is0 += block_size) { + const int is = threadIdx.x + is0; + if(is >= seq_len) { + break; + } + + S[is] *= inv_sum; + } + __syncthreads(); + + // softmax(QK^T)V + #pragma unroll + for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) { + const int d = threadIdx.x + d0; + if(d >= head_dim) { + break; + } + const int dst_index = d + s * head_dim + head * head_size; + const int value_offset = d * seq_len + head * head_size; + + float temp = 0.0f; + #pragma unroll + for(int ic = 0; ic < k_seq_len;ic++) { + if(ic >= seq_len) { + break; + } + temp += v[value_offset + ic] * S[ic]; + } + kqv[dst_index] = temp; + } +} + +#if __CUDA_ARCH__ >= CC_VOLTA +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; +#endif + +// based on metal version +template // D head size, Q queries per block, C cache items per block +static __global__ void flash_attn_ext_f16( + const char* __restrict__ q, + const char* __restrict__ k, + const char* __restrict__ v, + const char* __restrict__ mask, + float* __restrict__ dst, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, 
+ int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int warp_id = threadIdx.y; + const int lane_id = threadIdx.x; + + const int num_warps = blockDim.y; // number of warps + const int iq3 = blockIdx.z; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; + + const int D2 = D/2; + const int D16 = D/16; + const int Q16 = Q/16; + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) + + const int T = D + num_warps*SH; // shared memory size per query in (half) + const int T2 = T/2; // shared memory size per query in (half2) + + extern __shared__ half __flash_attn_f16_shmem[]; + // pq + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + + half16x16_acc zr; + half16x16_acc lo[Q16][D16]; + + // load heads from Q to shared memory + for (int64_t j = warp_id; j < Q; j += num_warps) { + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = lane_id; i < D2; i += NW) { + if (iq1 + j < ne01) { + sq2[j*T2 + i] = __float22half2_rn(q2[i]); + } else { + sq2[j*T2 + i] = make_half2(0.0, 0.0); + } + } + } + + nvcuda::wmma::fill_fragment(zr, 0.0); + + // zero out lo + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + } + } + + // zero out shared memory SH + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = lane_id; i < SH; i += NW) { + ss[j*T + i] = 0.0; + } + } + + __syncthreads(); + + { + half S[Q]; + half M[Q]; + + for(int i = 0; i < Q; i++) { + S[i] = __float2half(0.0f); + M[i] = __float2half(-INFINITY); + } + + // assume K and V are same shape + const int ne22 = ne12; + const int ne23 = ne13; + + const int nb21 = nb11; + const int nb22 = nb12; + const int nb23 = nb13; + + // broadcast + const int rk2 = ne02/ne12; + const int rk3 = ne03/ne13; + + const int rv2 = ne02/ne22; + const int rv3 = ne03/ne23; + + // k indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half16x16_a mq[Q16][D16]; + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + } + } + + // pointer to the mask + const half * mp = mask ? 
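// Worked example of the shared memory layout defined above (values assumed:
// D = 128, Q = NQPB = 16, C = NCPW = 128, num_warps = 4, matching the launch
// parameters chosen by the host code further below):
//
//   SH = C + Q            = 144 halves of per-warp scratch
//   T  = D + num_warps*SH = 128 + 4*144 = 704 halves per query row
//   dynamic shared memory = Q*T*sizeof(half) = 16*704*2 = 22528 bytes
//
// which is exactly what the host computes as
// nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2).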
(const half *) (mask + iq1*nb31) : nullptr; + + // prepare diagonal scale matrix + half16x16_b mscale; + for (int i = 0; i < 16; ++i) { + ss[i*T + i] = __float2half(scale); + } + nvcuda::wmma::load_matrix_sync(mscale, ss, T); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + // Q*K^T + { + for (int cc = 0; cc < C/16; ++cc) { + half16x16_acc mqk[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::fill_fragment(mqk[j], 0); + } + + const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t i = 0; i < D16; ++i) { + half16x16_bT mk; // transposed key + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + } + } + + // mqk = mqk*scale + mask + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mqka; + half16x16_acc mm; + if(mp) { + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + } + + // convert accumulator to matrix_a + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + } + } + } + + // used to detect blocks full of -INF + half smax = __float2half(-INFINITY); + + // online softmax + if (C == 32) { + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = lane_id; + + const half m = M[j]; + const half s = ss[j*T + p]; + + smax = warp_reduce_max(__hmax(smax, s)); + M[j] = warp_reduce_max(__hmax(M[j], s)); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + + S[j] = S[j]*ms + warp_reduce_sum(vs); + + // create a QxQ diagonal matrix for rescaling the output + if (p == j) { + ss[j*T + C + j] = ms; + } + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + } else { + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); + } + + smax = warp_reduce_max(smax); + M[j] = warp_reduce_max(M[j]); + + const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } + + // local sum + half ls = 0.0f; + + for (int64_t p = lane_id; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = __hisinf(s) ? 
__float2half(0.0f) : hexp(s - M[j]); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + S[j] = S[j]*ms + warp_reduce_sum(ls); + } + } + + // skip -INF blocks + if (__hisinf(smax)) { + continue; + } + + // O = diag(ms)*O + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a mm; + half16x16_b lob; + + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + + for (int64_t i = 0; i < D16; ++i) { + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + } + + // restore zeros + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C/16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + half16x16_b mk[D16]; + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + } + + half16x16_a mv[Q16]; + for (int64_t j = 0; j < Q16; ++j) { + nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + } + + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (int64_t j = 0; j < Q; ++j) { + if (lane_id == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (int64_t sg = 1; sg < num_warps; ++sg) { + half S = __float2half(0.0f); + half M = __float2half(-INFINITY); + + __syncthreads(); + + // each simdgroup stores its output to shared memory, reusing sq + if (warp_id == sg) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + __syncthreads(); + + // the first simdgroup accumulates the results from the other simdgroups + if (warp_id == 0) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + M = __hmax(M0, M1); + + const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) ? 
__float2half(0.0f) : hexp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (lane_id == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int64_t j = 0; j < Q16; ++j) { + half16x16_a ms0; + half16x16_a ms1; + half16x16_b t; + half16x16_acc t2; + + nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(t2, 0.0); + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(t2, ms1, t, t2); + + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); + } + } + } + } + + // store result to shared memory (reuse sq) + if (warp_id == 0) { + for (int64_t j = 0; j < Q16; ++j) { + for (int64_t i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + // final rescale with 1/S and store to global memory + if (warp_id == 0) { + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = lane_id; i < D; i += NW) { + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); + } + } + } +#else + NO_DEVICE_CODE; +#endif +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -7682,6 +8231,13 @@ static void im2col_cuda(const float* x, T* dst, im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } +static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) { + int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float); + int num_blocks = num_heads * seq_len; + flash_attn_f32<<>>( + q, k, v, dst, kq_scale, d_head, seq_len, num_heads); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -8659,7 +9215,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec( src1_dfloat = src1_dfloat_a.alloc(ne00); ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, stream); + ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream); } #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion @@ -10284,6 +10840,170 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s } } +inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F32); + GGML_ASSERT(V->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = 
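The sequential warp reduction above merges the per-warp partial results (S, M, O) two at a time, using the same max/rescale rule as the online softmax itself. A scalar sketch of the merge step (the helper is illustrative, not part of the patch):

    #include <math.h>

    // merge two partial online-softmax results computed over disjoint chunks
    // of the KV cache into a single (S, M, O) triple of head size D
    static void merge_partials(float S0, float M0, const float * O0,
                               float S1, float M1, const float * O1,
                               int D, float * S, float * M, float * O) {
        *M = fmaxf(M0, M1);

        const float ms0 = isinf(M0) ? 0.0f : expf(M0 - *M);
        const float ms1 = isinf(M1) ? 0.0f : expf(M1 - *M);

        *S = S0*ms0 + S1*ms1;

        // O = diag(ms0)*O0 + diag(ms1)*O1 -- in the kernel this is done with
        // two 16x16 tensor core multiplies against diagonal matrices
        for (int i = 0; i < D; ++i) {
            O[i] = ms0*O0[i] + ms1*O1[i];
        }
    }

The merged output is finally divided by S and written to dst by warp 0, which corresponds to the 1/S rescale in the last loop of the kernel.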
(ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + const int64_t d_head = Q->ne[0]; + const int64_t sequence_length = Q->ne[1]; + const int64_t num_heads = Q->ne[2]; + + GGML_ASSERT(Q->ne[0] == d_head); + GGML_ASSERT(K->ne[0] == d_head); + GGML_ASSERT(V->ne[1] == d_head); + + GGML_ASSERT(Q->ne[1] == sequence_length); + GGML_ASSERT(K->ne[1] == sequence_length); + GGML_ASSERT(V->ne[0] == sequence_length); + + GGML_ASSERT(Q->ne[2] == num_heads); + GGML_ASSERT(K->ne[2] == num_heads); + GGML_ASSERT(V->ne[2] == num_heads); + + float KQ_scale = 1.0f / sqrtf((float)d_head); + + flash_attn_f32_cuda( + (float *) src0_extra->data_device[g_main_device], // Query + (float *) src1_extra->data_device[g_main_device], // Key + (float *) src2_extra->data_device[g_main_device], // Value + (float *) dst_extra->data_device[g_main_device], // dst + KQ_scale, d_head, sequence_length, num_heads, main_stream); +} + + +inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->backend == GGML_BACKEND_GPU); + GGML_ASSERT(K->backend == GGML_BACKEND_GPU); + GGML_ASSERT(V->backend == GGML_BACKEND_GPU); + GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra; + ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra; + ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + +#define NQPB 16 +#define NCPW 128 + + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + + dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 block_dim(32, nwarps, 1); + + const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + + switch (Q->ne[0]) + { + case 16: + flash_attn_ext_f16<16, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? 
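+            // the dynamic shared memory requested above is nqpb*(D + nwarps*(ncpw + nqpb)) halves;
+            // for example, with D = 128, nqpb = 16, ncpw = 128 and nwarps = 8 this is
+            //   16*(128 + 8*(128 + 16))*2 bytes = 40960 bytes (40 KB),
+            // which still fits in the default 48 KB per-block limit without a cudaFuncSetAttribute opt-in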
((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; + } +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -10573,6 +11293,10 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_ARGSORT: func = ggml_cuda_argsort; break; + case GGML_OP_FLASH_ATTN: + break; + case GGML_OP_FLASH_ATTN_EXT: + break; default: return false; } @@ -10587,7 +11311,13 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return true; } - func(tensor->src[0], tensor->src[1], tensor); + if(tensor->op == GGML_OP_FLASH_ATTN) { + ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) { + ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } else { + func(tensor->src[0], tensor->src[1], tensor); + } return true; } @@ -11403,6 +12133,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/llama.cpp b/llama.cpp index fe2583966..2330efff5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6881,7 +6881,8 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext) + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); From 56e45a239e1d5a871009aa162b7ba99c93c40b62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:16:32 +0200 Subject: [PATCH 35/57] metal : optimize softmax for C > 32 --- ggml-metal.metal | 16 +++++++++++----- tests/test-backend-ops.cpp | 9 +++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 04c1aaf9c..3d5d762d1 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2217,29 +2217,35 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); + smax = max(smax, s); + M[j] = max(M[j], s); } - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + smax = simd_max(smax); + M[j] = simd_max(M[j]); - S[j] = S[j]*ms; + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (tiisg == j) { ss[j*T + C + j] = ms; } + // local sum + half ls = 0.0h; + for (int64_t p = tiisg; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = s == -INFINITY ? 
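                // note: compared to the previous version, the simd reductions happen once per row
                // instead of once per element: the row max is reduced with a single simd_max, and
                // the exponentials are accumulated into a per-thread local sum (ls) that is reduced
                // with a single simd_sum when updating S[j]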
0.0h : exp(s - M[j]); - S[j] = S[j] + simd_sum(vs); + ls += vs; // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + S[j] = S[j]*ms + simd_sum(ls); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b1b30b91c..2ab535406 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,9 +572,18 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; +#if 1 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } +#else + int n_nodes = gf->n_nodes; + for (int i = 1; i < n_runs; i++) { + for (int j = 0; j < n_nodes; j++) { + gf->nodes[gf->n_nodes++] = gf->nodes[j]; + } + } +#endif // calculate memory size_t mem = n_runs * op_size(out); From cda5a60a41c669f233a943e1182cd6625f61a924 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 20:53:29 +0200 Subject: [PATCH 36/57] metal : optimize softmax --- ggml-metal.m | 5 +++-- ggml-metal.metal | 34 +++++++++++++++++++--------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e00069624..2bbb6d17a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2285,8 +2285,9 @@ static bool ggml_metal_graph_compute( const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index 3d5d762d1..d9a536ae8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2188,6 +2188,8 @@ kernel void kernel_flash_attn_ext_f16( // online softmax if (C == 32) { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; @@ -2197,20 +2199,22 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms + simd_sum(vs); - - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } } else { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const half m = M[j]; @@ -2224,12 +2228,7 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; - } + ms[j] = m == -INFINITY ? 
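+                // note: ms is kept per row in a small array so that the QxQ diagonal used for
+                // rescaling the output can be written after the loop by Q threads at once
+                // (ss[tiisg*T + C + tiisg] = ms[tiisg]) instead of one element per loop iteration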
0.0h : exp(m - M[j]); // local sum half ls = 0.0h; @@ -2245,7 +2244,12 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } - S[j] = S[j]*ms + simd_sum(ls); + S[j] = S[j]*ms[j] + simd_sum(ls); + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; } } From c6769b942229a9e634965e6215651b8d4cf02403 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 21:24:26 +0200 Subject: [PATCH 37/57] tests : minor fix --- tests/test-backend-ops.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2ab535406..727f2ea06 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -578,6 +578,7 @@ struct test_case { } #else int n_nodes = gf->n_nodes; + n_runs = 1000; for (int i = 1; i < n_runs; i++) { for (int j = 0; j < n_nodes; j++) { gf->nodes[gf->n_nodes++] = gf->nodes[j]; From db1f3c482e256398330d44ad22b498ca2cd03161 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 1 Feb 2024 22:08:37 +0200 Subject: [PATCH 38/57] cuda : avoid zeroing fragments --- ggml-cuda.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 098b55e07..7130209e7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6443,11 +6443,11 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; + const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) @@ -6665,8 +6665,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } // restore zeros @@ -6760,9 +6759,8 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); for (int64_t i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, t2); + nvcuda::wmma::mma_sync(t2, ms1, t, zr); // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); From 12eaa22628740e388789081ccba93159c1b0b412 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 11:55:38 +0200 Subject: [PATCH 39/57] tests : update dims --- ggml-cuda.cu | 180 ++++++++++++++++++++++--------------- tests/test-backend-ops.cpp | 6 +- 2 files changed, 110 insertions(+), 76 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7130209e7..2c050c0c4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16( for (int64_t j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; - if(mp) { + + if (mp) { nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } @@ -10927,78 +10928,111 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor 
* Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - switch (Q->ne[0]) - { - case 16: - flash_attn_ext_f16<16, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - default: - break; + switch (Q->ne[0]) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 727f2ea06..9feb5e1fe 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,7 +572,7 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 1 +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } @@ -2209,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); -#if 0 - for (int hs : { 64, 80, 96, 112, 128, 256, }) { +#if 1 + for (int hs : { 64, 80, 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From b68a112204c58e2bed334273753211c15acc48e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 Feb 2024 15:12:28 +0200 Subject: [PATCH 40/57] cuda : fix __hisinf() result check --- ggml-cuda.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c050c0c4..0136fbf28 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(__hmax(smax, s)); M[j] = warp_reduce_max(__hmax(M[j], s)); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); S[j] = S[j]*ms + warp_reduce_sum(vs); @@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); - const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]); + const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); // create a QxQ diagonal matrix for rescaling the output if (lane_id == j) { @@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16( for (int64_t p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; - const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]); + const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); ls += vs; @@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16( } // skip -INF blocks - if (__hisinf(smax)) { + if (__hisinf(smax) == -1) { continue; } @@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M); + const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); + const half ms1 = __hisinf(M1) == -1 ? 
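+                // __hisinf() returns an int (-1 for -inf, +1 for +inf, 0 otherwise), not a bool,
+                // so the check has to compare against -1 to detect only the -INF (masked) case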
__float2half(0.0f) : hexp(M1 - M); S = S0*ms0 + S1*ms1; From b150abe83e6f0f8a0cf552d7fc0d8fe9f0f78569 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:17:47 +0200 Subject: [PATCH 41/57] cuda : avoid warp_reduce for smax --- ggml-cuda.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0136fbf28..c3f24242b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16( M[j] = __hmax(M[j], s); } - smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); @@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16( } } + smax = warp_reduce_max(smax); + // skip -INF blocks if (__hisinf(smax) == -1) { continue; From 7c34655b366e14d43f7fc9fa104a9ca7b8f60580 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 13:39:46 +0200 Subject: [PATCH 42/57] cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) --- ggml-cuda.cu | 70 ++++++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c3f24242b..558ffb8ac 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int64_t j = warp_id; j < Q; j += num_warps) { + for (int j = warp_id; j < Q; j += num_warps) { const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = lane_id; i < D2; i += NW) { + for (int i = lane_id; i < D2; i += NW) { if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = lane_id; i < SH; i += NW) { + for (int j = 0; j < Q; ++j) { + for (int i = lane_id; i < SH; i += NW) { ss[j*T + i] = 0.0; } } @@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16( // load the queries from shared memory into local memory half16x16_a mq[Q16][D16]; - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } @@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); } const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t 
j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; @@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16( // online softmax if (C == 32) { - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half m = M[j]; const half s = ss[j*T + p]; @@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16( ss[j*T + p] = vs; } } else { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mm; half16x16_b lob; nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); @@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16( const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mk[D16]; - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); } half16x16_a mv[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); } - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } @@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { if (lane_id == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -6708,7 +6708,7 @@ static __global__ void flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < num_warps; ++sg) { + for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); half M = __float2half(-INFINITY); @@ -6716,8 +6716,8 @@ static __global__ void flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q16; ++j) { 
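+                // 32-bit loop counters: 64-bit integer arithmetic is emulated with multiple
+                // instructions on NVIDIA GPUs and uses two registers per index, which is likely
+                // where the measured speed-up comes from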
+ for (int j = 0; j < Q16; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; @@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, zr); @@ -6776,8 +6776,8 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6785,10 +6785,10 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = lane_id; i < D; i += NW) { + for (int i = lane_id; i < D; i += NW) { dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } From 1f8a5924823aecaa6ab1d5c2ac70ddde1d6c27d0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:01:32 +0200 Subject: [PATCH 43/57] cuda : make loops use the same loop values Thanks Johannes again for the tip --- ggml-cuda.cu | 43 +++++++++++++++++++++++++++++++------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 558ffb8ac..a3a6c6455 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int j = warp_id; j < Q; j += num_warps) { + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int i = lane_id; i < D2; i += NW) { + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16( // zero out shared memory SH for (int j = 0; j < Q; ++j) { - for (int i = lane_id; i < SH; i += NW) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + ss[j*T + i] = 0.0; } } @@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { @@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; const half vs = 
__hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int i = lane_id; i < D; i += NW) { + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9feb5e1fe..e4076b49c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 64, 80, 128, }) { + for (int hs : { 128, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { From 92472ea22ca3eed7f65114b0e6b7de1585930759 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:10:01 +0200 Subject: [PATCH 44/57] cuda : unroll some of the loops --- ggml-cuda.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a3a6c6455..deda4cc70 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,6 +6462,7 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory +#pragma unroll for (int j0 = 0; j0 < Q; j0 += num_warps) { const int j = j0 + warp_id; if (j >= Q) { @@ -6470,6 +6471,7 @@ static __global__ void flash_attn_ext_f16( const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); +#pragma unroll for (int i0 = 0; i0 < D2; i0 += NW) { const int i = i0 + lane_id; if (i >= D2) { From c51f27c0dbd70fe8eda6182d61371d6a2dea6fb9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:27:36 +0200 Subject: [PATCH 45/57] cuda : avoid __hisinf branches --- ggml-cuda.cu | 89 +++++++++++++++++++--------------------------------- 1 file changed, 32 insertions(+), 57 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index deda4cc70..4d1fb008c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16( half S[Q]; half M[Q]; - for(int i = 0; i < Q; i++) { + for (int i = 0; i < Q; ++i) { S[i] = __float2half(0.0f); - M[i] = __float2half(-INFINITY); + M[i] = CUDART_MIN_DENORM_FP16; } // assume K and V are same shape @@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16( half smax = __float2half(-INFINITY); // online softmax - if (C == 32) { - for (int j = 0; j < Q; ++j) { - const int p = lane_id; + for (int j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; - const half m = M[j]; const half s = ss[j*T + p]; - smax = warp_reduce_max(__hmax(smax, s)); - M[j] = warp_reduce_max(__hmax(M[j], s)); + smax = __hmax(smax, s); + M[j] = __hmax(M[j], s); + } - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - const half vs = __hisinf(s) == -1 ? 
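+        // note: M is initialized to the smallest positive FP16 denormal instead of -INFINITY,
+        // so m - M[j] and s - M[j] can no longer produce (-inf) - (-inf) = NaN; hexp() then
+        // yields 0 for masked (-INF) scores on its own and the __hisinf() branches can be dropped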
__float2half(0.0f) : hexp(s - M[j]); + M[j] = warp_reduce_max(M[j]); - S[j] = S[j]*ms + warp_reduce_sum(vs); + const half ms = hexp(m - M[j]); - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + } + + // local sum + half ls = 0.0f; + + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + + const half s = ss[j*T + p]; + + const half vs = hexp(s - M[j]); + + ls += vs; // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } - } else { - for (int j = 0; j < Q; ++j) { - const half m = M[j]; - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; - - const half s = ss[j*T + p]; - - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); - } - - M[j] = warp_reduce_max(M[j]); - - const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - - // local sum - half ls = 0.0f; - - for (int p0 = 0; p0 < C; p0 += NW) { - const int p = p0 + lane_id; - - const half s = ss[j*T + p]; - - const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); - - ls += vs; - - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; - } - - S[j] = S[j]*ms + warp_reduce_sum(ls); - } + S[j] = S[j]*ms + warp_reduce_sum(ls); } smax = warp_reduce_max(smax); @@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); - half M = __float2half(-INFINITY); + half M = CUDART_MIN_DENORM_FP16; __syncthreads(); @@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16( M = __hmax(M0, M1); - const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); - const half ms1 = __hisinf(M1) == -1 ? 
__float2half(0.0f) : hexp(M1 - M); + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); S = S0*ms0 + S1*ms1; From b958151e3f66e17a9bc5131e446a50c5529b4b81 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:00:25 +0200 Subject: [PATCH 46/57] cuda : use half2 in softmax --- ggml-cuda.cu | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4d1fb008c..1fed9d23e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16( const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half16x16_acc zr; half16x16_acc lo[Q16][D16]; @@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16( } // used to detect blocks full of -INF - half smax = __float2half(-INFINITY); + half2 smax = make_half2(-INFINITY, -INFINITY); // online softmax for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - smax = __hmax(smax, s); - M[j] = __hmax(M[j], s); + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); } M[j] = warp_reduce_max(M[j]); @@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16( } // local sum - half ls = 0.0f; + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); - for (int p0 = 0; p0 < C; p0 += NW) { + for (int p0 = 0; p0 < C2; p0 += NW) { const int p = p0 + lane_id; - const half s = ss[j*T + p]; + const half2 s = ss2[j*T2 + p]; - const half vs = hexp(s - M[j]); + const half2 vs = h2exp(s - M2); ls += vs; // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + ss2[j*T2 + p] = vs; } - S[j] = S[j]*ms + warp_reduce_sum(ls); + ls = warp_reduce_sum(ls); + + S[j] = S[j]*ms + ls.x + ls.y; } smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax) == -1) { + if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { continue; } From a7b471569bdf4e09e97b2d02c27989b8cb801861 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:17:49 +0200 Subject: [PATCH 47/57] cuda : switch to 1 warp for bs > 16 --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1fed9d23e..c98b551b3 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10933,7 +10933,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + const int nwarps = Q->ne[1] <= nqpb ? 
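+    // heuristic: for small query batches (ne1 <= nqpb, i.e. a single block in x per head)
+    // parallelize over the KV cache with several warps per block; for larger batches the grid
+    // presumably already provides enough parallelism, so a single warp per block avoids the
+    // cross-warp reduction overhead at the end of the kernel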
std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1); From 3b1c4e76739031bee3028748e0cd288c148f77b4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:36:05 +0200 Subject: [PATCH 48/57] cuda : speed-up reduce part of the kernel --- ggml-cuda.cu | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c98b551b3..67541a61e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6715,9 +6715,6 @@ static __global__ void flash_attn_ext_f16( // reduce the warps sequentially for (int sg = 1; sg < num_warps; ++sg) { - half S = __float2half(0.0f); - half M = CUDART_MIN_DENORM_FP16; - __syncthreads(); // each simdgroup stores its output to shared memory, reusing sq @@ -6733,27 +6730,25 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int j = 0; j < Q; ++j) { + for (int j = lane_id; j < Q; j += NW) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; const half M0 = ss[j*T + 1]; const half M1 = ss[j*T + sg*SH + 1]; - M = __hmax(M0, M1); + const half M = __hmax(M0, M1); const half ms0 = hexp(M0 - M); const half ms1 = hexp(M1 - M); - S = S0*ms0 + S1*ms1; + const half S = S0*ms0 + S1*ms1; - if (lane_id == 0) { - ss[j*T + 0] = S; - ss[j*T + 1] = M; + ss[j*T + 0] = S; + ss[j*T + 1] = M; - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; - } + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 @@ -10931,6 +10926,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nqpb = NQPB; // queries per block const int ncpw = NCPW; // cache values per warp (does not work for other values) + GGML_ASSERT(NQPB <= 32); + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why const int nwarps = Q->ne[1] <= nqpb ? 
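+    // note: the kernel's final reduction now distributes the per-row (S, M) merging across the
+    // lanes of warp 0 (one query row per lane) instead of looping over all rows from lane 0 only;
+    // the assert above documents that this bookkeeping assumes at most 32 query rows per block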
std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; From 5b263dd83a5f906eddd10bc044051d7571097043 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:12:20 +0200 Subject: [PATCH 49/57] cuda : unroll Q*K^T loop --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 67541a61e..dbd482239 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6571,6 +6571,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { +#pragma unroll for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { From e04ff391819e1875beed3e30d9e7592db45e0e62 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 16:57:46 +0200 Subject: [PATCH 50/57] cuda : fix -INF block check --- ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dbd482239..e51ddc08f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { + if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { continue; } @@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } + } - // restore zeros + // restore zeros + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); } From cfd9732b2e45a442f4f7261ac0b50ec6e0862ab2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:31:55 +0200 Subject: [PATCH 51/57] cuda : simplify softmax --- ggml-cuda.cu | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e51ddc08f..25f810cbe 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); { - half S[Q]; + half S = __float2half(0.0f); half M[Q]; for (int i = 0; i < Q; ++i) { - S[i] = __float2half(0.0f); M[i] = CUDART_MIN_DENORM_FP16; } @@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16( M[j] = warp_reduce_max(M[j]); - const half ms = hexp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; - } - // local sum half2 ls = make_half2(0.0f, 0.0f); half2 M2 = make_half2(M[j], M[j]); @@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16( ls = warp_reduce_sum(ls); - S[j] = S[j]*ms + ls.x + ls.y; + const half ms = hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + + S = S*ms + ls.x + ls.y; + } } smax = warp_reduce_max(smax); @@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (int j = 0; j < Q; ++j) { - if (lane_id == 0) { - ss[j*T + 0] = S[j]; + if (lane_id == j) { + ss[j*T + 0] = S; ss[j*T + 1] = M[j]; } } From ef68fac2a8b51e2237234e3d7c6120cade457ce8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 18:36:58 +0200 Subject: [PATCH 52/57] cuda : fix matrix names --- ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 25f810cbe..d9ab2bd09 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16( for (int cc = 0; cc < C/16; ++cc) { const half * pv = (const 
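+                // renaming only: the b-fragments loaded from pv hold tiles of V, and the
+                // a-fragments loaded from ss hold tiles of the softmax matrix, so mv/ms are
+                // more accurate names than the previous mk/mv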
half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - half16x16_b mk[D16]; + half16x16_b mv[D16]; for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); + nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); } - half16x16_a mv[Q16]; + half16x16_a ms[Q16]; for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); + nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); } for (int j = 0; j < Q16; ++j) { for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); } } } From 1846e92a904ef17a55a3f7e7c2e837f35db2ce7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 4 Feb 2024 09:57:58 +0200 Subject: [PATCH 53/57] cuda : minor --- ggml-cuda.cu | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d9ab2bd09..713a6a89a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6399,10 +6399,10 @@ static __global__ void flash_attn_f32( } #if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; #endif // based on metal version @@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16( const int iq2 = blockIdx.y; const int iq1 = blockIdx.x * Q; - const int D2 = D/2; const int D16 = D/16; const int Q16 = Q/16; + const int C16 = C/16; + const int NW = WARP_SIZE; const int SH = (C + Q); // shared memory per simdgroup in (half) const int T = D + num_warps*SH; // shared memory size per query in (half) const int T2 = T/2; // shared memory size per query in (half2) const int C2 = C/2; + const int D2 = D/2; extern __shared__ half __flash_attn_f16_shmem[]; // pq @@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16( // Q*K^T { #pragma unroll - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { half16x16_acc mqk[Q16]; for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); @@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16( // O = O + (Q*K^T)*V { - for (int cc = 0; cc < C/16; ++cc) { + for (int cc = 0; cc < C16; ++cc) { const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mv[D16]; @@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int j = 0; j < Q; ++j) { - if (lane_id == j) { - ss[j*T + 0] = S; - ss[j*T + 1] = M[j]; - } + if (lane_id < Q) { + ss[lane_id*T + 0] = S; + ss[lane_id*T + 1] = M[lane_id]; } } @@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + // increase shared memory limit to 96KB + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + switch (Q->ne[0]) { case 64: flash_attn_ext_f16<64, NQPB, NCPW> @@ -11045,6 +11049,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * 
default: break; } + + CUDA_CHECK(cudaGetLastError()); } static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From f249c997a8b3f9b129fe825bebd609a362e9ab9c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 Feb 2024 13:10:24 +0200 Subject: [PATCH 54/57] llama : adapt to F16 KQ_pos --- ggml-cuda.cu | 2 +- ggml.c | 2 +- llama.cpp | 15 ++++++++++----- tests/test-backend-ops.cpp | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c8af51a6..5c6159a83 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6232,7 +6232,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? __half2float(slope*pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); diff --git a/ggml.c b/ggml.c index efc570db6..9a2ae6264 100644 --- a/ggml.c +++ b/ggml.c @@ -5192,7 +5192,7 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { diff --git a/llama.cpp b/llama.cpp index 2359ed10a..5aa3a508d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -102,7 +102,7 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 -#define LLAMA_FLASH_ATTN +//#define LLAMA_FLASH_ATTN // // logging @@ -4831,6 +4831,11 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * cur; #if defined(LLAMA_FLASH_ATTN) + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); + + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -5260,7 +5265,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); // shift the entire K-cache if needed @@ -5804,7 +5809,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { @@ -6043,7 +6048,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); inpL = llm_build_norm(ctx0, inpL, hparams, @@ -6140,7 +6145,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0); + struct ggml_tensor * KQ_pos = ggml_cast(ctx0, ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0), GGML_TYPE_F16); cb(KQ_pos, "KQ_pos", -1); for (int il = 0; il < n_layer; ++il) { diff --git 
a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 912223def..278c57299 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1505,7 +1505,7 @@ struct test_attn : public test_case { struct ggml_tensor * cur; cur = ggml_mul_mat (ctx, k, q); - cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f); cur = ggml_mul_mat (ctx, v, cur); cur = ggml_permute (ctx, cur, 0, 2, 1, 3); cur = ggml_cont_2d (ctx, cur, hs*nh, nb); From 6aefd11204199c9bd520b8991bab4085cb6fc977 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Mar 2024 13:50:54 +0200 Subject: [PATCH 55/57] llama : adapt new models to F16 KQ_mask --- llama.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1a099adcb..f2b224caf 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7362,7 +7362,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7489,7 +7489,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { @@ -7724,7 +7724,7 @@ struct llm_build_context { cb(inp_pos, "inp_pos", -1); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16); cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { From 58c7f6167c0f6540f0da0386fc65d940e1a16ea5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Mar 2024 20:44:57 +0200 Subject: [PATCH 56/57] ggml : fix F16 store (ARM NEON) --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 09ed9da34..5715b78ec 100644 --- a/ggml.c +++ b/ggml.c @@ -874,7 +874,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -900,7 +900,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define 
GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL From 3a468e6f9f0c7dff9ed78b0f7a5af069da420606 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 17:12:17 +0200 Subject: [PATCH 57/57] llama : fix type of KQ_mask and KQ_pos --- llama.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 77da94960..b80080daf 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5810,20 +5810,20 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16); } struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F16, n_kv); + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16); } struct ggml_tensor * build_inp_mean() {
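// A hedged sketch, not part of the patch, summarizing the pattern the last few
// commits converge on: the KQ_mask / KQ_pos graph inputs are allocated as F32
// so the host code that fills them keeps writing plain floats, and an explicit
// ggml_cast to F16 is placed in the graph before they reach soft_max / the
// attention path. Only ggml calls that already appear in these diffs are used;
// the function name and the kq / kq_scale / max_bias identifiers are
// illustrative.
static struct ggml_tensor * build_f16_kq_mask_sketch(
        struct ggml_context * ctx0, int64_t n_kv, int64_t n_tokens) {
    // allocate the mask as F32, with the row count padded as in build_inp_KQ_mask()
    struct ggml_tensor * mask_f32 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,
            n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
    ggml_set_input(mask_f32);

    // the attention graph consumes an F16 copy, matching the F16 KV cache
    return ggml_cast(ctx0, mask_f32, GGML_TYPE_F16);
}

// usage inside the attention block, with the F16 mask and optional F16 ALiBi
// positions (nullptr when ALiBi is not used):
//
//     kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, KQ_pos, kq_scale, max_bias);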