diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 34a8e0f66..9a5414787 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -15198,10 +15198,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    enum ggml_type    const kq_vec_dot_type = type_traits[k->type].vec_dot_type;
-    ggml_from_float_t const kq_from_float   = type_traits[kq_vec_dot_type].from_float;
-    ggml_vec_dot_t    const kq_vec_dot      = type_traits[k->type].vec_dot;
-    ggml_to_float_t   const v_to_float      = type_traits[v->type].to_float;
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -15238,7 +15238,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv2 = iq2 / rv2;
 
         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        kq_from_float(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, D);
 
         // online softmax / attention
         // loop over n_kv and n_head_kv
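
For context, here is a minimal stand-alone sketch (not ggml code) of the pattern the renamed variables describe: the f32 query row is converted once into the K tensor's vec_dot type (`q_to_vec_dot`), and each K row is then dotted against it in that type (`kq_vec_dot`). The typedefs and helper functions below are simplified stand-ins for the internal `type_traits` function pointers, not the actual ggml signatures.

```c
#include <stdio.h>

/* simplified stand-ins for ggml_from_float_t / ggml_vec_dot_t */
typedef void (*from_float_t)(const float * x, void * y, int n);
typedef void (*vec_dot_t)(int n, float * s, const void * x, const void * y);

/* identity "conversion"; in real ggml this would quantize f32 to e.g. q8_0 */
static void f32_from_float(const float * x, void * y, int n) {
    float * dst = (float *) y;
    for (int i = 0; i < n; ++i) dst[i] = x[i];
}

static void f32_vec_dot(int n, float * s, const void * x, const void * y) {
    const float * a = (const float *) x;
    const float * b = (const float *) y;
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) sum += a[i]*b[i];
    *s = sum;
}

int main(void) {
    enum { D = 4, N_KV = 2 };
    const float q[D]       = { 1, 2, 3, 4 };               /* one f32 query row */
    const float k[N_KV][D] = { {1,0,0,0}, {0,1,0,0} };     /* K rows, already in the vec_dot type */

    from_float_t q_to_vec_dot = f32_from_float;  /* plays the role of type_traits[k_vec_dot_type].from_float */
    vec_dot_t    kq_vec_dot   = f32_vec_dot;     /* plays the role of type_traits[k->type].vec_dot */

    float Q_q[D];
    q_to_vec_dot(q, Q_q, D);                     /* convert the query once per row */

    for (int ic = 0; ic < N_KV; ++ic) {
        float s;
        kq_vec_dot(D, &s, k[ic], Q_q);           /* KQ value for this K row */
        printf("kq[%d] = %f\n", ic, s);
    }
    return 0;
}
```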