Extend llama_kv_cache_seq_rm to allow matching any sequence (#3843)
* Extend llama_kv_cache_seq_rm to allow matichng any sequence * Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear Use llama_kv_cache_clear for cache clearing Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality
This commit is contained in:
		
							parent
							
								
									2046eb4345
								
							
						
					
					
						commit
						6e08281e58
					
				
					 8 changed files with 30 additions and 32 deletions
				
			
		|  | @ -889,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par | ||||||
| 
 | 
 | ||||||
|         std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), }; |         std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), }; | ||||||
|         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); |         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); | ||||||
|         llama_kv_cache_tokens_rm(lctx, -1, -1); |         llama_kv_cache_clear(lctx); | ||||||
|         llama_reset_timings(lctx); |         llama_reset_timings(lctx); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -185,7 +185,7 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|                 const auto t_pp_start = ggml_time_us(); |                 const auto t_pp_start = ggml_time_us(); | ||||||
| 
 | 
 | ||||||
|                 llama_kv_cache_tokens_rm(ctx, -1, -1); |                 llama_kv_cache_clear(ctx); | ||||||
| 
 | 
 | ||||||
|                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) { |                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) { | ||||||
|                     LOG_TEE("%s: llama_decode() failed\n", __func__); |                     LOG_TEE("%s: llama_decode() failed\n", __func__); | ||||||
|  |  | ||||||
|  | @ -1037,7 +1037,7 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|         test t(inst, lmodel, ctx); |         test t(inst, lmodel, ctx); | ||||||
| 
 | 
 | ||||||
|         llama_kv_cache_tokens_rm(ctx, -1, -1); |         llama_kv_cache_clear(ctx); | ||||||
| 
 | 
 | ||||||
|         // warmup run
 |         // warmup run
 | ||||||
|         if (t.n_prompt > 0) { |         if (t.n_prompt > 0) { | ||||||
|  | @ -1048,7 +1048,7 @@ int main(int argc, char ** argv) { | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         for (int i = 0; i < params.reps; i++) { |         for (int i = 0; i < params.reps; i++) { | ||||||
|             llama_kv_cache_tokens_rm(ctx, -1, -1); |             llama_kv_cache_clear(ctx); | ||||||
| 
 | 
 | ||||||
|             uint64_t t_start = get_time_ns(); |             uint64_t t_start = get_time_ns(); | ||||||
|             if (t.n_prompt > 0) { |             if (t.n_prompt > 0) { | ||||||
|  |  | ||||||
|  | @ -298,7 +298,7 @@ int main(int argc, char ** argv) { | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // remove any "future" tokens that we might have inherited from the previous session
 |         // remove any "future" tokens that we might have inherited from the previous session
 | ||||||
|         llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1); |         llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     LOGLN( |     LOGLN( | ||||||
|  |  | ||||||
|  | @ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & | ||||||
|         const auto t_start = std::chrono::high_resolution_clock::now(); |         const auto t_start = std::chrono::high_resolution_clock::now(); | ||||||
| 
 | 
 | ||||||
|         // clear the KV cache
 |         // clear the KV cache
 | ||||||
|         llama_kv_cache_tokens_rm(ctx, -1, -1); |         llama_kv_cache_clear(ctx); | ||||||
| 
 | 
 | ||||||
|         for (int j = 0; j < num_batches; ++j) { |         for (int j = 0; j < num_batches; ++j) { | ||||||
|             const int batch_start = start + j * n_batch; |             const int batch_start = start + j * n_batch; | ||||||
|  | @ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par | ||||||
|         const auto t_start = std::chrono::high_resolution_clock::now(); |         const auto t_start = std::chrono::high_resolution_clock::now(); | ||||||
| 
 | 
 | ||||||
|         // clear the KV cache
 |         // clear the KV cache
 | ||||||
|         llama_kv_cache_tokens_rm(ctx, -1, -1); |         llama_kv_cache_clear(ctx); | ||||||
| 
 | 
 | ||||||
|         for (int j = 0; j < num_batches; ++j) { |         for (int j = 0; j < num_batches; ++j) { | ||||||
|             const int batch_start = start + j * n_batch; |             const int batch_start = start + j * n_batch; | ||||||
|  | @ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // clear the KV cache
 |         // clear the KV cache
 | ||||||
|         llama_kv_cache_tokens_rm(ctx, -1, -1); |         llama_kv_cache_clear(ctx); | ||||||
| 
 | 
 | ||||||
|         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); |         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); | ||||||
|         if (logits.empty()) { |         if (logits.empty()) { | ||||||
|  |  | ||||||
|  | @ -857,7 +857,7 @@ struct llama_server_context | ||||||
| 
 | 
 | ||||||
|     void kv_cache_clear() { |     void kv_cache_clear() { | ||||||
|         // clear the entire KV cache
 |         // clear the entire KV cache
 | ||||||
|         llama_kv_cache_tokens_rm(ctx, -1, -1); |         llama_kv_cache_clear(ctx); | ||||||
|         clean_kv_cache = false; |         clean_kv_cache = false; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										27
									
								
								llama.cpp
									
										
									
									
									
								
							
							
						
						
									
										27
									
								
								llama.cpp
									
										
									
									
									
								
							|  | @ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { | static void llama_kv_cache_clear(struct llama_kv_cache & cache) { | ||||||
|     if (c0 < 0) c0 = 0; |     for (int32_t i = 0; i < cache.size; ++i) { | ||||||
|     if (c1 < 0) c1 = cache.size; |  | ||||||
| 
 |  | ||||||
|     for (int32_t i = c0; i < c1; ++i) { |  | ||||||
|         cache.cells[i].pos = -1; |         cache.cells[i].pos = -1; | ||||||
|         cache.cells[i].seq_id.clear(); |         cache.cells[i].seq_id.clear(); | ||||||
|     } |     } | ||||||
| 
 |     cache.head = 0; | ||||||
|     // Searching for a free slot can start here since we know it will be empty.
 |  | ||||||
|     cache.head = uint32_t(c0); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| static void llama_kv_cache_seq_rm( | static void llama_kv_cache_seq_rm( | ||||||
|  | @ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm( | ||||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); |     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||||
| 
 | 
 | ||||||
|     for (uint32_t i = 0; i < cache.size; ++i) { |     for (uint32_t i = 0; i < cache.size; ++i) { | ||||||
|         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { |         if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||||
|  |             if (seq_id < 0) { | ||||||
|  |                 cache.cells[i].seq_id.clear(); | ||||||
|  |             } else if (cache.cells[i].has_seq_id(seq_id)) { | ||||||
|                 cache.cells[i].seq_id.erase(seq_id); |                 cache.cells[i].seq_id.erase(seq_id); | ||||||
|  |             } else { | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|             if (cache.cells[i].seq_id.empty()) { |             if (cache.cells[i].seq_id.empty()) { | ||||||
|                 cache.cells[i].pos = -1; |                 cache.cells[i].pos = -1; | ||||||
|                 if (new_head == cache.size) new_head = i; |                 if (new_head == cache.size) new_head = i; | ||||||
|  | @ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) { | ||||||
|     return ctx->kv_self.head; |     return ctx->kv_self.head; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) { | void llama_kv_cache_clear(struct llama_context * ctx) { | ||||||
|     llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1); |     llama_kv_cache_clear(ctx->kv_self); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { | void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { | ||||||
|  | @ -9654,7 +9655,7 @@ int llama_eval( | ||||||
|                  llama_token * tokens, |                  llama_token * tokens, | ||||||
|                      int32_t   n_tokens, |                      int32_t   n_tokens, | ||||||
|                          int   n_past) { |                          int   n_past) { | ||||||
|     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); |     llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); | ||||||
| 
 | 
 | ||||||
|     const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); |     const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); | ||||||
|     if (ret < 0) { |     if (ret < 0) { | ||||||
|  | @ -9669,7 +9670,7 @@ int llama_eval_embd( | ||||||
|                            float * embd, |                            float * embd, | ||||||
|                          int32_t   n_tokens, |                          int32_t   n_tokens, | ||||||
|                              int   n_past) { |                              int   n_past) { | ||||||
|     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); |     llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); | ||||||
| 
 | 
 | ||||||
|     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; |     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										11
									
								
								llama.h
									
										
									
									
									
								
							
							
						
						
									
										11
									
								
								llama.h
									
										
									
									
									
								
							|  | @ -334,15 +334,12 @@ extern "C" { | ||||||
|     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), |     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), | ||||||
|             "avoid using this, it will be removed in the future, instead - count the tokens in user code"); |             "avoid using this, it will be removed in the future, instead - count the tokens in user code"); | ||||||
| 
 | 
 | ||||||
|     // Remove all tokens data of cells in [c0, c1)
 |     // Clear the KV cache
 | ||||||
|     // c0 < 0 : [0,  c1]
 |     LLAMA_API void llama_kv_cache_clear( | ||||||
|     // c1 < 0 : [c0, inf)
 |             struct llama_context * ctx); | ||||||
|     LLAMA_API void llama_kv_cache_tokens_rm( |  | ||||||
|             struct llama_context * ctx, |  | ||||||
|                          int32_t   c0, |  | ||||||
|                          int32_t   c1); |  | ||||||
| 
 | 
 | ||||||
|     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
 |     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
 | ||||||
|  |     // seq_id < 0 : match any sequence
 | ||||||
|     // p0 < 0     : [0,  p1]
 |     // p0 < 0     : [0,  p1]
 | ||||||
|     // p1 < 0     : [p0, inf)
 |     // p1 < 0     : [p0, inf)
 | ||||||
|     LLAMA_API void llama_kv_cache_seq_rm( |     LLAMA_API void llama_kv_cache_seq_rm( | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue