llama : minor sampling refactor (2) (#9386)
This commit is contained in:
parent
38ca6f644b
commit
5fb5e24811
12 changed files with 115 additions and 113 deletions
|
@ -140,8 +140,6 @@ while n_cur <= n_len {
|
|||
|
||||
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id)
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
i_batch[i] = -1
|
||||
|
|
|
@ -172,8 +172,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation? -> mark the stream as finished
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
i_batch[i] = -1;
|
||||
|
|
|
@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
|
|||
llama_decode(ctx, bat);
|
||||
|
||||
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
|
||||
llama_sampler_accept(smpl, token);
|
||||
|
||||
if (token == eos_token) {
|
||||
break;
|
||||
|
|
|
@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
// sample the most likely token
|
||||
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
|
||||
|
||||
llama_sampler_accept(sampler, new_token_id);
|
||||
|
||||
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
return nullptr;
|
||||
|
|
|
@ -152,8 +152,6 @@ actor LlamaContext {
|
|||
|
||||
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
|
||||
|
||||
llama_sampler_accept(sampling, new_token_id)
|
||||
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
print("\n")
|
||||
is_done = true
|
||||
|
|
|
@ -220,8 +220,6 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
LOG_TEE("\n");
|
||||
|
|
|
@ -74,8 +74,6 @@ int main(int argc, char ** argv) {
|
|||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx, next_token);
|
||||
|
||||
llama_sampler_accept(smpl, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result0 += next_token_str;
|
||||
|
||||
|
@ -132,8 +130,6 @@ int main(int argc, char ** argv) {
|
|||
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx2, next_token);
|
||||
|
||||
llama_sampler_accept(smpl2, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result1 += next_token_str;
|
||||
|
||||
|
@ -222,8 +218,6 @@ int main(int argc, char ** argv) {
|
|||
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx3, next_token);
|
||||
|
||||
llama_sampler_accept(smpl3, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result2 += next_token_str;
|
||||
|
||||
|
|
|
@ -613,7 +613,7 @@ struct server_context {
|
|||
|
||||
gpt_params params;
|
||||
|
||||
llama_batch batch;
|
||||
llama_batch batch = {};
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
|
|
|
@ -118,8 +118,6 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
|
||||
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
LOG_TEE("\n");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue