From 2d7baaf50f3277e65cf71071f61ea34823d14c30 Mon Sep 17 00:00:00 2001
From: AustinMroz
Date: Tue, 8 Aug 2023 06:44:48 -0500
Subject: [PATCH 1/5] vim : streaming and more (#2495)

* Update Vim plugin

* Remove getbufoneline usage, add input bind example

getbufoneline() appears to be a recently added function and has been
replaced with getbufline() for compatibility. An additional example
showing how to add a keybind that works in insert mode was also added.

---
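A minimal .vimrc setup for the plugin, as a sketch to get started (the <C-B>
insert-mode key and the override values here are illustrative choices, not
defaults shipped by the plugin):

    " point the plugin at a non-default server address
    let g:llama_api_url = "192.168.1.10:8080"
    " global sampling overrides, merged over the server defaults
    let g:llama_overrides = {"temperature": 0.8, "n_predict": 128}
    " generate from normal mode, or from insert mode without leaving it
    nnoremap Z :call llama#doLlamaGen()<CR>
    inoremap <C-B> <Cmd>call llama#doLlamaGen()<CR>

Triggering the mapping again while a request is streaming cancels it:
doLlamaGen() stops the in-flight job and returns instead of starting another.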
 examples/llama.vim | 132 +++++++++++++++++++++++++++++++++++++++++++++
 examples/llm.vim   |  23 -------
 2 files changed, 132 insertions(+), 23 deletions(-)
 create mode 100644 examples/llama.vim
 delete mode 100644 examples/llm.vim

diff --git a/examples/llama.vim b/examples/llama.vim
new file mode 100644
index 000000000..f03fadfb7
--- /dev/null
+++ b/examples/llama.vim
@@ -0,0 +1,132 @@
+" Requires an already running llama.cpp server
+" To install either copy or symlink to ~/.vim/autoload/llama.vim
+" Then start with either :call llama#doLlamaGen(),
+" or add a keybind to your vimrc such as
+" nnoremap Z :call llama#doLlamaGen()<CR>
+" Similarly, you could add an insert mode keybind with
+" inoremap <C-B> <Cmd>call llama#doLlamaGen()<CR>
+"
+" g:llama_api_url and g:llama_overrides can be configured in your .vimrc
+" let g:llama_api_url = "192.168.1.10:8080"
+" llama_overrides can also be set through buffer/window scopes. For instance
+" autocmd filetype python let b:llama_overrides = {"temp": 0.2}
+" could be added to your .vimrc to automatically set a lower temperature when
+" editing a python script
+" Additionally, an override dict can be stored at the top of a file
+" !*{"stop": ["User:"]}
+" could be added to the start of your chatlog.txt to set the stopping token
+" These parameter dicts are merged together from lowest to highest priority:
+" server default -> g:llama_overrides -> w:llama_overrides ->
+" b:llama_overrides -> in-file (!*) overrides
+"
+" Sublists (like logit_bias and stop) are overridden, not merged
+" Example override:
+" !*{"logit_bias": [[13, -5], [2, false]], "temperature": 1, "top_k": 5, "top_p": 0.5, "n_predict": 256, "repeat_last_n": 256, "repeat_penalty": 1.17647}
+if !exists("g:llama_api_url")
+  let g:llama_api_url = "127.0.0.1:8080"
+endif
+if !exists("g:llama_overrides")
+  let g:llama_overrides = {}
+endif
+const s:querydata = {"n_predict": 256, "stop": [ "\n" ], "stream": v:true }
+const s:curlcommand = ['curl', '--data-raw', "{\"prompt\":\"### System:\"}", '--silent', '--no-buffer', '--request', 'POST', '--url', g:llama_api_url .. '/completion', '--header', "Content-Type: application/json"]
+let s:linedict = {}
+
+func s:callbackHandler(bufn, channel, msg)
+  if len(a:msg) < 3
+    return
+  elseif a:msg[0] == "d"
+    let l:msg = a:msg[6:-1]
+  else
+    let l:msg = a:msg
+  endif
+  let l:decoded_msg = json_decode(l:msg)
+  let l:newtext = split(l:decoded_msg['content'], "\n", 1)
+  if len(l:newtext) > 0
+    call setbufline(a:bufn, s:linedict[a:bufn], getbufline(a:bufn, s:linedict[a:bufn])[0] .. l:newtext[0])
+  else
+    echo "nothing genned"
+  endif
+  if len(l:newtext) > 1
+    let l:failed = appendbufline(a:bufn, s:linedict[a:bufn], l:newtext[1:-1])
+    let s:linedict[a:bufn] = s:linedict[a:bufn] + len(l:newtext) - 1
+  endif
+  if has_key(l:decoded_msg, "stop") && l:decoded_msg.stop
+    echo "Finished generation"
+  endif
+endfunction
+
+func llama#doLlamaGen()
+  if exists("b:job")
+    if job_status(b:job) == "run"
+      call job_stop(b:job)
+      return
+    endif
+  endif
+
+  let l:cbuffer = bufnr("%")
+  let s:linedict[l:cbuffer] = line('$')
+  let l:buflines = getbufline(l:cbuffer, 1, 1000)
+  let l:querydata = copy(s:querydata)
+  call extend(l:querydata, g:llama_overrides)
+  if exists("w:llama_overrides")
+    call extend(l:querydata, w:llama_overrides)
+  endif
+  if exists("b:llama_overrides")
+    call extend(l:querydata, b:llama_overrides)
+  endif
+  if l:buflines[0][0:1] == '!*'
+    let l:userdata = json_decode(l:buflines[0][2:-1])
+    call extend(l:querydata, l:userdata)
+    let l:buflines = l:buflines[1:-1]
+  endif
+  let l:querydata.prompt = join(l:buflines, "\n")
+  let l:curlcommand = copy(s:curlcommand)
+  let l:curlcommand[2] = json_encode(l:querydata)
+  let b:job = job_start(l:curlcommand, {"callback": function("s:callbackHandler", [l:cbuffer])})
+endfunction
+
+" Echoes the tokenization of the provided string, or from the cursor to the end of the word
+" The onus is on the user to include any preceding space
+func llama#tokenizeWord(...)
+  if (a:0 > 0)
+    let l:input = a:1
+  else
+    exe "normal \"*ye"
+    let l:input = @*
+  endif
+  let l:querydata = {"content": l:input}
+  let l:curlcommand = copy(s:curlcommand)
+  let l:curlcommand[2] = json_encode(l:querydata)
+  let l:curlcommand[8] = g:llama_api_url .. "/tokenize"
+  let s:token_job = job_start(l:curlcommand, {"callback": function("s:tokenizeWordCallback", [l:input])})
+endfunction
+
+func s:tokenizeWordCallback(plaintext, channel, msg)
+  echo '"' .. a:plaintext .. '" - ' .. string(json_decode(a:msg).tokens)
+endfunction
+
+
+" Echoes the token count of the entire buffer (or of the provided string)
+" Example usage: :echo llama#tokenCount()
+func llama#tokenCount(...)
+  if (a:0 > 0)
+    let l:buflines = a:1
+  else
+    let l:buflines = getline(1, 1000)
+    if l:buflines[0][0:1] == '!*'
+      let l:buflines = l:buflines[1:-1]
+    endif
+    let l:buflines = join(l:buflines, "\n")
+  endif
+  let l:querydata = {"content": l:buflines}
+  let l:curlcommand = copy(s:curlcommand)
+  let l:curlcommand[2] = json_encode(l:querydata)
+  let l:curlcommand[8] = g:llama_api_url .. "/tokenize"
+  let s:token_job = job_start(l:curlcommand, {"callback": "s:tokenCountCallback"})
+endfunction
+
+func s:tokenCountCallback(channel, msg)
+  let resp = json_decode(a:msg)
+  echo len(resp.tokens)
+endfunction
diff --git a/examples/llm.vim b/examples/llm.vim
deleted file mode 100644
index efecad0cd..000000000
--- a/examples/llm.vim
+++ /dev/null
@@ -1,23 +0,0 @@
-function! Llm()
-
-  let url = "http://127.0.0.1:8080/completion"
-
-  " Get the content of the current buffer
-  let buffer_content = join(getline(1, '$'), "\n")
-
-  " Create the JSON payload
-  let json_payload = {"temp":0.72,"top_k":100,"top_p":0.73,"repeat_penalty":1.100000023841858,"n_predict":10,"stream": v:false}
-  let json_payload.prompt = buffer_content
-
-  " Define the curl command
-  let curl_command = 'curl -k -s -X POST -H "Content-Type: application/json" -d @- ' . url
-  let response = system(curl_command, json_encode(json_payload))
-
-  " Extract the content field from the response
-  let content = json_decode(response).content
-
-  " Insert the content at the cursor position
-  call setline(line('.'), getline('.') . content)
-endfunction
-
-command! Llm call Llm()

From e7f94d6fdc83b41ba449b4b8c80821673dd12ffc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 8 Aug 2023 15:05:30 +0300
Subject: [PATCH 2/5] vim : bring back simple llm.vim example

---
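Usage is unchanged: source the file and run :Llm to append a short completion
of the current buffer at the cursor line. The odd-looking repeat_penalty value
is simply 1.1 after a round trip through a 32-bit float. A convenience mapping
can be added in your .vimrc (key choice illustrative):

    nnoremap <leader>l :Llm<CR>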
 examples/llm.vim | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 examples/llm.vim

diff --git a/examples/llm.vim b/examples/llm.vim
new file mode 100644
index 000000000..473e0077a
--- /dev/null
+++ b/examples/llm.vim
@@ -0,0 +1,25 @@
+" Basic plugin example
+
+function! Llm()
+
+  let url = "http://127.0.0.1:8080/completion"
+
+  " Get the content of the current buffer
+  let buffer_content = join(getline(1, '$'), "\n")
+
+  " Create the JSON payload
+  let json_payload = {"temp":0.72,"top_k":100,"top_p":0.73,"repeat_penalty":1.100000023841858,"n_predict":10,"stream": v:false}
+  let json_payload.prompt = buffer_content
+
+  " Define the curl command
+  let curl_command = 'curl -k -s -X POST -H "Content-Type: application/json" -d @- ' . url
+  let response = system(curl_command, json_encode(json_payload))
+
+  " Extract the content field from the response
+  let content = json_decode(response).content
+
+  " Insert the content at the cursor position
+  call setline(line('.'), getline('.') . content)
+endfunction
+
+command! Llm call Llm()

From 7ed8d1fe7f8cbe6a6763e6b46759795ac8d21e12 Mon Sep 17 00:00:00 2001
From: chaihahaha
Date: Tue, 8 Aug 2023 20:07:02 +0800
Subject: [PATCH 3/5] llm.vim : multiline autocompletion, get rid of "^@" (#2543)

---
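Splitting the content is what removes the "^@" artifacts: a newline embedded
in a single buffer line is displayed by Vim as ^@, so the completion has to be
inserted as a list of lines. The third argument to split() (keepempty) also
preserves line breaks at the edges of the response. A quick illustration,
runnable in any Vim:

    :echo split("\nfoo\nbar", '\n')     " ['foo', 'bar']      leading newline lost
    :echo split("\nfoo\nbar", '\n', 1)  " ['', 'foo', 'bar']  empty item kept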
 examples/llm.vim | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/llm.vim b/examples/llm.vim
index 473e0077a..594a28549 100644
--- a/examples/llm.vim
+++ b/examples/llm.vim
@@ -18,8 +18,10 @@ function! Llm()
   " Extract the content field from the response
   let content = json_decode(response).content
 
+  let split_newlines = split(content, '\n', 1)
+
   " Insert the content at the cursor position
-  call setline(line('.'), getline('.') . content)
+  call setline(line('.'), [ getline('.') . split_newlines[0] ] + split_newlines[1:])
 endfunction
 
 command! Llm call Llm()

From acfc5478ff3446ca3b54553967a3dea09b7c771a Mon Sep 17 00:00:00 2001
From: Johannes Gäßler
Date: Tue, 8 Aug 2023 14:38:16 +0200
Subject: [PATCH 4/5] CUDA: tighter VRAM scratch size for 65b/70b (#2551)

---
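For a sense of scale: assuming the usual scratch sizing in llama.cpp of
n_batch * (base + n_ctx * per_context), a 70b model at n_batch = 512 and
n_ctx = 2048 goes from

    512 * (1536 KiB + 2048 * 416 B) = 1184 MiB

down to

    512 * (1280 KiB + 2048 * 256 B) = 896 MiB

i.e. roughly 288 MiB of VRAM handed back at typical settings.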
 llama.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 39aefd499..71061aab9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -149,7 +149,7 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 }
 
 // amount of VRAM needed per batch size to hold temporary results
-// the values for 3b and 65b are not derived from testing but instead chosen conservatively
+// the values for 3b are not derived from testing but instead chosen conservatively
 static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
 {
     static std::map<e_model, size_t> k_sizes = {
@@ -157,14 +157,14 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
         { MODEL_7B,   512ull * kB },
         { MODEL_13B,  640ull * kB },
         { MODEL_30B,  768ull * kB },
-        { MODEL_65B, 1536ull * kB },
-        { MODEL_70B, 1536ull * kB }, // TODO (likely can be reduced)
+        { MODEL_65B, 1280ull * kB },
+        { MODEL_70B, 1280ull * kB },
     };
     return k_sizes;
 }
 
 // amount of VRAM needed per batch size and context to hold temporary results
-// the values for 3b and 65b are not derived from testing but instead chosen conservatively
+// the values for 3b are not derived from testing but instead chosen conservatively
 static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
 {
     static std::map<e_model, size_t> k_sizes = {
@@ -172,8 +172,8 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
         { MODEL_7B,  128ull },
         { MODEL_13B, 160ull },
         { MODEL_30B, 208ull },
-        { MODEL_65B, 416ull },
-        { MODEL_70B, 416ull }, // TODO (likely can be reduced)
+        { MODEL_65B, 256ull },
+        { MODEL_70B, 256ull },
     };
     return k_sizes;
 }

From f5bfea0580e417f99850d5456ca541d871a3e48c Mon Sep 17 00:00:00 2001
From: Martin Krasser
Date: Tue, 8 Aug 2023 15:29:19 +0200
Subject: [PATCH 5/5] Allow passing grammar to completion endpoint (#2532)

* Allow passing grammar to completion endpoint

---
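The grammar parameter accepts a GBNF grammar string, the same format consumed
by grammar-parser elsewhere in the tree. As a sketch, a request body that
restricts the reply to a yes/no answer could look like:

    {"prompt": "Is water wet? Answer: ", "grammar": "root ::= \"yes\" | \"no\"", "n_predict": 4}

Note the interaction loadGrammar() warns about below: disabling the EOS token
via logit_bias leaves most grammars unable to terminate.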
 Makefile                   |  2 +-
 examples/server/README.md  |  2 ++
 examples/server/server.cpp | 60 ++++++++++++++++++++++++++++++++++++--
 3 files changed, 61 insertions(+), 3 deletions(-)

diff --git a/Makefile b/Makefile
index 897c5cb9a..32598edfe 100644
--- a/Makefile
+++ b/Makefile
@@ -380,7 +380,7 @@ embedding: examples/embedding/embedding.cpp build-info.h ggml.
 save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS)
+server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
 
 $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
diff --git a/examples/server/README.md b/examples/server/README.md
index aee31ae42..e56ca063a 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -151,6 +151,8 @@ node .
 
     `mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1).
 
+    `grammar`: Set the grammar for grammar-based sampling (default: no grammar)
+
     `seed`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
 
     `ignore_eos`: Ignore end of stream token and continue generating (default: false).
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6f7a66da1..10ae264f5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -195,6 +196,8 @@ struct llama_server_context
     llama_context *ctx = nullptr;
     gpt_params params;
 
+    llama_grammar *grammar = nullptr;
+
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -226,6 +229,7 @@ struct llama_server_context
     void rewind()
     {
         params.antiprompt.clear();
+        params.grammar.clear();
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
@@ -237,6 +241,7 @@ struct llama_server_context
         stopped_limit = false;
         stopping_word = "";
         multibyte_pending = 0;
+        grammar = nullptr;
 
         n_remain = 0;
         n_past = 0;
@@ -257,6 +262,33 @@ struct llama_server_context
         return true;
     }
 
+    bool loadGrammar()
+    {
+        if (!params.grammar.empty()) {
+            grammar_parser::parse_state parsed_grammar;
+
+            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar.rules.empty()) {
+                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
+                return false;
+            }
+            grammar_parser::print_grammar(stderr, parsed_grammar);
+
+            {
+                auto it = params.logit_bias.find(llama_token_eos());
+                if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
+                }
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+            grammar = llama_grammar_init(
+                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+        }
+        return true;
+    }
+
     void loadPrompt()
     {
         params.prompt.insert(0, 1, ' '); // always add a first space
@@ -420,6 +452,10 @@ struct llama_server_context
                 logits[llama_token_nl()] = nl_logit;
             }
 
+            if (grammar != nullptr) {
+                llama_sample_grammar(ctx, &candidates_p, grammar);
+            }
+
             if (temp <= 0)
             {
                 // Greedy sampling
@@ -457,10 +493,15 @@ struct llama_server_context
                 }
             }
 
+            if (grammar != nullptr) {
+                llama_grammar_accept_token(ctx, grammar, result.tok);
+            }
+
             for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
             {
                 result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
             }
+
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(result.tok);
             num_tokens_predicted++;
@@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
         {"stream", llama.stream},
         {"logit_bias", llama.params.logit_bias},
         {"n_probs", llama.params.n_probs},
+        {"grammar", llama.params.grammar},
     };
 }
 
@@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &llama)
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.grammar = body.value("grammar", default_params.grammar);
     llama.params.n_probs = body.value("n_probs", default_params.n_probs);
 
     llama.params.logit_bias.clear();
@@ -1179,6 +1222,12 @@ int main(int argc, char **argv)
             parse_options_completion(json::parse(req.body), llama);
 
+            if (!llama.loadGrammar())
+            {
+                res.status = 400;
+                return;
+            }
+
             llama.loadPrompt();
             llama.beginCompletion();
@@ -1334,8 +1383,12 @@ int main(int argc, char **argv)
     svr.set_error_handler([](const Request &, Response &res)
                           {
-        res.set_content("File Not Found", "text/plain");
-        res.status = 404; });
+        if (res.status == 400) {
+            res.set_content("Invalid request", "text/plain");
+        } else {
+            res.set_content("File Not Found", "text/plain");
+            res.status = 404;
+        } });
 
     // set timeouts and change hostname and port
     svr.set_read_timeout(sparams.read_timeout);
@@ -1363,6 +1416,9 @@ int main(int argc, char **argv)
         return 1;
     }
 
+    if (llama.grammar != nullptr) {
+        llama_grammar_free(llama.grammar);
+    }
     llama_backend_free();
 
     return 0;
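Taken together with the first patch in this series, the new parameter should
also be reachable from the editor: llama.vim forwards its override dict to
/completion unchanged, so a buffer-local override such as

    let b:llama_overrides = {"grammar": 'root ::= "yes" | "no"', "n_predict": 4}

or an in-file !*{"grammar": ...} header would constrain generation for that
buffer. This is a sketch of the combination, not something exercised by the
series itself.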