llama : fix copy-paste errors, add TODO

commit bab346ba69 (parent c225609f10)
Author: Georgi Gerganov
Date:   2024-04-25 19:45:36 +03:00


@@ -16781,13 +16781,14 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         }
     }

+    // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
-            // Write key type
+            // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));

-            // Write row size of key
+            // Write row size of value
             const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
             data_ctx.write(&v_size_row, sizeof(v_size_row));
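
For context, the write path in this first hunk serializes, per layer, the value tensor's type id followed by its row size, ahead of the row data itself. Below is a minimal standalone sketch of that two-field layout; byte_writer and write_v_meta are hypothetical stand-ins for the real data_ctx writer, not llama.cpp API.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the data_ctx writer in the real code:
// appends raw bytes to a growing buffer.
struct byte_writer {
    std::vector<uint8_t> buf;

    void write(const void * src, size_t size) {
        const uint8_t * p = static_cast<const uint8_t *>(src);
        buf.insert(buf.end(), p, p + size);
    }
};

// Per-layer V-cache metadata layout from the hunk above: the value type id,
// then the row size in bytes. In llama.cpp these come from
// kv_self.v_l[il]->type and ggml_row_size(...); here they are plain inputs.
static void write_v_meta(byte_writer & out, int32_t v_type_i, size_t v_size_row) {
    out.write(&v_type_i,   sizeof(v_type_i));   // value type (mislabeled "key" before this commit)
    out.write(&v_size_row, sizeof(v_size_row)); // row size of value
}
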
@@ -16947,32 +16948,33 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
         }
     }

+    // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
-            // Read type of key
+            // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
             inp += sizeof(v_type_i_ref);
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             if (v_type_i != v_type_i_ref) {
                 llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return 0;
             }

-            // Read row size of key
+            // Read row size of value
             size_t v_size_row_ref;
             memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
             inp += sizeof(v_size_row_ref);
             const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
             if (v_size_row != v_size_row_ref) {
                 llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
                 return 0;
             }

             if (cell_count) {
-                // Read and set the keys for the whole cell range
+                // Read and set the values for the whole cell range
                 ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
                 inp += cell_count * v_size_row;
             }
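
The read path mirrors the same layout: each field is copied out of the input buffer, the cursor is advanced, and the stored value is checked against what the current context expects before any cell data is applied. A sketch of that consume-and-validate pattern follows; read_v_meta is a hypothetical helper standing in for the real llama_state_seq_set_data internals, which additionally clear the destination sequence with llama_kv_cache_seq_rm on mismatch.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Consume-and-validate pattern from the hunk above: read a field, advance
// the cursor, and reject the stream if it disagrees with the expected value.
// Returns false on mismatch; the real code also wipes the destination
// sequence and returns 0 from llama_state_seq_set_data.
static bool read_v_meta(const uint8_t ** inp, int32_t v_type_i, size_t v_size_row, int il) {
    int32_t v_type_i_ref;
    memcpy(&v_type_i_ref, *inp, sizeof(v_type_i_ref));
    *inp += sizeof(v_type_i_ref);
    if (v_type_i != v_type_i_ref) {
        fprintf(stderr, "mismatched value type (%d != %d, layer %d)\n", v_type_i, v_type_i_ref, il);
        return false;
    }

    size_t v_size_row_ref;
    memcpy(&v_size_row_ref, *inp, sizeof(v_size_row_ref));
    *inp += sizeof(v_size_row_ref);
    if (v_size_row != v_size_row_ref) {
        fprintf(stderr, "mismatched value row size (%zu != %zu, layer %d)\n", v_size_row, v_size_row_ref, il);
        return false;
    }

    return true;
}

Only once both checks pass does the caller copy cell_count * v_size_row bytes into the layer's value tensor, which is why the validation happens before any ggml_backend_tensor_set call.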