server : fix crash when prompt exceeds context size (#3996)
This commit is contained in:
		
							parent
							
								
									34b0a08207
								
							
						
					
					
						commit
						d96ca7ded7
					
				
					 1 changed files with 29 additions and 29 deletions
				
			
		|  | @ -1557,6 +1557,35 @@ struct llama_server_context | |||
| 
 | ||||
|                     slot.num_prompt_tokens = prompt_tokens.size(); | ||||
| 
 | ||||
|                     if (slot.params.n_keep < 0) | ||||
|                     { | ||||
|                         slot.params.n_keep = slot.num_prompt_tokens; | ||||
|                     } | ||||
|                     slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); | ||||
| 
 | ||||
|                     // if input prompt is too big, truncate it
 | ||||
|                     if (slot.num_prompt_tokens >= slot.n_ctx) | ||||
|                     { | ||||
|                         const int n_left = slot.n_ctx - slot.params.n_keep; | ||||
|                         const int n_block_size = n_left / 2; | ||||
|                         const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; | ||||
| 
 | ||||
|                         std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); | ||||
|                         new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); | ||||
| 
 | ||||
|                         LOG_VERBOSE("input truncated", { | ||||
|                             {"n_ctx",  slot.n_ctx}, | ||||
|                             {"n_keep", slot.params.n_keep}, | ||||
|                             {"n_left", n_left}, | ||||
|                             {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, | ||||
|                         }); | ||||
|                         slot.truncated = true; | ||||
|                         prompt_tokens = new_tokens; | ||||
| 
 | ||||
|                         slot.num_prompt_tokens = prompt_tokens.size(); | ||||
|                         GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx); | ||||
|                     } | ||||
| 
 | ||||
|                     if (!slot.params.cache_prompt) | ||||
|                     { | ||||
|                         llama_sampling_reset(slot.ctx_sampling); | ||||
|  | @ -1566,35 +1595,6 @@ struct llama_server_context | |||
|                     } | ||||
|                     else | ||||
|                     { | ||||
|                         if (slot.params.n_keep < 0) | ||||
|                         { | ||||
|                             slot.params.n_keep = slot.num_prompt_tokens; | ||||
|                         } | ||||
|                         slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); | ||||
| 
 | ||||
|                         // if input prompt is too big, truncate it
 | ||||
|                         if (slot.num_prompt_tokens >= slot.n_ctx) | ||||
|                         { | ||||
|                             const int n_left = slot.n_ctx - slot.params.n_keep; | ||||
|                             const int n_block_size = n_left / 2; | ||||
|                             const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; | ||||
| 
 | ||||
|                             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); | ||||
|                             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); | ||||
| 
 | ||||
|                             LOG_VERBOSE("input truncated", { | ||||
|                                                             {"n_ctx",  slot.n_ctx}, | ||||
|                                                             {"n_keep", slot.params.n_keep}, | ||||
|                                                             {"n_left", n_left}, | ||||
|                                                             {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, | ||||
|                                                         }); | ||||
|                             slot.truncated = true; | ||||
|                             prompt_tokens = new_tokens; | ||||
| 
 | ||||
|                             slot.num_prompt_tokens = prompt_tokens.size(); | ||||
|                             GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx); | ||||
|                         } | ||||
| 
 | ||||
|                         // push the prompt into the sampling context (do not apply grammar)
 | ||||
|                         for (auto &token : prompt_tokens) | ||||
|                         { | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue