Resolve merge conflict with grammar stuff.

This commit is contained in:
goerch 2023-07-25 18:14:38 +02:00
parent 3bdf106e06
commit b4a5461ff8

View file

@ -2624,6 +2624,25 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}
// TODO: reorder functions?
std::string llama_token_to_str(
const struct llama_context * ctx,
llama_token token) {
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
if (length < 0) {
result.resize(-length);
int check = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
assert(ctx);
const int64_t t_start_sample_us = ggml_time_us();
@ -2643,15 +2662,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const char * str = llama_token_to_str(ctx, id);
std::string str = llama_token_to_str(ctx, id);
if (id == eos) {
if (!allow_eos) {
candidates->data[i].logit = -INFINITY;
}
} else if (*str == 0) {
} else if (str.empty()) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(str));
candidates_decoded.push_back(decode_utf8(str.c_str()));
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
}
}
@ -2852,9 +2871,9 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
LLAMA_ASSERT(false);
}
const char * str = llama_token_to_str(ctx, token);
std::string str = llama_token_to_str(ctx, token);
// Note terminating 0 in decoded string
auto code_points = decode_utf8(str);
auto code_points = decode_utf8(str.c_str());
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
}
@ -4182,24 +4201,6 @@ int llama_token_to_str(const struct llama_context * ctx, llama_token token, char
return llama_token_to_str_with_model(&ctx->model, token, str, length);
}
std::string llama_token_to_str(
const struct llama_context * ctx,
llama_token token) {
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
if (length < 0) {
result.resize(-length);
int check = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}
int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) {
if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) {
std::string result = ctx->model.vocab.id_to_token[token].tok;