first try to integrate sentencepiece
parent 460c482540
commit 307dba3dd2

4 changed files with 41 additions and 30 deletions

.gitignore (vendored): 1 change

@@ -21,3 +21,4 @@ models/*
 arm_neon.h
 compile_commands.json
+*.dSYM/
 
Makefile: 6 changes

@@ -30,9 +30,9 @@ endif
 # Compile flags
 #
 
-CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
-LDFLAGS  =
+CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC -g -I/opt/homebrew/include
+CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -g -I/opt/homebrew/include
+LDFLAGS  = -L/opt/homebrew/lib -lsentencepiece
 
 # OS specific
 # TODO: support Windows
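
Editor's note: the -g flag explains the *.dSYM/ addition to .gitignore (macOS debug bundles), but the /opt/homebrew prefix is specific to Homebrew on Apple Silicon, so this build only works out of the box there. A more portable sketch, assuming the local sentencepiece install ships its pkg-config file (sentencepiece.pc), would be:

# Sketch, not part of the commit: derive sentencepiece flags from pkg-config
# instead of hardcoding the Homebrew prefix. Assumes sentencepiece.pc is
# visible on PKG_CONFIG_PATH; both variables stay empty if it is not.
SP_CFLAGS := $(shell pkg-config --cflags sentencepiece 2>/dev/null)
SP_LIBS   := $(shell pkg-config --libs   sentencepiece 2>/dev/null)

CXXFLAGS += $(SP_CFLAGS)
LDFLAGS  += $(SP_LIBS)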

main.cpp: 20 changes

@@ -10,6 +10,12 @@
 #include <map>
 #include <string>
 #include <vector>
+#include <sentencepiece_processor.h>
+#include <stdexcept>
+#include <iostream>
+#include <bitset>
+#include <sstream>
+#include <regex>
 
 #include <signal.h>
 #include <unistd.h>

@@ -82,7 +88,7 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+bool llama_model_load(const std::string & fname, llama_model & model, sentencepiece::SentencePieceProcessor & sp, gpt_vocab & vocab, int n_ctx) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     auto fin = std::ifstream(fname, std::ios::binary);

@@ -144,6 +150,8 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         return false;
     }
 
+    printf("total pieces: %d", sp.GetPieceSize());
+
     std::string word;
     for (int i = 0; i < n_vocab; i++) {
         uint32_t len;

@@ -152,8 +160,9 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         word.resize(len);
         fin.read((char *) word.data(), len);
 
-        vocab.token_to_id[word] = i;
-        vocab.id_to_token[i] = word;
+        std::string wordx = sp.IdToPiece(i);
+        vocab.token_to_id[wordx] = i;
+        vocab.id_to_token[i] = wordx;
 
         //if (i < 30000) {
         //    printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());

@@ -764,6 +773,9 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
 
+    sentencepiece::SentencePieceProcessor sp;
+    sp.Load("./models/tokenizer.model");
+
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }

@@ -791,7 +803,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!llama_model_load(params.model, model, vocab, 512)) {  // TODO: set context from user input ??
+        if (!llama_model_load(params.model, model, sp, vocab, 512)) {  // TODO: set context from user input ??
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }
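
Editor's note: two small hazards in the new code above. The printf("total pieces: %d", ...) call has no trailing newline, and sp.Load("./models/tokenizer.model") discards the returned status, so a missing tokenizer file only surfaces later as broken tokenization. A minimal checked-load sketch (the helper name and error text are illustrative, not from the commit):

#include <sentencepiece_processor.h>

#include <iostream>
#include <string>

// Sketch only: SentencePieceProcessor::Load() returns a
// sentencepiece::util::Status, which the commit currently ignores.
static bool load_tokenizer(sentencepiece::SentencePieceProcessor & sp, const std::string & path) {
    const auto status = sp.Load(path);
    if (!status.ok()) {
        std::cerr << "failed to load tokenizer '" << path << "': " << status.ToString() << std::endl;
        return false;
    }
    return true;
}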

utils.cpp: 44 changes

@@ -4,6 +4,7 @@
 #include <cstring>
 #include <fstream>
 #include <regex>
+#include <sentencepiece_processor.h>
 #include <iostream>
 #include <iterator>
 #include <string>

@@ -281,33 +282,30 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
 
     std::vector<gpt_vocab::id> res;
 
-    if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
+    // if (bos) {
+    //     res.push_back(1); // TODO: replace with vocab.bos
+    // }
+
+    sentencepiece::SentencePieceProcessor sp;
+    sp.Load("./models/tokenizer.model");
+
+    std::vector<std::string> pieces;
+    return sp.EncodeAsIds(text);
+    /*
+    for (const auto & piece : pieces) {
+        printf("piece: %s\n", piece.c_str());
+        if (vocab.token_to_id.count(piece) > 0) {
+            res.push_back(vocab.token_to_id.at(piece));
+        } else {
+            // handle unknown token
+        }
     }
 
-    //find the longest token that matches the text
-    int pos = 0;
-    while (true) {
-        int l = 0;
-        int t = 0;
-        for (const auto & kv : vocab.id_to_token) {
-            if (kv.second.size() < l) continue;
-            if (kv.second.size() > text.size() - pos) continue;
-            if (text.substr(pos, kv.second.size()) == kv.second) {
-                l = kv.second.size();
-                t = kv.first;
-            }
-        }
-
-        if (l == 0) {
-            break;
-        }
-
-        res.push_back(t);
-        pos += l;
+    for (const auto& id : res) {
+        printf("%d\n", id);
     }
 
-    return res;
+    return res;*/
 }
 
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
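
Editor's note: with the early return sp.EncodeAsIds(text); everything from /* down is dead code, and the bos parameter is now silently ignored rather than handled. EncodeAsIds() returns std::vector<int>, and the processor exposes its own special token ids, so the hardcoded res.push_back(1) has a direct replacement in sp.bos_id(). A self-contained round-trip sketch under those assumptions (sample text and output format are illustrative):

#include <sentencepiece_processor.h>

#include <cstdio>
#include <string>
#include <vector>

int main() {
    sentencepiece::SentencePieceProcessor sp;

    // Same path the commit uses; bail out if the model cannot be loaded.
    if (!sp.Load("./models/tokenizer.model").ok()) {
        fprintf(stderr, "failed to load tokenizer\n");
        return 1;
    }

    // EncodeAsIds() returns std::vector<int>; prepend BOS via the processor
    // instead of the hardcoded id 1 from the commented-out code.
    std::vector<int> ids = sp.EncodeAsIds("Hello world");
    ids.insert(ids.begin(), sp.bos_id());

    for (const int id : ids) {
        printf("%d -> '%s'\n", id, sp.IdToPiece(id).c_str());
    }

    // Control tokens such as BOS are dropped when decoding back to text.
    printf("decoded: '%s'\n", sp.DecodeIds(ids).c_str());
    return 0;
}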