Fix C linkage for llama_token_to_str

This commit is contained in:
goerch 2023-07-24 08:05:16 +02:00
parent dba8369a39
commit b97a505c5d
6 changed files with 80 additions and 20 deletions

View file

@ -3,6 +3,7 @@
#pragma once #pragma once
#include "llama.h" #include "llama.h"
#include "llama.cpp.h"
#include <string> #include <string>
#include <vector> #include <vector>

View file

@ -520,7 +520,7 @@ struct llama_file_loader {
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
auto & tok_score = vocab.id_to_token[i]; auto & tok_score = vocab.id_to_token[i];
tok_score.tok = word; tok_score.tok = std::move(word);
tok_score.score = score; tok_score.score = score;
} }
} }
@ -3725,24 +3725,32 @@ float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data(); return ctx->embedding.data();
} }
std::string llama_token_to_str_with_model(const struct llama_model * model, llama_token token) { int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
if (token >= llama_n_vocab_from_model(model)) { if (0 <= token && token < llama_n_vocab_from_model(model)) {
return nullptr; std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
if(result.length() > length) {
return - result.length();
}
strcpy(str, result.c_str());
return result.length();
} }
return 0;
return llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
} }
std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * str, int length) {
return llama_token_to_str_with_model(&ctx->model, token); return llama_token_to_str_with_model(&ctx->model, token, str, length);
} }
std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) { int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) {
if (token >= llama_n_vocab_from_model(&ctx->model)) { if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) {
return nullptr; std::string result = ctx->model.vocab.id_to_token[token].tok;
if (result.length() > length) {
return -result.length();
}
strcpy(str, result.c_str());
return result.length();
} }
return 0;
return ctx->model.vocab.id_to_token[token].tok;
} }
llama_token llama_token_bos() { llama_token llama_token_bos() {

44
llama.cpp.h Normal file
View file

@ -0,0 +1,44 @@
#ifndef LLAMA_CPP_H
#define LLAMA_CPP_H
#include "llama.h"
#include <cassert>
static std::string llama_token_to_str(
const struct llama_context * ctx,
llama_token token) {
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
if (length < 0) {
result.resize(-length);
int check = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}
static std::string llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token) {
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length());
if (length < 0) {
result.resize(-length);
int check = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length());
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}
#endif

19
llama.h
View file

@ -327,18 +327,23 @@ extern "C" {
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API std::string llama_token_to_str( LLAMA_API int llama_token_to_str(
const struct llama_context * ctx, const struct llama_context * ctx,
llama_token token); llama_token token,
char * str,
int length);
LLAMA_API std::string llama_token_to_str_bpe( LLAMA_API int llama_token_to_str_bpe(
const struct llama_context * ctx, const struct llama_context * ctx,
llama_token token); llama_token token,
char * str,
int length);
LLAMA_API std::string llama_token_to_str_with_model( LLAMA_API int llama_token_to_str_with_model(
const struct llama_model * model, const struct llama_model * model,
llama_token token); llama_token token,
char * str,
int length);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_eos(); // end-of-sentence

View file

@ -1,4 +1,5 @@
#include "llama.h" #include "llama.h"
#include "llama.cpp.h"
#include <cstdio> #include <cstdio>
#include <string> #include <string>

View file

@ -1,4 +1,5 @@
#include "llama.h" #include "llama.h"
#include "llama.cpp.h"
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>