Merge branch 'master' into gg/flash-attn

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-29 17:19:25 +03:00
commit a1616e9f72
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
82 changed files with 3896 additions and 1063 deletions

View file

@ -38,7 +38,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
float max_val = -INFINITY;
@ -50,8 +50,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
break;
}
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
@ -119,7 +119,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
return;
}
const int idst = rowx*ncols + col;
const int64_t idst = (int64_t)rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
}
}