grammars: mutex-guarded lazy caching of token pieces in llama_sample_grammar
parent 80736c556b
commit 939e143fe2

1 changed file with 16 additions and 12 deletions

llama.cpp (28 changed lines: +16 −12)
@@ -2329,6 +2329,7 @@ struct llama_context {
     struct llama_control_vector cvec;
 
     // caching token pieces & their decoded codepoints.
+    std::mutex token_cache_mutex;
     std::vector<std::string> token_pieces;
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>>
         token_codepoints_without_partial_utf8_prefix;
@@ -13624,6 +13625,21 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
+    {
+        // cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
+        std::unique_lock<std::mutex> lock(ctx->token_cache_mutex);
+        if (ctx->token_pieces.empty()) {
+            auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+            ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
+            ctx->token_pieces.resize(n_vocab);
+            for (llama_token id = 0; id < n_vocab; ++id) {
+                const std::string piece = llama_token_to_piece(ctx, id, false);
+                ctx->token_pieces[id] = piece;
+                ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
+            }
+        }
+    }
+
     // Store decoded codepoints when they are not cached (happens when there's a partial utf8 string prefix).
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     if (grammar->partial_utf8.n_remain > 0) {
@@ -15730,18 +15746,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    // cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar-constrained sampling.
-    {
-        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
-        ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
-        ctx->token_pieces.resize(n_vocab);
-        for (llama_token id = 0; id < n_vocab; ++id) {
-            const std::string piece = llama_token_to_piece(ctx, id, false);
-            ctx->token_pieces[id] = piece;
-            ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
-        }
-    }
-
 #ifdef GGML_USE_MPI
     ctx->ctx_mpi = ggml_mpi_init();
 
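For context, the block added to llama_sample_grammar above is ordinary mutex-guarded lazy initialization: every caller locks ctx->token_cache_mutex, the first caller (the one that finds ctx->token_pieces empty) decodes the whole vocabulary once, and every later caller sees the populated cache and skips the work. Below is a minimal standalone sketch of the same pattern, with hypothetical names (token_cache, ensure_filled, decode_piece) that are not part of the diff or of the llama.cpp API:

#include <mutex>
#include <string>
#include <vector>

// Sketch only: hypothetical stand-in for llama_token_to_piece(), not the real API.
static std::string decode_piece(int id) {
    return "<piece " + std::to_string(id) + ">";
}

struct token_cache {
    std::mutex               mutex;   // guards the lazy fill below
    std::vector<std::string> pieces;  // empty until first use
};

// Safe to call from any thread; only the first caller pays the decoding cost.
static void ensure_filled(token_cache & cache, int n_vocab) {
    std::unique_lock<std::mutex> lock(cache.mutex);
    if (cache.pieces.empty()) {
        cache.pieces.resize(n_vocab);
        for (int id = 0; id < n_vocab; ++id) {
            cache.pieces[id] = decode_piece(id);
        }
    }
}

Because the fill now happens on the first grammar-constrained sample rather than in llama_new_context_with_model (the removed hunk above), contexts that never call llama_sample_grammar no longer pay for decoding the whole vocabulary.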