sampling : hide prev behind API and apply #3661

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-20 18:26:20 +03:00
parent 7e2b5fb1dd
commit 56ba00b923
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 119 additions and 105 deletions

View file

@ -4,5 +4,5 @@ install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

View file

@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id);
llama_sampling_accept(ctx_sampling, ctx, id, true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@ -541,8 +541,11 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@ -574,7 +577,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@ -591,7 +594,7 @@ int main(int argc, char ** argv) {
buffer += line;
} while (another_line);
// check if we got an empty line, if so we use the old input
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_prefix = buffer;
}
buffer.clear();
@ -601,7 +604,7 @@ int main(int argc, char ** argv) {
buffer += line;
} while (another_line);
// check if we got an empty line
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_suffix = buffer;
}
buffer.clear();
@ -614,7 +617,7 @@ int main(int argc, char ** argv) {
process_escapes(params.input_suffix);
}
suff_rm_leading_spc = params.escape;
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
@ -641,7 +644,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {