Merge branch 'master' into gg/flash-attn
ggml-ci
This commit is contained in:
commit
a1616e9f72
82 changed files with 3896 additions and 1063 deletions
|
@ -38,7 +38,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
|
|||
extern __shared__ float data_soft_max_f32[];
|
||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||
// shared memory buffer to cache values between iterations:
|
||||
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
|
||||
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
|
||||
|
||||
float max_val = -INFINITY;
|
||||
|
||||
|
@ -50,8 +50,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
|
|||
break;
|
||||
}
|
||||
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
const int64_t ix = (int64_t)rowx*ncols + col;
|
||||
const int64_t iy = (int64_t)rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
|
||||
|
||||
|
@ -119,7 +119,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
|
|||
return;
|
||||
}
|
||||
|
||||
const int idst = rowx*ncols + col;
|
||||
const int64_t idst = (int64_t)rowx*ncols + col;
|
||||
dst[idst] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue