From d52dd501f084ce9a8c1d883dedcf14eeda0ac5a4 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 2 Mar 2024 21:39:28 -0500 Subject: [PATCH] ggml : in ggml_ssm_scan, use a threshold for soft_plus This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does. --- ggml.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 981a2302a..9b5d0302b 100644 --- a/ggml.c +++ b/ggml.c @@ -14904,7 +14904,8 @@ static void ggml_compute_forward_ssm_scan_f32( // d_inner for (int i1 = 0; i1 < ir; ++i1) { - float dt_soft_plus = log1pf(expf(dt[i1])); + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; float x_dt = x[i1] * dt_soft_plus; float sumf = 0.0f; // d_state