Fix failed assertions while running Falcon Mamba
This commit is contained in:
parent
061e520075
commit
fae826fb56
2 changed files with 5 additions and 5 deletions
|
@ -153,9 +153,9 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
|
|||
}
|
||||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0 || ncols < WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_dims(min(ncols, WARP_SIZE), 1, 1);
|
||||
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
|
|
|
@ -9119,9 +9119,9 @@ static struct ggml_tensor * llm_build_mamba(
|
|||
|
||||
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
|
||||
if (ssm_dt_b_c_rms) {
|
||||
dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
|
||||
B = ggml_rms_norm(ctx, B, norm_rms_eps);
|
||||
C = ggml_rms_norm(ctx, C, norm_rms_eps);
|
||||
dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
|
||||
B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps);
|
||||
C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps);
|
||||
}
|
||||
|
||||
// {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue