Merge a899673346 into ef6dada60c
commit 448860d5e3
1 changed file with 157 additions and 33 deletions
@@ -1,6 +1,6 @@
 #if defined(_WIN32)
-#  include <windows.h>
 #  include <io.h>
+#  include <windows.h>
 #else
 #  include <sys/file.h>
 #  include <sys/ioctl.h>
@@ -12,12 +12,14 @@
 #endif

 #include <signal.h>
 #include <sys/stat.h>

 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
 #include <fstream>
 #include <iostream>
 #include <list>
 #include <sstream>
@@ -37,13 +39,14 @@
 #endif

 GGML_ATTRIBUTE_FORMAT(1, 2)
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
     GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
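The fmt() helper above uses the standard two-pass vsnprintf idiom: the first call with a NULL destination only measures, the second writes into a buffer sized from that measurement. A minimal standalone sketch of the same pattern (illustrative; not the llama.cpp implementation):

#include <cstdarg>
#include <cstdio>
#include <string>

// First vsnprintf call measures; second call writes into the sized buffer.
static std::string fmt_sketch(const char * format, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, format);
    va_copy(ap2, ap);
    const int size = vsnprintf(nullptr, 0, format, ap);  // measure only
    std::string buf(size > 0 ? static_cast<size_t>(size) : 0, '\0');
    vsnprintf(buf.data(), buf.size() + 1, format, ap2);  // write text + NUL
    va_end(ap2);
    va_end(ap);
    return buf;
}

int main() {
    std::printf("%s\n", fmt_sketch("%3d%% |", 42).c_str());
    return 0;
}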
@@ -55,6 +58,7 @@ static std::string fmt(const char * fmt, ...) {
 }

 GGML_ATTRIBUTE_FORMAT(1, 2)
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -103,7 +107,8 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params model_params;
     std::string model_;
+    std::string chat_template_;
     std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
@@ -139,7 +144,7 @@ class Opt {
     }

     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -168,6 +173,11 @@ class Opt {
                 ++positional_args_i;
                 model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
                 user = argv[i];
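The new --chat-template branch follows the same shape as the existing options: peek at argv[i + 1], fail if the value is missing, otherwise consume it and advance. A self-contained sketch of that pattern; the helper name here is made up and is not the handle_option_with_value() from this file:

#include <cstring>
#include <string>

// Made-up helper mirroring the "--flag <value>" pattern used in Opt::parse().
static int handle_str_option(int argc, char ** argv, int & i, std::string & out) {
    if (i + 1 >= argc) {
        return 1;     // flag given without a value
    }
    out = argv[++i];  // consume the value and advance past it
    return 0;
}

int main(int argc, char ** argv) {
    std::string chat_template;
    for (int i = 1; i < argc; ++i) {
        if (std::strcmp(argv[i], "--chat-template") == 0) {
            if (handle_str_option(argc, argv, i, chat_template)) {
                return 1;
            }
        }
    }
    return 0;
}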
@@ -477,7 +487,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }

-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }

     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -517,6 +529,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
                progress_suffix.c_str());
     }

     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
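write_data() above is a standard libcurl CURLOPT_WRITEFUNCTION callback. For context, a minimal sketch of how such a callback is typically wired into a transfer; the URL and filename are placeholders and error handling is trimmed:

#include <cstdio>
#include <curl/curl.h>

// curl hands us raw bytes; we stream them straight into the open file.
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
    FILE * out = static_cast<FILE *>(stream);
    return fwrite(ptr, size, nmemb, out);  // a short count makes curl abort the transfer
}

int main() {
    curl_global_init(CURL_GLOBAL_DEFAULT);
    CURL * curl = curl_easy_init();
    FILE * out  = fopen("model.gguf", "wb");
    if (!curl || !out) {
        return 1;
    }
    curl_easy_setopt(curl, CURLOPT_URL, "https://example.com/model.gguf");
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, out);
    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
    const CURLcode res = curl_easy_perform(curl);
    fclose(out);
    curl_easy_cleanup(curl);
    curl_global_cleanup();
    return res == CURLE_OK ? 0 : 1;
}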
@@ -540,6 +553,7 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::list<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;

     int init(Opt & opt) {
         model = initialize_model(opt);
@@ -547,12 +561,15 @@ class LlamaData {
             return 1;
         }

+        chat_template = initialize_chat_template(model, opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }

         sampler = initialize_sampler(opt);

         return 0;
     }
@@ -575,21 +592,74 @@ class LlamaData {
     }
 #endif

-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        if (std::filesystem::exists(tn)) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }

         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }

-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
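huggingface_dl_tmpl() extracts the Jinja chat template from the repository's tokenizer_config.json. A self-contained sketch of that extraction with nlohmann::json; the inline JSON is a made-up stand-in, and value() with a default is shown as a defensive alternative to the operator[] access used in the diff, which throws if the key is absent:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
    // Trimmed, fabricated example of a tokenizer_config.json payload.
    const std::string tokenizer_config_str = R"({
        "bos_token": "<s>",
        "chat_template": "{% for message in messages %}...{% endfor %}"
    })";

    nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
    const std::string tmpl = config.value("chat_template", "");  // empty if key missing
    if (tmpl.empty()) {
        std::cerr << "no chat_template in tokenizer_config.json\n";
        return 1;
    }
    std::cout << tmpl << "\n";
    return 0;
}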
@@ -609,16 +679,34 @@ class LlamaData {
         }

         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
             }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
         }

-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
+            }
+        }
+
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
+            }
+        }
+
+        return 0;
     }

     std::string basename(const std::string & path) {
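The reworked ollama_dl() walks the manifest's layers and keys off mediaType to pick up both the model blob and the template blob instead of stopping at the first model layer. A standalone sketch against a trimmed, made-up manifest in the Docker v2 schema that registry.ollama.ai serves:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
    // Fabricated manifest; real digests and sizes differ.
    const std::string manifest_str = R"({
        "schemaVersion": 2,
        "layers": [
            { "mediaType": "application/vnd.ollama.image.model",    "digest": "sha256:aaaa" },
            { "mediaType": "application/vnd.ollama.image.template", "digest": "sha256:bbbb" },
            { "mediaType": "application/vnd.ollama.image.params",   "digest": "sha256:cccc" }
        ]
    })";

    nlohmann::json manifest = nlohmann::json::parse(manifest_str);
    std::string sha_model;
    std::string sha_template;
    for (const auto & l : manifest["layers"]) {
        if (l["mediaType"] == "application/vnd.ollama.image.model")    { sha_model    = l["digest"]; }
        if (l["mediaType"] == "application/vnd.ollama.image.template") { sha_template = l["digest"]; }
    }
    std::cout << "model blob:    " << sha_model << "\n"
              << "template blob: " << sha_template << "\n";
    return 0;
}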
@@ -630,6 +718,15 @@ class LlamaData {
         return path.substr(pos + 1);
     }

+    std::string get_proto(const std::string & model_) {
+        const std::string::size_type pos = model_.find("://");
+        if (pos == std::string::npos) {
+            return "";
+        }
+
+        return model_.substr(0, pos + 3); // Include "://"
+    }
+
     int remove_proto(std::string & model_) {
         const std::string::size_type pos = model_.find("://");
         if (pos == std::string::npos) {
@@ -640,30 +737,32 @@ class LlamaData {
         return 0;
     }

-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);

             return ret;
         }

+        std::string proto = get_proto(model_);
+        remove_proto(model_);
+
         const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
+        if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
+            ret = huggingface_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "ollama://")) {
+            ret = ollama_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "https://")) {
             download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }

         model_ = bn;
+        chat_template_ = tn;

         return ret;
     }
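resolve_model() now snapshots the scheme with get_proto() before stripping it, so the dispatch can test the protocol while the download helpers receive the bare model reference. A standalone sketch of that split; free functions here, member functions in the diff:

#include <iostream>
#include <string>

// Return the leading "scheme://" (including the separator), or "" if absent.
static std::string get_proto(const std::string & s) {
    const std::string::size_type pos = s.find("://");
    return pos == std::string::npos ? "" : s.substr(0, pos + 3);
}

// Return the reference with any "scheme://" prefix removed.
static std::string remove_proto(const std::string & s) {
    const std::string::size_type pos = s.find("://");
    return pos == std::string::npos ? s : s.substr(pos + 3);
}

int main() {
    const std::string model = "ollama://granite-code";
    std::cout << get_proto(model) << "\n";     // "ollama://"
    std::cout << remove_proto(model) << "\n";  // "granite-code"
    return 0;
}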
@@ -671,7 +770,7 @@ class LlamaData {
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -704,6 +803,31 @@ class LlamaData {
         return sampler;
     }

+    std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
+        if (!std::filesystem::exists(opt.chat_template_)) {
+            return common_get_builtin_chat_template(model.get());
+        }
+
+        FILE * tmpl_file = ggml_fopen(opt.chat_template_.c_str(), "r");
+        if (!tmpl_file) {
+            std::cerr << "Error opening file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+
+        fseek(tmpl_file, 0, SEEK_END);
+        size_t size = ftell(tmpl_file);
+        fseek(tmpl_file, 0, SEEK_SET);
+
+        std::vector<unsigned char> data(size);
+        size_t read_size = fread(data.data(), 1, size, tmpl_file);
+        fclose(tmpl_file);
+        if (read_size != size) {
+            std::cerr << "Error reading file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+        return std::string(data.begin(), data.end());
+    }
 };

 // Add a message to `messages` and store its content in `msg_strs`
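initialize_chat_template() reads the template file with an fseek/ftell/fread sequence. For comparison, an equivalent standard-library sketch that slurps a file into a std::string; this is an alternative, not what the commit does:

#include <fstream>
#include <sstream>
#include <string>

static std::string slurp(const std::string & path) {
    std::ifstream in(path, std::ios::binary);
    if (!in) {
        return "";  // caller can treat empty as "no template on disk"
    }
    std::ostringstream ss;
    ss << in.rdbuf();  // stream the entire file into the buffer
    return ss.str();
}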
@@ -715,11 +839,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
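apply_chat_template() keeps the measure, grow, retry contract: llama_chat_apply_template() reports the size it needed, and when that exceeds fmtted.size() the buffer is resized and the call repeated. The same pattern in a self-contained sketch, with render() as a made-up stand-in for any size-reporting API:

#include <string>
#include <vector>

// Stand-in for an API that writes into (buf, len) and returns the bytes required.
static int render(const std::string & src, char * buf, int len) {
    const int needed = static_cast<int>(src.size());
    if (buf != nullptr && len >= needed) {
        src.copy(buf, needed);
    }
    return needed;
}

static std::vector<char> render_all(const std::string & src) {
    std::vector<char> out(64);  // initial guess
    int n = render(src, out.data(), static_cast<int>(out.size()));
    if (n > static_cast<int>(out.size())) {
        out.resize(n);  // grow to the reported size and retry once
        n = render(src, out.data(), static_cast<int>(out.size()));
    }
    out.resize(n);
    return out;
}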
@@ -732,8 +856,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
                            std::vector<llama_token> & prompt_tokens) {
     const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
         printe("failed to tokenize the prompt\n");
         return -1;
     }