llama : adapt to F16 KQ_pos

This commit is contained in:
Georgi Gerganov 2024-02-19 13:10:24 +02:00
parent 31109ca00a
commit f249c997a8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 13 additions and 8 deletions

View file

@ -6232,7 +6232,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? __half2float(slope*pos[col]) : 0.0f);
const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);