diff --git a/common/sampling.cpp b/common/sampling.cpp
index 4b14abe1a..c9285892e 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -173,10 +173,10 @@ static llama_token llama_sampling_sample_impl(
     const float mirostat_tau = params.mirostat_tau;
     const float mirostat_eta = params.mirostat_eta;
 
-    std::vector<float>* original_logits = nullptr;
+    std::vector<float> original_logits = {};
     auto cur_p = llama_sampling_configure_token_candidates(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
-        GGML_ASSERT(original_logits != nullptr);
+        GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
     // Get a pointer to the logits
@@ -236,14 +236,11 @@ static llama_token llama_sampling_sample_impl(
             LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
 
             // Restore logits from the copy
-            std::copy((*original_logits).begin(), (*original_logits).end(), logits);
+            std::copy(original_logits.begin(), original_logits.end(), logits);
 
             return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
         }
     }
 
-    if (original_logits != nullptr) {
-        delete original_logits;
-    }
     return id;
 }
 
@@ -254,7 +251,7 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
                   struct llama_context * ctx_cfg,
                   const int idx,
                   bool apply_grammar,
-                  std::vector<float>** original_logits) {
+                  std::vector<float>* original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -263,7 +260,7 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
     const float penalty_repeat  = params.penalty_repeat;
     const float penalty_freq    = params.penalty_freq;
     const float penalty_present = params.penalty_present;
-
+
     const bool penalize_nl = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -272,9 +269,9 @@ static llama_token_data_array llama_sampling_configure_token_candidates_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    if (apply_grammar && original_logits != nullptr) {
+    if (apply_grammar && original_logits != NULL) {
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = new std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+        (*original_logits).insert((*original_logits).end(), logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
     }
 
     // apply params.logit_bias map
@@ -338,7 +335,7 @@ llama_token_data_array llama_sampling_configure_token_candidates(
                   struct llama_context * ctx_cfg,
                   const int idx,
                   bool apply_grammar,
-                  std::vector<float>** original_logits) {
+                  std::vector<float>* original_logits) {
     return llama_sampling_configure_token_candidates_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
diff --git a/common/sampling.h b/common/sampling.h
index 1a0d9e3b6..38557a67d 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -138,7 +138,7 @@ llama_token_data_array llama_sampling_configure_token_candidates(
         struct llama_context * ctx_cfg,
         int idx = 0,
         bool apply_grammar = true,
-        std::vector<float>** original_logits = nullptr);
+        std::vector<float>* original_logits = nullptr);
 
 void llama_sampling_accept(
     struct llama_sampling_context * ctx_sampling,
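
For context: the substantive change in this patch is an ownership fix. Previously, llama_sampling_configure_token_candidates heap-allocated the logits copy through a std::vector<float>** out-parameter, and the resampling branch returns before the matching delete runs, leaking the copy on that path. The patch makes the caller own a stack-allocated std::vector<float> and pass its address, so the callee only fills the buffer and no new/delete pair is needed. Below is a minimal standalone sketch of that pattern, assuming only what the diff shows; fill_logits_copy and n_vocab are hypothetical stand-ins, not llama.cpp API.

// Sketch of the caller-owned out-buffer pattern the patch adopts.
// fill_logits_copy is a hypothetical stand-in for the copying step
// inside llama_sampling_configure_token_candidates.
#include <cassert>
#include <vector>

// Callee: appends a copy of `logits` into the caller-owned buffer.
// A null pointer means the caller does not want a copy.
static void fill_logits_copy(const float * logits, int n_vocab,
                             std::vector<float> * original_logits) {
    if (original_logits != nullptr) {
        original_logits->insert(original_logits->end(), logits, logits + n_vocab);
    }
}

int main() {
    const float logits[3] = { 0.1f, 0.2f, 0.3f };

    std::vector<float> original_logits = {};   // caller-owned; no new/delete
    fill_logits_copy(logits, 3, &original_logits);
    assert(!original_logits.empty());          // mirrors the patched GGML_ASSERT

    // The vector is destroyed automatically on every return path,
    // including an early return like the resampling branch above.
    return 0;
}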