diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 133e219f0..f2d643e4e 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -153,9 +153,9 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    GGML_ASSERT(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE);
     if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
+        const dim3 block_dims(min(ncols, WARP_SIZE), 1, 1);
         rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
diff --git a/src/llama.cpp b/src/llama.cpp
index aeea54cff..b1bcbbbcf 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9119,9 +9119,9 @@ static struct ggml_tensor * llm_build_mamba(
 
     // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
     if (ssm_dt_b_c_rms) {
-        dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
-        B = ggml_rms_norm(ctx, B, norm_rms_eps);
-        C = ggml_rms_norm(ctx, C, norm_rms_eps);
+        dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
+        B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps);
+        C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps);
     }
 
     // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
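
Note on the norm.cu hunk: Mamba's dt, B and C rows can be narrower than a warp (ncols < WARP_SIZE), which the old assert rejected outright; relaxing the assert and launching with min(ncols, WARP_SIZE) threads lets those narrow rows through without changing the wide-row path. Below is a minimal CPU reference one could use to sanity-check such narrow rows against the CUDA path, assuming ggml's RMS formula 1/sqrt(mean(x^2) + eps); rms_norm_ref and the test sizes are hypothetical, not part of the patch.

// Minimal CPU reference sketch for RMS norm over narrow rows; not part of ggml.
#include <cmath>
#include <cstddef>
#include <cstdio>

static void rms_norm_ref(const float * x, float * dst, int ncols, int nrows, float eps) {
    for (int r = 0; r < nrows; ++r) {
        const float * xr = x + (size_t) r*ncols;
        // sum of squares over the row, then scale by 1/sqrt(mean + eps)
        float sumsq = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            sumsq += xr[c]*xr[c];
        }
        const float scale = 1.0f/std::sqrt(sumsq/ncols + eps);
        for (int c = 0; c < ncols; ++c) {
            dst[(size_t) r*ncols + c] = scale*xr[c];
        }
    }
}

int main() {
    const int ncols = 16;  // narrower than WARP_SIZE (32): the case the hunk unlocks
    const int nrows = 2;
    float x[32], y[32];
    for (int i = 0; i < 32; ++i) {
        x[i] = (float) (i % 7) - 3.0f;
    }
    rms_norm_ref(x, y, ncols, nrows, 1e-6f);
    printf("y[0] = %f\n", y[0]);
    return 0;
}

Comparing this reference element-wise against rms_norm_f32_cuda's output is enough to exercise the ncols = 16 case the old GGML_ASSERT forbade.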
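
Note on the llama.cpp hunk: dt, B and C are views into the SSM projection output, so they are non-contiguous, while the CUDA RMS-norm path assumes contiguous rows; wrapping each in ggml_cont materializes a contiguous copy before normalizing. The sketch below shows the same view-then-cont-then-norm pattern in isolation on the CPU backend, assuming the ggml API as of this tree; the 6x4 shape and variable names are illustrative only.

// Standalone sketch of normalizing a non-contiguous view; shapes are illustrative.
#include <cstdio>
#include "ggml.h"

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    // a 6x4 matrix, then a 2x4 slice of columns 2..3: the slice keeps the
    // parent's row stride (a->nb[1]), so it is not contiguous
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 6, 4);
    ggml_tensor * v = ggml_view_2d(ctx, a, 2, 4, a->nb[1], 2*sizeof(float));

    // same pattern as the hunk: make the view contiguous before ggml_rms_norm
    ggml_tensor * n = ggml_rms_norm(ctx, ggml_cont(ctx, v), 1e-6f);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, n);

    for (int i = 0; i < 6*4; ++i) {
        ((float *) a->data)[i] = (float) i + 1.0f;
    }
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    printf("n[0] = %f\n", ((float *) n->data)[0]);
    ggml_free(ctx);
    return 0;
}

The ggml_cont does cost one extra copy per tensor per layer; it is the conservative fix here, the alternative being to teach the CUDA RMS-norm kernel to walk strided rows.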