ggml : in ggml_ssm_scan, use a threshold for soft_plus

This is how the official Mamba implementation does it,
and it's also what torch.nn.Softplus does.
This commit is contained in:
Francis Couture-Harpin 2024-03-02 21:39:28 -05:00
parent 1af1000f10
commit d52dd501f0

3
ggml.c
View file

@ -14904,7 +14904,8 @@ static void ggml_compute_forward_ssm_scan_f32(
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
float dt_soft_plus = log1pf(expf(dt[i1]));
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state