main : add self-extend support (#4815)
* examples : add passkey test
* passkey : better prints
* passkey : select pass key pos from CLI
* passkey : simplify n_past logic
* llama : "self-extend"-like context extension
* passkey : add comment
* main : add Self-Extend support
* llama : add comment about llama_kv_cache_seq_div
This commit is contained in:
		
							parent
							
								
									b0034d93ce
								
							
						
					
					
						commit
						52531fdff8
					
				
					 4 changed files with 87 additions and 24 deletions
				
			
		|  | @ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|             params.n_ctx = std::stoi(argv[i]); |             params.n_ctx = std::stoi(argv[i]); | ||||||
|  |         } else if (arg == "--grp-attn-n" || arg == "-gan") { | ||||||
|  |             if (++i >= argc) { | ||||||
|  |                 invalid_param = true; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             params.grp_attn_n = std::stoi(argv[i]); | ||||||
|  |         } else if (arg == "--grp-attn-w" || arg == "-gaw") { | ||||||
|  |             if (++i >= argc) { | ||||||
|  |                 invalid_param = true; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             params.grp_attn_w = std::stoi(argv[i]); | ||||||
|         } else if (arg == "--rope-freq-base") { |         } else if (arg == "--rope-freq-base") { | ||||||
|             if (++i >= argc) { |             if (++i >= argc) { | ||||||
|                 invalid_param = true; |                 invalid_param = true; | ||||||
|  | @ -904,6 +918,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | ||||||
|     printf("                        Not recommended since this is both slower and uses more VRAM.\n"); |     printf("                        Not recommended since this is both slower and uses more VRAM.\n"); | ||||||
| #endif // GGML_USE_CUBLAS
 | #endif // GGML_USE_CUBLAS
 | ||||||
| #endif | #endif | ||||||
|  |     printf("  -gan N, --grp-attn-n N\n"); | ||||||
|  |     printf("                        group-attention factor (default: %d)\n", params.grp_attn_n); | ||||||
|  |     printf("  -gaw N, --grp-attn-w N\n"); | ||||||
|  |     printf("                        group-attention width (default: %.1f)\n", (double)params.grp_attn_w); | ||||||
|     printf("  --verbose-prompt      print prompt before generation\n"); |     printf("  --verbose-prompt      print prompt before generation\n"); | ||||||
|     printf("  -dkvc, --dump-kv-cache\n"); |     printf("  -dkvc, --dump-kv-cache\n"); | ||||||
|     printf("                        verbose print of the KV cache\n"); |     printf("                        verbose print of the KV cache\n"); | ||||||
|  |  | ||||||
|  | @ -62,6 +62,8 @@ struct gpt_params { | ||||||
|     int32_t main_gpu                        = 0;     // the GPU that is used for scratch and small tensors
 |     int32_t main_gpu                        = 0;     // the GPU that is used for scratch and small tensors
 | ||||||
|     float   tensor_split[LLAMA_MAX_DEVICES] = {0};   // how split tensors should be distributed across GPUs
 |     float   tensor_split[LLAMA_MAX_DEVICES] = {0};   // how split tensors should be distributed across GPUs
 | ||||||
|     int32_t n_beams                         = 0;     // if non-zero then use beam search of given width.
 |     int32_t n_beams                         = 0;     // if non-zero then use beam search of given width.
 | ||||||
|  |     int32_t grp_attn_n                      = 1;     // group-attention factor
 | ||||||
|  |     int32_t grp_attn_w                      = 512;   // group-attention width
 | ||||||
|     float   rope_freq_base                  = 0.0f;  // RoPE base frequency
 |     float   rope_freq_base                  = 0.0f;  // RoPE base frequency
 | ||||||
|     float   rope_freq_scale                 = 0.0f;  // RoPE frequency scaling factor
 |     float   rope_freq_scale                 = 0.0f;  // RoPE frequency scaling factor
 | ||||||
|     float   yarn_ext_factor                 = -1.0f; // YaRN extrapolation mix factor
 |     float   yarn_ext_factor                 = -1.0f; // YaRN extrapolation mix factor
 | ||||||
|  |  | ||||||
|  | @ -439,6 +439,21 @@ int main(int argc, char ** argv) { | ||||||
|     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); |     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); | ||||||
|     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); |     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); | ||||||
|     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); |     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); | ||||||
|  | 
 | ||||||
|  |     // group-attention state
 | ||||||
|  |     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
 | ||||||
|  |     int ga_i = 0; | ||||||
|  | 
 | ||||||
|  |     const int ga_n = params.grp_attn_n; | ||||||
|  |     const int ga_w = params.grp_attn_w; | ||||||
|  | 
 | ||||||
|  |     if (ga_n != 1) { | ||||||
|  |         GGML_ASSERT(ga_n > 0                    && "grp_attn_n must be positive");                     // NOLINT
 | ||||||
|  |         GGML_ASSERT(ga_w % ga_n == 0            && "grp_attn_w must be a multiple of grp_attn_n");     // NOLINT
 | ||||||
|  |       //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of grp_attn_w");    // NOLINT
 | ||||||
|  |       //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
 | ||||||
|  |         LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w); | ||||||
|  |     } | ||||||
|     LOG_TEE("\n\n"); |     LOG_TEE("\n\n"); | ||||||
| 
 | 
 | ||||||
