Clean up sampling code
This commit is contained in:
parent
3c8f404243
commit
bd4fe936f5
4 changed files with 50 additions and 62 deletions
|
@ -14,7 +14,7 @@
|
|||
|
||||
#include "ggml.h"
|
||||
|
||||
#define CL_DMMV_BLOCK_SIZE 32;
|
||||
#define CL_DMMV_BLOCK_SIZE 64;
|
||||
|
||||
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
||||
static std::string program_source = MULTILINE_QUOTE(
|
||||
|
|
|
@ -857,18 +857,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
gpt_vocab::id id = 0;
|
||||
// predict
|
||||
unsigned int embdsize = embd.size();
|
||||
//print progress
|
||||
if (!startedsampling)
|
||||
{
|
||||
printf("\rProcessing Prompt%s (%d / %d tokens)", (blasmode ? " [BLAS]" : ""), input_consumed, embd_inp.size());
|
||||
}
|
||||
fflush(stdout);
|
||||
|
||||
if (embdsize > 0)
|
||||
{
|
||||
//print progress
|
||||
if (!startedsampling)
|
||||
{
|
||||
printf("\rProcessing Prompt%s (%d / %d tokens)", (blasmode ? " [BLAS]" : ""), input_consumed, embd_inp.size());
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\rGenerating (%d / %d tokens)", (1 + params.n_predict - remaining_tokens), params.n_predict);
|
||||
}
|
||||
fflush(stdout);
|
||||
|
||||
bool evalres = false;
|
||||
|
||||
|
@ -954,40 +951,35 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
printf("\n");
|
||||
}
|
||||
|
||||
unsigned int eosID = 0;
|
||||
float * logitsPtr;
|
||||
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3)
|
||||
{
|
||||
float * logits;
|
||||
if(file_format == FileFormat::GGJT_3)
|
||||
{
|
||||
logits = llama_get_logits(llama_ctx_v3);
|
||||
logitsPtr = llama_get_logits(llama_ctx_v3);
|
||||
}
|
||||
else
|
||||
{
|
||||
logits = llama_v2_get_logits(llama_ctx_v2);
|
||||
logitsPtr = llama_v2_get_logits(llama_ctx_v2);
|
||||
}
|
||||
|
||||
eosID = llama_token_eos();
|
||||
|
||||
if (!unbanTokens)
|
||||
{
|
||||
// set the logit of the eos token (2) to zero to avoid sampling it
|
||||
logits[llama_token_eos()] = 0;
|
||||
//set logits of opening square bracket to zero. (disabled as obsolete)
|
||||
// logits[518] = 0;
|
||||
// logits[29961] = 0;
|
||||
logitsPtr[eosID] = 0;
|
||||
}
|
||||
|
||||
|
||||
id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty,
|
||||
top_k, top_p, typical_p, tfs_z, temp, rng,
|
||||
params.mirostat,params.mirostat_tau,params.mirostat_eta);
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
logitsPtr = logits.data();
|
||||
if (!unbanTokens)
|
||||
{
|
||||
//gpt2 uses negative logits, so we can't zero it
|
||||
// set the logit of the eos token to minimum to avoid sampling it
|
||||
if ((file_format == FileFormat::GPT2_1 ||
|
||||
if (file_format == FileFormat::GPT2_1 ||
|
||||
file_format == FileFormat::GPT2_2 ||
|
||||
file_format == FileFormat::GPT2_3 ||
|
||||
file_format == FileFormat::GPT2_4 ||
|
||||
|
@ -995,11 +987,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
file_format == FileFormat::GPTJ_2 ||
|
||||
file_format == FileFormat::GPTJ_3 ||
|
||||
file_format == FileFormat::GPTJ_4 ||
|
||||
file_format == FileFormat::GPTJ_5) &&
|
||||
logits.size() > 50256)
|
||||
{
|
||||
int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
|
||||
logits[50256] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||
file_format == FileFormat::GPTJ_5)
|
||||
{
|
||||
eosID = 50256;
|
||||
if(logits.size() > eosID)
|
||||
{
|
||||
int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
|
||||
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||
}
|
||||
}
|
||||
|
||||
// set the logit of the eos token (0) to minimum to avoid sampling it
|
||||
|
@ -1011,16 +1006,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
file_format == FileFormat::NEOX_6 ||
|
||||
file_format == FileFormat::NEOX_7)
|
||||
{
|
||||
eosID = 0;
|
||||
int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
|
||||
logits[0] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||
}
|
||||
}
|
||||
|
||||
id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty,
|
||||
top_k, top_p, typical_p, tfs_z, temp, rng,
|
||||
params.mirostat,params.mirostat_tau,params.mirostat_eta);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
|
||||
top_k, top_p, typical_p, tfs_z, temp, rng,
|
||||
params.mirostat,params.mirostat_tau,params.mirostat_eta);
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(id);
|
||||
current_context_tokens.push_back(id);
|
||||
|
@ -1031,31 +1028,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
// decrement remaining sampling budget
|
||||
--remaining_tokens;
|
||||
|
||||
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2|| file_format == FileFormat::GGJT_3)
|
||||
for (auto id : embd)
|
||||
{
|
||||
if(file_format == FileFormat::GGJT_3)
|
||||
{
|
||||
concat_output += llama_token_to_str(llama_ctx_v3, id);
|
||||
}
|
||||
else
|
||||
{
|
||||
concat_output += llama_v2_token_to_str(llama_ctx_v2, id);
|
||||
}
|
||||
|
||||
if(unbanTokens && id==llama_token_eos())
|
||||
{
|
||||
printf("\n(EOS token triggered!)");
|
||||
remaining_tokens = 0;
|
||||
}
|
||||
concat_output += FileFormatTokenizeID(id,file_format);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto id : embd)
|
||||
{
|
||||
concat_output += vocab.id_to_token[id].c_str();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (startedsampling)
|
||||
{
|
||||
printf("\rGenerating (%d / %d tokens)", (params.n_predict - remaining_tokens), params.n_predict);
|
||||
}
|
||||
if(debugmode && top_picks.size()>0)
|
||||
{
|
||||
printf(" [");
|
||||
|
@ -1074,6 +1055,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
printf("]\n");
|
||||
}
|
||||
|
||||
if(unbanTokens && id==eosID)
|
||||
{
|
||||
printf("\n(EOS token triggered!)");
|
||||
remaining_tokens = 0;
|
||||
}
|
||||
|
||||
for (const auto &matched : stop_sequence)
|
||||
{
|
||||
if (concat_output.find(matched) != std::string::npos)
|
||||
|
@ -1084,6 +1071,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
break;
|
||||
}
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
@ -208,7 +208,7 @@ maxctx = 2048
|
|||
maxlen = 128
|
||||
modelbusy = False
|
||||
defaultport = 5001
|
||||
KcppVersion = "1.25.1"
|
||||
KcppVersion = "1.26"
|
||||
|
||||
class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
sys_version = ""
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
#include "ggml_v2.h"
|
||||
|
||||
#define CL_DMMV_BLOCK_SIZE 32;
|
||||
#define CL_DMMV_BLOCK_SIZE 64;
|
||||
|
||||
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
||||
static std::string program_source = MULTILINE_QUOTE(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue