make rms_norm_eps a parameter (#2374)
* make rms_norm_eps a parameter
* add rms_norm_eps to command line
* fix baby llama, test-grad0
* use scientific notation for eps param in the help

ggml-ci
This commit is contained in:
parent
b3f138d058
commit
41c674161f
11 changed files with 89 additions and 56 deletions
13
ggml-cuda.cu
13
ggml-cuda.cu
|
@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
|||
}
|
||||
}
|
||||
|
||||
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
|
||||
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const float eps = 1e-6f;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
|
@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
|
|||
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
}
|
||||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
}
|
||||
|
||||
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
|
||||
|
@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
|
|||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t i01_diff = i01_high - i01_low;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
// compute
|
||||
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
|
||||
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue