llama : minor

This commit is contained in:
Georgi Gerganov 2024-04-21 20:06:30 +03:00
parent 4bd26644bf
commit 5cf8ccb191
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 34 additions and 43 deletions

View file

@ -235,7 +235,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) { bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
const char* sep = strchr(data, '='); const char * sep = strchr(data, '=');
if (sep == nullptr || sep - data >= 128) { if (sep == nullptr || sep - data >= 128) {
fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
return false; return false;
@ -247,18 +247,18 @@ bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> &
if (strncmp(sep, "int:", 4) == 0) { if (strncmp(sep, "int:", 4) == 0) {
sep += 4; sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep); kvo.val_i64 = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) { } else if (strncmp(sep, "float:", 6) == 0) {
sep += 6; sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep); kvo.val_f64 = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) { } else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5; sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) { if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true; kvo.val_bool = true;
} else if (std::strcmp(sep, "false") == 0) { } else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false; kvo.val_bool = false;
} else { } else {
fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
return false; return false;
@ -266,7 +266,7 @@ bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> &
} else if (strncmp(sep, "str:", 4) == 0) { } else if (strncmp(sep, "str:", 4) == 0) {
sep += 4; sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
strncpy(kvo.str_value, sep, 128); strncpy(kvo.val_str, sep, 128);
} else { } else {
fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
return false; return false;
@ -276,7 +276,7 @@ bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> &
} }
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
llama_sampling_params& sparams = params.sparams; llama_sampling_params & sparams = params.sparams;
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
if (++i >= argc) { if (++i >= argc) {

View file

@ -305,14 +305,14 @@ int main(int argc, char ** argv) {
llama_model_kv_override kvo; llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE); std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
strncpy(kvo.str_value, imatrix_file.c_str(), 128); strncpy(kvo.val_str, imatrix_file.c_str(), 128);
kv_overrides.emplace_back(std::move(kvo)); kv_overrides.emplace_back(std::move(kvo));
} }
if (!imatrix_dataset.empty()) { if (!imatrix_dataset.empty()) {
llama_model_kv_override kvo; llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET); std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET);
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
strncpy(kvo.str_value, imatrix_dataset.c_str(), 128); strncpy(kvo.val_str, imatrix_dataset.c_str(), 128);
kv_overrides.emplace_back(std::move(kvo)); kv_overrides.emplace_back(std::move(kvo));
} }
@ -320,7 +320,7 @@ int main(int argc, char ** argv) {
llama_model_kv_override kvo; llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES); std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = imatrix_data.size(); kvo.val_i64 = imatrix_data.size();
kv_overrides.emplace_back(std::move(kvo)); kv_overrides.emplace_back(std::move(kvo));
} }
@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
llama_model_kv_override kvo; llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS); std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS);
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = m_last_call; kvo.val_i64 = m_last_call;
kv_overrides.emplace_back(std::move(kvo)); kv_overrides.emplace_back(std::move(kvo));
} }
} }

View file

@ -2875,16 +2875,16 @@ namespace GGUFMeta {
__func__, override_type_to_str(ovrd->tag), ovrd->key); __func__, override_type_to_str(ovrd->tag), ovrd->key);
switch (ovrd->tag) { switch (ovrd->tag) {
case LLAMA_KV_OVERRIDE_TYPE_BOOL: { case LLAMA_KV_OVERRIDE_TYPE_BOOL: {
LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false"); LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
} break; } break;
case LLAMA_KV_OVERRIDE_TYPE_INT: { case LLAMA_KV_OVERRIDE_TYPE_INT: {
LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value); LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
} break; } break;
case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
LLAMA_LOG_INFO("%.6f\n", ovrd->float_value); LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
} break; } break;
case LLAMA_KV_OVERRIDE_TYPE_STR: { case LLAMA_KV_OVERRIDE_TYPE_STR: {
LLAMA_LOG_INFO("%s\n", ovrd->str_value); LLAMA_LOG_INFO("%s\n", ovrd->val_str);
} break; } break;
default: default:
// Shouldn't be possible to end up here, but just in case... // Shouldn't be possible to end up here, but just in case...
@ -2903,7 +2903,7 @@ namespace GGUFMeta {
static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override * ovrd) { try_override(OT & target, const struct llama_model_kv_override * ovrd) {
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
target = ovrd->bool_value; target = ovrd->val_bool;
return true; return true;
} }
return false; return false;
@ -2913,7 +2913,7 @@ namespace GGUFMeta {
static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override * ovrd) { try_override(OT & target, const struct llama_model_kv_override * ovrd) {
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
target = ovrd->int_value; target = ovrd->val_i64;
return true; return true;
} }
return false; return false;
@ -2923,17 +2923,7 @@ namespace GGUFMeta {
static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override * ovrd) { try_override(T & target, const struct llama_model_kv_override * ovrd) {
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
target = ovrd->float_value; target = ovrd->val_f64;
return true;
}
return false;
}
template<typename OT>
static typename std::enable_if<std::is_same<OT, char *>::value, char *>::type
try_override(T & target, const struct llama_model_kv_override * ovrd) {
if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
target = ovrd->str_value;
return true; return true;
} }
return false; return false;
@ -2942,12 +2932,11 @@ namespace GGUFMeta {
template<typename OT> template<typename OT>
static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override * ovrd) { try_override(T & target, const struct llama_model_kv_override * ovrd) {
(void)target; if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
(void)ovrd; target = ovrd->val_str;
if (!ovrd) { return false; } return true;
// Currently, we should never end up here so it would be a bug if we do. }
throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n", return false;
ovrd ? ovrd->key : "NULL"));
} }
static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
@ -14276,13 +14265,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
for (auto & o : overrides) { for (auto & o : overrides) {
if (o.key[0] == 0) break; if (o.key[0] == 0) break;
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
gguf_set_val_f32(ctx_out, o.key, o.float_value); gguf_set_val_f32(ctx_out, o.key, o.val_f64);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
gguf_set_val_i32(ctx_out, o.key, o.int_value); gguf_set_val_i32(ctx_out, o.key, o.val_i64);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
gguf_set_val_bool(ctx_out, o.key, o.bool_value); gguf_set_val_bool(ctx_out, o.key, o.val_bool);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
gguf_set_val_str(ctx_out, o.key, o.str_value); gguf_set_val_str(ctx_out, o.key, o.val_str);
} else { } else {
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
} }

12
llama.h
View file

@ -199,13 +199,15 @@ extern "C" {
}; };
struct llama_model_kv_override { struct llama_model_kv_override {
char key[128];
enum llama_model_kv_override_type tag; enum llama_model_kv_override_type tag;
char str_value[128];
char key[128];
union { union {
int64_t int_value; int64_t val_i64;
double float_value; double val_f64;
bool bool_value; bool val_bool;
char val_str[128];
}; };
}; };