goerch 2023-08-15 08:24:56 +02:00
commit c545d85f83
11 changed files with 711 additions and 373 deletions


@@ -36,19 +36,15 @@ KEY_ROPE_DIMENSION_COUNT = "{llm}.rope.dimension_count"
KEY_ROPE_SCALE = "{llm}.rope.scale" KEY_ROPE_SCALE = "{llm}.rope.scale"
# tokenization # tokenization
KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens" KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens"
KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores" KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type"
KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges" KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores"
KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id" KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges"
KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id" KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id"
KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id" KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id"
KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id" KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id"
KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id" KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id"
KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json" KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id"
KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world" KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json"
KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id" KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id"
KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id"
KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.separator_token_id"
KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id"


@@ -110,8 +110,9 @@ gguf_writer.add_layer_norm_rms_eps(llm_arch, hparams["rms_norm_eps"])
print("gguf: get tokenizer metadata") print("gguf: get tokenizer metadata")
tokens: List[str] = [] tokens: List[bytes] = []
scores: List[float] = [] scores: List[float] = []
toktypes: List[int] = []
if Path(dir_model + "/tokenizer.model").is_file(): if Path(dir_model + "/tokenizer.model").is_file():
# vocab type sentencepiece # vocab type sentencepiece
@@ -121,26 +122,31 @@ if Path(dir_model + "/tokenizer.model").is_file():
for i in range(tokenizer.vocab_size()): for i in range(tokenizer.vocab_size()):
text: bytes text: bytes
if tokenizer.is_unknown(i): score: float
text = " \u2047 ".encode("utf-8")
elif tokenizer.is_control(i): piece = tokenizer.id_to_piece(i)
text = b"" text = piece.encode("utf-8")
if tokenizer.is_byte(i): score = tokenizer.get_score(i)
piece = tokenizer.id_to_piece(i)
if len(piece) != 6: toktype = 1 # default to normal token type
raise Exception(f"Invalid token: {piece}") if tokenizer.is_unknown(i): toktype = 2
byte_value = int(piece[3:-1], 16) if tokenizer.is_control(i): toktype = 3
text = struct.pack("B", byte_value)
else: # TODO: How to determine if a token is user defined?
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
score: float = tokenizer.get_score(i) # if tokenizer.is_user_defined(i): toktype = 4
if tokenizer.is_unused(i): toktype = 5
if tokenizer.is_byte(i): toktype = 6
tokens.append(text) tokens.append(text)
scores.append(score) scores.append(score)
toktypes.append(toktype)
gguf_writer.add_tokenizer_model("llama") gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens) gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
if Path(dir_model + "/tokenizer.json").is_file(): if Path(dir_model + "/tokenizer.json").is_file():
with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
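The hunk above is where the converter starts emitting a per-token type array next to the token list and scores. A minimal, self-contained sketch of that logic (not part of the diff; dir_model is assumed to contain a SentencePiece tokenizer.model), using the same numeric convention as the code above, 1 = normal, 2 = unknown, 3 = control, 4 = user defined, 5 = unused, 6 = byte:

```python
# Sketch only (not part of this commit): deriving the new tokenizer.ggml.token_type
# array from a SentencePiece model, with the convention used above:
# 1 = normal, 2 = unknown, 3 = control, 4 = user defined, 5 = unused, 6 = byte.
from typing import List
from sentencepiece import SentencePieceProcessor

def sentencepiece_token_types(dir_model: str) -> List[int]:
    tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")
    toktypes: List[int] = []
    for i in range(tokenizer.vocab_size()):
        toktype = 1                               # default to normal token type
        if tokenizer.is_unknown(i): toktype = 2
        if tokenizer.is_control(i): toktype = 3
        # user-defined pieces (type 4) are not detected here, same as the diff
        if tokenizer.is_unused(i):  toktype = 5
        if tokenizer.is_byte(i):    toktype = 6
        toktypes.append(toktype)
    return toktypes
```

The resulting list is written with gguf_writer.add_token_types(toktypes) alongside add_token_list and add_token_scores, as the hunk shows.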


@@ -1,4 +1,4 @@
# HF llama --> gguf conversion, GQA/70b not supported # HF llama --> gguf conversion
import gguf import gguf
import gguf_namemap as tmap import gguf_namemap as tmap
@@ -10,7 +10,7 @@ import json
import numpy as np import numpy as np
import torch import torch
from typing import Any, List from typing import Any, List, Optional
from pathlib import Path from pathlib import Path
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
@@ -18,11 +18,11 @@ from sentencepiece import SentencePieceProcessor
# compatible with python < 3.9 # compatible with python < 3.9
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
def permute(weights: NDArray, n_head: int) -> NDArray: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2) .swapaxes(1, 2)
.reshape(weights.shape)) .reshape(weights.shape))
def count_model_parts(dir_model: str) -> int: def count_model_parts(dir_model: str) -> int:
num_parts = 0 num_parts = 0
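permute now takes an optional n_kv_head so that grouped-query attention (GQA) checkpoints convert correctly: when the key/value head count differs from the query head count, the interleave runs over n_head // n_kv_head groups. A small NumPy sketch with made-up dimensions (illustrative only, not part of the diff) showing that the permutation preserves tensor shape in both the plain and the GQA case:

```python
# Illustrative sketch of the GQA-aware permute from the hunk above, applied to
# toy q/k projection weights (sizes are made up for the example).
from typing import Optional
import numpy as np

def permute(weights: np.ndarray, n_head: int, n_kv_head: Optional[int] = None) -> np.ndarray:
    if n_kv_head is not None and n_head != n_kv_head:
        n_head //= n_kv_head
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                   .swapaxes(1, 2)
                   .reshape(weights.shape))

n_embd, n_head, n_kv_head = 32, 8, 4
wq = np.arange(n_embd * n_embd, dtype=np.float32).reshape(n_embd, n_embd)
wk = np.arange(n_embd * n_kv_head // n_head * n_embd, dtype=np.float32).reshape(-1, n_embd)

assert permute(wq, n_head, n_head).shape == wq.shape      # plain attention: head count unchanged
assert permute(wk, n_head, n_kv_head).shape == wk.shape   # GQA: interleaves n_head // n_kv_head groups
```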
@@ -95,7 +95,7 @@ else:
gguf_writer.add_architecture(llm_arch) gguf_writer.add_architecture(llm_arch)
gguf_writer.add_name(last_dir) gguf_writer.add_name(last_dir)
gguf_writer.add_file_type( "All tensors F32" if ftype == 0 else "Most tensors F16, some F32") gguf_writer.add_file_type("All tensors F32" if ftype == 0 else "Most tensors F16, some F32")
gguf_writer.add_source_hf_repo(hf_repo) gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_context_length(llm_arch, hparams["max_position_embeddings"]) gguf_writer.add_context_length(llm_arch, hparams["max_position_embeddings"])
gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"]) gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"])
@@ -111,37 +111,43 @@ gguf_writer.add_layer_norm_rms_eps(llm_arch, hparams["rms_norm_eps"])
print("gguf: get tokenizer metadata") print("gguf: get tokenizer metadata")
tokens: List[str] = [] tokens: List[bytes] = []
scores: List[float] = [] scores: List[float] = []
toktypes: List[int] = []
if Path(dir_model + "/tokenizer.model").is_file(): if Path(dir_model + "/tokenizer.model").is_file():
# vocab type sentencepiece # vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab and scores") print("gguf: get sentencepiece tokenizer vocab, scores and token types")
tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model") tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")
for i in range(tokenizer.vocab_size()): for i in range(tokenizer.vocab_size()):
text: bytes text: bytes
if tokenizer.is_unknown(i): score: float
text = " \u2047 ".encode("utf-8")
elif tokenizer.is_control(i): piece = tokenizer.id_to_piece(i)
text = b"" text = piece.encode("utf-8")
if tokenizer.is_byte(i): score = tokenizer.get_score(i)
piece = tokenizer.id_to_piece(i)
if len(piece) != 6: toktype = 1 # default to normal token type
raise Exception(f"Invalid token: {piece}") if tokenizer.is_unknown(i): toktype = 2
byte_value = int(piece[3:-1], 16) if tokenizer.is_control(i): toktype = 3
text = struct.pack("B", byte_value)
else: # TODO: How to determine if a token is user defined?
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
score: float = tokenizer.get_score(i) # if tokenizer.is_user_defined(i): toktype = 4
if tokenizer.is_unused(i): toktype = 5
if tokenizer.is_byte(i): toktype = 6
tokens.append(text) tokens.append(text)
scores.append(score) scores.append(score)
toktypes.append(toktype)
gguf_writer.add_tokenizer_model("llama") gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens) gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
if Path(dir_model + "/tokenizer.json").is_file(): if Path(dir_model + "/tokenizer.json").is_file():
with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
@@ -214,7 +220,7 @@ for part_name in part_names:
# permute these # permute these
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"): if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
data = permute(data,head_count) data = permute(data, head_count, head_count_kv)
# map tensor names # map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map: if name.endswith(".weight") and name[:-7] in tensor_map:
@@ -283,7 +289,7 @@ for part_name in part_names:
# permute these # permute these
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"): if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
data = permute(data, head_count) data = permute(data, head_count, head_count_kv)
# map tensor names # map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map: if name.endswith(".weight") and name[:-7] in tensor_map:


@@ -54,7 +54,7 @@ int main(int argc, char ** argv) {
const int max_context_size = llama_n_ctx(ctx); const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4; const int max_tokens_list_size = max_context_size - 4;
if ((int)tokens_list.size() > max_tokens_list_size) { if ((int) tokens_list.size() > max_tokens_list_size) {
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size); fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
return 1; return 1;
} }
@@ -62,7 +62,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n\n"); fprintf(stderr, "\n\n");
for (auto id : tokens_list) { for (auto id : tokens_list) {
fprintf(stderr, "%s", llama_token_to_str(ctx, id)); fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str());
} }
fflush(stderr); fflush(stderr);
@@ -109,7 +109,7 @@ int main(int argc, char ** argv) {
} }
// print the new token : // print the new token :
printf("%s", llama_token_to_str(ctx, new_token_id)); printf("%s", llama_token_to_str(ctx, new_token_id).c_str());
fflush(stdout); fflush(stdout);
// push this new token for next evaluation // push this new token for next evaluation


@@ -7,6 +7,7 @@
#endif #endif
#include "gguf-util.h" #include "gguf-util.h"
#define LLAMA_API_CPP // TODO: eliminate me
#include "gguf-llama.h" #include "gguf-llama.h"
#include "ggml.h" #include "ggml.h"
@@ -76,7 +77,7 @@ static std::string to_string(const T & val) {
#define LLAMA_MAX_SCRATCH_BUFFERS 16 #define LLAMA_MAX_SCRATCH_BUFFERS 16
#endif #endif
typedef void (*offload_func_t)(struct ggml_tensor * tensor); #define UNUSED GGML_UNUSED
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
#define llama_host_malloc(n) ggml_cuda_host_malloc(n) #define llama_host_malloc(n) ggml_cuda_host_malloc(n)
@@ -125,6 +126,8 @@ struct llama_buffer {
} }
}; };
typedef void (*offload_func_t)(struct ggml_tensor * tensor);
void llama_nop(struct ggml_tensor * tensor) { // don't offload by default void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
(void) tensor; (void) tensor;
} }
@@ -623,12 +626,13 @@ struct gguf_file_loader {
hparams.n_embd = read_u32("llama.embedding_length"); hparams.n_embd = read_u32("llama.embedding_length");
hparams.n_ff = read_u32("llama.feed_forward_length"); hparams.n_ff = read_u32("llama.feed_forward_length");
hparams.n_head = read_u32("llama.attention.head_count"); hparams.n_head = read_u32("llama.attention.head_count");
hparams.n_layer = read_u32("llama.layer_count"); hparams.n_layer = read_u32("llama.block_count");
hparams.n_rot = read_u32("llama.rope.dimension_count"); hparams.n_rot = read_u32("llama.rope.dimension_count");
hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon"); hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon");
// LLaMAv2 // n_head_kv default to n_head
// hparams.n_head_kv = read_u32("llama.attention.head_count_kv"); hparams.n_head_kv = gguf_find_key(gguf_ctx, "llama.attention.head_count_kv") == -1 ? hparams.n_head : read_u32("llama.attention.head_count_kv");
} }
void read_vocab() { void read_vocab() {
@@ -1081,7 +1085,7 @@ static bool kv_cache_init(
cache.ctx = ggml_init(params); cache.ctx = ggml_init(params);
if (!cache.ctx) { if (!cache.ctx) {
fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
return false; return false;
} }
@@ -1370,7 +1374,7 @@ static void llama_model_load_internal(
ml->ggml_ctx = ctx; ml->ggml_ctx = ctx;
model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); model.tok_embeddings = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
// "output" tensor // "output" tensor
{ {
@@ -1391,8 +1395,8 @@ static void llama_model_load_internal(
backend_output = GGML_BACKEND_CPU; backend_output = GGML_BACKEND_CPU;
} }
model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm); model.norm = ml->get_tensor("output_norm.weight", {n_embd}, backend_norm);
model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
if (backend_norm == GGML_BACKEND_GPU) { if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.norm); vram_weights += ggml_nbytes(model.norm);
} }
@@ -1410,20 +1414,20 @@ static void llama_model_load_internal(
auto & layer = model.layers[i]; auto & layer = model.layers[i];
std::string layers_i = "layers." + std::to_string(i); std::string layers_i = "blk." + std::to_string(i);
layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); layer.attention_norm = ml->get_tensor(layers_i + ".attn_norm.weight", {n_embd}, backend);
layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); layer.wq = ml->get_tensor(layers_i + ".attn_q.weight", {n_embd, n_embd}, backend_split);
layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split); layer.wk = ml->get_tensor(layers_i + ".attn_k.weight", {n_embd, n_embd_gqa}, backend_split);
layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split); layer.wv = ml->get_tensor(layers_i + ".attn_v.weight", {n_embd, n_embd_gqa}, backend_split);
layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); layer.wo = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend_split);
layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); layer.w1 = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend_split);
layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); layer.w2 = ml->get_tensor(layers_i + ".ffn_down.weight", { n_ff, n_embd}, backend_split);
layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); layer.w3 = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) { if (backend == GGML_BACKEND_GPU) {
vram_weights += vram_weights +=
@@ -2109,6 +2113,122 @@ static bool llama_eval_internal(
// tokenizer // tokenizer
// //
static std::string llama_vocab_type(const llama_vocab & vocab) {
return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
}
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return token >= 259;
}
if (llama_vocab_type(vocab) == "bpe") {
return token >= 95;
}
return false;
}
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return token == 0;
}
// TODO: improve?
return false;
}
static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return token == 1 || token == 2;
}
// TODO: improve?
return false;
}
static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return token == 1;
}
// TODO: improve?
return false;
}
static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return token == 2;
}
// TODO: improve?
return false;
}
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
UNUSED(vocab);
UNUSED(token);
// TODO: improve?
return false;
}
static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) {
UNUSED(vocab);
UNUSED(token);
// TODO: improve?
return false;
}
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return 3 <= token && token < 259;
}
if (llama_vocab_type(vocab) == "bpe") {
return 1 <= token && token < 95;
}
return false;
}
static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
if (llama_vocab_type(vocab) == "spm") {
return byte + 3;
}
if (llama_vocab_type(vocab) == "bpe") {
return byte + 32;
}
return false;
}
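The predicates above hard-code a fixed id layout for the default SentencePiece ("spm") vocabulary, plus a provisional heuristic for BPE vocabularies. An illustrative Python summary of the spm layout they imply (not part of the commit):

```python
# Illustrative summary (not part of the commit) of the token-id layout encoded
# by the llama_is_* helpers and llama_byte_to_char above for the "spm"
# (SentencePiece, 32000-entry) vocabulary.
def classify_spm_token(token_id: int) -> str:
    if token_id == 0:
        return "unknown"
    if token_id in (1, 2):
        return "control"             # 1 = BOS, 2 = EOS
    if 3 <= token_id < 259:
        return "byte"                # raw byte b is stored at id b + 3
    return "normal"                  # regular pieces start at id 259

assert classify_spm_token(0) == "unknown"
assert classify_spm_token(2) == "control"
assert classify_spm_token(ord("A") + 3) == "byte"
assert classify_spm_token(29871) == "normal"
```

For the provisional "bpe" path the same block treats ids 1 to 94 as bytes (byte + 32) and ids 95 and up as normal tokens.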
static std::string llama_escape_whitespace(const std::string& text) {
std::string result;
bool escaping = false;
result += "\xe2\x96\x81";
for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') {
if (!escaping) {
result += "\xe2\x96\x81";
escaping = true;
}
}
else {
escaping = false;
result += text[offs];
}
}
return result;
}
static std::string llama_unescape_whitespace(const std::string& word) {
if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") {
return std::string(" ") + word.substr(3);
}
return word;
}
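For reference, a Python rendering of the two whitespace helpers above (equivalent logic only, not part of the commit). The C++ code works on UTF-8 bytes, where the marker is the three-byte sequence \xe2\x96\x81; here it is the single character "\u2581":

```python
# Python rendering (equivalent logic only) of llama_escape_whitespace and
# llama_unescape_whitespace above; "\u2581" is the UTF-8 sequence \xe2\x96\x81.
def escape_whitespace(text: str) -> str:
    result = "\u2581"                 # a marker is always prepended
    escaping = False
    for ch in text:
        if ch == " ":
            if not escaping:          # note: a run of spaces collapses to one marker
                result += "\u2581"
                escaping = True
        else:
            escaping = False
            result += ch
    return result

def unescape_whitespace(word: str) -> str:
    # only a single leading marker is turned back into a space
    return " " + word[1:] if word.startswith("\u2581") else word

assert escape_whitespace("Hello world") == "\u2581Hello\u2581world"
assert unescape_whitespace("\u2581Hello") == " Hello"
```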
static size_t utf8_len(char src) { static size_t utf8_len(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
uint8_t highbits = static_cast<uint8_t>(src) >> 4; uint8_t highbits = static_cast<uint8_t>(src) >> 4;
@@ -2150,10 +2270,11 @@ struct llama_tokenizer {
size_t offs = 0; size_t offs = 0;
while (offs < text.size()) { while (offs < text.size()) {
llama_sp_symbol sym; llama_sp_symbol sym;
size_t char_len = std::min(text.size() - offs, utf8_len(text[offs])); size_t len = utf8_len(text[offs]);
GGML_ASSERT(offs + len <= text.size());
sym.text = text.c_str() + offs; sym.text = text.c_str() + offs;
sym.n = char_len; sym.n = len;
offs += char_len; offs += len;
sym.prev = index - 1; sym.prev = index - 1;
sym.next = offs == text.size() ? -1 : index + 1; sym.next = offs == text.size() ? -1 : index + 1;
index++; index++;
@@ -2198,23 +2319,36 @@ struct llama_tokenizer {
for (int i = 0; i != -1; i = symbols_[i].next) { for (int i = 0; i != -1; i = symbols_[i].next) {
auto & symbol = symbols_[i]; auto & symbol = symbols_[i];
auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n)); resegment(symbol, output);
if (token == vocab_.token_to_id.end()) {
// output any symbols that did not form tokens as bytes.
for (int j = 0; j < (int) symbol.n; ++j) {
// NOTE: old version, before #2420 - not sure what are the implications of this
//llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
llama_vocab::id token_id = vocab_.token_to_id.at(std::string(1, symbol.text[j]));
output.push_back(token_id);
}
} else {
output.push_back((*token).second);
}
} }
} }
private: private:
void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
auto text = std::string(symbol.text, symbol.n);
auto token = vocab_.token_to_id.find(text);
// Do we need to support is_unused?
if (token != vocab_.token_to_id.end()) {
output.push_back((*token).second);
return;
}
const auto p = rev_merge.find(text);
if (p == rev_merge.end()) {
// output any symbols that did not form tokens as bytes.
for (int j = 0; j < (int)symbol.n; ++j) {
llama_vocab::id token_id = llama_byte_to_char(vocab_, symbol.text[j]);
output.push_back(token_id);
}
return;
}
resegment(symbols_[p->second.first], output);
resegment(symbols_[p->second.second], output);
}
void try_add_bigram(int left, int right) { void try_add_bigram(int left, int right) {
if (left == -1 || right == -1) { if (left == -1 || right == -1) {
return; return;
@@ -2239,18 +2373,22 @@ private:
bigram.score = tok_score.score; bigram.score = tok_score.score;
bigram.size = text.size(); bigram.size = text.size();
work_queue_.push(bigram); work_queue_.push(bigram);
// Do we need to support is_unused?
rev_merge[text] = std::make_pair(left, right);
} }
const llama_vocab & vocab_; const llama_vocab & vocab_;
std::vector<llama_sp_symbol> symbols_; std::vector<llama_sp_symbol> symbols_;
llama_sp_bigram::queue work_queue_; llama_sp_bigram::queue work_queue_;
std::map<std::string, std::pair<int, int> > rev_merge;
}; };
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
llama_tokenizer tokenizer(vocab); llama_tokenizer tokenizer(vocab);
std::vector<llama_vocab::id> output; std::vector<llama_vocab::id> output;
if (text.empty()) { if (raw_text.empty()) {
return output; return output;
} }
@@ -2258,6 +2396,13 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
output.push_back(llama_token_bos()); output.push_back(llama_token_bos());
} }
std::string text;
if (escape) {
text = llama_escape_whitespace(raw_text);
} else {
text = raw_text;
}
tokenizer.tokenize(text, output); tokenizer.tokenize(text, output);
return output; return output;
} }
@@ -2839,15 +2984,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id; const llama_token id = candidates->data[i].id;
const char * str = llama_token_to_str(ctx, id); std::string str = llama_token_to_str(ctx, id);
if (id == eos) { if (id == eos) {
if (!allow_eos) { if (!allow_eos) {
candidates->data[i].logit = -INFINITY; candidates->data[i].logit = -INFINITY;
} }
} else if (*str == 0) { } else if (str.empty()) {
candidates->data[i].logit = -INFINITY; candidates->data[i].logit = -INFINITY;
} else { } else {
candidates_decoded.push_back(decode_utf8(str)); candidates_decoded.push_back(decode_utf8(str.c_str()));
candidates_grammar.push_back({ i, candidates_decoded.back().data() }); candidates_grammar.push_back({ i, candidates_decoded.back().data() });
} }
} }
@@ -3048,9 +3193,9 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false); GGML_ASSERT(false);
} }
const char * str = llama_token_to_str(ctx, token); std::string str = llama_token_to_str(ctx, token);
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
auto code_points = decode_utf8(str); auto code_points = decode_utf8(str.c_str());
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
} }
@@ -3507,6 +3652,12 @@ struct llama_context * llama_new_context_with_model(
// this allocates all Metal resources and memory buffers // this allocates all Metal resources and memory buffers
ctx->ctx_metal = ggml_metal_init(1); ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
llama_free(ctx);
return NULL;
}
void * data_ptr = NULL; void * data_ptr = NULL;
size_t data_size = 0; size_t data_size = 0;
@@ -4281,7 +4432,8 @@ int llama_tokenize_with_model(
llama_token * tokens, llama_token * tokens,
int n_max_tokens, int n_max_tokens,
bool add_bos) { bool add_bos) {
auto res = llama_tokenize(model->vocab, text, add_bos); auto escape = llama_vocab_type(model->vocab) == "spm";
auto res = llama_tokenize(model->vocab, text, add_bos, escape);
if (n_max_tokens < (int) res.size()) { if (n_max_tokens < (int) res.size()) {
LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
@@ -4304,6 +4456,62 @@ int llama_tokenize(
return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
} }
std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const std::string & text,
bool add_bos) {
int length = text.length() + add_bos;
std::vector<llama_token> result(length);
length = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
if (length < 0) {
result.resize(-length);
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}
int llama_tokenize_bpe(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
auto res = llama_tokenize(ctx->model.vocab, text, add_bos, false);
if (n_max_tokens < (int) res.size()) {
LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size());
}
for (size_t i = 0; i < res.size(); i++) {
tokens[i] = res[i];
}
return res.size();
}
std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos) {
int length = text.length() + add_bos;
std::vector<llama_token> result(length);
length = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
if (length < 0) {
result.resize(-length);
int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
assert(check == -length);
GGML_UNUSED(check);
} else {
result.resize(length);
}
return result;
}
int llama_n_vocab_from_model(const struct llama_model * model) { int llama_n_vocab_from_model(const struct llama_model * model) {
return model->vocab.id_to_token.size(); return model->vocab.id_to_token.size();
} }
@@ -4357,16 +4565,80 @@ float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data(); return ctx->embedding.data();
} }
const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) { int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
if (token >= llama_n_vocab_from_model(model)) { if (0 <= token && token < llama_n_vocab_from_model(model)) {
return nullptr; if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].tok;
if(llama_vocab_type(model->vocab) == "spm") {
result = llama_unescape_whitespace(result);
}
if (length < (int) result.length()) {
return -result.length();
}
strncpy(str, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) {
if (length < 3) {
return -3;
}
strncpy(str, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_control_token(model->vocab, token)) {
;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
}
str[0] = llama_byte_to_char(model->vocab, token);
str[1] = 0x00;
return 1;
}
} }
return 0;
return model->vocab.id_to_token[token].tok.c_str();
} }
const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * str, int length) {
return llama_token_to_str_with_model(&ctx->model, token); return llama_token_to_str_with_model(&ctx->model, token, str, length);
}
std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int length = llama_token_to_str(ctx, token, result.data(), result.size());
if (length < 0) {
result.resize(-length);
int check = llama_token_to_str(ctx, token, result.data(), result.size());
GGML_ASSERT(check == -length);
} else {
result.resize(length);
}
return std::string(result.data(), result.size());
}
int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) {
if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) {
std::string result = ctx->model.vocab.id_to_token[token].tok;
if (length < (int) result.length()) {
return -result.length();
}
strncpy(str, result.c_str(), result.length());
return result.length();
}
return 0;
}
std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int length = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
if (length < 0) {
result.resize(-length);
const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
GGML_ASSERT(check == -length);
} else {
result.resize(length);
}
return std::string(result.data(), result.size());
} }
llama_token llama_token_bos() { llama_token llama_token_bos() {


@@ -311,6 +311,13 @@ extern "C" {
int n_max_tokens, int n_max_tokens,
bool add_bos); bool add_bos);
LLAMA_API int llama_tokenize_bpe(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_tokenize_with_model( LLAMA_API int llama_tokenize_with_model(
const struct llama_model * model, const struct llama_model * model,
const char * text, const char * text,
@@ -352,14 +359,23 @@ extern "C" {
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context // Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str( LLAMA_API int llama_token_to_str(
const struct llama_context * ctx, const struct llama_context * ctx,
llama_token token); llama_token token,
char * str,
int length);
LLAMA_API const char * llama_token_to_str_with_model( LLAMA_API int llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token,
char * str,
int length);
LLAMA_API int llama_token_to_str_with_model(
const struct llama_model * model, const struct llama_model * model,
llama_token token); llama_token token,
char * str,
int length);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_eos(); // end-of-sentence
@@ -447,15 +463,43 @@ extern "C" {
} }
#endif #endif
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only // C++ API, will be moving to common.h soon (TM)
#ifdef LLAMA_API_INTERNAL #ifdef LLAMA_API_CPP
#include <vector> #include <vector>
#include <string> #include <string>
//
// Vocab utils
//
std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const std::string & text,
bool add_bos);
std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos);
std::string llama_token_to_str(
const struct llama_context * ctx,
llama_token token);
std::string llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token);
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL
struct ggml_tensor; struct ggml_tensor;
const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx); const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
#endif #endif // LLAMA_API_CPP
#endif // LLAMA_API_INTERNAL
#endif // LLAMA_H #endif // LLAMA_H


@@ -297,6 +297,9 @@ class GGUFWriter:
def add_token_merges(self, merges: List): def add_token_merges(self, merges: List):
self.add_array(constants.KEY_TOKENIZER_MERGES, merges) self.add_array(constants.KEY_TOKENIZER_MERGES, merges)
def add_token_types(self, types: List[int]):
self.add_array(constants.KEY_TOKENIZER_TOKEN_TYPE, types)
def add_token_scores(self, scores: List[float]): def add_token_scores(self, scores: List[float]):
self.add_array(constants.KEY_TOKENIZER_SCORES, scores) self.add_array(constants.KEY_TOKENIZER_SCORES, scores)


@@ -4,92 +4,92 @@ def get_tensor_namemap( n_blocks : int):
tensor_map = {} tensor_map = {}
# Token embeddings # Token embeddings
mapped_to = "token_embd" mapped_to = "token_embd"
tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox
tensor_map["transformer.wte"] = mapped_to # gpt2 mpt tensor_map["transformer.wte"] = mapped_to # gpt2 mpt
tensor_map["transformer.word_embeddings"] = mapped_to # falcon tensor_map["transformer.word_embeddings"] = mapped_to # falcon
tensor_map["model.embed_tokens"] = mapped_to # llama-hf tensor_map["model.embed_tokens"] = mapped_to # llama-hf
tensor_map["tok_embeddings"] = mapped_to # llama-pth tensor_map["tok_embeddings"] = mapped_to # llama-pth
# Position embeddings # Position embeddings
mapped_to = "pos_embd" mapped_to = "pos_embd"
tensor_map["transformer.wpe"] = mapped_to # gpt2 tensor_map["transformer.wpe"] = mapped_to # gpt2
# Output norm # Output norm
mapped_to = "output_norm" mapped_to = "output_norm"
tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox
tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon
tensor_map["transformer.norm_f"] = mapped_to # mpt tensor_map["transformer.norm_f"] = mapped_to # mpt
tensor_map["model.norm"] = mapped_to # llama-hf tensor_map["model.norm"] = mapped_to # llama-hf
tensor_map["norm"] = mapped_to # llama-pth tensor_map["norm"] = mapped_to # llama-pth
# Output # Output
mapped_to = "output" mapped_to = "output"
tensor_map["embed_out"] = mapped_to # gptneox tensor_map["embed_out"] = mapped_to # gptneox
tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
tensor_map["output"] = mapped_to # llama-pth tensor_map["output"] = mapped_to # llama-pth
# Attention and feed-forward layer blocks # Attention and feed-forward layer blocks
for i in range(0,n_blocks): for i in range(0,n_blocks):
# Attention norm # Attention norm
mapped_to = "blk."+str(i)+".attn_norm" mapped_to = "blk."+str(i)+".attn_norm"
tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b
tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b
tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth
# Attention norm 2 # Attention norm 2
mapped_to = "blk."+str(i)+".attn_norm_2" mapped_to = "blk."+str(i)+".attn_norm_2"
tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b
# Attention query-key-value # Attention query-key-value
mapped_to = "blk."+str(i)+".attn_qkv" mapped_to = "blk."+str(i)+".attn_qkv"
tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon
# Attention query # Attention query
mapped_to = "blk."+str(i)+".attn_q" mapped_to = "blk."+str(i)+".attn_q"
tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth
# Attention key # Attention key
mapped_to = "blk."+str(i)+".attn_k" mapped_to = "blk."+str(i)+".attn_k"
tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth
# Attention value # Attention value
mapped_to = "blk."+str(i)+".attn_v" mapped_to = "blk."+str(i)+".attn_v"
tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth
# Attention output # Attention output
mapped_to = "blk."+str(i)+".attn_output" mapped_to = "blk."+str(i)+".attn_output"
tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth
# Feed-forward norm # Feed-forward norm
mapped_to = "blk."+str(i)+".ffn_norm" mapped_to = "blk."+str(i)+".ffn_norm"
tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt
tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
# Feed-forward up # Feed-forward up
mapped_to = "blk."+str(i)+".ffn_up" mapped_to = "blk."+str(i)+".ffn_up"
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth
# Feed-forward gate # Feed-forward gate
mapped_to = "blk."+str(i)+".ffn_gate" mapped_to = "blk."+str(i)+".ffn_gate"
tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth
# Feed-forward down # Feed-forward down
mapped_to = "blk."+str(i)+".ffn_down" mapped_to = "blk."+str(i)+".ffn_down"
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth
return tensor_map return tensor_map
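A short usage sketch for the name map built above (illustrative; the HF tensor names are just examples), following the same .weight suffix handling as the converter loops shown earlier in this commit:

```python
# Usage sketch (illustrative) for gguf_namemap: translate a few HF tensor names
# to the new GGUF names. The ".weight" suffix is stripped for the lookup
# (name[:-7]) and re-appended afterwards, as in the convert scripts.
import gguf_namemap as tmap

tensor_map = tmap.get_tensor_namemap(32)        # e.g. a 32-block LLaMA model

for name in ("model.embed_tokens.weight",               # llama-hf
             "model.layers.0.self_attn.q_proj.weight",  # llama-hf
             "layers.0.feed_forward.w2.weight"):        # llama-pth
    if name.endswith(".weight") and name[:-7] in tensor_map:
        print(name, "->", tensor_map[name[:-7]] + ".weight")
```

This prints token_embd.weight, blk.0.attn_q.weight and blk.0.ffn_down.weight respectively.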

llama.cpp

@@ -72,6 +72,7 @@ static void llama_log_callback_default(llama_log_level level, const char * text,
#define LLAMA_MAX_SCRATCH_BUFFERS 16 #define LLAMA_MAX_SCRATCH_BUFFERS 16
#endif #endif
#define UNUSED GGML_UNUSED
// available llama models // available llama models
enum e_model { enum e_model {
@@ -1943,77 +1944,94 @@ static bool llama_eval_internal(
// tokenizer // tokenizer
// //
static std::string llama_vocab_type(const llama_vocab& vocab) { static std::string llama_vocab_type(const llama_vocab & vocab) {
return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
} }
static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token >= 259; return token >= 259;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return token >= 95; return token >= 95;
else }
return false;
return false;
} }
static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 0; return token == 0;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 1 || token == 2; return token == 1 || token == 2;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 1; return token == 1;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 2; return token == 2;
else }
// TODO: improve?
return false;
}
static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token token) {
// TODO: improve? // TODO: improve?
return false; return false;
} }
static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
UNUSED(vocab);
UNUSED(token);
// TODO: improve? // TODO: improve?
return false; return false;
} }
static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") UNUSED(vocab);
UNUSED(token);
// TODO: improve?
return false;
}
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return 3 <= token && token < 259; return 3 <= token && token < 259;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return 1 <= token && token < 95; return 1 <= token && token < 95;
else }
return false;
return false;
} }
static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) { static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return byte + 3; return byte + 3;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return byte + 32; return byte + 32;
else }
return false;
return false;
} }
static std::string llama_escape_whitespace(const std::string& text) { static std::string llama_escape_whitespace(const std::string& text) {
@@ -4399,21 +4417,21 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token
if(llama_vocab_type(model->vocab) == "spm") { if(llama_vocab_type(model->vocab) == "spm") {
result = llama_unescape_whitespace(result); result = llama_unescape_whitespace(result);
} }
if(result.length() > length) { if (length < (int) result.length()) {
return - result.length(); return -result.length();
} }
strcpy(str, result.c_str()); strncpy(str, result.c_str(), result.length());
return result.length(); return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { } else if (llama_is_unknown_token(model->vocab, token)) {
if(3 > length) { if (length < 3) {
return -3; return -3;
} }
strcpy(str, "\xe2\x96\x85"); strncpy(str, "\xe2\x96\x85", 3);
return 3; return 3;
} else if (llama_is_control_token(model->vocab, token)) { } else if (llama_is_control_token(model->vocab, token)) {
; ;
} else if (llama_is_byte_token(model->vocab, token)) { } else if (llama_is_byte_token(model->vocab, token)) {
if(1 > length) { if (length < 1) {
return -1; return -1;
} }
str[0] = llama_byte_to_char(model->vocab, token); str[0] = llama_byte_to_char(model->vocab, token);
@@ -4428,52 +4446,44 @@ int llama_token_to_str(const struct llama_context * ctx, llama_token token, char
return llama_token_to_str_with_model(&ctx->model, token, str, length); return llama_token_to_str_with_model(&ctx->model, token, str, length);
} }
std::string llama_token_to_str( std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
const struct llama_context * ctx, std::vector<char> result(8, 0);
llama_token token) { const int length = llama_token_to_str(ctx, token, result.data(), result.size());
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str(ctx, token, (char *)result.data(), result.length());
if (length < 0) { if (length < 0) {
result.resize(-length); result.resize(-length);
int check = llama_token_to_str(ctx, token, (char *)result.data(), result.length()); int check = llama_token_to_str(ctx, token, result.data(), result.size());
assert(check == -length); GGML_ASSERT(check == -length);
GGML_UNUSED(check);
} else { } else {
result.resize(length); result.resize(length);
} }
return result;
return std::string(result.data(), result.size());
} }
int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) { int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * str, int length) {
if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) { if (0 <= token && token < llama_n_vocab_from_model(&ctx->model)) {
std::string result = ctx->model.vocab.id_to_token[token].tok; std::string result = ctx->model.vocab.id_to_token[token].tok;
if (result.length() > length) { if (length < (int) result.length()) {
return - result.length(); return -result.length();
} }
strcpy(str, result.c_str()); strncpy(str, result.c_str(), result.length());
return result.length(); return result.length();
} }
return 0; return 0;
} }
std::string llama_token_to_str_bpe( std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
const struct llama_context * ctx, std::vector<char> result(8, 0);
llama_token token) { const int length = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
std::string result;
int length = 8;
result.resize(length);
length = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length());
if (length < 0) { if (length < 0) {
result.resize(-length); result.resize(-length);
int check = llama_token_to_str_bpe(ctx, token, (char*)result.data(), result.length()); const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
assert(check == -length); GGML_ASSERT(check == -length);
GGML_UNUSED(check);
} else { } else {
result.resize(length); result.resize(length);
} }
return result;
return std::string(result.data(), result.size());
} }
llama_token llama_token_bos() { llama_token llama_token_bos() {


@@ -8,14 +8,13 @@
static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) { static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
std::string result; std::string result;
for (int i = 0; i < tokens.size(); ++i) { for (size_t i = 0; i < tokens.size(); ++i) {
result += llama_token_to_str(ctx, tokens[i]); result += llama_token_to_str(ctx, tokens[i]);
} }
return result; return result;
} }
static const std::map<std::string, std::vector<llama_token>> & k_tests() static const std::map<std::string, std::vector<llama_token>> & k_tests() {
{
static std::map<std::string, std::vector<llama_token>> _k_tests = { static std::map<std::string, std::vector<llama_token>> _k_tests = {
{ " ", {1, 259, }, }, { " ", {1, 259, }, },
{ "\t", { 1, 29871, 12, }, }, { "\t", { 1, 29871, 12, }, },
@@ -30,16 +29,17 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests()
{ "w048 7tuijk dsdfhu", { 1, 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, }, { "w048 7tuijk dsdfhu", { 1, 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, },
{ "нещо на Български", { 1, 1538, 4851, 665, 1386, 29713, 1305, }, }, { "нещо на Български", { 1, 1538, 4851, 665, 1386, 29713, 1305, }, },
{ "កាន់តែពិសេសអាចខលចេញ", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161, { "កាន់តែពិសេសអាចខលចេញ", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161,
146, 228, 162, 133, 228, 161, 153, 228, 161, 186, 146, 228, 162, 133, 228, 161, 153, 228, 161, 186,
31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228, 31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228,
161, 136, 228, 161, 132, 228, 161, 158, 228, 161, 161, 136, 228, 161, 132, 228, 161, 158, 228, 161,
136, 228, 162, 132, 228, 161, 140, }, }, 136, 228, 162, 132, 228, 161, 140, }, },
{ "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
{ 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871, { 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871,
243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598, 243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598,
313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681, 313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681,
313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, }, 313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, },
}; };
return _k_tests; return _k_tests;
}; };
@@ -90,7 +90,7 @@ int main(int argc, char **argv) {
} }
for (const auto & test_kv : k_tests()) { for (const auto & test_kv : k_tests()) {
std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first.c_str(), true); std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, true);
fprintf(stderr, "%s : '%s' tokenized to '%s'\n", fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
__func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str()); __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str());


@@ -8,8 +8,9 @@
#include <codecvt> #include <codecvt>
#include <map> #include <map>
#include <vector> #include <vector>
#include <locale>
static std::string vocab_type(llama_context* ctx) { static std::string vocab_type(llama_context * ctx) {
return llama_n_vocab(ctx) == 32000 ? "spm": "bpe"; return llama_n_vocab(ctx) == 32000 ? "spm": "bpe";
} }
@@ -32,9 +33,9 @@ static std::string escape_whitespace(const std::string& text) {
return result; return result;
} }
static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) { static std::string unescape_whitespace(llama_context * ctx, const std::vector<llama_token> & tokens) {
std::string result; std::string result;
for (int i = 0; i < tokens.size(); ++i) { for (size_t i = 0; i < tokens.size(); ++i) {
result += llama_token_to_str(ctx, tokens[i]); result += llama_token_to_str(ctx, tokens[i]);
} }
return result; return result;