llama : fix copy-paste errors, add TODO
This commit is contained in:
parent
c225609f10
commit
bab346ba69
1 changed file with 9 additions and 7 deletions
llama.cpp | 16 +++++++++-------
@@ -16781,13 +16781,14 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         }
     }
 
+    // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
-            // Write key type
+            // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
 
-            // Write row size of key
+            // Write row size of value
             const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
             data_ctx.write(&v_size_row, sizeof(v_size_row));
 
@@ -16947,32 +16948,33 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
         }
     }
 
+    // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
-            // Read type of key
+            // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
             inp += sizeof(v_type_i_ref);
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             if (v_type_i != v_type_i_ref) {
                 llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return 0;
             }
 
-            // Read row size of key
+            // Read row size of value
             size_t v_size_row_ref;
             memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
             inp += sizeof(v_size_row_ref);
             const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
             if (v_size_row != v_size_row_ref) {
                 llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
                 return 0;
             }
 
             if (cell_count) {
-                // Read and set the keys for the whole cell range
+                // Read and set the values for the whole cell range
                 ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
                 inp += cell_count * v_size_row;
             }
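Both hunks touch the value-cache half of the per-sequence state save/restore path, which mirrors the key-cache half: the save side writes the tensor type id and the row size ahead of the row data, and the restore side reads those two fields back and validates them against the current cache before copying any rows, aborting (llama_kv_cache_seq_rm + return 0) on a mismatch. Below is a minimal, self-contained sketch of that write/validate pattern in plain C++; the helper names, buffer layout, and constants are illustrative assumptions, not the actual llama.cpp state format.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Illustrative helper: append a trivially-copyable value to a byte buffer,
// playing the role of data_ctx.write() on the save side.
template <typename T>
static void write_pod(std::vector<uint8_t> & out, const T & v) {
    const uint8_t * p = reinterpret_cast<const uint8_t *>(&v);
    out.insert(out.end(), p, p + sizeof(T));
}

// Illustrative helper: read a value back and advance the input pointer,
// mirroring the memcpy + "inp += sizeof(...)" pattern on the restore side.
template <typename T>
static void read_pod(const uint8_t * & inp, T & v) {
    memcpy(&v, inp, sizeof(T));
    inp += sizeof(T);
}

int main() {
    // Stand-ins for the current value cache of one layer (hypothetical numbers).
    const int32_t v_type_i   = 1;   // tensor type id
    const size_t  v_size_row = 64;  // bytes per row

    // Save side: type id first, then row size, then (not shown) the rows themselves.
    std::vector<uint8_t> state;
    write_pod(state, v_type_i);
    write_pod(state, v_size_row);

    // Restore side: read the reference values and check them against the current
    // cache before copying rows, bailing out on any mismatch.
    const uint8_t * inp = state.data();

    int32_t v_type_i_ref;
    read_pod(inp, v_type_i_ref);
    if (v_type_i != v_type_i_ref) {
        fprintf(stderr, "mismatched value type (%d != %d)\n", v_type_i, v_type_i_ref);
        return 1;
    }

    size_t v_size_row_ref;
    read_pod(inp, v_size_row_ref);
    if (v_size_row != v_size_row_ref) {
        fprintf(stderr, "mismatched value row size (%zu != %zu)\n", v_size_row, v_size_row_ref);
        return 1;
    }

    printf("value cache metadata matches, rows can be copied\n");
    return 0;
}

Validating the metadata up front means a state blob produced with a different KV-cache layout fails cleanly instead of writing misaligned rows into the cache, which is what the mismatch branches in the hunk above guard against.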