example : fix build + fix speculative

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-04 22:16:30 +03:00
parent 9b950671f4
commit b2b36e9e95
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 45 additions and 22 deletions

View file

@ -158,7 +158,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_da
return llama_sampler_sample(gsmpl->smpl, cur_p);
}
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) {
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
auto & bias = gsmpl->bias;
auto & pnlt = gsmpl->pnlt;
auto & grmr = gsmpl->grmr;
@ -173,10 +173,18 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
llama_constraint_apply(bias, cur_p);
llama_constraint_apply(pnlt, cur_p);
if (grammar_first) {
llama_constraint_apply(grmr, cur_p);
}
llama_sampler_apply(smpl, cur_p);
const llama_token id = llama_sampler_sample(smpl, cur_p);
if (grammar_first) {
return id;
}
// check if it the sampled token fits the grammar
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };

View file

@ -92,7 +92,10 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx);
// if grammar_first is true, the grammar is applied before the constraints (slower)
// useful in cases where all the resulting candidates must fit the grammar
//
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// helpers

View file

@ -141,9 +141,7 @@ while n_cur <= n_len {
llama_sampler_set_logits(smpl, logits)
let new_token_id = llama_sampler_sample_dist(smpl, nil)
// const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nil, false);
let new_token_id = llama_sampler_sample(smpl, nil)
// is it an end of stream? -> mark the stream as finished
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {

View file

@ -399,7 +399,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
llama_sampler_set_logits(sampling, logits);
// sample the most likely token
const auto new_token_id = llama_sampler_sample_greedy(sampling, nullptr, false);
const auto new_token_id = llama_sampler_sample(sampling, nullptr);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

View file

@ -43,7 +43,9 @@ actor LlamaContext {
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
self.sampling = llama_sampler_init(context, llama_sampler_default_params())
var sparams = llama_sampler_default_params()
sparams.type = LLAMA_SAMPLER_TYPE_GREEDY
self.sampling = llama_sampler_init(context, sparams)
}
deinit {
@ -151,7 +153,7 @@ actor LlamaContext {
llama_sampler_set_logits(sampling, logits);
new_token_id = llama_sampler_sample_greedy(sampling, nil, false)
new_token_id = llama_sampler_sample(sampling, nil)
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")

View file

@ -228,20 +228,13 @@ int main(int argc, char ** argv) {
bool accept = false;
if (params.sparams.temp > 0) {
// stochastic verification
const float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
gpt_sampler_set_logits(smpl, logits);
gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
gpt_sampler_apply_grammar(smpl, &dist_tgt);
llama_constraint_apply(softmax, &dist_tgt);
float p_tgt = 0.0f;
float p_dft = 0.0f;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
while (active_seqs.size() > 0) {
// randomly select a sequence to verify from active sequences
std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
@ -259,9 +252,13 @@ int main(int argc, char ** argv) {
}
continue;
}
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
float r = u_dist(rng);
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
//GGML_ASSERT(dist_tgt.size <= dist_dft.size);
// acquire the token probabilities assigned by the draft and target models
for (size_t i = 0; i < dist_tgt.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
@ -291,7 +288,6 @@ int main(int argc, char ** argv) {
// calculate residual probability
GGML_ASSERT(dist_tgt.sorted);
GGML_ASSERT(dist_dft.sorted);
float sum_probs = 0.0f;
// sort dist by id
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
@ -301,10 +297,18 @@ int main(int argc, char ** argv) {
return a.id < b.id;
});
float sum_probs = 0.0f;
for (size_t i = 0; i < dist_tgt.size; i++) {
if (i < dist_dft.size) {
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
} else {
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
}
sum_probs += dist_tgt.data[i].p;
}
for (size_t i = 0; i < dist_tgt.size; i++) {
dist_tgt.data[i].p /= sum_probs;
}
@ -334,7 +338,16 @@ int main(int argc, char ** argv) {
// all drafted tokens were rejected
// sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = gpt_sampler_sample(smpl, &dist_tgt);
std::vector<float> probs(dist_tgt.size);
for (size_t i = 0; i < dist_tgt.size; ++i) {
probs[i] = dist_tgt.data[i].p;
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
const int idx = dist(rng);
token_id = dist_tgt.data[idx].id;
gpt_sampler_accept(smpl, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id);
}
@ -467,7 +480,7 @@ int main(int argc, char ** argv) {
continue;
}
gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);
@ -512,7 +525,6 @@ int main(int argc, char ** argv) {
}
drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl);
sa.push_back(n_seq_cur);
n_seq_cur++;