ggml : in ggml_ssm_scan, use a threshold for soft_plus
This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does: above the threshold (20), softplus(x) ≈ x, so the input is returned as-is for numerical stability.
This commit is contained in:
parent
1af1000f10
commit
d52dd501f0
1 changed file with 2 additions and 1 deletion
3
ggml.c
3
ggml.c
|
@@ -14904,7 +14904,8 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
|
||||
// d_inner
|
||||
for (int i1 = 0; i1 < ir; ++i1) {
|
||||
float dt_soft_plus = log1pf(expf(dt[i1]));
|
||||
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
|
||||
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
||||
float x_dt = x[i1] * dt_soft_plus;
|
||||
float sumf = 0.0f;
|
||||
// d_state
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue