common: introduce llama_load_model_from_url to download model from hf url using libopenssl only

Pierrick HYMBERT 2024-03-16 09:59:05 +01:00
parent d84c48505f
commit 3221ab01ad
4 changed files with 179 additions and 1 deletions


@@ -47,6 +47,16 @@ if (BUILD_SHARED_LIBS)
    set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# Check for OpenSSL
find_package(OpenSSL QUIET)
if (OPENSSL_FOUND)
    add_definitions(-DHAVE_OPENSSL)
    include_directories(${OPENSSL_INCLUDE_DIR})
    link_libraries(${OPENSSL_LIBRARIES})
else()
    message(WARNING "OpenSSL not found. Building without model download support.")
endif()
set(TARGET common)
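
Since OpenSSL is optional, the HAVE_OPENSSL define is the only signal the C++ code gets about download support. A minimal sketch of probing it at compile time, with a hypothetical helper name:

// sketch only: llama_download_supported is a hypothetical helper, not part of this commit
#include <cstdio>

static bool llama_download_supported() {
#ifdef HAVE_OPENSSL
    return true;  // built with -DHAVE_OPENSSL set by CMake
#else
    return false; // OpenSSL not found at configure time
#endif
}

int main() {
    printf("model download support: %s\n", llama_download_supported() ? "yes" : "no");
    return 0;
}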


@@ -1376,10 +1376,160 @@ void llama_batch_add(
    batch.n_tokens++;
}
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
                                               struct llama_model_params params) {
#ifdef HAVE_OPENSSL
    // Initialize OpenSSL
    SSL_library_init();
    SSL_load_error_strings();
    OpenSSL_add_all_algorithms();

    // Parse the URL to extract host, path and optional "user:password" credentials
    char host[256];
    char path[256];
    char userpass[256] = {0}; // must be initialized: only filled when the URL carries credentials
    if (sscanf(model_url, "https://%255[^/]/%255s", host, path) != 2) {
        fprintf(stderr, "%s: invalid URL format: %s\n", __func__, model_url);
        return nullptr;
    }
    if (strchr(host, '@')) {
        // sscanf must not read from and write to the same buffer, so split via a temporary
        char host_only[256];
        if (sscanf(host, "%255[^@]@%255s", userpass, host_only) == 2) {
            strcpy(host, host_only);
        }
    }
    // Create an SSL context
    auto ctx = SSL_CTX_new(TLS_client_method());
    if (!ctx) {
        fprintf(stderr, "%s: error creating SSL context\n", __func__);
        return nullptr;
    }

    // Set up certificate verification
    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr);

    // Load trusted CA certificates based on platform
    const char * ca_cert_path = nullptr;
#ifdef _WIN32
    ca_cert_path = "C:\\path\\to\\ca-certificates.crt"; // Windows path (FIXME)
#elif __APPLE__
    ca_cert_path = "/etc/ssl/cert.pem"; // macOS path
#else
    ca_cert_path = "/etc/ssl/certs/ca-certificates.crt"; // Linux path
#endif
    if (!SSL_CTX_load_verify_locations(ctx, ca_cert_path, nullptr)) {
        fprintf(stderr, "%s: error loading CA certificates\n", __func__);
        SSL_CTX_free(ctx);
        return nullptr;
    }
    // Create an SSL connection
    auto bio = BIO_new_ssl_connect(ctx);
    if (!bio) {
        fprintf(stderr, "%s: error creating SSL connection\n", __func__);
        SSL_CTX_free(ctx);
        return nullptr;
    }

    // Set the target host and port - BIO_set_conn_hostname expects "host:port"
    char host_port[272];
    snprintf(host_port, sizeof(host_port), "%s:443", host);
    if (!BIO_set_conn_hostname(bio, host_port)) {
        fprintf(stderr, "%s: unable to set connection hostname %s\n", __func__, host);
        BIO_free_all(bio);
        SSL_CTX_free(ctx);
        return nullptr;
    }

    // Set the SNI hostname - required by most virtual-hosted HTTPS servers
    SSL * ssl = nullptr;
    BIO_get_ssl(bio, &ssl);
    if (!ssl || !SSL_set_tlsext_host_name(ssl, host)) {
        fprintf(stderr, "%s: unable to set SNI hostname %s\n", __func__, host);
        BIO_free_all(bio);
        SSL_CTX_free(ctx);
        return nullptr;
    }
    // Construct the HTTP request
    char request[2048];
    snprintf(request, sizeof(request),
             "GET /%s HTTP/1.1\r\nHost: %s\r\nAccept: */*\r\nUser-Agent: llama-client\r\nConnection: close\r\n",
             path, host);

    // Add an Authorization header if the URL carried credentials
    // (HTTP Basic auth requires the "user:password" pair to be base64-encoded)
    if (userpass[0] != '\0') {
        unsigned char userpass_b64[4 * ((sizeof(userpass) + 2) / 3) + 1] = {0};
        EVP_EncodeBlock(userpass_b64, (const unsigned char *) userpass, (int) strlen(userpass));
        char auth_header[512];
        snprintf(auth_header, sizeof(auth_header), "Authorization: Basic %s\r\n", userpass_b64);
        strcat(request, auth_header);
    }

    // End of headers
    strcat(request, "\r\n");

    // Send the request - the TLS handshake happens transparently on the first write
    fprintf(stdout, "%s: downloading model from https://%s/%s to %s ...\n", __func__, host, path, path_model);
    if (BIO_puts(bio, request) <= 0) {
        fprintf(stderr, "%s: error sending HTTP request https://%s/%s\n", __func__, host, path);
        BIO_free_all(bio);
        SSL_CTX_free(ctx);
        return nullptr;
    }
    // Read the response status line
    char status_line[256];
    if (BIO_gets(bio, status_line, sizeof(status_line)) <= 0) {
        fprintf(stderr, "%s: error reading response status line\n", __func__);
        BIO_free_all(bio);
        SSL_CTX_free(ctx);
        return nullptr;
    }

    // Verify the HTTP status code
    if (strncmp(status_line, "HTTP/1.1 200", 12) != 0) {
        fprintf(stderr, "%s: HTTP request failed: %s\n", __func__, status_line);
        BIO_free_all(bio);
        SSL_CTX_free(ctx);
        return nullptr;
    }
    // Open the local file before consuming the body
    FILE * outfile = fopen(path_model, "wb");
    if (!outfile) {
        fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model);
        BIO_free_all(bio);
        SSL_CTX_free(ctx);
        return nullptr;
    }

    // Skip the response headers, then save the body to the file; bytes that follow the
    // empty line in the same chunk already belong to the body and must not be dropped
    // (note: a "\r\n\r\n" sequence split across two chunks is not handled here)
    char buffer[4096];
    int  n_bytes_received;
    int  n_bytes_received_total = 0;
    int  n_mib_reported = 0;
    bool headers_done = false;
    while ((n_bytes_received = BIO_read(bio, buffer, sizeof(buffer) - 1)) > 0) {
        const char * body = buffer;
        int n_body = n_bytes_received;
        if (!headers_done) {
            buffer[n_bytes_received] = '\0'; // null-terminate so strstr stays in bounds
            const char * headers_end = strstr(buffer, "\r\n\r\n");
            if (!headers_end) {
                continue; // still inside the headers
            }
            headers_done = true;
            body   = headers_end + 4; // skip the empty line that terminates the headers
            n_body = n_bytes_received - (int)(body - buffer);
        }
        if (n_body > 0) {
            fwrite(body, 1, n_body, outfile);
            n_bytes_received_total += n_body;
        }
        // report progress once per MiB crossed, independent of the chunk sizes
        const int n_mib = n_bytes_received_total / (1024 * 1024);
        if (n_mib > n_mib_reported) {
            n_mib_reported = n_mib;
            fprintf(stdout, "%s: model downloading %d MiB %s ...\n", __func__, n_mib, path_model);
        }
    }
    fclose(outfile);
    // Clean up
    BIO_free_all(bio);
    SSL_CTX_free(ctx);

    fprintf(stdout, "%s: model downloaded from https://%s/%s to %s\n", __func__, host, path, path_model);
    return llama_load_model_from_file(path_model, params);
#else
    fprintf(stderr, "%s: llama.cpp built without SSL support, downloading from url not supported\n", __func__);
    return nullptr;
#endif
}
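
For reference, a minimal caller sketch for the new API, assuming a build where OpenSSL was found; the URL and local file name below are placeholders, not values taken from this commit:

// sketch only: the URL and output path are illustrative placeholders
#include <cstdio>
#include "common.h"
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();

    llama_model * model = llama_load_model_from_url(
        "https://huggingface.co/<user>/<repo>/resolve/main/model-q4_0.gguf",
        "model-q4_0.gguf", // local destination, then loaded via llama_load_model_from_file
        mparams);
    if (model == nullptr) {
        fprintf(stderr, "download or load failed\n");
        return 1;
    }

    llama_free_model(model);
    return 0;
}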
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
    auto mparams = llama_model_params_from_gpt_params(params);
    llama_model * model = nullptr;

    // a model URL, when provided, takes precedence over the local model path
    if (!params.model_url.empty()) {
        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
    } else {
        model = llama_load_model_from_file(params.model.c_str(), mparams);
    }
    if (model == NULL) {
        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
        return std::make_tuple(nullptr, nullptr);


@@ -17,6 +17,12 @@
#include <unordered_map>
#include <tuple>
#ifdef HAVE_OPENSSL
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/evp.h> // EVP_EncodeBlock, used for HTTP Basic auth
#endif
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
#else
@@ -89,6 +95,7 @@ struct gpt_params {
    struct llama_sampling_params sparams;
    std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_url = ""; // model path
    std::string model_draft = ""; // draft model for speculative decoding
    std::string model_alias = "unknown"; // model alias
    std::string prompt = "";
@@ -191,6 +198,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
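
// Download a model over HTTPS from model_url into path_model, then load it from there;
// returns nullptr if the download fails or the build has no OpenSSL support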
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
                                               struct llama_model_params params);
// Batch utils

void llama_batch_clear(struct llama_batch & batch);


@@ -2195,6 +2195,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams)
    }
    printf("  -m FNAME, --model FNAME\n");
    printf("                            model path (default: %s)\n", params.model.c_str());
printf(" -u MODEL_URL, --url MODEL_URL\n");
printf(" model url (default: %s)\n", params.model_url.c_str());
printf(" -a ALIAS, --alias ALIAS\n"); printf(" -a ALIAS, --alias ALIAS\n");
printf(" set an alias for the model, will be added as `model` field in completion response\n"); printf(" set an alias for the model, will be added as `model` field in completion response\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@@ -2317,6 +2319,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params)
                break;
            }
            params.model = argv[i];
} else if (arg == "-u" || arg == "--model-url") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.model_url = argv[i];
} else if (arg == "-a" || arg == "--alias") { } else if (arg == "-a" || arg == "--alias") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;