Make sampling not throw exception
This commit is contained in:
parent
aa3094c91d
commit
a94895217c
9 changed files with 36 additions and 4 deletions
|
@ -189,13 +189,16 @@ static llama_token llama_sampling_sample_impl(
|
||||||
|
|
||||||
std::vector<float> original_logits;
|
std::vector<float> original_logits;
|
||||||
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
|
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
|
||||||
|
if (cur_p.data == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
||||||
GGML_ASSERT(!original_logits.empty());
|
GGML_ASSERT(!original_logits.empty());
|
||||||
}
|
}
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
GGML_ASSERT(logits); // already checked in llama_sampling_prepare
|
GGML_ASSERT(logits); // already checked in llama_sampling_prepare
|
||||||
|
|
||||||
if (temp < 0.0) {
|
if (temp < 0.0) {
|
||||||
// greedy sampling, with probs
|
// greedy sampling, with probs
|
||||||
|
@ -286,7 +289,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
if (!logits) {
|
if (!logits) {
|
||||||
throw std::runtime_error("llama_get_logits_ith failed");
|
return {NULL, 0, false};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
||||||
|
@ -303,7 +306,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
if (ctx_cfg) {
|
if (ctx_cfg) {
|
||||||
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
||||||
if (!logits_guidance) {
|
if (!logits_guidance) {
|
||||||
throw std::runtime_error("llama_get_logits_ith failed");
|
return {NULL, 0, false};
|
||||||
}
|
}
|
||||||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||||
}
|
}
|
||||||
|
|
|
@ -530,6 +530,9 @@ int main(int argc, char ** argv) {
|
||||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_llama,
|
struct llama_context * ctx_llama,
|
||||||
int * n_past) {
|
int * n_past) {
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
||||||
|
GGML_ASSERT(id != -1);
|
||||||
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
||||||
static std::string ret;
|
static std::string ret;
|
||||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||||
|
|
|
@ -159,6 +159,9 @@ int main(int argc, char ** argv) {
|
||||||
// sample first token
|
// sample first token
|
||||||
{
|
{
|
||||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
@ -284,6 +287,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// sample the next token
|
// sample the next token
|
||||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
@ -361,6 +367,9 @@ int main(int argc, char ** argv) {
|
||||||
// sample from the last level
|
// sample from the last level
|
||||||
for (int i = 0; i < W; i++) {
|
for (int i = 0; i < W; i++) {
|
||||||
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
||||||
|
if (tokens_j[N - 2][i] == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < W; i++) {
|
for (int i = 0; i < W; i++) {
|
||||||
|
|
|
@ -131,6 +131,7 @@ int main(int argc, char ** argv){
|
||||||
while (true) {
|
while (true) {
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
||||||
|
GGML_ASSERT(id != -1);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -706,6 +706,9 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
|
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
|
||||||
|
|
||||||
|
|
|
@ -341,6 +341,7 @@ int main(int argc, char ** argv) {
|
||||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||||
|
GGML_ASSERT(id != -1);
|
||||||
|
|
||||||
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -2257,6 +2257,9 @@ struct server_context {
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
||||||
|
if (id == -1) {
|
||||||
|
continue; // keep going, don't crash, already logged
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -229,6 +229,9 @@ int main(int argc, char ** argv) {
|
||||||
// stochastic verification
|
// stochastic verification
|
||||||
|
|
||||||
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
||||||
|
if (dist_tgt.data == NULL) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
||||||
float p_tgt = 0, p_dft = 0;
|
float p_tgt = 0, p_dft = 0;
|
||||||
|
|
||||||
|
@ -337,6 +340,9 @@ int main(int argc, char ** argv) {
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
if (token_id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
|
||||||
|
@ -457,7 +463,9 @@ int main(int argc, char ** argv) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
|
if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
const auto & cur_p = drafts[s].ctx_sampling->cur;
|
const auto & cur_p = drafts[s].ctx_sampling->cur;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue