added support for Authorization Bearer tokens

This commit is contained in:
Derrick T. Woolworth 2024-07-04 13:48:32 -05:00
parent d7fd29fff1
commit 6fa99b4fad
2 changed files with 64 additions and 9 deletions

View file

@ -190,6 +190,18 @@ int32_t cpu_get_num_math() {
// CLI argument parsing // CLI argument parsing
// //
// Decide which token ends up in params.auth_token.
//
// - If neither hf_token nor auth_token has been set, hf_get_token_from_env()
//   is given the chance to populate hf_token.
// - If hf_token is (now) set it is copied into auth_token, unless auth_token
//   was also set, in which case std::invalid_argument is thrown.
// - If only auth_token is set it is left untouched.
void gpt_params_handle_auth_token(gpt_params & params) {
    const bool no_token_given = params.hf_token.empty() && params.auth_token.empty();
    if (no_token_given) {
        hf_get_token_from_env(params);
    }

    if (params.hf_token.empty()) {
        // nothing to promote; auth_token (possibly empty) is used as-is
        return;
    }

    if (!params.auth_token.empty()) {
        throw std::invalid_argument("error: --hf-token and --bearer-token are mutually exclusive\n");
    }

    params.auth_token = params.hf_token;
}
void gpt_params_handle_model_default(gpt_params & params) { void gpt_params_handle_model_default(gpt_params & params) {
if (!params.hf_repo.empty()) { if (!params.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model // short-hand to avoid specifying --hf-file -> default it to --model
@ -237,6 +249,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
gpt_params_handle_model_default(params); gpt_params_handle_model_default(params);
gpt_params_handle_auth_token(params);
if (params.escape) { if (params.escape) {
string_process_escapes(params.prompt); string_process_escapes(params.prompt);
string_process_escapes(params.input_prefix); string_process_escapes(params.input_prefix);
@ -644,6 +658,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.model_url = argv[i]; params.model_url = argv[i];
return true; return true;
} }
// -bt / --bearer-token TOKEN: bearer token used for the model download
if (arg == "-bt" || arg == "--bearer-token") {
    const bool value_present = (++i < argc);
    if (value_present) {
        params.auth_token = argv[i];
    } else {
        // the flag was the last argument, so there is no value to consume
        invalid_param = true;
    }
    return true;
}
// -hft / --hf-token TOKEN: Hugging Face access token
// (the long form must match the spelling printed by the usage text)
if (arg == "-hft" || arg == "--hf-token") {
    if (++i >= argc) {
        // the flag was the last argument, so there is no value to consume
        invalid_param = true;
        return true;
    }
    params.hf_token = argv[i];
    return true;
}
if (arg == "-hfr" || arg == "--hf-repo") { if (arg == "-hfr" || arg == "--hf-repo") {
CHECK_ARG CHECK_ARG
params.hf_repo = argv[i]; params.hf_repo = argv[i];
@ -1559,8 +1589,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"or --model-url if set, otherwise %s)", DEFAULT_MODEL_PATH }); "or --model-url if set, otherwise %s)", DEFAULT_MODEL_PATH });
options.push_back({ "*", "-md, --model-draft FNAME", "draft model for speculative decoding (default: unused)" }); options.push_back({ "*", "-md, --model-draft FNAME", "draft model for speculative decoding (default: unused)" });
options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" }); options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" });
options.push_back({ "*", "-bt, --bearer-token TOKEN", "model download bearer token (default: unused)" });
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: unused)" });
options.push_back({ "retrieval" }); options.push_back({ "retrieval" });
options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
@ -1989,6 +2021,12 @@ std::string fs_get_cache_file(const std::string & filename) {
return cache_directory + filename; return cache_directory + filename;
} }
// Copy the HF_TOKEN environment variable into params.hf_token when it is set;
// params is left unchanged otherwise.
void hf_get_token_from_env(gpt_params & params) {
    const char * env_token = std::getenv("HF_TOKEN");
    if (env_token != nullptr) {
        params.hf_token = env_token;
    }
}
// //
// Model utils // Model utils
@ -2000,9 +2038,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
llama_model * model = nullptr; llama_model * model = nullptr;
if (!params.hf_repo.empty() && !params.hf_file.empty()) { if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.auth_token.c_str(), mparams);
} else if (!params.model_url.empty()) { } else if (!params.model_url.empty()) {
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.auth_token.c_str(), mparams);
} else { } else {
model = llama_load_model_from_file(params.model.c_str(), mparams); model = llama_load_model_from_file(params.model.c_str(), mparams);
} }
@ -2189,7 +2227,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) == 0; return str.rfind(prefix, 0) == 0;
} }
static bool llama_download_file(const std::string & url, const std::string & path) { static bool llama_download_file(const std::string & url, const std::string & path, const std::string & auth_token) {
// Initialize libcurl // Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup); std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
@ -2204,6 +2242,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
// Check if hf-token or bearer-token was specified
if (!auth_token.empty()) {
std::string auth_header = "Authorization: Bearer ";
auth_header += auth_token.c_str();
struct curl_slist *http_headers = NULL;
http_headers = curl_slist_append(http_headers, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
}
#if defined(_WIN32) #if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows. // operating system. Currently implemented under MS-Windows.
@ -2399,6 +2446,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat
struct llama_model * llama_load_model_from_url( struct llama_model * llama_load_model_from_url(
const char * model_url, const char * model_url,
const char * path_model, const char * path_model,
const char * auth_token,
const struct llama_model_params & params) { const struct llama_model_params & params) {
// Basic validation of the model_url // Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) { if (!model_url || strlen(model_url) == 0) {
@ -2406,7 +2454,7 @@ struct llama_model * llama_load_model_from_url(
return NULL; return NULL;
} }
if (!llama_download_file(model_url, path_model)) { if (!llama_download_file(model_url, path_model, auth_token)) {
return NULL; return NULL;
} }
@ -2454,14 +2502,14 @@ struct llama_model * llama_load_model_from_url(
// Prepare download in parallel // Prepare download in parallel
std::vector<std::future<bool>> futures_download; std::vector<std::future<bool>> futures_download;
for (int idx = 1; idx < n_split; idx++) { for (int idx = 1; idx < n_split; idx++) {
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool { futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, auth_token](int download_idx) -> bool {
char split_path[PATH_MAX] = {0}; char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split); llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
return llama_download_file(split_url, split_path); return llama_download_file(split_url, split_path, auth_token);
}, idx)); }, idx));
} }
@ -2480,6 +2528,7 @@ struct llama_model * llama_load_model_from_hf(
const char * repo, const char * repo,
const char * model, const char * model,
const char * path_model, const char * path_model,
const char * auth_token,
const struct llama_model_params & params) { const struct llama_model_params & params) {
// construct hugging face model url: // construct hugging face model url:
// //
@ -2495,7 +2544,7 @@ struct llama_model * llama_load_model_from_hf(
model_url += "/resolve/main/"; model_url += "/resolve/main/";
model_url += model; model_url += model;
return llama_load_model_from_url(model_url.c_str(), path_model, params); return llama_load_model_from_url(model_url.c_str(), path_model, auth_token, params);
} }
#else #else
@ -2503,6 +2552,7 @@ struct llama_model * llama_load_model_from_hf(
struct llama_model * llama_load_model_from_url( struct llama_model * llama_load_model_from_url(
const char * /*model_url*/, const char * /*model_url*/,
const char * /*path_model*/, const char * /*path_model*/,
const char * /*auth_token*/
const struct llama_model_params & /*params*/) { const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr; return nullptr;
@ -2512,6 +2562,7 @@ struct llama_model * llama_load_model_from_hf(
const char * /*repo*/, const char * /*repo*/,
const char * /*model*/, const char * /*model*/,
const char * /*path_model*/, const char * /*path_model*/,
const char * /*auth_token*/,
const struct llama_model_params & /*params*/) { const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr; return nullptr;

View file

@ -107,6 +107,8 @@ struct gpt_params {
std::string model_draft = ""; // draft model for speculative decoding std::string model_draft = ""; // draft model for speculative decoding
std::string model_alias = "unknown"; // model alias std::string model_alias = "unknown"; // model alias
std::string model_url = ""; // model url to download std::string model_url = ""; // model url to download
std::string auth_token = ""; // auth bearer token
std::string hf_token = ""; // HF token
std::string hf_repo = ""; // HF repo std::string hf_repo = ""; // HF repo
std::string hf_file = ""; // HF file std::string hf_file = ""; // HF file
std::string prompt = ""; std::string prompt = "";
@ -255,6 +257,7 @@ struct gpt_params {
bool spm_infill = false; // suffix/prefix/middle pattern for infill bool spm_infill = false; // suffix/prefix/middle pattern for infill
}; };
// Resolves params.auth_token: falls back to the HF_TOKEN environment variable when
// neither hf_token nor auth_token is set, copies hf_token into auth_token, and throws
// std::invalid_argument when both are set.
void gpt_params_handle_auth_token(gpt_params & params);
void gpt_params_handle_model_default(gpt_params & params); void gpt_params_handle_model_default(gpt_params & params);
bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
@ -297,6 +300,7 @@ void string_process_escapes(std::string & input);
bool fs_validate_filename(const std::string & filename); bool fs_validate_filename(const std::string & filename);
bool fs_create_directory_with_parents(const std::string & path); bool fs_create_directory_with_parents(const std::string & path);
// Stores the value of the HF_TOKEN environment variable in params.hf_token when it is set.
void hf_get_token_from_env(gpt_params & params);
std::string fs_get_cache_directory(); std::string fs_get_cache_directory();
std::string fs_get_cache_file(const std::string & filename); std::string fs_get_cache_file(const std::string & filename);
@ -310,8 +314,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params); struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * auth_token, const struct llama_model_params & params);
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params); struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * auth_token, const struct llama_model_params & params);
// Batch utils // Batch utils