llama : avoid ggml_cast, use F32 query

This commit is contained in:
Georgi Gerganov 2024-01-25 17:46:07 +02:00
parent 40ea8cd1ac
commit f9ca5dcbe8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 44 additions and 17 deletions

View file

@@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16(
for (int64_t i = 0; i < L4; ++i) {
// load heads from Q to shared memory
for (int64_t j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
if (iq1 + j < ne01) {
pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg];
pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg];
} else {
pq4[j*T4 + N4*i + tiisg] = 0.0h;
}