llama : minor sampling refactor (2) (#9386)

This commit is contained in:
slaren 2024-09-09 17:10:46 +02:00 committed by GitHub
parent 38ca6f644b
commit 5fb5e24811
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 115 additions and 113 deletions

View file

@ -140,8 +140,6 @@ while n_cur <= n_len {
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
llama_sampler_accept(smpl, new_token_id)
// is it an end of stream? -> mark the stream as finished
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
i_batch[i] = -1

View file

@ -172,8 +172,6 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
llama_sampler_accept(smpl, new_token_id);
// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
i_batch[i] = -1;

View file

@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_decode(ctx, bat);
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
llama_sampler_accept(smpl, token);
if (token == eos_token) {
break;

View file

@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
// sample the most likely token
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
llama_sampler_accept(sampler, new_token_id);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
return nullptr;

View file

@ -152,8 +152,6 @@ actor LlamaContext {
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
llama_sampler_accept(sampling, new_token_id)
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")
is_done = true

View file

@ -220,8 +220,6 @@ int main(int argc, char ** argv) {
{
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
llama_sampler_accept(smpl, new_token_id);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
LOG_TEE("\n");

View file

@ -74,8 +74,6 @@ int main(int argc, char ** argv) {
auto next_token = llama_sampler_sample(smpl, ctx, -1);
auto next_token_str = llama_token_to_piece(ctx, next_token);
llama_sampler_accept(smpl, next_token);
printf("%s", next_token_str.c_str());
result0 += next_token_str;
@ -132,8 +130,6 @@ int main(int argc, char ** argv) {
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
auto next_token_str = llama_token_to_piece(ctx2, next_token);
llama_sampler_accept(smpl2, next_token);
printf("%s", next_token_str.c_str());
result1 += next_token_str;
@ -222,8 +218,6 @@ int main(int argc, char ** argv) {
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
auto next_token_str = llama_token_to_piece(ctx3, next_token);
llama_sampler_accept(smpl3, next_token);
printf("%s", next_token_str.c_str());
result2 += next_token_str;

View file

@ -613,7 +613,7 @@ struct server_context {
gpt_params params;
llama_batch batch;
llama_batch batch = {};
bool clean_kv_cache = true;
bool add_bos_token = true;

View file

@ -118,8 +118,6 @@ int main(int argc, char ** argv) {
{
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
llama_sampler_accept(smpl, new_token_id);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
LOG_TEE("\n");