Code cleanup

This commit is contained in:
Igor Pissolati 2023-06-19 14:52:57 -03:00
parent 61a98bc30a
commit 0c14627438
2 changed files with 14 additions and 24 deletions

View file

@@ -281,7 +281,6 @@ struct llama_vocab {
     llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
-    std::vector<id> special_tokens;
     size_t max_special_token_length;
 };
@@ -580,14 +579,13 @@ struct llama_file_loader {
     for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
         uint32_t token_id = file.read_u32();
-        const auto & token = vocab.id_to_token[token_id].tok;
-        vocab.special_token_trie.add(token);
-        vocab.special_tokens.push_back(token_id);
-        vocab.special_token_to_id[token] = token_id;
+        const auto & word = vocab.id_to_token[token_id].tok;
+        vocab.special_token_trie.add(word);
+        vocab.special_token_to_id[word] = token_id;

-        if (vocab.max_special_token_length < token.size()) {
-            vocab.max_special_token_length = token.size();
+        if (vocab.max_special_token_length < word.size()) {
+            vocab.max_special_token_length = word.size();
         }
     }
 }
@@ -674,9 +672,8 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
-        uint32_t n_vocab_sp = any_file_loader->hparams.n_vocab_sp;
-        for (uint32_t i = 0; i < n_vocab; i++) {
-            file.write_u32(any_file_loader->vocab.special_tokens[i]);
+        for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
+            file.write_u32(pair.second);
         }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
@@ -2111,24 +2108,23 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }

-    auto offsets = vocab.special_token_trie.split(text);
+    std::vector<int> offsets = vocab.special_token_trie.split(text);
     int start = 0;
     for (int end : offsets) {
         if (start >= end) {
             continue;
         }

-        size_t part_length = end - start;
-        //printf("\"%.*s\"\n", (int) part_length, text.c_str() + start);
+        const char *part = text.c_str() + start;
+        size_t part_len = end - start;

-        if (vocab.max_special_token_length < part_length) {
-            tokenizer.tokenize(text.c_str() + start, part_length, output);
+        if (vocab.max_special_token_length < part_len) {
+            tokenizer.tokenize(part, part_len, output);
         } else {
-            auto token_it = vocab.special_token_to_id.find(std::string(text.c_str() + start, part_length));
+            auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
             if (token_it != vocab.special_token_to_id.end()) {
                 output.push_back(token_it->second);
             } else {
-                tokenizer.tokenize(text.c_str() + start, part_length, output);
+                tokenizer.tokenize(part, part_len, output);
             }
         }

         start = end;
@@ -4270,10 +4266,6 @@ llama_token llama_token_nl() {
     return 13;
 }

-bool llama_is_special_token(const struct llama_context *ctx, llama_token token) {
-    return std::find(ctx->vocab.special_tokens.begin(), ctx->vocab.special_tokens.end(), token) != ctx->vocab.special_tokens.end();
-}

 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,

View file

@@ -373,8 +373,6 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos(); // end-of-sentence
     LLAMA_API llama_token llama_token_nl();  // next-line

-    LLAMA_API bool llama_is_special_token(const struct llama_context * ctx, llama_token token);

     // Grammar
     //
     LLAMA_API struct llama_grammar * llama_grammar_init(