mamba : fix self-overlapping view depth stride

This commit is contained in:
Francis Couture-Harpin 2024-01-31 08:47:53 -05:00
parent e9cc45ecae
commit 81b57bb375

View file

@ -7960,7 +7960,7 @@ struct llm_build_context {
const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x); const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0); conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
// unfortunately, making x contiguous is necessary because ggml_set expects nb0 == sizeof(float) // making x contiguous is necessary because ggml_set expects it
conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x)); conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));
// store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
@ -7969,9 +7969,10 @@ struct llm_build_context {
ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)), ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
ggml_view_tensor(ctx0, kv_self.k_l[il]))); ggml_view_tensor(ctx0, kv_self.k_l[il])));
// prepare convolution for all tokens in the batch with a self-overlapping view // prepare convolution for all tokens in the batch with a self-overlapping view,
// shifting the window by one column per step along the third (token) dimension, with a window of d_conv columns.
// {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok} // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, -(d_conv - 1)*d_inner*ggml_element_size(conv_x), 0); conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
// perform convolution // perform convolution
// => {1, d_inner, n_tok} // => {1, d_inner, n_tok}