llama : some llama_kv_cell simplifications

This commit is contained in:
Georgi Gerganov 2024-02-25 10:59:52 +02:00
parent 032ff85706
commit 715a343343
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1709,6 +1709,14 @@ struct llama_kv_cell {
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const llama_kv_cell & other) const {
return seq_id == other.seq_id;
}
};
// ring-buffer of cached KV data
@ -2101,7 +2109,7 @@ static bool llama_kv_cache_find_slot(
// find how many cells are currently in use
static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
for (uint32_t i = cache.size - 1; i > 0; --i) {
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
return i + 1;
}
}
@ -2137,7 +2145,7 @@ static void llama_kv_cache_seq_rm(
} else {
continue;
}
if (cache.cells[i].seq_id.empty()) {
if (cache.cells[i].is_empty()) {
// keep count of the number of used cells
if (cache.cells[i].pos >= 0) cache.used--;
@ -2206,10 +2214,14 @@ static void llama_kv_cache_seq_add(
cache.cells[i].delta += delta;
if (cache.cells[i].pos < 0) {
if (!cache.cells[i].seq_id.empty()) cache.used--;
if (!cache.cells[i].is_empty()) {
cache.used--;
}
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
if (new_head == cache.size) {
new_head = i;
}
}
}
}
@ -11618,8 +11630,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(ctx->backend_cpu);
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v,
cparams.n_ctx, cparams.offload_kqv)) {
if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
@ -12203,10 +12214,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const auto & hparams = ctx->model.hparams;
const auto & cparams = ctx->cparams;
const auto n_layer = hparams.n_layer;
const auto n_embd_k_gqa = hparams.n_embd_k_gqa();
const auto n_embd_v_gqa = hparams.n_embd_v_gqa();
const auto n_ctx = cparams.n_ctx;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t n_ctx = cparams.n_ctx;
const size_t kv_buf_size = kv_self.total_size();
const uint32_t kv_head = kv_self.head;
@ -12221,14 +12232,16 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
if (kv_buf_size) {
std::vector<uint8_t> tmp_buf;
for (int il = 0; il < (int) n_layer; ++il) {
size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
tmp_buf.resize(k_size);
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
// v is not contiguous, copy row by row
size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
tmp_buf.resize(v_row_size);
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*v_row_stride, tmp_buf.size());
@ -12315,10 +12328,10 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
const auto & hparams = ctx->model.hparams;
const auto & cparams = ctx->cparams;
const int n_layer = hparams.n_layer;
const int n_embd_k_gqa = hparams.n_embd_k_gqa();
const int n_embd_v_gqa = hparams.n_embd_v_gqa();
const int n_ctx = cparams.n_ctx;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t n_ctx = cparams.n_ctx;
size_t kv_buf_size;
uint32_t kv_head;
@ -12334,13 +12347,15 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
GGML_ASSERT(kv_self.total_size() == kv_buf_size);
for (int il = 0; il < (int) n_layer; ++il) {
size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
inp += k_size;
// v is not contiguous, copy row by row
size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
inp += v_row_size;