speculative : bug fixes
This commit is contained in:
		
							parent
							
								
									0e89203b51
								
							
						
					
					
						commit
						4e82b2ea3f
					
				
					 1 changed files with 10 additions and 17 deletions
				
			
		|  | @ -37,8 +37,8 @@ int main(int argc, char ** argv) { | |||
|     const int n_seq_dft = params.n_parallel; | ||||
| 
 | ||||
|     // TODO: make this configurable
 | ||||
|     const float p_accept = 0.4f; | ||||
|     const float p_split  = 0.3f; | ||||
|     const float p_accept = 0.80f; | ||||
|     const float p_split  = 0.10f; | ||||
| 
 | ||||
| #ifndef LOG_DISABLE_LOGS | ||||
|     log_set_target(log_filename_generator("speculative", "log")); | ||||
|  | @ -118,7 +118,7 @@ int main(int argc, char ** argv) { | |||
|     std::vector<seq_draft> drafts(n_seq_dft); | ||||
| 
 | ||||
|     params.grammar.clear();             // the draft samplers will copy the target sampler's grammar
 | ||||
|     params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
 | ||||
|     params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp); | ||||
| 
 | ||||
|     for (int s = 0; s < n_seq_dft; ++s) { | ||||
|         drafts[s].ctx_sampling = llama_sampling_init(params); | ||||
|  | @ -156,7 +156,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|             llama_sampling_accept(ctx_sampling, ctx_tgt, id); | ||||
| 
 | ||||
|             //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
 | ||||
|             //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 | ||||
| 
 | ||||
|             const std::string token_str = llama_token_to_piece(ctx_tgt, id); | ||||
| 
 | ||||
|  | @ -202,7 +202,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|             // TODO: simplify
 | ||||
|             { | ||||
|                 LOG("keeping sequence %d\n", s_keep); | ||||
|                 LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); | ||||
| 
 | ||||
|                 llama_kv_cache_seq_keep(ctx_dft, s_keep); | ||||
|                 llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1); | ||||
|  | @ -277,7 +277,7 @@ int main(int argc, char ** argv) { | |||
|                 } | ||||
| 
 | ||||
|                 if (cur_p[0].p < p_accept) { | ||||
|                     LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p); | ||||
|                     LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); | ||||
|                     drafts[s].drafting = false; | ||||
|                     continue; | ||||
|                 } | ||||
|  | @ -337,16 +337,14 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|                     llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); | ||||
| 
 | ||||
|                     // no need to evaluate the last drafted token, since we won't use the result
 | ||||
|                     if (batch_tgt.n_tokens > n_draft) { | ||||
|                         drafts[s].drafting = false; | ||||
|                         continue; | ||||
|                     } | ||||
| 
 | ||||
|                     // add the token to the batch for batched decoding with the draft model
 | ||||
|                     drafts[s].i_batch_dft = batch_dft.n_tokens; | ||||
| 
 | ||||
|                     llama_batch_add(batch_dft, id, n_past_cur, { s }, true); | ||||
| 
 | ||||
|                     if (batch_tgt.n_tokens > n_draft) { | ||||
|                         drafts[s].drafting = false; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|  | @ -365,11 +363,6 @@ int main(int argc, char ** argv) { | |||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // account for the last drafted token that we didn't evaluate
 | ||||
|         if (batch_tgt.n_tokens > n_draft) { | ||||
|             ++n_drafted; | ||||
|         } | ||||
| 
 | ||||
|         // evaluate the target model on the drafted tokens
 | ||||
|         { | ||||
|             llama_kv_cache_seq_keep(ctx_tgt, 0); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue