llama : more consistent names of count variables (#5994)

* llama : more consistent names of count variables

ggml-ci

* llama : n_parallel -> n_seq_max

* common : fix param name

* examples : fix param name
This commit is contained in:
Georgi Gerganov 2024-03-11 17:49:47 +02:00 committed by GitHub
parent 83796e62bc
commit 05b06210c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 35 additions and 34 deletions

View file

@ -12538,7 +12538,7 @@ struct llama_context_params llama_context_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
/*.n_parallel =*/ 1,
/*.n_seq_max =*/ 1,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@ -12700,7 +12700,7 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
// TODO: maybe add n_parallel here too
// TODO: maybe add n_seq_max here too
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -12767,7 +12767,7 @@ struct llama_context * llama_new_context_with_model(
// Mamba only needs a constant number of KV cache cells per sequence
if (model->arch == LLM_ARCH_MAMBA) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_parallel);
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
@ -13024,7 +13024,7 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
return ctx->cparams.n_batch;
}
uint32_t llama_n_max_seq(const struct llama_context * ctx) {
uint32_t llama_n_seq_max(const struct llama_context * ctx) {
return ctx->kv_self.size;
}
@ -13188,10 +13188,10 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
}
}
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
struct llama_kv_cache_view result = {
/*.n_cells = */ 0,
/*.n_max_seq = */ n_max_seq,
/*.n_seq_max = */ n_seq_max,
/*.token_count = */ 0,
/*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
/*.max_contiguous = */ 0,
@ -13219,7 +13219,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
view->cells = (struct llama_kv_cache_view_cell *)p;
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
view->cells_sequences = (llama_seq_id *)p;
}
@ -13233,7 +13233,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
uint32_t max_contig = 0;
int32_t max_contig_idx = -1;
for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@ -13250,7 +13250,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) {
if (seq_idx >= view->n_seq_max) {
break;
}
cs_curr[seq_idx] = it;
@ -13259,7 +13259,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
if (seq_idx != 0) {
used_cells++;
}
for (; seq_idx < view->n_max_seq; seq_idx++) {
for (; seq_idx < view->n_seq_max; seq_idx++) {
cs_curr[seq_idx] = -1;
}
}
@ -13921,12 +13921,12 @@ int32_t llama_tokenize(
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_max_tokens,
int32_t n_tokens_max,
bool add_bos,
bool special) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
if (n_max_tokens < (int) res.size()) {
if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size());
}