Use a global progress callback instead of completely removing progress

This commit is contained in:
TevinWang 2024-04-28 17:54:39 -04:00
parent 309a918ed7
commit 763460ba71

View file

@ -68,6 +68,8 @@
#endif #endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
#define LLAMA_CURL_MAX_HEADER_LENGTH 256 #define LLAMA_CURL_MAX_HEADER_LENGTH 256
#define LLAMA_PROGRESS_UPDATE_INTERVAL 1
#define LLAMA_PROGRESS_PERCENTAGE_WIDTH 10
#endif // LLAMA_USE_CURL #endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
@ -1866,6 +1868,72 @@ void llama_batch_add(
#ifdef LLAMA_USE_CURL #ifdef LLAMA_USE_CURL
// Download progress of a single shard file, as last recorded by the curl
// progress callback.
struct shard_file_progress {
    // Part of the shard URL after the last '/', used as the display label.
    std::string filename;
    // Value of `dltotal` from the most recent progress callback.
    // Defaults to 0 so a freshly created entry never holds an indeterminate value.
    double total_bytes = 0.0;
    // Value of `dlnow` from the most recent progress callback.
    double received_bytes = 0.0;
};
// Latest progress of each shard download, keyed by the URL string that is
// handed to the curl progress callback as its client pointer.
std::map<std::string, shard_file_progress> progress_table;
// Guards every access to progress_table (callback writers and the table printer).
std::mutex progress_mutex;
// Log lines produced while shards are downloading; they are buffered here and
// written to stderr once all downloads are reported complete, so they do not
// interleave with the redrawn progress table.
// NOTE(review): the visible writers append to this stream without holding
// progress_mutex — confirm whether concurrent shard downloads can race here.
std::stringstream download_done_buffer;
// curl progress callback used for sharded downloads.
//
// Records the latest download counters for one shard in progress_table.
//
//   clientp  - the shard URL as a NUL-terminated C string (set via
//              CURLOPT_PROGRESSDATA); used as the table key
//   dltotal  - total number of bytes curl expects to download
//   dlnow    - number of bytes downloaded so far
//   ultotal, ulnow - upload counters, unused
//
// Always returns 0 so that the transfer continues (a non-zero return would
// make curl abort the transfer).
static int shard_progress_callback(void* clientp, double dltotal, double dlnow, double ultotal, double ulnow) {
    // upload not needed for downloading
    (void) ultotal;
    (void) ulnow;

    const char* url = static_cast<const char*>(clientp);
    if (url == nullptr) {
        // No key to record progress under; skip the update but keep the
        // transfer running. Building a std::string from nullptr would be UB.
        return 0;
    }

    std::lock_guard<std::mutex> lock(progress_mutex);
    shard_file_progress& progress = progress_table[url];
    progress.total_bytes = dltotal;
    progress.received_bytes = dlnow;

    // The label only depends on the URL, so derive it once per entry instead
    // of re-allocating it on every progress tick while holding the lock.
    if (progress.filename.empty()) {
        const std::string url_string(url);
        // find_last_of returns npos when there is no '/'; npos + 1 wraps to 0,
        // which makes substr return the whole string.
        progress.filename = url_string.substr(url_string.find_last_of('/') + 1);
    }
    return 0;
}
static void print_shard_progress_table(bool first_progress) {
if (first_progress) {
fprintf(stderr, "=========================\n");
} else {
// use updating output
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (unsigned int i = 0; i < progress_table.size(); i++) {
fprintf(stderr, "\033[1A\033[K\033[1A\033[K");
}
fprintf(stderr, "\r");
}
}
struct winsize ws;
ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws);
int progress_bar_width = ws.ws_col - LLAMA_PROGRESS_PERCENTAGE_WIDTH;
// Print the progress information for each downloading file
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (const auto& entry : progress_table) {
shard_file_progress progress = entry.second;
int progress_width = static_cast<int>((progress.received_bytes / progress.total_bytes) * progress_bar_width);
fprintf(stderr, "%s\n", progress.filename.c_str());
fprintf(stderr, "[");
for (int i = 0; i < progress_width; ++i) {
fprintf(stderr, "=");
}
for (int i = progress_width; i < progress_bar_width; ++i) {
fprintf(stderr, " ");
}
fprintf(stderr, "] %d%%\n", static_cast<int>((progress.received_bytes / progress.total_bytes) * 100));
}
}
}
static bool llama_download_file(CURL * curl, const char * url, const char * path, bool is_shard) { static bool llama_download_file(CURL * curl, const char * url, const char * path, bool is_shard) {
bool force_download = false; bool force_download = false;
@ -1999,11 +2067,13 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback)); curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile); curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile);
// display download progress if not sharded // display download progress
if (is_shard) {
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L);
} else {
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
// custom progress callback on sharded download
if (is_shard) {
curl_easy_setopt(curl, CURLOPT_PROGRESSFUNCTION, shard_progress_callback);
curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, url);
} }
// helper function to hide password in URL // helper function to hide password in URL
@ -2050,9 +2120,13 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
if (etag_file) { if (etag_file) {
fputs(headers.etag, etag_file); fputs(headers.etag, etag_file);
fclose(etag_file); fclose(etag_file);
if (is_shard) {
download_done_buffer << __func__ << ": file etag saved " << etag_path << ": " << headers.etag << "\n";
} else {
fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag); fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
} }
} }
}
// Write the new lastModified to the .etag file // Write the new lastModified to the .etag file
if (strlen(headers.last_modified) > 0) { if (strlen(headers.last_modified) > 0) {
@ -2060,8 +2134,11 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
if (last_modified_file) { if (last_modified_file) {
fputs(headers.last_modified, last_modified_file); fputs(headers.last_modified, last_modified_file);
fclose(last_modified_file); fclose(last_modified_file);
fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path, if (is_shard) {
headers.last_modified); download_done_buffer << __func__ << ": unable to rename file: " << path_temporary << " to " << path << "\n";
} else {
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path);
}
} }
} }
@ -2159,6 +2236,33 @@ struct llama_model * llama_load_model_from_url(
}, idx)); }, idx));
} }
bool first_progress = true;
while (true) {
// Print the progress table periodically
std::this_thread::sleep_for(std::chrono::seconds(LLAMA_PROGRESS_UPDATE_INTERVAL));
// Print the progress table header
print_shard_progress_table(first_progress);
first_progress = false;
// Check if all downloads are complete
bool all_complete = true;
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (const auto& entry : progress_table) {
const shard_file_progress& progress = entry.second;
if (progress.received_bytes < progress.total_bytes) {
all_complete = false;
break;
}
}
}
if (all_complete) {
fprintf(stderr, "%s", download_done_buffer.str().c_str());
break;
}
}
// Wait for all downloads to complete // Wait for all downloads to complete
for (auto & f : futures_download) { for (auto & f : futures_download) {
if (!f.get()) { if (!f.get()) {