diff --git a/common/common.cpp b/common/common.cpp
index 59095a126..d8ecbf5dc 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -68,6 +68,8 @@
 #endif
 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
 #define LLAMA_CURL_MAX_HEADER_LENGTH 256
+#define LLAMA_PROGRESS_UPDATE_INTERVAL 1
+#define LLAMA_PROGRESS_PERCENTAGE_WIDTH 10
 #endif // LLAMA_USE_CURL
 
 using json = nlohmann::ordered_json;
@@ -1866,6 +1868,106 @@ void llama_batch_add(
 
 #ifdef LLAMA_USE_CURL
 
+// Download progress of one shard, as last reported by curl.
+struct shard_file_progress {
+    std::string filename;
+    double total_bytes    = 0.0; // stays 0 until curl knows the size
+    double received_bytes = 0.0;
+};
+
+// State shared by the download threads and the thread that draws the table.
+// progress_mutex guards every object below.
+static std::map<std::string, shard_file_progress> progress_table;
+static std::mutex progress_mutex;
+static std::stringstream download_done_buffer;
+static size_t progress_lines_printed = 0;
+
+// curl transfer-info callback (CURLOPT_XFERINFOFUNCTION): stores the latest byte
+// counts for the shard whose URL was registered with CURLOPT_XFERINFODATA.
+// Returning 0 tells curl to continue the transfer.
+static int shard_progress_callback(void * clientp, curl_off_t dltotal, curl_off_t dlnow, curl_off_t ultotal, curl_off_t ulnow) {
+    // upload not needed for downloading
+    (void) ultotal;
+    (void) ulnow;
+
+    const std::string url = static_cast<const char *>(clientp);
+
+    std::lock_guard<std::mutex> lock(progress_mutex);
+
+    shard_file_progress & progress = progress_table[url];
+    progress.total_bytes    = static_cast<double>(dltotal);
+    progress.received_bytes = static_cast<double>(dlnow);
+    if (progress.filename.empty()) {
+        // without a '/', find_last_of() yields npos and npos + 1 wraps to 0: the whole URL is kept
+        progress.filename = url.substr(url.find_last_of('/') + 1);
+    }
+
+    return 0;
+}
+
+// Width in columns of the terminal attached to stderr; 80 when it cannot be queried.
+static int shard_progress_terminal_width() {
+    const int default_width = 80;
+#if defined(_WIN32)
+    return default_width;
+#else
+    struct winsize ws;
+    if (ioctl(STDERR_FILENO, TIOCGWINSZ, &ws) != 0 || ws.ws_col == 0) {
+        return default_width;
+    }
+    return ws.ws_col;
+#endif
+}
+
+// Draws a "<filename>" line and a "[====    ] NN%" line for every known shard.
+// Calls with first_progress == false first erase the lines drawn by the previous call.
+static void print_shard_progress_table(bool first_progress) {
+    int progress_bar_width = shard_progress_terminal_width() - LLAMA_PROGRESS_PERCENTAGE_WIDTH;
+    if (progress_bar_width < 0) {
+        progress_bar_width = 0;
+    }
+
+    // one lock for the whole redraw, so the table cannot grow between erasing and printing
+    std::lock_guard<std::mutex> lock(progress_mutex);
+
+    if (first_progress) {
+        fprintf(stderr, "=========================\n");
+    } else {
+        // use updating output: move up over exactly the lines printed last time
+        for (size_t i = 0; i < progress_lines_printed; i++) {
+            fprintf(stderr, "\033[1A\033[K");
+        }
+        fprintf(stderr, "\r");
+    }
+    progress_lines_printed = 0;
+
+    // Print the progress information for each downloading file
+    for (const auto & entry : progress_table) {
+        const shard_file_progress & progress = entry.second;
+
+        // total_bytes is 0 until the size is known: do not divide by it
+        double fraction = 0.0;
+        if (progress.total_bytes > 0.0) {
+            fraction = progress.received_bytes / progress.total_bytes;
+        }
+        if (fraction > 1.0) {
+            fraction = 1.0;
+        }
+        const int progress_width = static_cast<int>(fraction * progress_bar_width);
+
+        fprintf(stderr, "%s\n", progress.filename.c_str());
+        fprintf(stderr, "[");
+        for (int i = 0; i < progress_width; ++i) {
+            fprintf(stderr, "=");
+        }
+        for (int i = progress_width; i < progress_bar_width; ++i) {
+            fprintf(stderr, " ");
+        }
+        fprintf(stderr, "] %d%%\n", static_cast<int>(fraction * 100));
+        progress_lines_printed += 2;
+    }
+}
+
 static bool llama_download_file(CURL * curl, const char * url, const char * path, bool is_shard) {
     bool force_download = false;
 
@@ -1999,11 +2101,13 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
         curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
         curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile);
 
-        // display download progress if not sharded
+        // display download progress
+        curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
+
+        // custom progress callback on sharded download
         if (is_shard) {
-            curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L);
-        } else {
-            curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
+            curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, shard_progress_callback);
+            curl_easy_setopt(curl, CURLOPT_XFERINFODATA, url);
         }
 
         // helper function to hide password in URL
@@ -2050,7 +2154,13 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
             if (etag_file) {
                 fputs(headers.etag, etag_file);
                 fclose(etag_file);
-                fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
+                if (is_shard) {
+                    // defer the message so it cannot interleave with the progress table
+                    std::lock_guard<std::mutex> lock(progress_mutex);
+                    download_done_buffer << __func__ << ": file etag saved " << etag_path << ": " << headers.etag << "\n";
+                } else {
+                    fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
+                }
             }
         }
 
@@ -2060,8 +2170,14 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
             if (last_modified_file) {
                 fputs(headers.last_modified, last_modified_file);
                 fclose(last_modified_file);
-                fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
-                        headers.last_modified);
+                if (is_shard) {
+                    // defer the message so it cannot interleave with the progress table
+                    std::lock_guard<std::mutex> lock(progress_mutex);
+                    download_done_buffer << __func__ << ": file last modified saved " << last_modified_path << ": " << headers.last_modified << "\n";
+                } else {
+                    fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
+                            headers.last_modified);
+                }
             }
         }
 
@@ -2158,6 +2274,37 @@ struct llama_model * llama_load_model_from_url(
                 return res;
             }, idx));
         }
+
+        // Redraw the progress table until every download task has finished. Completion is
+        // read from the futures: byte counters are still 0/0 before curl reports a size,
+        // so they cannot tell "not started" from "done".
+        bool first_progress = true;
+        while (true) {
+            bool all_complete = true;
+            for (auto & f : futures_download) {
+                if (f.wait_for(std::chrono::seconds(0)) == std::future_status::timeout) {
+                    all_complete = false;
+                    break;
+                }
+            }
+
+            print_shard_progress_table(first_progress);
+            first_progress = false;
+
+            if (all_complete) {
+                break;
+            }
+            std::this_thread::sleep_for(std::chrono::seconds(LLAMA_PROGRESS_UPDATE_INTERVAL));
+        }
+
+        // All workers are done: flush the deferred log lines and reset the shared state
+        {
+            std::lock_guard<std::mutex> lock(progress_mutex);
+            fprintf(stderr, "%s", download_done_buffer.str().c_str());
+            download_done_buffer.str("");
+            download_done_buffer.clear();
+            progress_table.clear();
+        }
 
         // Wait for all downloads to complete
         for (auto & f : futures_download) {