diff --git a/llama.cpp b/llama.cpp index 7d8b9a0ac..bd777ac04 100644 --- a/llama.cpp +++ b/llama.cpp @@ -954,6 +954,12 @@ struct llama_vocab { id linefeed_id = 13; + // codellama FIM special tokens + id special_prefix_id = 32007; + id special_middle_id = 32009; + id special_suffix_id = 32008; + id special_eot_id = 32010; + int find_bpe_rank(std::string token_left, std::string token_right) const { replace_all(token_left, " ", "Ġ"); replace_all(token_left, "\n", "Ċ"); @@ -6080,6 +6086,22 @@ llama_token llama_token_nl(const struct llama_context * ctx) { return ctx->model.vocab.linefeed_id; } +llama_token llama_token_prefix(const struct llama_context * ctx) { + return ctx->model.vocab.special_prefix_id; +} + +llama_token llama_token_middle(const struct llama_context * ctx) { + return ctx->model.vocab.special_middle_id; +} + +llama_token llama_token_suffix(const struct llama_context * ctx) { + return ctx->model.vocab.special_suffix_id; +} + +llama_token llama_token_eot(const struct llama_context * ctx) { + return ctx->model.vocab.special_eot_id; +} + int llama_tokenize( struct llama_context * ctx, const char * text, diff --git a/llama.h b/llama.h index b77dd7735..f15f63e0e 100644 --- a/llama.h +++ b/llama.h @@ -359,6 +359,12 @@ extern "C" { LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line + // codellama FIM tokens + LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of FIM prefix + LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of FIM middle + LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of FIM suffix + LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of FIM middle + // // Tokenization //