now working
This commit is contained in:
parent
ef4c0f572b
commit
09a1f053e7
4 changed files with 25 additions and 13 deletions
|
@ -144,7 +144,7 @@ namespace grammar_parser {
|
||||||
while (*pos != '"') {
|
while (*pos != '"') {
|
||||||
auto char_pair = parse_char(pos);
|
auto char_pair = parse_char(pos);
|
||||||
pos = char_pair.second;
|
pos = char_pair.second;
|
||||||
out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_CHAR, char_pair.first));
|
||||||
}
|
}
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
} else if (*pos == '[') { // char range(s)
|
} else if (*pos == '[') { // char range(s)
|
||||||
|
@ -162,11 +162,11 @@ namespace grammar_parser {
|
||||||
? LLAMA_GRETYPE_CHAR_ALT
|
? LLAMA_GRETYPE_CHAR_ALT
|
||||||
: start_type;
|
: start_type;
|
||||||
|
|
||||||
out_elements.push_back({type, char_pair.first});
|
out_elements.push_back(llama_grammar_element(type, char_pair.first));
|
||||||
if (pos[0] == '-' && pos[1] != ']') {
|
if (pos[0] == '-' && pos[1] != ']') {
|
||||||
auto endchar_pair = parse_char(pos + 1);
|
auto endchar_pair = parse_char(pos + 1);
|
||||||
pos = endchar_pair.second;
|
pos = endchar_pair.second;
|
||||||
out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
@ -175,7 +175,7 @@ namespace grammar_parser {
|
||||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
||||||
pos = parse_space(name_end, is_nested);
|
pos = parse_space(name_end, is_nested);
|
||||||
last_sym_start = out_elements.size();
|
last_sym_start = out_elements.size();
|
||||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_RULE_REF, ref_rule_id));
|
||||||
} else if (*pos == '(') { // grouping
|
} else if (*pos == '(') { // grouping
|
||||||
// parse nested alternates into synthesized rule
|
// parse nested alternates into synthesized rule
|
||||||
pos = parse_space(pos + 1, true);
|
pos = parse_space(pos + 1, true);
|
||||||
|
@ -183,7 +183,7 @@ namespace grammar_parser {
|
||||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
||||||
last_sym_start = out_elements.size();
|
last_sym_start = out_elements.size();
|
||||||
// output reference to synthesized rule
|
// output reference to synthesized rule
|
||||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
out_elements.push_back(llama_grammar_element(LLAMA_GRETYPE_RULE_REF, sub_rule_id));
|
||||||
if (*pos != ')') {
|
if (*pos != ')') {
|
||||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||||
}
|
}
|
||||||
|
@ -219,7 +219,8 @@ namespace grammar_parser {
|
||||||
|
|
||||||
// in original rule, replace previous symbol with reference to generated rule
|
// in original rule, replace previous symbol with reference to generated rule
|
||||||
out_elements.resize(last_sym_start);
|
out_elements.resize(last_sym_start);
|
||||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
llama_grammar_element(LLAMA_GRETYPE_RULE_REF, sub_rule_id) a;
|
||||||
|
out_elements.push_back(a);
|
||||||
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -138,7 +138,7 @@ llama_token llama_sampling_sample(
|
||||||
cur.clear();
|
cur.clear();
|
||||||
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
cur.emplace_back(llama_token_data(token_id, logits[token_id], 0.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
|
12
llama.cpp
12
llama.cpp
|
@ -6745,7 +6745,7 @@ struct llama_grammar * llama_grammar_init(
|
||||||
for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
|
for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
|
||||||
vec_rules[i].push_back(*pos);
|
vec_rules[i].push_back(*pos);
|
||||||
}
|
}
|
||||||
llama_grammar_element ge = {.type=LLAMA_GRETYPE_END, .value=0};
|
llama_grammar_element ge(LLAMA_GRETYPE_END,0);
|
||||||
vec_rules[i].push_back(ge);
|
vec_rules[i].push_back(ge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7398,11 +7398,11 @@ struct llama_logit_info {
|
||||||
{ }
|
{ }
|
||||||
llama_token_data get_token_data(const llama_token token_id) const {
|
llama_token_data get_token_data(const llama_token token_id) const {
|
||||||
constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used
|
constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used
|
||||||
llama_token_data dd {
|
llama_token_data dd(
|
||||||
.id = token_id,
|
token_id,
|
||||||
.logit = logits[token_id],
|
logits[token_id],
|
||||||
.p = p
|
p
|
||||||
};
|
);
|
||||||
return dd;
|
return dd;
|
||||||
}
|
}
|
||||||
// Return top k token_data by logit.
|
// Return top k token_data by logit.
|
||||||
|
|
11
llama.h
11
llama.h
|
@ -115,6 +115,13 @@ extern "C" {
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef struct llama_token_data : refl::attr::usage::type{
|
typedef struct llama_token_data : refl::attr::usage::type{
|
||||||
|
llama_token_data( llama_token id,
|
||||||
|
float logit,
|
||||||
|
float p):
|
||||||
|
id( id),
|
||||||
|
logit(logit),
|
||||||
|
p(p){
|
||||||
|
}
|
||||||
llama_token id; // token id
|
llama_token id; // token id
|
||||||
float logit; // log-odds of the token
|
float logit; // log-odds of the token
|
||||||
float p; // probability of the token
|
float p; // probability of the token
|
||||||
|
@ -244,6 +251,10 @@ extern "C" {
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef struct llama_grammar_element : refl::attr::usage::type{
|
typedef struct llama_grammar_element : refl::attr::usage::type{
|
||||||
|
llama_grammar_element( enum llama_gretype type,
|
||||||
|
uint32_t value // Unicode code point or rule ID
|
||||||
|
):type(type), value(value){}
|
||||||
|
llama_grammar_element( ):type(0), value(0){}
|
||||||
enum llama_gretype type;
|
enum llama_gretype type;
|
||||||
uint32_t value; // Unicode code point or rule ID
|
uint32_t value; // Unicode code point or rule ID
|
||||||
} llama_grammar_element;
|
} llama_grammar_element;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue