diff --git a/ggml-metal.m b/ggml-metal.m
index 7b161c69d..7b6762e6d 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute(
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
-                    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
                     struct ggml_tensor * src2 = gf->nodes[i]->src[2];
                     struct ggml_tensor * src3 = gf->nodes[i]->src[3];
@@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute(
                     [encoder setBytes:&scale length:sizeof( float) atIndex:27];
 
                     // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-                    const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps)
+                    const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps)
 
                     const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
                     const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 9ab9e16c3..c9e4dcfe9 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16(
     for (int64_t i = 0; i < L4; ++i) {
         // load heads from Q to shared memory
         for (int64_t j = sgitg; j < Q; j += nsg) {
+            device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
            if (iq1 + j < ne01) {
-                pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg];
+                pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg];
             } else {
                 pq4[j*T4 + N4*i + tiisg] = 0.0h;
             }
diff --git a/ggml.c b/ggml.c
index 10df03c9c..5e515c03f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat(
 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
     const int32_t prec_i32 = (int32_t) prec;
 
     ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext(
     return result;
 }
 
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
 // ggml_flash_ff
 
 struct ggml_tensor * ggml_flash_ff(
@@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne2 == N);
     GGML_ASSERT(P >= 0);
 
-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbq0 == sizeof(float));
     GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
 
@@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float M = -INFINITY;
 
         float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
         ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
 
         memset(V16, 0, D*sizeof(ggml_fp16_t));
@@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
             float s;
 
+            // convert Q to F16 in V32
+            {
+                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                for (int64_t d = 0; d < D; ++d) {
+                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                }
+            }
+
             ggml_vec_dot_f16(D,
                     &s,
                     (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                    (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                    Q16);
 
             s = s*scale + mv;
 
@@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (q->type) {
-        case GGML_TYPE_F16:
+    switch (dst->op_params[1]) {
+        case GGML_PREC_DEFAULT:
             {
                 ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
             } break;
         default:
             {
+                // TODO: implement F32 precision
                 GGML_ASSERT(false);
             } break;
     }
diff --git a/ggml.h b/ggml.h
index 7bca02f2a..e2f74412f 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1633,6 +1633,10 @@ extern "C" {
             struct ggml_tensor  * mask,
             float                 scale);
 
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
diff --git a/llama.cpp b/llama.cpp
index 4e6c9f9cc..550caced4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4368,7 +4368,8 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);
         //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
         //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
         //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 41ddfcca5..db1244876 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case {
 
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const ggml_type typeq;
     const int64_t hs; // head size
     const int64_t nh; // num heads
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
     std::string vars() override {
-        return VARS_TO_STR5(typeq, hs, nh, kv, nb);
+        return VARS_TO_STR4(hs, nh, kv, nb);
     }
 
     double max_nmse_err() override {
         return 5e-5;
     }
 
-    test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
-            int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : hs(hs), nh(nh), kv(kv), nb(nb) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
@@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());
 
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1));
 
 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer
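
// ---------------------------------------------------------------------------
// Usage sketch (not part of the diff above): how the new
// ggml_flash_attn_ext_set_prec() API added in ggml.h/ggml.c is wired up by the
// llama.cpp hunk. The handles ctx, q, k, v, kq_mask and the float kq_scale are
// assumed to come from the caller (as in llm_build_kqv); this is illustrative,
// not a drop-in patch.

// Q is now passed as F32 directly -- the previous ggml_cast to F16 is no longer needed
struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);

// select the attention precision on the resulting node; GGML_PREC_DEFAULT uses
// the existing F16 path, while full F32 precision is still TODO on the CPU backend
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);
// ---------------------------------------------------------------------------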