context : initial need_reserve logic
ggml-ci
This commit is contained in:
		
							parent
							
								
									c75ba6851e
								
							
						
					
					
						commit
						133ad6a723
					
				
					 3 changed files with 269 additions and 245 deletions
				
			
		|  | @ -576,9 +576,7 @@ ggml_tensor * llama_context::build_lora_mm_id( | |||
|     return res; | ||||
| } | ||||
| 
 | ||||
| bool llama_context::kv_self_update() { | ||||
|     bool need_reserve = false; | ||||
| 
 | ||||
| void llama_context::kv_self_update() { | ||||
|     auto & kv = kv_self; | ||||
| 
 | ||||
|     if (kv.has_shift) { | ||||
|  | @ -655,12 +653,14 @@ bool llama_context::kv_self_update() { | |||
| 
 | ||||
|         ggml_free(ctx0); | ||||
| 
 | ||||
|         need_reserve = true; | ||||
| 
 | ||||
|         kv.do_defrag = false; | ||||
|     } | ||||
| 
 | ||||
|     return need_reserve; | ||||
|         need_reserve = true; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_update(llama_context * ctx) { | ||||
|     ctx->kv_self_update(); | ||||
| } | ||||
| 
 | ||||
| void llama_context::build_attn_inp( | ||||
|  | @ -1824,6 +1824,165 @@ int32_t llama_apply_adapter_cvec( | |||
|     return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end); | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // kv cache view
 | ||||
| //
 | ||||
| 
 | ||||
| struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { | ||||
|     return llama_kv_cache_view_init(ctx->kv_self, n_seq_max); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { | ||||
|     llama_kv_cache_view_update(view, ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // kv cache
 | ||||
| //
 | ||||
| 
 | ||||
| // deprecated
 | ||||
| int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { | ||||
|     return llama_kv_self_n_tokens(ctx); | ||||
| } | ||||
| 
 | ||||
| int32_t llama_kv_self_n_tokens(const llama_context * ctx) { | ||||
|     return llama_kv_cache_n_tokens(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { | ||||
|     return llama_kv_self_used_cells(ctx); | ||||
| } | ||||
| 
 | ||||
| int32_t llama_kv_self_used_cells(const llama_context * ctx) { | ||||
|     return llama_kv_cache_used_cells(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_clear(llama_context * ctx) { | ||||
|     llama_kv_self_clear(ctx); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_clear(llama_context * ctx) { | ||||
|     llama_kv_cache_clear(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| bool llama_kv_cache_seq_rm( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); | ||||
| } | ||||
| 
 | ||||
| bool llama_kv_self_seq_rm( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_cp( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id_src, | ||||
|          llama_seq_id   seq_id_dst, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_cp( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id_src, | ||||
|          llama_seq_id   seq_id_dst, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_keep( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id) { | ||||
|     return llama_kv_self_seq_keep(ctx, seq_id); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { | ||||
|     return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_add( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|             llama_pos   delta) { | ||||
|     return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_add( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|             llama_pos   delta) { | ||||
|     return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_div( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|                   int   d) { | ||||
|     return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_div( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|                   int   d) { | ||||
|     return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { | ||||
|     return llama_kv_self_seq_pos_max(ctx, seq_id); | ||||
| } | ||||
| 
 | ||||
| llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { | ||||
|     return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_defrag(llama_context * ctx) { | ||||
|     return llama_kv_self_defrag(ctx); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_defrag(llama_context * ctx) { | ||||
|     return llama_kv_cache_defrag(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| bool llama_kv_cache_can_shift(const llama_context * ctx) { | ||||
|     return llama_kv_self_can_shift(ctx); | ||||
| } | ||||
| 
 | ||||
| bool llama_kv_self_can_shift(const llama_context * ctx) { | ||||
|     return llama_kv_cache_can_shift(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_update(llama_context * ctx) { | ||||
|     llama_kv_self_update(ctx); | ||||
| } | ||||
| 
 | ||||
| // llama state API
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -62,6 +62,7 @@ struct llama_context { | |||
|     int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch
 | ||||
| 
 | ||||
|     bool logits_all = false; | ||||
|     bool need_reserve = false; | ||||
| 
 | ||||
|     // embeddings output (2-dimensional array: [n_outputs][n_embd])
 | ||||
|     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
 | ||||
|  | @ -87,6 +88,7 @@ struct llama_context { | |||
|     // max token position across all sequences in the current context
 | ||||
|     llama_pos pos_max() const; | ||||
| 
 | ||||
|     // certain implementations could require a padding for the context size
 | ||||
|     uint32_t get_ctx_padding(const llama_cparams & cparams) const; | ||||
| 
 | ||||
|     void reset(); | ||||
|  | @ -140,7 +142,7 @@ struct llama_context { | |||
|     struct ggml_tensor * inp_K_shift;         // I32 [kv_size]
 | ||||
| 
 | ||||
|     // return true if need to reserve new worst-case graph
 | ||||
|     bool kv_self_update(); | ||||
|     void kv_self_update(); | ||||
| 
 | ||||
|     void build_attn_inp( | ||||
|             ggml_context * ctx0, | ||||
|  |  | |||
							
								
								
									
										337
									
								
								src/llama.cpp
									
										
									
									
									
								
							
							
						
						
									
										337
									
								
								src/llama.cpp
									
										
									
									
									
								
							|  | @ -28,57 +28,6 @@ | |||
| #pragma warning(disable: 4244 4267) // possible loss of data
 | ||||
| #endif | ||||
| 
 | ||||
| // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 | ||||
| static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) { | ||||
|     // loading time will be recalculated after the first eval, so
 | ||||
|     // we take page faults deferred by mmap() into consideration
 | ||||
|     model.t_load_us = 0; | ||||
|     time_meas tm(model.t_load_us); | ||||
| 
 | ||||
|     model.t_start_us = tm.t_start_us; | ||||
| 
 | ||||
|     try { | ||||
|         llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); | ||||
| 
 | ||||
|         ml.print_info(); | ||||
| 
 | ||||
|         model.hparams.vocab_only = params.vocab_only; | ||||
| 
 | ||||
|         try { | ||||
|             model.load_arch(ml); | ||||
|         } catch(const std::exception & e) { | ||||
|             throw std::runtime_error("error loading model architecture: " + std::string(e.what())); | ||||
|         } | ||||
|         try { | ||||
|             model.load_hparams(ml); | ||||
|         } catch(const std::exception & e) { | ||||
|             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); | ||||
|         } | ||||
|         try { | ||||
|             model.load_vocab(ml); | ||||
|         } catch(const std::exception & e) { | ||||
|             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); | ||||
|         } | ||||
| 
 | ||||
|         model.load_stats(ml); | ||||
|         model.print_info(); | ||||
| 
 | ||||
|         if (params.vocab_only) { | ||||
|             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); | ||||
|             return 0; | ||||
|         } | ||||
| 
 | ||||
|         if (!model.load_tensors(ml)) { | ||||
|             return -2; | ||||
|         } | ||||
|     } catch (const std::exception & err) { | ||||
|         LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // llm_build
 | ||||
| //
 | ||||
|  | @ -7951,6 +7900,30 @@ static int llama_decode_impl( | |||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // reserve a worst case graph if needed
 | ||||
|         // TODO: extract to a function
 | ||||
|         if (lctx.need_reserve) { | ||||
|             const auto & cparams = lctx.cparams; | ||||
|             const auto & model   = lctx.model; | ||||
| 
 | ||||
|             // build worst-case graph
 | ||||
|             uint32_t n_seqs = 1; // TODO: worst-case number of sequences
 | ||||
|             uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); | ||||
| 
 | ||||
|             llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 | ||||
|             llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; | ||||
| 
 | ||||
|             ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); | ||||
| 
 | ||||
|             // initialize scheduler with the worst-case graph
 | ||||
|             ggml_backend_sched_reset(lctx.sched.get()); | ||||
|             if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) { | ||||
|                 LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); | ||||
|             } | ||||
| 
 | ||||
|             lctx.need_reserve = false; | ||||
|         } | ||||
| 
 | ||||
|         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 | ||||
| 
 | ||||
|         ggml_backend_sched_reset(lctx.sched.get()); | ||||
|  | @ -8206,6 +8179,31 @@ static int llama_encode_impl( | |||
| 
 | ||||
|     lctx.prepare_decode(ubatch); | ||||
| 
 | ||||
|     // reserve a worst case graph if needed
 | ||||
|     // TODO: extract to a function
 | ||||
|     if (lctx.need_reserve) { | ||||
|         // TODO: extract to a function
 | ||||
|         const auto & cparams = lctx.cparams; | ||||
|         const auto & model   = lctx.model; | ||||
| 
 | ||||
|         // build worst-case graph
 | ||||
|         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
 | ||||
|         uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); | ||||
| 
 | ||||
|         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 | ||||
|         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; | ||||
| 
 | ||||
|         ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); | ||||
| 
 | ||||
|         // initialize scheduler with the worst-case graph
 | ||||
|         ggml_backend_sched_reset(lctx.sched.get()); | ||||
|         if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) { | ||||
|             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); | ||||
|         } | ||||
| 
 | ||||
|         lctx.need_reserve = false; | ||||
|     } | ||||
| 
 | ||||
|     ggml_backend_sched_reset(lctx.sched.get()); | ||||
|     ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); | ||||
| 
 | ||||
|  | @ -8419,6 +8417,57 @@ int64_t llama_time_us(void) { | |||
|     return ggml_time_us(); | ||||
| } | ||||
| 
 | ||||
| // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 | ||||
| static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) { | ||||
|     // loading time will be recalculated after the first eval, so
 | ||||
|     // we take page faults deferred by mmap() into consideration
 | ||||
|     model.t_load_us = 0; | ||||
|     time_meas tm(model.t_load_us); | ||||
| 
 | ||||
|     model.t_start_us = tm.t_start_us; | ||||
| 
 | ||||
|     try { | ||||
|         llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); | ||||
| 
 | ||||
|         ml.print_info(); | ||||
| 
 | ||||
|         model.hparams.vocab_only = params.vocab_only; | ||||
| 
 | ||||
|         try { | ||||
|             model.load_arch(ml); | ||||
|         } catch(const std::exception & e) { | ||||
|             throw std::runtime_error("error loading model architecture: " + std::string(e.what())); | ||||
|         } | ||||
|         try { | ||||
|             model.load_hparams(ml); | ||||
|         } catch(const std::exception & e) { | ||||
|             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); | ||||
|         } | ||||
|         try { | ||||
|             model.load_vocab(ml); | ||||
|         } catch(const std::exception & e) { | ||||
|             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); | ||||
|         } | ||||
| 
 | ||||
|         model.load_stats(ml); | ||||
|         model.print_info(); | ||||
| 
 | ||||
|         if (params.vocab_only) { | ||||
|             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); | ||||
|             return 0; | ||||
|         } | ||||
| 
 | ||||
|         if (!model.load_tensors(ml)) { | ||||
|             return -2; | ||||
|         } | ||||
|     } catch (const std::exception & err) { | ||||
|         LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| static struct llama_model * llama_model_load_from_file_impl( | ||||
|         const std::string & path_model, | ||||
|         std::vector<std::string> & splits, | ||||
|  | @ -8889,192 +8938,6 @@ struct llama_context * llama_new_context_with_model( | |||
|     return llama_init_from_model(model, params); | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // kv cache view
 | ||||
| //
 | ||||
| 
 | ||||
| struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { | ||||
|     return llama_kv_cache_view_init(ctx->kv_self, n_seq_max); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { | ||||
|     llama_kv_cache_view_update(view, ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // kv cache
 | ||||
| //
 | ||||
| 
 | ||||
| // deprecated
 | ||||
| int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { | ||||
|     return llama_kv_self_n_tokens(ctx); | ||||
| } | ||||
| 
 | ||||
| int32_t llama_kv_self_n_tokens(const llama_context * ctx) { | ||||
|     return llama_kv_cache_n_tokens(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { | ||||
|     return llama_kv_self_used_cells(ctx); | ||||
| } | ||||
| 
 | ||||
| int32_t llama_kv_self_used_cells(const llama_context * ctx) { | ||||
|     return llama_kv_cache_used_cells(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_clear(llama_context * ctx) { | ||||
|     llama_kv_self_clear(ctx); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_clear(llama_context * ctx) { | ||||
|     llama_kv_cache_clear(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| bool llama_kv_cache_seq_rm( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); | ||||
| } | ||||
| 
 | ||||
| bool llama_kv_self_seq_rm( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_cp( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id_src, | ||||
|          llama_seq_id   seq_id_dst, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_cp( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id_src, | ||||
|          llama_seq_id   seq_id_dst, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1) { | ||||
|     return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_keep( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id) { | ||||
|     return llama_kv_self_seq_keep(ctx, seq_id); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { | ||||
|     return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_add( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|             llama_pos   delta) { | ||||
|     return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_add( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|             llama_pos   delta) { | ||||
|     return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_seq_div( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|                   int   d) { | ||||
|     return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_seq_div( | ||||
|         llama_context * ctx, | ||||
|          llama_seq_id   seq_id, | ||||
|             llama_pos   p0, | ||||
|             llama_pos   p1, | ||||
|                   int   d) { | ||||
|     return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { | ||||
|     return llama_kv_self_seq_pos_max(ctx, seq_id); | ||||
| } | ||||
| 
 | ||||
| llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { | ||||
|     return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_defrag(llama_context * ctx) { | ||||
|     return llama_kv_self_defrag(ctx); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_defrag(llama_context * ctx) { | ||||
|     return llama_kv_cache_defrag(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| bool llama_kv_cache_can_shift(const llama_context * ctx) { | ||||
|     return llama_kv_self_can_shift(ctx); | ||||
| } | ||||
| 
 | ||||
| bool llama_kv_self_can_shift(const llama_context * ctx) { | ||||
|     return llama_kv_cache_can_shift(&ctx->kv_self); | ||||
| } | ||||
| 
 | ||||
| // deprecated
 | ||||
| void llama_kv_cache_update(llama_context * ctx) { | ||||
|     llama_kv_self_update(ctx); | ||||
| } | ||||
| 
 | ||||
| void llama_kv_self_update(llama_context * ctx) { | ||||
|     const bool need_reserve = ctx->kv_self_update(); | ||||
| 
 | ||||
|     // reserve a worst case graph again
 | ||||
|     if (need_reserve) { | ||||
|         // TODO: extract to a function
 | ||||
|         const auto & cparams = ctx->cparams; | ||||
|         const auto & model   = ctx->model; | ||||
| 
 | ||||
|         // build worst-case graph
 | ||||
|         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
 | ||||
|         uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); | ||||
| 
 | ||||
|         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 | ||||
|         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; | ||||
| 
 | ||||
|         ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); | ||||
| 
 | ||||
|         // initialize scheduler with the worst-case graph
 | ||||
|         ggml_backend_sched_reset(ctx->sched.get()); | ||||
|         if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) { | ||||
|             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ///
 | ||||
| 
 | ||||
| int32_t llama_encode( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue