llama : avoid ggml_cast, use F32 query
parent 40ea8cd1ac
commit f9ca5dcbe8
6 changed files with 44 additions and 17 deletions
@@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16(
     for (int64_t i = 0; i < L4; ++i) {
         // load heads from Q to shared memory
         for (int64_t j = sgitg; j < Q; j += nsg) {
+            device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
             if (iq1 + j < ne01) {
-                pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg];
+                pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg];
             } else {
                 pq4[j*T4 + N4*i + tiisg] = 0.0h;
             }
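In plain terms, the change keeps the query tensor in F32 and converts it to half precision only when the kernel stages it into shared memory, rather than requiring a prior ggml_cast of Q to F16 (as the commit title says: "avoid ggml_cast, use F32 query"). The standalone C sketch below mirrors that staging pattern outside Metal; the f32_to_f16 helper, the buffer names, and the truncating conversion are illustrative assumptions, not code from this commit.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Simplified, truncating float -> half conversion (no rounding/denormal handling).
// Stand-in for the (half4) cast the Metal kernel now performs; illustration only.
static uint16_t f32_to_f16(float f) {
    uint32_t x;
    memcpy(&x, &f, sizeof(x));
    uint32_t sign = (x >> 16) & 0x8000u;
    int32_t  exp  = (int32_t)((x >> 23) & 0xffu) - 127 + 15;
    uint32_t mant = (x >> 13) & 0x3ffu;
    if (exp <= 0)  return (uint16_t)sign;             // flush tiny values to signed zero
    if (exp >= 31) return (uint16_t)(sign | 0x7c00u); // overflow -> infinity
    return (uint16_t)(sign | ((uint32_t)exp << 10) | mant);
}

int main(void) {
    // Query row kept in F32; converted to F16 only when written into the
    // kernel-local (shared-memory-like) buffer, instead of pre-casting Q on the host.
    float    q[4]  = { 0.25f, -1.5f, 3.0f, 0.0f };
    uint16_t pq[4];
    for (int i = 0; i < 4; ++i) {
        pq[i] = f32_to_f16(q[i]); // was: read a pre-cast half buffer; now: cast on load
    }
    for (int i = 0; i < 4; ++i) {
        printf("q[%d] = %g -> 0x%04x\n", i, q[i], pq[i]);
    }
    return 0;
}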