now working

This commit is contained in:
mike dupont 2023-11-22 12:01:26 -05:00
parent ef4c0f572b
commit 09a1f053e7
4 changed files with 25 additions and 13 deletions

View file

@ -144,7 +144,7 @@ namespace grammar_parser {
while (*pos != '"') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_CHAR, char_pair.first));
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
@ -162,11 +162,11 @@ namespace grammar_parser {
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
out_elements.push_back({type, char_pair.first});
out_elements.push_back(llama_grammar_element(type, char_pair.first));
if (pos[0] == '-' && pos[1] != ']') {
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first));
}
}
pos = parse_space(pos + 1, is_nested);
@ -175,7 +175,7 @@ namespace grammar_parser {
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = out_elements.size();
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_RULE_REF, ref_rule_id));
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
@ -183,7 +183,7 @@ namespace grammar_parser {
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
last_sym_start = out_elements.size();
// output reference to synthesized rule
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_RULE_REF, sub_rule_id));
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
@ -219,7 +219,8 @@ namespace grammar_parser {
// in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start);
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
llama_grammar_element(LLAMA_GRETYPE_RULE_REF, sub_rule_id) a;
out_elements.push_back(a);
pos = parse_space(pos + 1, is_nested);
} else {

View file

@ -138,7 +138,7 @@ llama_token llama_sampling_sample(
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
cur.emplace_back(llama_token_data(token_id, logits[token_id], 0.0f));
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };

View file

@ -6745,7 +6745,7 @@ struct llama_grammar * llama_grammar_init(
for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
vec_rules[i].push_back(*pos);
}
llama_grammar_element ge = {.type=LLAMA_GRETYPE_END, .value=0};
llama_grammar_element ge(LLAMA_GRETYPE_END,0);
vec_rules[i].push_back(ge);
}
@ -7398,11 +7398,11 @@ struct llama_logit_info {
{ }
llama_token_data get_token_data(const llama_token token_id) const {
constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used
llama_token_data dd {
.id = token_id,
.logit = logits[token_id],
.p = p
};
llama_token_data dd(
token_id,
logits[token_id],
p
);
return dd;
}
// Return top k token_data by logit.

11
llama.h
View file

@ -115,6 +115,13 @@ extern "C" {
};
typedef struct llama_token_data : refl::attr::usage::type{
llama_token_data( llama_token id,
float logit,
float p):
id( id),
logit(logit),
p(p){
}
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
@ -244,6 +251,10 @@ extern "C" {
};
typedef struct llama_grammar_element : refl::attr::usage::type{
llama_grammar_element( enum llama_gretype type,
uint32_t value // Unicode code point or rule ID
):type(type), value(value){}
llama_grammar_element( ):type(0), value(0){}
enum llama_gretype type;
uint32_t value; // Unicode code point or rule ID
} llama_grammar_element;