mamba : support state saving and restoring
This commit is contained in:
parent
b83fbc9287
commit
eefb794bd7
1 changed files with 25 additions and 4 deletions
29
llama.cpp
29
llama.cpp
|
@ -13344,8 +13344,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||||
|
|
||||||
const size_t kv_buf_size = kv_self.total_size();
|
const size_t kv_buf_size = kv_self.total_size();
|
||||||
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
|
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
|
||||||
|
@ -13366,6 +13366,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||||
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
|
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
|
||||||
data_ctx->write(tmp_buf.data(), tmp_buf.size());
|
data_ctx->write(tmp_buf.data(), tmp_buf.size());
|
||||||
|
|
||||||
|
if (kv_self.recurrent) {
|
||||||
|
// v is contiguous for recurrent models
|
||||||
|
// TODO: use other tensors for state models than k and v
|
||||||
|
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
|
||||||
|
|
||||||
|
tmp_buf.resize(v_size);
|
||||||
|
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
|
||||||
|
data_ctx->write(tmp_buf.data(), tmp_buf.size());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// v is not contiguous, copy row by row
|
// v is not contiguous, copy row by row
|
||||||
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
|
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
|
||||||
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
|
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
|
||||||
|
@ -13456,8 +13467,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||||
|
|
||||||
size_t kv_buf_size;
|
size_t kv_buf_size;
|
||||||
uint32_t kv_head;
|
uint32_t kv_head;
|
||||||
|
@ -13478,6 +13489,16 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
|
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
|
||||||
inp += k_size;
|
inp += k_size;
|
||||||
|
|
||||||
|
if (kv_self.recurrent) {
|
||||||
|
// v is contiguous for recurrent models
|
||||||
|
// TODO: use other tensors for state models than k and v
|
||||||
|
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
|
||||||
|
|
||||||
|
ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
|
||||||
|
inp += v_size;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// v is not contiguous, copy row by row
|
// v is not contiguous, copy row by row
|
||||||
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
|
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
|
||||||
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
|
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue