llama : adapt to F16 KQ_pos
commit f249c997a8
parent 31109ca00a

4 changed files with 13 additions and 8 deletions
@@ -6232,7 +6232,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;

-        const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? __half2float(slope*pos[col]) : 0.0f);
+        const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);

         vals[col] = val;
         max_val = max(max_val, val);
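Note on the change: with KQ_pos stored as F16, the position value is now converted to F32 first and the slope multiply is done in F32 (slope*__half2float(pos[col])), rather than multiplying the F32 slope into the half value and converting afterwards. Below is a minimal standalone sketch of the fixed expression, not the actual llama.cpp kernel; the kernel name apply_slope and the host driver are made up for illustration, and the slope is assumed to play the role of the position-bias scale used in the soft-max term.

// Standalone sketch: apply an F32 slope to F16 position values,
// converting each position to F32 before the multiply, as in the + line above.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void apply_slope(const half * pos, const float slope, float * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    // convert F16 -> F32 first, then multiply by the F32 slope
    dst[i] = slope*__half2float(pos[i]);
}

int main() {
    const int n = 4;

    half h_pos[n];
    for (int i = 0; i < n; ++i) {
        h_pos[i] = __float2half((float) i); // positions 0, 1, 2, 3 as F16
    }

    half  * d_pos = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_pos, n*sizeof(half));
    cudaMalloc(&d_dst, n*sizeof(float));
    cudaMemcpy(d_pos, h_pos, n*sizeof(half), cudaMemcpyHostToDevice);

    apply_slope<<<1, n>>>(d_pos, 0.5f, d_dst, n);

    float h_dst[n];
    cudaMemcpy(h_dst, d_dst, n*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("pos %d -> %f\n", i, h_dst[i]); // expected: 0.0, 0.5, 1.0, 1.5
    }

    cudaFree(d_pos);
    cudaFree(d_dst);
    return 0;
}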