From 09a1f053e7f97c7b455305ffc7b124fcc178b8dc Mon Sep 17 00:00:00 2001 From: mike dupont Date: Wed, 22 Nov 2023 12:01:26 -0500 Subject: [PATCH] now working --- common/grammar-parser.cpp | 13 +++++++------ common/sampling.cpp | 2 +- llama.cpp | 12 ++++++------ llama.h | 11 +++++++++++ 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index ff51cc803..59bcf0d73 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -144,7 +144,7 @@ namespace grammar_parser { while (*pos != '"') { auto char_pair = parse_char(pos); pos = char_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_CHAR, char_pair.first)); } pos = parse_space(pos + 1, is_nested); } else if (*pos == '[') { // char range(s) @@ -162,11 +162,11 @@ namespace grammar_parser { ? LLAMA_GRETYPE_CHAR_ALT : start_type; - out_elements.push_back({type, char_pair.first}); + out_elements.push_back(llama_grammar_element(type, char_pair.first)); if (pos[0] == '-' && pos[1] != ']') { auto endchar_pair = parse_char(pos + 1); pos = endchar_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first)); } } pos = parse_space(pos + 1, is_nested); @@ -175,7 +175,7 @@ namespace grammar_parser { uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_RULE_REF, ref_rule_id)); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); @@ -183,7 +183,7 @@ namespace grammar_parser { pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); last_sym_start = out_elements.size(); // output reference to synthesized rule - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_RULE_REF, sub_rule_id)); if (*pos != ')') { throw std::runtime_error(std::string("expecting ')' at ") + pos); } @@ -219,7 +219,8 @@ namespace grammar_parser { // in original rule, replace previous symbol with reference to generated rule out_elements.resize(last_sym_start); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + llama_grammar_element(LLAMA_GRETYPE_RULE_REF, sub_rule_id) a; + out_elements.push_back(a); pos = parse_space(pos + 1, is_nested); } else { diff --git a/common/sampling.cpp b/common/sampling.cpp index 1317024c2..d83cfc57c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -138,7 +138,7 @@ llama_token llama_sampling_sample( cur.clear(); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + cur.emplace_back(llama_token_data(token_id, logits[token_id], 0.0f)); } llama_token_data_array cur_p = { cur.data(), cur.size(), false }; diff --git a/llama.cpp b/llama.cpp index 3f203fe33..01e675019 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6745,7 +6745,7 @@ struct llama_grammar * llama_grammar_init( for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { vec_rules[i].push_back(*pos); } - llama_grammar_element ge = {.type=LLAMA_GRETYPE_END, .value=0}; + llama_grammar_element ge(LLAMA_GRETYPE_END,0); vec_rules[i].push_back(ge); } @@ -7398,11 +7398,11 @@ struct llama_logit_info { { } llama_token_data get_token_data(const llama_token token_id) const { constexpr auto p = std::numeric_limits::quiet_NaN(); // never used - llama_token_data dd { - .id = token_id, - .logit = logits[token_id], - .p = p - }; + llama_token_data dd( + token_id, + logits[token_id], + p + ); return dd; } // Return top k token_data by logit. diff --git a/llama.h b/llama.h index d2da1abea..32d5dff8d 100644 --- a/llama.h +++ b/llama.h @@ -115,6 +115,13 @@ extern "C" { }; typedef struct llama_token_data : refl::attr::usage::type{ + llama_token_data( llama_token id, + float logit, + float p): + id( id), + logit(logit), + p(p){ + } llama_token id; // token id float logit; // log-odds of the token float p; // probability of the token @@ -244,6 +251,10 @@ extern "C" { }; typedef struct llama_grammar_element : refl::attr::usage::type{ + llama_grammar_element( enum llama_gretype type, + uint32_t value // Unicode code point or rule ID + ):type(type), value(value){} + llama_grammar_element( ):type(0), value(0){} enum llama_gretype type; uint32_t value; // Unicode code point or rule ID } llama_grammar_element;