|     if (params.interactive) { |     if (params.interactive) { | ||||||
|  | @ -500,37 +515,61 @@ int main(int argc, char ** argv) { | ||||||
|                 fflush(stdout); |                 fflush(stdout); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // infinite text generation via context swapping
 |             if (ga_n == 1) { | ||||||
|             // if we run out of context:
 |                 // infinite text generation via context shifting
 | ||||||
|             // - take the n_keep first tokens from the original prompt (via n_past)
 |                 // if we run out of context:
 | ||||||
|             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 |                 // - take the n_keep first tokens from the original prompt (via n_past)
 | ||||||
|             if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) { |                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 | ||||||
|                 if (params.n_predict == -2) { |                 if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) { | ||||||
|                     LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); |                     if (params.n_predict == -2) { | ||||||
|                     break; |                         LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     const int n_left    = n_past - params.n_keep - 1; | ||||||
|  |                     const int n_discard = n_left/2; | ||||||
|  | 
 | ||||||
|  |                     LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", | ||||||
|  |                             n_past, n_left, n_ctx, params.n_keep, n_discard); | ||||||
|  | 
 | ||||||
|  |                     llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1); | ||||||
|  |                     llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); | ||||||
|  | 
 | ||||||
|  |                     n_past -= n_discard; | ||||||
|  | 
 | ||||||
|  |                     if (ctx_guidance) { | ||||||
|  |                         n_past_guidance -= n_discard; | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); | ||||||
|  | 
 | ||||||
|  |                     LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); | ||||||
|  | 
 | ||||||
|  |                     LOG("clear session path\n"); | ||||||
|  |                     path_session.clear(); | ||||||
|                 } |                 } | ||||||
|  |             } else { | ||||||
|  |                 // context extension via Self-Extend
 | ||||||
|  |                 while (n_past >= ga_i + ga_w) { | ||||||
|  |                     const int ib = (ga_n*ga_i)/ga_w; | ||||||
|  |                     const int bd = (ga_w/ga_n)*(ga_n - 1); | ||||||
|  |                     const int dd = (ga_w/ga_n) - ib*bd - ga_w; | ||||||
| 
 | 
 | ||||||
|                 const int n_left    = n_past - params.n_keep - 1; |                     LOG("\n"); | ||||||
|                 const int n_discard = n_left/2; |                     LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd); | ||||||
|  |                     LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); | ||||||
|  |                     LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); | ||||||
| 
 | 
 | ||||||
|                 LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", |                     llama_kv_cache_seq_shift(ctx, 0, ga_i,                n_past,              ib*bd); | ||||||
|                     n_past, n_left, n_ctx, params.n_keep, n_discard); |                     llama_kv_cache_seq_div  (ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n); | ||||||
|  |                     llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd); | ||||||
| 
 | 
 | ||||||
|                 llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1); |                     n_past -= bd; | ||||||
|                 llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); |  | ||||||
| 
 | 
 | ||||||
|                 n_past -= n_discard; |                     ga_i += ga_w/ga_n; | ||||||
| 
 | 
 | ||||||
|                 if (ctx_guidance) { |                     LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i); | ||||||
|                     n_past_guidance -= n_discard; |  | ||||||
|                 } |                 } | ||||||
| 
 |  | ||||||
|                 LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); |  | ||||||
| 
 |  | ||||||
|                 LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); |  | ||||||
| 
 |  | ||||||
|                 LOG("clear session path\n"); |  | ||||||
|                 path_session.clear(); |  | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
 |             // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
 | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								llama.h
									
										
									
									
									
								
							
							
						
						
									
										4
									
								
								llama.h
									
										
									
									
									
								
							|  | @ -484,6 +484,10 @@ extern "C" { | ||||||
|                        llama_pos   p1, |                        llama_pos   p1, | ||||||
|                        llama_pos   delta); |                        llama_pos   delta); | ||||||
| 
 | 
 | ||||||
|  |     // Integer division of the positions by factor of `d > 1`
 | ||||||
|  |     // If the KV cache is RoPEd, the KV data is updated accordingly
 | ||||||
|  |     // p0 < 0 : [0,  p1]
 | ||||||
|  |     // p1 < 0 : [p0, inf)
 | ||||||
|     LLAMA_API void llama_kv_cache_seq_div( |     LLAMA_API void llama_kv_cache_seq_div( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                     llama_seq_id   seq_id, |                     llama_seq_id   seq_id, | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue