From 465b8f4fc0ae54ddbbbe891ade101f9a17ef30f8 Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Fri, 6 Oct 2023 07:33:10 -0600
Subject: [PATCH] Ensure kv cache head points to a valid slot in llama_decode
 internal

Add some comments to prevent dumb people (like me) from getting confused.

---
 llama.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index ce4d68f38..bf640bc02 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1044,6 +1044,9 @@ struct llama_kv_cell {
 
 struct llama_kv_cache {
     bool has_shift = false;
 
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_internal also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
 
@@ -1301,6 +1304,8 @@ static bool llama_kv_cache_init(
 
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// Note: On success, it's important that cache.head points
+// to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
        const struct llama_batch & batch) {
@@ -4563,8 +4568,12 @@ static int llama_decode_internal(
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.head += n_tokens;
     lctx.kv_self.has_shift = false;
+    lctx.kv_self.head += n_tokens;
+    // Ensure kv cache head points to a valid index.
+    if (lctx.kv_self.head >= lctx.kv_self.size) {
+        lctx.kv_self.head = 0;
+    }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)