Ensure kv cache head points to a valid slot in llama_decode internal
Add some comments to prevent dumb people (like me) from getting confused.
This commit is contained in:
parent
3144563db1
commit
465b8f4fc0
1 changed files with 10 additions and 1 deletions
11
llama.cpp
11
llama.cpp
|
@ -1044,6 +1044,9 @@ struct llama_kv_cell {
|
|||
struct llama_kv_cache {
|
||||
bool has_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_internal also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
|
||||
|
@ -1301,6 +1304,8 @@ static bool llama_kv_cache_init(
|
|||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
// updates the cache head
|
||||
// Note: On success, it's important that cache.head points
|
||||
// to the first cell of the slot.
|
||||
static bool llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_batch & batch) {
|
||||
|
@ -4563,8 +4568,12 @@ static int llama_decode_internal(
|
|||
#endif
|
||||
|
||||
// update the kv ring buffer
|
||||
lctx.kv_self.head += n_tokens;
|
||||
lctx.kv_self.has_shift = false;
|
||||
lctx.kv_self.head += n_tokens;
|
||||
// Ensure kv cache head points to a valid index.
|
||||
if (lctx.kv_self.head >= lctx.kv_self.size) {
|
||||
lctx.kv_self.head = 0;
|
||||
}
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// print timing information per ggml operation (for debugging purposes)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue