From 6633689fa5cd972bfa3de3c06477996fb554f79b Mon Sep 17 00:00:00 2001
From: Pierrick HYMBERT
Date: Sat, 16 Mar 2024 16:49:44 +0100
Subject: [PATCH] llama_load_model_from_url: cleanup code

---
 common/common.cpp | 134 +++++++++++++++++++++++++++-------------------
 1 file changed, 79 insertions(+), 55 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 8b256e7fb..89b5ee501 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -53,6 +53,19 @@
 #define GGML_USE_CUBLAS_SYCL_VULKAN
 #endif
 
+#ifdef LLAMA_USE_CURL
+#ifdef __linux__
+#include <linux/limits.h>
+#elif defined(_WIN32)
+#include <windows.h>
+#define PATH_MAX MAX_PATH
+#else
+#include <sys/syslimits.h>
+#endif
+#define LLAMA_CURL_MAX_PATH_LENGTH PATH_MAX
+#define LLAMA_CURL_MAX_HEADER_LENGTH 256
+#endif // LLAMA_USE_CURL
+
 int32_t get_num_physical_cores() {
 #ifdef __linux__
     // enumerate the set of thread siblings, num entries is num cores
@@ -1389,11 +1402,17 @@ void llama_batch_add(
 
 #ifdef LLAMA_USE_CURL
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
-                                               struct llama_model_params params) {
+struct llama_model *llama_load_model_from_url(const char *model_url, const char *path_model,
+                                              struct llama_model_params params) {
+    // Basic validation of the model_url
+    if (!model_url || strlen(model_url) == 0) {
+        fprintf(stderr, "%s: invalid model_url\n", __func__);
+        return NULL;
+    }
+
     // Initialize libcurl globally
     curl_global_init(CURL_GLOBAL_DEFAULT);
-    CURL *curl = curl_easy_init();
+    auto curl = curl_easy_init();
 
     if (!curl) {
         curl_global_cleanup();
@@ -1408,73 +1427,77 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
 
     // Check if the file already exists locally
     struct stat buffer;
-    int file_exists = (stat(path_model, &buffer) == 0);
+    auto file_exists = (stat(path_model, &buffer) == 0);
 
-    // If the file exists, check for an ETag file or a lastModified file
-    char etag[256] = {0};
-    char etag_path[256] = {0};
-    strcpy(etag_path, path_model);
-    strcat(etag_path, ".etag");
+    // If the file exists, check for ${model_path}.etag or ${model_path}.lastModified files
+    char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+    char etag_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
+    strncpy(etag_path, path_model, LLAMA_CURL_MAX_PATH_LENGTH - 6); // 6 is the length of ".etag\0"
+    strncat(etag_path, ".etag", 6);
 
-    char last_modified[256] = {0};
-    char last_modified_path[256] = {0};
-    strcpy(last_modified_path, path_model);
-    strcat(last_modified_path, ".lastModified");
+    char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+    char last_modified_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0};
+    strncpy(last_modified_path, path_model, LLAMA_CURL_MAX_PATH_LENGTH - 15); // 15 is the length of ".lastModified\0"
+    strncat(last_modified_path, ".lastModified", 15);
 
     if (file_exists) {
-        FILE *f_etag = fopen(etag_path, "r");
+        auto *f_etag = fopen(etag_path, "r");
         if (f_etag) {
             fgets(etag, sizeof(etag), f_etag);
             fclose(f_etag);
             fprintf(stderr, "%s: previous model .etag file found %s: %s\n", __func__, path_model, etag);
         }
 
-        FILE *f_last_modified = fopen(last_modified_path, "r");
+        auto *f_last_modified = fopen(last_modified_path, "r");
         if (f_last_modified) {
             fgets(last_modified, sizeof(last_modified), f_last_modified);
-            fclose(f_etag);
-            fprintf(stderr, "%s: previous model .lastModified file found %s: %s\n", __func__, last_modified_path, last_modified);
+            fclose(f_last_modified);
+            fprintf(stderr, "%s: previous model .lastModified file found %s: %s\n", __func__, last_modified_path,
+                    last_modified);
         }
     }
 
-    // Send a HEAD request to retrieve the ETag and Last-Modified headers
+    // Send a HEAD request to retrieve the etag and last-modified headers
     struct llama_load_model_from_url_headers {
-        char etag[256] = {0};
-        char last_modified[256] = {0};
+        char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+        char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
     };
-    typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
-    auto header_callback = [](char *buffer, size_t /*size*/, size_t n_items, void *userdata) -> size_t {
-        llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers*) userdata;
-
-        const char *etag_prefix = "etag: ";
-        if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) {
-            strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix)- 2); // Remove LRLF
-        }
-
-        const char *last_modified_prefix = "last-modified: ";
-        if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) {
-            strncpy(headers->last_modified, buffer + strlen(last_modified_prefix), n_items - strlen(last_modified_prefix) - 2); // Remove LRLF
-        }
-        return n_items;
-    };
-
-    curl_easy_setopt(curl, CURLOPT_NOBODY, 1L);
     llama_load_model_from_url_headers headers;
-    curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
-    curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers);
+    {
+        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
+        auto header_callback = [](char *buffer, size_t /*size*/, size_t n_items, void *userdata) -> size_t {
+            llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;
 
-    CURLcode res = curl_easy_perform(curl);
-    if (res != CURLE_OK) {
-        curl_easy_cleanup(curl);
-        curl_global_cleanup();
-        fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
-        return NULL;
+            const char *etag_prefix = "etag: ";
+            if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) {
+                strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove LRLF
+            }
+
+            const char *last_modified_prefix = "last-modified: ";
+            if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) {
+                strncpy(headers->last_modified, buffer + strlen(last_modified_prefix),
+                        n_items - strlen(last_modified_prefix) - 2); // Remove LRLF
+            }
+            return n_items;
+        };
+
+        curl_easy_setopt(curl, CURLOPT_NOBODY, 1L);
+        curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
+        curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers);
+
+        CURLcode res = curl_easy_perform(curl);
+        if (res != CURLE_OK) {
+            curl_easy_cleanup(curl);
+            curl_global_cleanup();
+            fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+            return NULL;
+        }
     }
 
     // If only the ETag or the Last-Modified header are different, trigger a new download
     if (strcmp(etag, headers.etag) != 0 || strcmp(last_modified, headers.last_modified) != 0) {
         // Set the output file
-        FILE *outfile = fopen(path_model, "wb");
+        auto *outfile = fopen(path_model, "wb");
         if (!outfile) {
             curl_easy_cleanup(curl);
             curl_global_cleanup();
@@ -1490,7 +1513,7 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
         // start the download
         fprintf(stderr, "%s: downloading model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
                 model_url, path_model, headers.etag, headers.last_modified);
-        res = curl_easy_perform(curl);
+        auto res = curl_easy_perform(curl);
         if (res != CURLE_OK) {
             fclose(outfile);
             curl_easy_cleanup(curl);
@@ -1513,22 +1536,23 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
         fclose(outfile);
 
         // Write the new ETag to the .etag file
-        if (strlen( headers.etag) > 0) {
-            FILE *etag_file = fopen(etag_path, "w");
+        if (strlen(headers.etag) > 0) {
+            auto *etag_file = fopen(etag_path, "w");
             if (etag_file) {
-                fputs( headers.etag, etag_file);
+                fputs(headers.etag, etag_file);
                 fclose(etag_file);
-                fprintf(stderr, "%s: model etag saved %s:%s\n", __func__, etag_path, etag);
+                fprintf(stderr, "%s: model etag saved %s:%s\n", __func__, etag_path, headers.etag);
             }
         }
 
        // Write the new lastModified to the .etag file
-        if (strlen( headers.last_modified) > 0) {
-            FILE *last_modified_file = fopen(last_modified_path, "w");
+        if (strlen(headers.last_modified) > 0) {
+            auto *last_modified_file = fopen(last_modified_path, "w");
            if (last_modified_file) {
                 fputs(headers.last_modified, last_modified_file);
                 fclose(last_modified_file);
-                fprintf(stderr, "%s: model last modified saved %s:%s\n", __func__, last_modified_path, headers.last_modified);
+                fprintf(stderr, "%s: model last modified saved %s:%s\n", __func__, last_modified_path,
+                        headers.last_modified);
             }
         }
     }
@@ -1547,7 +1571,7 @@ struct llama_model *llama_load_model_from_url(const char * /*model_url*/, const
     return nullptr;
 }
 
-#endif
+#endif // LLAMA_USE_CURL
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
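
Note (reviewer sketch, not part of the patch): a minimal example of how the API touched here might be called from application code. It assumes the build enables LLAMA_USE_CURL, that common.h declares llama_load_model_from_url as defined above, and that the URL and local file name below are placeholders. When the cached ETag and Last-Modified values match the server's headers, the function reuses the existing local file instead of re-downloading it.

    // Hypothetical caller; URL and output path are placeholders.
    #include "common.h"
    #include "llama.h"

    #include <cstdio>

    int main() {
        llama_backend_init();

        // Default model parameters (fields such as n_gpu_layers can be adjusted before loading).
        struct llama_model_params mparams = llama_model_default_params();

        // Performs the HEAD request, compares the cached .etag/.lastModified files,
        // and downloads the model only when the server copy has changed.
        struct llama_model * model = llama_load_model_from_url(
                "https://example.com/models/model.gguf", // placeholder URL
                "model.gguf",                            // local path; cache files are written next to it
                mparams);
        if (model == NULL) {
            fprintf(stderr, "failed to load model from url\n");
            return 1;
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }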