BUG: generates gibberish/repeating tokens after a while

Leon Ericsson 2023-12-15 14:14:17 +01:00
parent 0ec5fdb5ce
commit 1665ad8bf1
2 changed files with 33 additions and 34 deletions

@@ -75,10 +75,10 @@ struct gpt_params {
// sampling parameters
struct llama_sampling_params sparams;
std::string model = "models/7B/ggml-model-q4_0.gguf"; // model path
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding
std::string model_alias = "unknown"; // model alias
std::string prompt = "Hello my name is";
std::string prompt = "";
std::string prompt_file = ""; // store the external prompt file name
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
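
This hunk swaps the q4_0 test model path and the hard-coded "Hello my name is" prompt back for the stock defaults (the f16 model path and an empty prompt). A minimal sketch of keeping such a test setup in caller code instead of patching common.h; it assumes only the gpt_params fields visible in this hunk:

#include "common.h"

int main() {
    gpt_params params;  // stock defaults from the struct above: f16 model path, empty prompt

    // Local test configuration, set here instead of editing the header.
    params.model  = "models/7B/ggml-model-q4_0.gguf";
    params.prompt = "Hello my name is";

    return 0;
}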
@@ -228,4 +228,4 @@ void dump_non_result_info_yaml(
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);

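For context on the two dump helpers declared above, here is a short sketch of how they are typically driven. The llama_kv_cache_view_init / _update / _free calls are assumed from the llama.h API of this period, and the llama_context * ctx is assumed to already exist:

#include "common.h"
#include "llama.h"

// Print both views of the KV cache for debugging (assumes an initialized ctx).
static void debug_dump_kv(llama_context * ctx, int n_seq_max) {
    llama_kv_cache_view view = llama_kv_cache_view_init(ctx, n_seq_max);
    llama_kv_cache_view_update(ctx, &view);   // refresh the snapshot before printing
    dump_kv_cache_view(view, 80);             // compact, aggregated view
    dump_kv_cache_view_seqs(view, 40);        // long form: per-sequence cells
    llama_kv_cache_view_free(&view);
}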
@@ -122,16 +122,7 @@ int main(int argc, char ** argv){
draft.clear();
draft.push_back(id);
// drafts[0].i_batch_tgt.push_back(0);
// llama_batch_clear(batch_dft);
// llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
// llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
// llama_decode (ctx_dft, batch_dft);
// ++n_past_dft;
inp.push_back(id);
break;
}
@@ -142,33 +133,41 @@ int main(int argc, char ** argv){
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
bool match = false;
// generate n_pred tokens through prompt lookup
for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){
if (match){
break;
}
const auto & prev = ctx_sampling->prev;
int prev_size = prev.size();
const llama_token * ngram = &prev[prev_size - ngram_size];
for (int i = 0; i <= (int) prev_size - (ngram_size * 2); ++i) {
if (prev[i] == ngram[0] && prev[i + 1] == ngram[1] && prev[i + 2] == ngram[2]) {
const int startIdx = i + ngram_size;
const int endIdx = startIdx + n_draft;
if (endIdx < prev_size){
match = true;
for (int j = startIdx; j < endIdx; ++j) {
LOG(" - draft candidate %d: %d\n", j, prev[j]);
draft.push_back(prev[j]);
llama_batch_add(batch_tgt, prev[j], n_past + j + 1, { 1 }, true);
++n_drafted;
auto prompt_lookup = [&]() -> void {
int inp_size = inp.size();
for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){
const llama_token * ngram = &inp[inp_size - ngram_size];
for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
bool match = true;
for (int j = 0; j < ngram_size; ++j) {
if (inp[i + j] != ngram[j]) {
match = false;
break;
}
}
if (match) {
const int startIdx = i + ngram_size;
const int endIdx = startIdx + n_draft;
if (endIdx < inp_size){
for (int j = startIdx; j < endIdx; ++j) {
LOG(" - draft candidate %d: %d\n", j, inp[j]);
draft.push_back(inp[j]);
llama_batch_add(batch_tgt, inp[j], n_past + j + 1, { 0 }, true);
++n_drafted;
}
return;
}
}
}
}
}
return;
};
prompt_lookup();
llama_decode(ctx, batch_tgt);
++n_past;
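
This hunk replaces the old inline search, which compared only the first three tokens of the n-gram against ctx_sampling->prev regardless of ngram_size, with a prompt_lookup lambda that matches all ngram_size tokens against inp (the prompt plus every accepted token, which the earlier hunk now keeps extending via inp.push_back(id)). Below is a standalone sketch of the same n-gram lookup, stripped of the llama_batch bookkeeping; find_draft and the toy token ids are this example's own names, not part of the patch:

#include <cstdio>
#include <vector>

// Return up to n_draft tokens that followed the most recent match of the
// trailing n-gram of `inp`; an empty result means "nothing to draft".
static std::vector<int> find_draft(const std::vector<int> & inp,
                                   int max_ngram_size, int n_draft) {
    const int inp_size = (int) inp.size();
    for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
        if (inp_size < 2*ngram_size) {
            continue;                                    // not enough history for this size
        }
        const int * ngram = &inp[inp_size - ngram_size]; // the suffix we look up
        for (int i = 0; i <= inp_size - 2*ngram_size; ++i) {
            bool match = true;
            for (int j = 0; j < ngram_size; ++j) {
                if (inp[i + j] != ngram[j]) { match = false; break; }
            }
            if (!match) {
                continue;
            }
            const int startIdx = i + ngram_size;         // tokens right after the match
            const int endIdx   = startIdx + n_draft;
            if (endIdx < inp_size) {                     // same bound as the lambda above
                return std::vector<int>(inp.begin() + startIdx, inp.begin() + endIdx);
            }
        }
    }
    return {};
}

int main() {
    // The trailing {2, 3, 4} also occurs earlier in the history, so the two
    // tokens that followed it there ({5, 6}) become the draft continuation.
    std::vector<int> inp = { 1, 2, 3, 4, 5, 6, 7, 2, 3, 4 };
    for (int t : find_draft(inp, /*max_ngram_size=*/3, /*n_draft=*/2)) {
        printf("draft: %d\n", t);
    }
    return 0;
}

In the patch itself the matched tokens are not returned but pushed into batch_tgt via llama_batch_add, so the target model verifies the whole draft in the single llama_decode(ctx, batch_tgt) call that follows.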