mamba : support state saving and restoring

This commit is contained in:
Francis Couture-Harpin 2024-03-03 13:55:48 -05:00
parent b83fbc9287
commit eefb794bd7

View file

@ -13344,8 +13344,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const size_t kv_buf_size = kv_self.total_size();
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@ -13366,6 +13366,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
if (kv_self.recurrent) {
// v is contiguous for recurrent models
// TODO: use other tensors for state models than k and v
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
tmp_buf.resize(v_size);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
continue;
}
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
@ -13456,8 +13467,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
size_t kv_buf_size;
uint32_t kv_head;
@ -13478,6 +13489,16 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
inp += k_size;
if (kv_self.recurrent) {
// v is contiguous for recurrent models
// TODO: use other tensors for state models than k and v
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
inp += v_size;
continue;
}
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);