Merge branch 'master' into gg/flash-attn

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-29 17:19:25 +03:00
commit a1616e9f72
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
82 changed files with 3896 additions and 1063 deletions

View file

@ -90,7 +90,8 @@ export default function () {
"model": model,
"stream": true,
"seed": 42,
"max_tokens": max_tokens
"max_tokens": max_tokens,
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
}
const params = {method: 'POST', body: JSON.stringify(payload)};

View file

@ -1207,6 +1207,27 @@ struct server_context {
LOG_VERBOSE("eos token found", {});
}
auto n_ctx_train = llama_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
&& slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
LOG_WARNING("n_predict is not set and self-context extend is disabled."
" Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
{ "id_slot", slot.id },
{ "params.n_predict", slot.params.n_predict },
{ "slot.n_prompt_tokens", slot.n_prompt_tokens },
{ "slot.n_decoded", slot.n_decoded },
{ "slot.n_predict", slot.n_predict },
{ "n_slots", params.n_parallel },
{ "slot.n_ctx", slot.n_ctx },
{ "n_ctx", n_ctx },
{ "n_ctx_train", n_ctx_train },
{ "ga_n", slot.ga_n },
});
slot.truncated = true;
slot.stopped_limit = true;
slot.has_next_token = false; // stop prediction
}
LOG_VERBOSE("next token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
@ -2141,7 +2162,7 @@ struct server_context {
});
// process the created batch of tokens
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (auto & slot : slots) {
@ -2372,7 +2393,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`\n");
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`\n");
printf(" --chat-template JINJA_TEMPLATE\n");
@ -2805,43 +2826,11 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true;
break;
}
char * sep = strchr(argv[i], '=');
if (sep == nullptr || sep - argv[i] >= 128) {
fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
invalid_param = true;
break;
}
struct llama_model_kv_override kvo;
std::strncpy(kvo.key, argv[i], sep - argv[i]);
kvo.key[sep - argv[i]] = 0;
sep++;
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
} else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false;
} else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
invalid_param = true;
break;
}
} else {
if (!parse_kv_override(argv[i], params.kv_overrides)) {
fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
invalid_param = true;
break;
}
params.kv_overrides.push_back(kvo);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
server_print_usage(argv[0], default_params, default_sparams);