llama : support save/load state with FA enabled

ggml-ci
Georgi Gerganov, 2024-04-25 18:18:13 +03:00
commit 1fd5bc3d5e (parent cb3547ac46)
2 changed files with 12 additions and 7 deletions

ci/run.sh

@@ -518,6 +518,7 @@ function gg_run_open_llama_7b_v2 {
     (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
     (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
     function check_ppl {
         qnt="$1"

llama.cpp

@@ -2036,8 +2036,8 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool do_copy   = false;
-    // with recurrent state models, a cell can hold the state for more than one past token
-    bool recurrent = false;
+    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+    bool v_trans   = true;  // the value tensor is transposed
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
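
To make the new flag concrete, a conceptual sketch (the helper name is local to this sketch, not from the diff) of where element (cell i, embedding dimension j) of a per-layer V cache lives under the two layouts that v_trans selects between:

// illustrative index math only; the real cache keeps the data in ggml tensors
static size_t v_elem_index(bool v_trans, uint32_t i /* cell */, uint32_t j /* embedding dim */,
                           uint32_t kv_size, uint32_t n_embd_v_gqa) {
    return v_trans
        ? (size_t) j*kv_size      + i   // transposed: one row per embedding dimension, kv_size cells long
        : (size_t) i*n_embd_v_gqa + j;  // not transposed (FA path): one row per cell, like the K cache
}

Only the layout differs; the amount of live data per layer is the same either way, which is why the save/load hunks below only widen the contiguous fast path rather than change the serialized format.
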
@@ -2335,11 +2335,14 @@ struct llama_context {
 static bool llama_kv_cache_init(
              struct llama_kv_cache & cache,
-                 const llama_model & model,
+               const llama_context * ctx,
                          ggml_type   type_k,
                          ggml_type   type_v,
                           uint32_t   kv_size,
                               bool   offload) {
+    const llama_model & model = ctx->model;
+    const llama_cparams & cparams = ctx->cparams;
+
     const struct llama_hparams & hparams = model.hparams;
 
     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
@@ -2350,6 +2353,7 @@ static bool llama_kv_cache_init(
     // TODO: find a nicer way to add other recurrent model architectures
     cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+    cache.v_trans   = !cparams.flash_attn;
 
     // TODO: support mixed reccurent Transformer architectues
     // NOTE: (!a || b) is a logical implication (a -> b)
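
The NOTE above reads (!a || b) as the implication a -> b. A small sketch, assuming it lives inside llama.cpp (llama_kv_cache and llama_cparams are internal types), of the invariant the new assignment establishes:

// sketch only: read "(!a || b)" as "a implies b"
static bool implies(bool a, bool b) { return !a || b; }

// assumed check, not in the diff: enabling flash attention implies
// the V cache is stored non-transposed
static void check_v_layout(const llama_kv_cache & cache, const llama_cparams & cparams) {
    GGML_ASSERT(implies(cparams.flash_attn, !cache.v_trans));
}
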
@@ -15550,7 +15554,7 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends.push_back(ctx->backend_cpu);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
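
For completeness, a hedged usage sketch of creating a context with flash attention enabled, assuming the flash_attn flag this branch exposes in llama_context_params (the -fa switch in the CI hunk above maps to it):

// sketch: state saved from this context uses the non-transposed V layout
#include "llama.h"

static llama_context * make_fa_context(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 2048;
    cparams.flash_attn = true; // assumption: public flag added by the flash-attention branch

    return llama_new_context_with_model(model, cparams);
}
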
@@ -16330,7 +16334,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
             ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
             data_ctx->write(tmp_buf.data(), tmp_buf.size());
 
-            if (kv_self.recurrent) {
+            if (kv_self.recurrent || !kv_self.v_trans) {
                 // v is contiguous for recurrent models
                 // TODO: use other tensors for state models than k and v
                 const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
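
The branch widened here is the contiguous fast path. For contrast, a sketch (assumed local helper inside llama.cpp, mirroring the logic above) of the two ways a layer's live V data can be read out of the backend buffer:

// sketch only: read layer il's live V data, either as one block or row by row
static void read_v_layer(const llama_kv_cache & kv, int il, uint32_t n_embd_v_gqa,
                         uint32_t kv_head, uint32_t kv_size, std::vector<uint8_t> & out) {
    const ggml_type type = kv.v_l[il]->type;

    if (kv.recurrent || !kv.v_trans) {
        // contiguous: the first kv_head cells form one block, same as the K cache
        const size_t v_size = ggml_row_size(type, (int64_t) n_embd_v_gqa*kv_head);
        out.resize(v_size);
        ggml_backend_tensor_get(kv.v_l[il], out.data(), 0, v_size);
        return;
    }

    // transposed: copy the first kv_head elements of each of the n_embd_v_gqa rows,
    // stepping through the buffer with the full row stride of kv_size cells
    const size_t v_row_size   = ggml_row_size(type, kv_head);
    const size_t v_row_stride = ggml_row_size(type, kv_size);
    out.resize((size_t) n_embd_v_gqa*v_row_size);
    for (uint32_t ir = 0; ir < n_embd_v_gqa; ++ir) {
        ggml_backend_tensor_get(kv.v_l[il], out.data() + ir*v_row_size, ir*v_row_stride, v_row_size);
    }
}
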
@@ -16486,7 +16490,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
             ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
             inp += k_size;
 
-            if (kv_self.recurrent) {
+            if (kv_self.recurrent || !kv_self.v_trans) {
                 // v is contiguous for recurrent models
                 // TODO: use other tensors for state models than k and v
                 const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
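
And the mirror image on the load side: a sketch (again an assumed helper inside llama.cpp, not from the diff) of pushing the bytes written above back into the backend tensors with ggml_backend_tensor_set, so that whichever layout was saved is restored at the matching offsets.

// sketch only: restore layer il's V data from a flat buffer, advancing the input pointer
static void write_v_layer(const llama_kv_cache & kv, int il, uint32_t n_embd_v_gqa,
                          uint32_t kv_head, uint32_t kv_size, const uint8_t * & inp) {
    const ggml_type type = kv.v_l[il]->type;

    if (kv.recurrent || !kv.v_trans) {
        // contiguous: one backend write restores the first kv_head cells
        const size_t v_size = ggml_row_size(type, (int64_t) n_embd_v_gqa*kv_head);
        ggml_backend_tensor_set(kv.v_l[il], inp, 0, v_size);
        inp += v_size;
        return;
    }

    // transposed: scatter the saved prefix of each row back at its strided offset
    const size_t v_row_size   = ggml_row_size(type, kv_head);
    const size_t v_row_stride = ggml_row_size(type, kv_size);
    for (uint32_t ir = 0; ir < n_embd_v_gqa; ++ir) {
        ggml_backend_tensor_set(kv.v_l[il], inp, ir*v_row_stride, v_row_size);
        inp += v_row_size;
    }
}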