From 69f2fafebcfffb123f67ec4233b0aa6aa85453e6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 3 Sep 2023 15:25:53 +0300
Subject: [PATCH 1/9] speculative : add grammar support

---
 examples/speculative/speculative.cpp | 81 +++++++++++++++++++++++++++-
 llama.cpp                            | 19 +++++++
 llama.h                             |  2 +
 3 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index f0400c13f..594d4f5d6 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -6,6 +6,7 @@
 
 #include "common.h"
 #include "llama.h"
+#include "grammar-parser.h"
 
 #include <cmath>
 #include <cstdio>
@@ -109,6 +110,41 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
+    // grammar stuff
+    struct llama_grammar * grammar_dft = NULL;
+    struct llama_grammar * grammar_tgt = NULL;
+
+    grammar_parser::parse_state parsed_grammar_dft;
+    grammar_parser::parse_state parsed_grammar_tgt;
+
+    std::vector<llama_grammar *> grammar_mem(n_draft, NULL);
+
+    if (!params.grammar.empty()) {
+        // dft
+        {
+            parsed_grammar_dft = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar_dft.rules.empty()) {
+                return 1;
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_dft.c_rules());
+            grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_dft.symbol_ids.at("root"));
+        }
+
+        // tgt
+        {
+            parsed_grammar_tgt = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar_tgt.rules.empty()) {
+                return 1;
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_tgt.c_rules());
+            grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_tgt.symbol_ids.at("root"));
+        }
+    }
+
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
@@ -117,7 +153,7 @@ int main(int argc, char ** argv) {
         // sample from the drafted tokens if any
         int i_dft = 0;
         while (true) {
-            const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
+            const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
@@ -144,6 +180,24 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
+            if (i_dft < (int) drafted.size()) {
+                LOG("drafted token %d rejected\n", id);
+
+                if (grammar_mem[i_dft]) {
+                    grammar_dft = llama_grammar_copy(grammar_mem[i_dft]);
+                    LOG("restored grammar %d\n", i_dft);
+                }
+            }
+
+            for (auto & g : grammar_mem) {
+                if (g) {
+                    llama_grammar_free(g);
+                    g = NULL;
+                }
+            }
+
+            LOG("i_dft = %d, drafted.size() = %d\n", i_dft, (int) drafted.size());
+
             // the drafted token was rejected or we are out of drafted tokens
             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
             ++n_past_dft;
@@ -151,6 +205,10 @@ int main(int argc, char ** argv) {
             drafted.clear();
             drafted.push_back(id);
 
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
+
             break;
         }
 
@@ -161,6 +219,11 @@ int main(int argc, char ** argv) {
         // sample n_draft tokens from the draft model picking the best token
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
+            // remember the grammar state
+            if (grammar_dft != NULL) {
+                grammar_mem[i] = llama_grammar_copy(grammar_dft);
+            }
+
             float * logits = llama_get_logits(ctx_dft);
 
             candidates.clear();
@@ -170,6 +233,10 @@ int main(int argc, char ** argv) {
 
             llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
 
+            if (grammar_dft != NULL) {
+                llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
+            }
+
             // computes softmax and sorts the candidates
             llama_sample_softmax(ctx_dft, &cur_p);
 
@@ -182,7 +249,13 @@ int main(int argc, char ** argv) {
                 break;
             }
 
-            drafted.push_back(cur_p.data[0].id);
+            const llama_token id = cur_p.data[0].id;
+
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
+
+            drafted.push_back(id);
             ++n_drafted;
 
             if (i < n_draft - 1) {
@@ -226,6 +299,10 @@ int main(int argc, char ** argv) {
     llama_free(ctx_dft);
     llama_free_model(model_dft);
 
+    if (grammar_dft != NULL) {
+        llama_grammar_free(grammar_dft);
+        llama_grammar_free(grammar_tgt);
+    }
     llama_backend_free();
 
     fprintf(stderr, "\n\n");

diff --git a/llama.cpp b/llama.cpp
index c97c1462f..cbf255115 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3850,6 +3850,25 @@ void llama_grammar_free(struct llama_grammar * grammar) {
     delete grammar;
 }
 
+struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+
+    // redirect elements in stacks to point to new rules
+    for (size_t is = 0; is < result->stacks.size(); is++) {
+        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
+            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
+                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+                        result->stacks[is][ie] = &result->rules[ir0][ir1];
+                    }
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
 //
 // sampling
 //

diff --git a/llama.h b/llama.h
index 422f28527..5b95aaa87 100644
--- a/llama.h
+++ b/llama.h
@@ -410,6 +410,8 @@ extern "C" {
 
     LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
 
+    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
+
     //
     // Sampling functions
     //
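A note on llama_grammar_copy in this patch: the nested loops handle one subtlety - the grammar's stacks store raw pointers into its own rules, so copying the members verbatim would leave the new stacks aliasing the old grammar. Below is a minimal standalone sketch of the same pointer-remapping idea, using simplified stand-in types rather than the actual llama.cpp structures:

    #include <cstdio>
    #include <vector>

    // stand-in for llama_grammar: `stack` holds pointers into `rules`
    struct state {
        std::vector<std::vector<int>> rules;
        std::vector<const int *>      stack;
    };

    // deep copy: locate each stack pointer in the source rules and
    // redirect it to the corresponding element of the copied rules
    static state copy_state(const state & src) {
        state dst = { src.rules, src.stack };
        for (auto & p : dst.stack) {
            for (size_t ir0 = 0; ir0 < src.rules.size(); ir0++) {
                for (size_t ir1 = 0; ir1 < src.rules[ir0].size(); ir1++) {
                    if (p == &src.rules[ir0][ir1]) {
                        p = &dst.rules[ir0][ir1];
                    }
                }
            }
        }
        return dst;
    }

    int main() {
        state a = { { { 1, 2 }, { 3 } }, {} };
        a.stack.push_back(&a.rules[1][0]);

        state b = copy_state(a);
        printf("value: %d\n", *b.stack[0]);                               // 3
        printf("%s\n", b.stack[0] == a.stack[0] ? "aliased" : "independent");
        return 0;
    }

Without the remapping step, mutating or freeing the original grammar would invalidate the copy's stacks - exactly the situation the speculative example runs into when it snapshots and restores grammar states.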
From e0a8658e7c1b27c761c6b46e9aaf4449a93634f4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 3 Sep 2023 17:52:49 +0300
Subject: [PATCH 2/9] grammars : add json_arr.gbnf

---
 grammars/json_arr.gbnf | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 grammars/json_arr.gbnf

diff --git a/grammars/json_arr.gbnf b/grammars/json_arr.gbnf
new file mode 100644
index 000000000..0106c2384
--- /dev/null
+++ b/grammars/json_arr.gbnf
@@ -0,0 +1,31 @@
+root ::= arr
+value ::= object | array | string | number | ("true" | "false" | "null") ws
+
+object ::=
+  "{" ws (
+    string ":" ws value
+    ("," ws string ":" ws value)*
+  )? "}" ws
+
+array ::=
+  "[" ws (
+    value
+    ("," ws value)*
+  )? "]" ws
+
+arr ::=
+  "[\n" ws (
+    value
+    (",\n" ws value)*
+  )? "]"
+
+string ::=
+  "\"" (
+    [^"\\] |
+    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+  )* "\"" ws
+
+number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= ([ \t\n] ws)?
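Once the example links against the grammar parser (see PATCH 9 for the required Makefile fix), the new grammar can be exercised through the usual common.cpp sampling arguments. A hypothetical invocation - the model paths and the prompt are placeholders, and the flags assumed are the standard ones (-m target model, -md draft model, --grammar-file):

    ./speculative -m models/target/ggml-model.gguf -md models/draft/ggml-model.gguf \
        -p "List the planets as a JSON array: " \
        --grammar-file grammars/json_arr.gbnf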
From 2d89da4f774906eb9204924e6d354d7bde9f3f37 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 3 Sep 2023 18:47:38 +0300
Subject: [PATCH 3/9] grammar : add comments to new grammar file

---
 grammars/json_arr.gbnf | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/grammars/json_arr.gbnf b/grammars/json_arr.gbnf
index 0106c2384..653119d88 100644
--- a/grammars/json_arr.gbnf
+++ b/grammars/json_arr.gbnf
@@ -1,6 +1,15 @@
+# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
+# Useful for generating JSON arrays
+
 root ::= arr
 value ::= object | array | string | number | ("true" | "false" | "null") ws
 
+arr ::=
+  "{\n\t[\n" ws (
+    value
+    (",\n" ws value)*
+  )? "\t]\n}"
+
 object ::=
   "{" ws (
     string ":" ws value
@@ -13,12 +22,6 @@ array ::=
     ("," ws value)*
   )? "]" ws
 
-arr ::=
-  "[\n" ws (
-    value
-    (",\n" ws value)*
-  )? "]"
-
 string ::=
   "\"" (
     [^"\\] |

From 013457885a23a5006bbb47be1ce4e371a12df3de Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 3 Sep 2023 19:10:43 +0300
Subject: [PATCH 4/9] grammar : remove one nested level

---
 grammars/json_arr.gbnf | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/grammars/json_arr.gbnf b/grammars/json_arr.gbnf
index 653119d88..ef53e77a0 100644
--- a/grammars/json_arr.gbnf
+++ b/grammars/json_arr.gbnf
@@ -5,10 +5,10 @@ root ::= arr
 value ::= object | array | string | number | ("true" | "false" | "null") ws
 
 arr ::=
-  "{\n\t[\n" ws (
+  "[\n" ws (
     value
     (",\n" ws value)*
-  )? "\t]\n}"
+  )? "]"
 
 object ::=
   "{" ws (

From ebe41d49a69a46735d65e0e7fc8ae6e4b0a0fbda Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 3 Sep 2023 21:07:01 +0300
Subject: [PATCH 5/9] common : warm-up with 2 tokens - seems to work better

---
 common/common.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/common/common.cpp b/common/common.cpp
index 313821375..f72f1b84d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -770,7 +770,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     {
         LOG("warming up the model with an empty run\n");
 
-        const std::vector<llama_token> tmp = { llama_token_bos(lctx), };
+        const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
         llama_eval(lctx, tmp.data(), tmp.size(), 0, params.n_threads);
         llama_reset_timings(lctx);
     }

From 6c150d763ee3ed2ae4df4595ee74693c7e4304f6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 12:54:38 +0300
Subject: [PATCH 6/9] speculative : print draft token pieces

---
 examples/speculative/speculative.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 594d4f5d6..78681dd25 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -241,7 +241,7 @@ int main(int argc, char ** argv) {
             llama_sample_softmax(ctx_dft, &cur_p);
 
             for (int i = 0; i < 3; ++i) {
-                LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p);
+                LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str());
             }
 
             // too low probability, stop drafting
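Back on the grammar: with the arr rule in its PATCH 4 form, the root array must open with "[" followed by a newline, may put each value on its own line, and ends with "]" with no trailing whitespace allowed. For instance, output of the following shape is accepted (the values here are only an illustration):

    [
    "Mercury",
    "Venus",
    "Mars"]

The closing bracket can follow the last value directly because the optional ws inside values may be empty, and the rule deliberately permits no whitespace after the final "]".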
From e7dc5b08acedb066129dee08babb7ac5ba089044 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 15:18:38 +0300
Subject: [PATCH 7/9] speculative : reuse grammar parser + better logs and
 comments

---
 examples/speculative/speculative.cpp | 84 ++++++++++++++--------------
 1 file changed, 41 insertions(+), 43 deletions(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 78681dd25..9fab8266d 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -114,35 +114,21 @@ int main(int argc, char ** argv) {
     struct llama_grammar * grammar_dft = NULL;
     struct llama_grammar * grammar_tgt = NULL;
 
-    grammar_parser::parse_state parsed_grammar_dft;
-    grammar_parser::parse_state parsed_grammar_tgt;
+    grammar_parser::parse_state parsed_grammar;
 
     std::vector<llama_grammar *> grammar_mem(n_draft, NULL);
 
+    // if requested - load the grammar, error checking is omitted for brevity
     if (!params.grammar.empty()) {
-        // dft
-        {
-            parsed_grammar_dft = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar_dft.rules.empty()) {
-                return 1;
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_dft.c_rules());
-            grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_dft.symbol_ids.at("root"));
+        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        // will be empty (default) if there are parse errors
+        if (parsed_grammar.rules.empty()) {
+            return 1;
         }
 
-        // tgt
-        {
-            parsed_grammar_tgt = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar_tgt.rules.empty()) {
-                return 1;
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_tgt.c_rules());
-            grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_tgt.symbol_ids.at("root"));
-        }
+        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+        grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+        grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
     const auto t_dec_start = ggml_time_us();
@@ -150,11 +136,12 @@
     while (true) {
         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
 
-        // sample from the drafted tokens if any
         int i_dft = 0;
         while (true) {
+            // sample from the target model
             const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
 
+            // remember which tokens were sampled - used for repetition penalties during sampling
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
@@ -170,8 +157,9 @@ int main(int argc, char ** argv) {
 
             ++n_predict;
 
+            // check if the draft matches the target
             if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
-                LOG("drafted token %d accepted\n", id);
+                LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
                 ++n_accept;
                 ++n_past_tgt;
                 ++n_past_dft;
@@ -180,25 +168,20 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
+            // the drafted token was rejected or we are out of drafted tokens
+
             if (i_dft < (int) drafted.size()) {
-                LOG("drafted token %d rejected\n", id);
+                LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
+                    i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
 
                 if (grammar_mem[i_dft]) {
                     grammar_dft = llama_grammar_copy(grammar_mem[i_dft]);
-                    LOG("restored grammar %d\n", i_dft);
+                    LOG("restored draft grammar state %d\n", i_dft);
                 }
+            } else {
+                LOG("out of drafted tokens\n");
             }
 
-            for (auto & g : grammar_mem) {
-                if (g) {
-                    llama_grammar_free(g);
-                    g = NULL;
-                }
-            }
-
-            LOG("i_dft = %d, drafted.size() = %d\n", i_dft, (int) drafted.size());
-
-            // the drafted token was rejected or we are out of drafted tokens
             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
             ++n_past_dft;
@@ -212,11 +195,20 @@ int main(int argc, char ** argv) {
             break;
         }
 
+        for (int i = 0; i < (int) grammar_mem.size(); ++i) {
+            auto & g = grammar_mem[i];
+            if (g) {
+                LOG("freeing grammar state %d\n", i);
+                llama_grammar_free(g);
+                g = NULL;
+            }
+        }
+
         if (n_predict > params.n_predict || has_eos) {
             break;
         }
 
-        // sample n_draft tokens from the draft model picking the best token
+        // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
             // remember the grammar state
@@ -236,11 +228,13 @@ int main(int argc, char ** argv) {
                 LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str());
             }
 
-            // too low probability, stop drafting
+            // TODO: better logic?
             if (cur_p.data[0].p < 2*cur_p.data[1].p) {
+                LOG("stopping drafting, probability too low: %8.f < 2*%8.f\n", cur_p.data[0].p, cur_p.data[1].p);
                 break;
             }
 
+            // drafted token
             const llama_token id = cur_p.data[0].id;
 
             if (grammar_dft != NULL) {
@@ -258,17 +252,21 @@ int main(int argc, char ** argv) {
             drafted.push_back(id);
             ++n_drafted;
 
-            if (i < n_draft - 1) {
-                // evaluate the drafted token on the draft model
-                llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
-                ++n_past_cur;
+            // no need to evaluate the last drafted token, since we won't use the result
+            if (i == n_draft - 1) {
+                break;
             }
+
+            // evaluate the drafted token on the draft model
+            llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+            ++n_past_cur;
         }
 
         // evaluate the target model on the drafted tokens
         llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
         ++n_past_tgt;
 
+        // the first token is always proposed by the target model before the speculation loop
         drafted.erase(drafted.begin());
     }
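The grammar_mem logic that this patch annotates is a plain checkpoint/restore pattern: before each speculative step the draft grammar is snapshotted, and when the target rejects a drafted token, the snapshot taken just before that token is reinstated. A toy sketch of the pattern, with a trivial parser_state standing in for llama_grammar (not the real API):

    #include <cstdio>
    #include <string>
    #include <vector>

    // toy stand-in for llama_grammar: any state that advances token by token
    struct parser_state {
        std::string accepted;
        void accept(char c) { accepted += c; }
    };

    int main() {
        parser_state dft;                    // draft-side grammar state
        const std::string drafted = "abc";   // speculated tokens

        // snapshot the state before each drafted token
        // (analogous to grammar_mem[i] = llama_grammar_copy(grammar_dft))
        std::vector<parser_state> mem;
        for (char c : drafted) {
            mem.push_back(dft);
            dft.accept(c);
        }

        // suppose the target rejects the draft at position 1 ('b'):
        // roll the draft state back to just before that token
        const int i_rej = 1;
        dft = mem[i_rej];
        printf("restored state: \"%s\"\n", dft.accepted.c_str());   // prints "a"
        return 0;
    }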
From 2db2471c13db3b0e60fe02341faa1b7e036fd9d5 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 15:42:54 +0300
Subject: [PATCH 8/9] speculative : avoid grammar_mem

---
 examples/speculative/speculative.cpp | 45 +++++++++------------------
 1 file changed, 14 insertions(+), 31 deletions(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 9fab8266d..c6211ac79 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -116,8 +116,6 @@ int main(int argc, char ** argv) {
 
     grammar_parser::parse_state parsed_grammar;
 
-    std::vector<llama_grammar *> grammar_mem(n_draft, NULL);
-
     // if requested - load the grammar, error checking is omitted for brevity
     if (!params.grammar.empty()) {
         parsed_grammar = grammar_parser::parse(params.grammar.c_str());
@@ -127,7 +125,6 @@ int main(int argc, char ** argv) {
         }
 
         std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-        grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
         grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
@@ -173,11 +170,6 @@ int main(int argc, char ** argv) {
             if (i_dft < (int) drafted.size()) {
                 LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
                     i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
-
-                if (grammar_mem[i_dft]) {
-                    grammar_dft = llama_grammar_copy(grammar_mem[i_dft]);
-                    LOG("restored draft grammar state %d\n", i_dft);
-                }
             } else {
                 LOG("out of drafted tokens\n");
             }
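A note on the "TODO: better logic?" cutoff retained in this patch: drafting simply stops as soon as the best draft candidate is no longer at least twice as probable as the runner-up after softmax. A self-contained sketch of that margin check (the names are illustrative, not from the codebase):

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // true if drafting should stop: top candidate lacks a 2x probability margin
    static bool should_stop_drafting(std::vector<float> probs) {
        std::sort(probs.begin(), probs.end(), std::greater<float>());
        return probs[0] < 2*probs[1];
    }

    int main() {
        printf("%d\n", should_stop_drafting({0.80f, 0.10f, 0.10f})); // 0 -> keep drafting
        printf("%d\n", should_stop_drafting({0.45f, 0.40f, 0.15f})); // 1 -> stop
        return 0;
    }

The intuition: a confident draft model is worth running ahead, while a near-tie means the drafted continuation is likely to be rejected by the target anyway, so extending the draft further is not worth the extra evaluations.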
@@ -188,34 +180,25 @@ int main(int argc, char ** argv) {
             drafted.clear();
             drafted.push_back(id);
 
-            if (grammar_dft != NULL) {
-                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
-            }
-
             break;
         }
 
-        for (int i = 0; i < (int) grammar_mem.size(); ++i) {
-            auto & g = grammar_mem[i];
-            if (g) {
-                LOG("freeing grammar state %d\n", i);
-                llama_grammar_free(g);
-                g = NULL;
-            }
-        }
-
         if (n_predict > params.n_predict || has_eos) {
             break;
         }
 
+        if (grammar_tgt) {
+            if (grammar_dft) {
+                llama_grammar_free(grammar_dft);
+            }
+            grammar_dft = llama_grammar_copy(grammar_tgt);
+
+            LOG("copied target grammar to draft grammar\n");
+        }
+
         // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
-            // remember the grammar state
-            if (grammar_dft != NULL) {
-                grammar_mem[i] = llama_grammar_copy(grammar_dft);
-            }
-
             float * logits = llama_get_logits(ctx_dft);
 
             candidates.clear();
@@ -238,17 +221,13 @@ int main(int argc, char ** argv) {
 
             // TODO: better logic?
             if (cur_p.data[0].p < 2*cur_p.data[1].p) {
-                LOG("stopping drafting, probability too low: %8.f < 2*%8.f\n", cur_p.data[0].p, cur_p.data[1].p);
+                LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
                 break;
             }
 
             // drafted token
             const llama_token id = cur_p.data[0].id;
 
-            if (grammar_dft != NULL) {
-                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
-            }
-
             drafted.push_back(id);
             ++n_drafted;
 
@@ -260,6 +239,10 @@ int main(int argc, char ** argv) {
             // evaluate the drafted token on the draft model
             llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
             ++n_past_cur;
+
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
         }

From c79d130f74cdfb18cb47fe7ed7a0b56e2774c6f1 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 15:50:04 +0300
Subject: [PATCH 9/9] make : fix speculative build

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 9ff2f9e95..71c6fdde4 100644
--- a/Makefile
+++ b/Makefile
@@ -477,7 +477,7 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
 beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
 ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
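With the Makefile fix in place, the example builds with the existing target:

    make speculative

Without grammar-parser.o in the prerequisite list, the grammar_parser:: calls introduced in PATCH 1 would presumably fail at link time, which is what this final patch addresses.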