llama : rename batch.logits to batch.output

This commit renames the `logits` field of the `llama_batch` struct to
`output`.

The motivation for this change (apart from the TODO comment) is that
the `logits` field is actually used to specify that output should be
generated. For example, when generating embeddings, setting `logits`
to true can be confusing since no logits are produced in that case.
Daniel Bevenius 2024-10-22 15:03:00 +02:00
parent 9f4cc8f8d3
commit 291a785587
19 changed files with 52 additions and 53 deletions
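
For reference, the call-site pattern after this rename looks roughly like the minimal sketch below. It mirrors the call sites updated in this commit and assumes an initialized `llama_context * ctx` and a tokenized prompt in `tokens`; only the field and parameter names change, the semantics of requesting output for a token stay the same.

```cpp
// Minimal sketch (not part of the commit): build a batch for a prompt and
// request output only for the last token, using the renamed field/parameter.
// Assumes `ctx` is an initialized llama_context * and `tokens` holds the
// tokenized prompt.
llama_batch batch = llama_batch_init(tokens.size(), /*embd =*/ 0, /*n_seq_max =*/ 1);

for (size_t i = 0; i < tokens.size(); i++) {
    // previously: common_batch_add(batch, tokens[i], i, { 0 }, /*logits =*/ false);
    common_batch_add(batch, tokens[i], i, { 0 }, /*output =*/ false);
}

// previously: batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
    LOG_ERR("%s: llama_decode() failed\n", __func__);
}

llama_batch_free(batch);
```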

@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
<< ", output " << std::to_string(batch.output[i]);
}
buf << " ]";
@ -1617,7 +1617,7 @@ void common_batch_add(
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
bool output) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
batch.token [batch.n_tokens] = id;
@ -1626,7 +1626,7 @@ void common_batch_add(
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.output [batch.n_tokens] = output;
batch.n_tokens++;
}

@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
};
const int ret = llama_decode(ctx, batch_view);
@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
common_batch_add(batch, 0, i, { j }, false);
}
}
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
const auto t_pp_start = ggml_time_us();

@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() {
if let seq_id = batch.seq_id[i] {
seq_id[0] = 0
}
batch.logits[i] = 0
batch.output[i] = 0
}
// llama_decode will output logits only for the last token of the prompt
batch.logits[Int(batch.n_tokens) - 1] = 1
batch.output[Int(batch.n_tokens) - 1] = 1
if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
@ -171,7 +171,7 @@ while n_cur <= n_len {
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
seq_id[0] = Int32(i)
}
batch.logits[Int(batch.n_tokens)] = 1
batch.output[Int(batch.n_tokens)] = 1
i_batch[i] = batch.n_tokens

@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
}
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
if (llama_decode(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);

@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
if (!batch.output[i]) {
continue;
}

@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
common_batch_add(*batch, 0, i, { 0 }, false);
}
batch->logits[batch->n_tokens - 1] = true;
batch->output[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context);
const auto t_pp_start = ggml_time_us();
@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
for (int i = 0; i < n_tokens; ++i) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return reinterpret_cast<jlong>(batch);
}
@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
}
// llama_decode will output logits only for the last token of the prompt
batch->logits[batch->n_tokens - 1] = true;
batch->output[batch->n_tokens - 1] = true;
if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() failed");

@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) {
batch.n_tokens = 0
}
func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) {
batch.token [Int(batch.n_tokens)] = id
batch.pos [Int(batch.n_tokens)] = pos
batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
for i in 0..<seq_ids.count {
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
}
batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
batch.output [Int(batch.n_tokens)] = outputs ? 1 : 0
batch.n_tokens += 1
}
@ -139,7 +139,7 @@ actor LlamaContext {
let i = Int(i1)
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
batch.output[Int(batch.n_tokens) - 1] = 1 // true
if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
@ -208,7 +208,7 @@ actor LlamaContext {
for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
batch.output[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)

@ -441,13 +441,13 @@ struct llava_embd_batch {
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
std::vector<int8_t> outputs;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
outputs .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
@ -458,13 +458,13 @@ struct llava_embd_batch {
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
/*output =*/ outputs.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
batch.output [i] = false;
}
}
};

@ -266,7 +266,7 @@ int main(int argc, char ** argv) {
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
}
client.n_prompt = tokens_prompt.size();
@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
};
const int ret = llama_decode(ctx, batch_view);

@ -146,7 +146,7 @@ int main(int argc, char ** argv) {
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
}
if (llama_decode(ctx, batch) != 0) {
@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
}
if (llama_decode(ctx, batch) != 0) {

@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
batch.pos [idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id [idx][0] = seq;
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
batch.output [idx] = batch.pos[idx] >= first ? 1 : 0;
n_outputs += batch.logits[idx] != 0;
n_outputs += batch.output[idx] != 0;
}
batch.n_tokens += batch_size;
@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
};
const int ret = llama_decode(ctx, batch_view);
@ -680,7 +680,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
n_outputs += batch_view.output[i] != 0;
}
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
n_logits += 1;
for (int s = 0; s < 4; ++s) {
@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
n_logits += 1;
for (int s = 0; s < 2; ++s) {
@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
n_logits += 1;
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {

@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
if (!batch.output[i]) {
continue;
}

@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = true; // generate next token
batch.output[batch.n_tokens - 1] = true; // generate next token
// evaluate prompt
llama_decode(ctx, batch);

@ -2413,7 +2413,7 @@ struct server_context {
std::vector<float> embd_res(n_embd, 0.0f);
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
continue;
}
@ -2451,7 +2451,7 @@ struct server_context {
res->n_tokens = slot.n_prompt_tokens;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
continue;
}
@ -3109,7 +3109,7 @@ struct server_context {
}
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
@ -3149,7 +3149,7 @@ struct server_context {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch.output + i,
};
const int ret = llama_decode(ctx, batch_view);

@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
batch.output[batch.n_tokens - 1] = true;
if (llama_decode(ctx_ttc, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);

@ -252,7 +252,7 @@ extern "C" {
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
int8_t * output;
} llama_batch;
enum llama_model_kv_override_type {

@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
ubatch.output[ubatch.n_tokens + i] = 1;
out_ids.push_back(ids[seq.offset + i]);
}
} else if (batch->logits) {
} else if (batch->output) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
int8_t is_output = batch->logits[id];
int8_t is_output = batch->output[id];
ubatch.output[ubatch.n_tokens + i] = is_output;
if (is_output) { out_ids.push_back(id); }
}
} else {
// simple split
ubatch.output = batch->logits + seq.offset;
ubatch.output = batch->output + seq.offset;
for (size_t i = 0; i < length; ++i) {
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
}
@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
logits.resize(batch.n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
if (!batch.output) {
outputs.resize(batch.n_tokens);
outputs[outputs.size() - 1] = true;
batch.output = outputs.data();
}
}
@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
}
batch.seq_id[n_tokens_alloc] = nullptr;
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
return batch;
}
@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) {
}
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
if (batch.output) free(batch.output);
}

@ -81,7 +81,7 @@ struct llama_batch_allocr {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
std::vector<int8_t> outputs;
// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);

@ -8473,9 +8473,9 @@ static int llama_prepare_sbatch(
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
if (batch.output && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs += batch.logits[i] != 0;
n_outputs += batch.output[i] != 0;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs = n_tokens_all;
@ -9972,7 +9972,6 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
return llama_kv_cache_can_shift(ctx->kv_self);
}
///
int32_t llama_encode(
struct llama_context * ctx,