diff --git a/ggml.c b/ggml.c index 981a2302a..9b5d0302b 100644 --- a/ggml.c +++ b/ggml.c @@ -14904,7 +14904,8 @@ static void ggml_compute_forward_ssm_scan_f32( // d_inner for (int i1 = 0; i1 < ir; ++i1) { - float dt_soft_plus = log1pf(expf(dt[i1])); + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; float x_dt = x[i1] * dt_soft_plus; float sumf = 0.0f; // d_state