integrated token probability viewer in debugmode
parent 8b8f2f4cf5
commit 3c8f404243
1 changed file with 95 additions and 47 deletions
@@ -65,6 +65,7 @@ static size_t mem_per_token = 0;
 static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
+static std::vector<llama_token_data> top_picks;
 
 inline bool IsNanCheck(float f)
 {
@@ -96,11 +97,26 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
     llama_sample_softmax(nullptr, candidates);
     std::vector<float> probs;
     probs.reserve(candidates->size);
+    top_picks.clear();
     for (size_t i = 0; i < candidates->size; ++i) {
         probs.push_back(candidates->data[i].p);
     }
 
     std::discrete_distribution<> dist(probs.begin(), probs.end());
     int idx = dist(rng);
 
+    if(debugmode)
+    {
+        top_picks.push_back(candidates->data[idx]);
+        for (size_t i = 0; (i < candidates->size && i<4); ++i)
+        {
+            if(i!=idx)
+            {
+                top_picks.push_back(candidates->data[i]);
+            }
+        }
+    }
+
     llama_token result = candidates->data[idx].id;
     return result;
 }
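Note: the sampling step above draws the token index from the softmaxed candidate probabilities via std::discrete_distribution, then (in debugmode) records the chosen candidate plus the top few runner-up candidates into top_picks. A minimal self-contained sketch of the same draw-and-record pattern, with a hypothetical TokenData struct standing in for llama_token_data:

    #include <cstdio>
    #include <random>
    #include <vector>

    // Hypothetical stand-in for llama_token_data: token id plus softmaxed probability.
    struct TokenData { int id; float p; };

    int main()
    {
        std::mt19937 rng(1234);
        std::vector<TokenData> candidates = {{10, 0.6f}, {42, 0.3f}, {7, 0.1f}};

        // Weight the draw by each candidate's probability, as sample_token does.
        std::vector<float> probs;
        probs.reserve(candidates.size());
        for (const auto &c : candidates) probs.push_back(c.p);
        std::discrete_distribution<> dist(probs.begin(), probs.end());
        int idx = dist(rng);

        // Record the pick first, then the top few alternatives, mirroring the diff.
        std::vector<TokenData> top_picks;
        top_picks.push_back(candidates[idx]);
        for (size_t i = 0; i < candidates.size() && i < 4; ++i)
            if ((int)i != idx) top_picks.push_back(candidates[i]);

        std::printf("picked id %d (p=%.2f), stored %zu picks\n",
                    candidates[idx].id, candidates[idx].p, top_picks.size());
    }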
@@ -216,6 +232,22 @@ int mirostat, float mirostat_tau, float mirostat_eta)
     return id;
 }
 
+static std::string FileFormatTokenizeID(int id, FileFormat file_format)
+{
+    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
+    {
+        return std::string(llama_v2_token_to_str(llama_ctx_v2, id));
+    }
+    else if (file_format == FileFormat::GGJT_3)
+    {
+        return std::string(llama_token_to_str(llama_ctx_v3, id));
+    }
+    else
+    {
+        return vocab.id_to_token[id];
+    }
+}
+
 ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 {
     ggml_time_init();
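The new FileFormatTokenizeID helper centralizes the per-format token-to-string dispatch (llama_v2_token_to_str for the legacy GGML/GGHF/GGJT formats, llama_token_to_str for GGJT_3, and the gpt-style vocab table otherwise), which is what lets the debug dump below collapse three near-identical loops into one. A sketch of the calling pattern it enables, with a stub tokenizer standing in for the real dispatch:

    #include <iostream>
    #include <string>
    #include <vector>

    // Stub standing in for FileFormatTokenizeID: the real helper routes to the
    // correct tokenizer based on FileFormat; here we just stringify the id.
    static std::string TokenizeIDStub(int id) { return "tok" + std::to_string(id); }

    // One format-agnostic dump loop instead of one loop per file format.
    static std::string DumpTokens(const std::vector<int> &ids)
    {
        std::string out;
        for (int id : ids)
            out += "'" + TokenizeIDStub(id) + " (" + std::to_string(id) + ")', ";
        return out;
    }

    int main() { std::cout << DumpTokens({1, 2, 3}) << "\n"; }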
@@ -593,7 +625,6 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
     }
 
 
-
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
 {
     stop_sequence.clear();
@@ -628,7 +659,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
     }
     if (params.top_k < 1)
     {
-        params.top_k = 300; //to disable top_k we actually need to increase this value to a very high number
+        params.top_k = 200; //to disable top_k we actually need to increase this value to a very high number
     }
     if (params.seed <= 0)
     {
@@ -795,49 +826,32 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
     printf("\n");
 
-    if(debugmode)
+    if (debugmode)
     {
-        printf("\n[Debug: Dump Input Tokens, format: %d]\n",file_format);
-        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
-        {
-            for (auto id : embd_inp)
-            {
-                printf("'%s (%d)', ",llama_v2_token_to_str(llama_ctx_v2, id),id);
-            }
-            printf("\n\n[Debug: Context Size = %d]\n",current_context_tokens.size());
-            for (auto id : current_context_tokens)
-            {
-                printf("'%s (%d)', ",llama_v2_token_to_str(llama_ctx_v2, id),id);
-            }
-        }
-        else if (file_format == FileFormat::GGJT_3)
-        {
-            for (auto id : embd_inp)
-            {
-                printf("'%s (%d)', ",llama_token_to_str(llama_ctx_v3, id),id);
-            }
-            printf("\n\n[Debug: Context Size = %d]\n",current_context_tokens.size());
-            for (auto id : current_context_tokens)
-            {
-                printf("'%s (%d)', ",llama_token_to_str(llama_ctx_v3, id),id);
-            }
-        }
-        else
-        {
-            for (auto id : embd_inp)
-            {
-                printf("'%s (%d)', ",vocab.id_to_token[id].c_str(),id);
-            }
-            printf("\n\n[Debug: Context Size = %d]\n",current_context_tokens.size());
-            for (auto id : current_context_tokens)
-            {
-                printf("'%s (%d)', ",vocab.id_to_token[id].c_str(),id);
-            }
-        }
+        std::string outstr = "";
+        printf("\n[Debug: Dump Input Tokens, format: %d]\n", file_format);
+
+        std::string tmp = "";
+        for (auto id : embd_inp)
+        {
+            tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
+        }
+        ::utreplace(tmp, "\n", "\\n");
+        outstr += tmp;
+
+        outstr += "\n\n[Debug: Context Size = " + std::to_string(current_context_tokens.size()) + "]\n";
+        tmp = "";
+        for (auto id : current_context_tokens)
+        {
+            tmp += "'" + FileFormatTokenizeID(id, file_format) + " (" + std::to_string(id) + ")', ";
+        }
+        ::utreplace(tmp, "\n", "\\n");
+        outstr += tmp;
+        printf(outstr.c_str());
 
         printf("\n\n");
     }
 
     while (remaining_tokens > 0)
     {
         gpt_vocab::id id = 0;
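One caveat in the rewritten dump: printf(outstr.c_str()) passes the accumulated token text as the format string, so any '%' that happens to appear in a token would be parsed as a conversion specifier. A safer variant (a suggestion, not part of this commit) passes it as an argument instead:

    // Token text cannot be misread as format conversions this way.
    printf("%s", outstr.c_str());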
@@ -851,7 +865,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
             printf("\rProcessing Prompt%s (%d / %d tokens)", (blasmode ? " [BLAS]" : ""), input_consumed, embd_inp.size());
         }
         else
         {
             printf("\rGenerating (%d / %d tokens)", (1 + params.n_predict - remaining_tokens), params.n_predict);
         }
         fflush(stdout);
@@ -956,11 +970,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
         {
             // set the logit of the eos token (2) to zero to avoid sampling it
             logits[llama_token_eos()] = 0;
-            //set logits of opening square bracket to zero.
-            logits[518] = 0;
-            logits[29961] = 0;
+            //set logits of opening square bracket to zero. (disabled as obsolete)
+            // logits[518] = 0;
+            // logits[29961] = 0;
         }
 
 
         id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty,
             top_k, top_p, typical_p, tfs_z, temp, rng,
             params.mirostat,params.mirostat_tau,params.mirostat_eta);
@@ -970,7 +985,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
         {
             if (!unbanTokens)
             {
-                // set the logit of the eos token (2) to zero to avoid sampling it
+                //gpt2 uses negative logits, so we cant zero it
+                // set the logit of the eos token to minimum to avoid sampling it
                 if ((file_format == FileFormat::GPT2_1 ||
                      file_format == FileFormat::GPT2_2 ||
                      file_format == FileFormat::GPT2_3 ||
@@ -981,11 +997,24 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
                      file_format == FileFormat::GPTJ_4 ||
                      file_format == FileFormat::GPTJ_5) &&
                     logits.size() > 50256)
                 {
-                    logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
+                    int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
+                    logits[50256] = (logits[topid] < 0 ? logits[topid] : 0);
                 }
-                //gpt2 uses negative logits, so we cant zero it
-            }
+
+                // set the logit of the eos token (0) to minimum to avoid sampling it
+                if (file_format == FileFormat::NEOX_1 ||
+                    file_format == FileFormat::NEOX_2 ||
+                    file_format == FileFormat::NEOX_3 ||
+                    file_format == FileFormat::NEOX_4 ||
+                    file_format == FileFormat::NEOX_5 ||
+                    file_format == FileFormat::NEOX_6 ||
+                    file_format == FileFormat::NEOX_7)
+                {
+                    int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
+                    logits[0] = (logits[topid] < 0 ? logits[topid] : 0);
+                }
+            }
 
         id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty,
             top_k, top_p, typical_p, tfs_z, temp, rng,
@@ -1026,6 +1055,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
                 concat_output += vocab.id_to_token[id].c_str();
             }
         }
 
+        if(debugmode && top_picks.size()>0)
+        {
+            printf(" [");
+            bool firstloop = true;
+            for (auto & pick : top_picks)
+            {
+                if (!firstloop)
+                {
+                    printf(" ");
+                }
+                firstloop = false;
+                std::string tokenizedstr = FileFormatTokenizeID(pick.id, file_format);
+                ::utreplace(tokenizedstr, "\n", "\\n");
+                printf("(%s %.2f%%)", tokenizedstr.c_str(), pick.p*100);
+            }
+            printf("]\n");
+        }
+
         for (const auto &matched : stop_sequence)
         {
             if (concat_output.find(matched) != std::string::npos)
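This is the viewer itself: after each generated token, debugmode prints the sampled pick followed by the stored alternatives as "(token probability%)" pairs inside brackets. A runnable sketch of just the formatting loop, using hypothetical sample picks:

    #include <cstdio>
    #include <string>
    #include <vector>

    // Hypothetical picks; in the commit these come from top_picks, with the
    // sampled token stored first and newlines escaped via ::utreplace.
    struct Pick { std::string text; float p; };

    int main()
    {
        std::vector<Pick> top_picks = {{"Hello", 0.6241f}, {"Hi", 0.2005f},
                                       {"Hey", 0.0977f}, {"G", 0.0312f}};
        printf(" [");
        bool firstloop = true;
        for (const auto &pick : top_picks)
        {
            if (!firstloop) printf(" ");
            firstloop = false;
            printf("(%s %.2f%%)", pick.text.c_str(), pick.p * 100);
        }
        printf("]\n"); // -> " [(Hello 62.41%) (Hi 20.05%) (Hey 9.77%) (G 3.12%)]"
    }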