llama : rename batch.logits to batch.output
This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for this change (apart from the TODO comment in `llama.h`) is that the field does not hold logits; it specifies whether output should be generated for a token. When generating embeddings, for example, setting a field named `logits` to true is confusing, since no logits are produced.
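A minimal sketch of the renamed API (hedged: `ctx` and `tokens` are assumed to be an initialized `llama_context *` and a tokenized prompt; error handling omitted). The flag marks which tokens should produce output; for a generative model that output is logits, for an embedding model it is an embedding, so `output` describes the intent where `logits` did not:

    // Build a batch and request output only for the last prompt token.
    llama_batch batch = llama_batch_init(512, 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        common_batch_add(batch, tokens[i], i, { 0 }, false); // no output for these tokens
    }
    batch.output[batch.n_tokens - 1] = true; // was batch.logits before this commit

    llama_decode(ctx, batch); // computes logits or embeddings only for the flagged token
    llama_batch_free(batch);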
parent 9f4cc8f8d3
commit 291a785587

19 changed files with 52 additions and 53 deletions
@@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
             << ", pos " << std::to_string(batch.pos[i])
             << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ", seq_id " << std::to_string(batch.seq_id[i][0])
-            << ", logits " << std::to_string(batch.logits[i]);
+            << ", output " << std::to_string(batch.output[i]);
     }

     buf << " ]";

@@ -1617,7 +1617,7 @@ void common_batch_add(
                        llama_token   id,
                          llama_pos   pos,
     const std::vector<llama_seq_id> & seq_ids,
-                              bool   logits) {
+                              bool   output) {
     GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

     batch.token [batch.n_tokens] = id;

@@ -1626,7 +1626,7 @@ void common_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits [batch.n_tokens] = logits;
+    batch.output [batch.n_tokens] = output;

     batch.n_tokens++;
 }

@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };

         const int ret = llama_decode(ctx, batch_view);

@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
             common_batch_add(batch, 0, i, { j }, false);
         }
     }
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     const auto t_pp_start = ggml_time_us();

@@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }

 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1

 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")

@@ -171,7 +171,7 @@ while n_cur <= n_len {
         if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
             seq_id[0] = Int32(i)
         }
-        batch.logits[Int(batch.n_tokens)] = 1
+        batch.output[Int(batch.n_tokens)] = 1

         i_batch[i] = batch.n_tokens

@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
     }

     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     if (llama_decode(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }

     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         common_batch_add(*batch, 0, i, { 0 }, false);
     }

-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
     llama_kv_cache_clear(context);

     const auto t_pp_start = ggml_time_us();

@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     for (int i = 0; i < n_tokens; ++i) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

     return reinterpret_cast<jlong>(batch);
 }

@@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }

     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;

     if (llama_decode(context, *batch) != 0) {
         LOGe("llama_decode() failed");

@@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) {
     batch.n_tokens = 0
 }

-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) {
     batch.token   [Int(batch.n_tokens)] = id
     batch.pos     [Int(batch.n_tokens)] = pos
     batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
     for i in 0..<seq_ids.count {
         batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
     }
-    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
+    batch.output  [Int(batch.n_tokens)] = outputs ? 1 : 0

     batch.n_tokens += 1
 }

@@ -139,7 +139,7 @@ actor LlamaContext {
             let i = Int(i1)
             llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true

         if llama_decode(context, batch) != 0 {
             print("llama_decode() failed")

@@ -208,7 +208,7 @@ actor LlamaContext {
        for i in 0..<n_tokens {
            llama_batch_add(&batch, 0, Int32(i), [0], false)
        }
-       batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+       batch.output[Int(batch.n_tokens) - 1] = 1 // true

        llama_kv_cache_clear(context)

@@ -441,13 +441,13 @@ struct llava_embd_batch {
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id>   seq_id_0;
     std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;
     llama_batch batch;
     llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
+        outputs .resize(n_tokens);
         seq_id_0.resize(1);
         seq_id_0[0] = seq_id;
         seq_ids [n_tokens] = nullptr;

@@ -458,13 +458,13 @@ struct llava_embd_batch {
             /*pos      =*/ pos.data(),
             /*n_seq_id =*/ n_seq_id.data(),
             /*seq_id   =*/ seq_ids.data(),
-            /*logits   =*/ logits.data(),
+            /*output   =*/ outputs.data(),
         };
         for (int i = 0; i < n_tokens; i++) {
             batch.pos     [i] = pos_0 + i;
             batch.n_seq_id[i] = 1;
             batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
+            batch.output  [i] = false;
         }
     }
 };

@@ -266,7 +266,7 @@ int main(int argc, char ** argv) {

             // extract the logits only for the last token
             if (batch.n_tokens > 0) {
-                batch.logits[batch.n_tokens - 1] = true;
+                batch.output[batch.n_tokens - 1] = true;
             }

             client.n_prompt = tokens_prompt.size();

@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };

             const int ret = llama_decode(ctx, batch_view);

@@ -146,7 +146,7 @@ int main(int argc, char ** argv) {
         }

         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }

         if (llama_decode(ctx, batch) != 0) {

@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
         }

         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }

         if (llama_decode(ctx, batch) != 0) {

@@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
                 batch.pos     [idx]    = j*n_batch + k;
                 batch.n_seq_id[idx]    = 1;
                 batch.seq_id  [idx][0] = seq;
-                batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+                batch.output  [idx]    = batch.pos[idx] >= first ? 1 : 0;

-                n_outputs += batch.logits[idx] != 0;
+                n_outputs += batch.output[idx] != 0;
             }
             batch.n_tokens += batch_size;

@@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
         };

         const int ret = llama_decode(ctx, batch_view);

@@ -680,7 +680,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<

         int n_outputs = 0;
         for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
+            n_outputs += batch_view.output[i] != 0;
         }

         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));

@@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
             common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;

         for (int s = 0; s < 4; ++s) {

@@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         for (size_t i = 0; i < data[i1].common_prefix; ++i) {
             common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
         n_logits += 1;

         for (int s = 0; s < 2; ++s) {

@@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
             common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
         }
-        batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+        batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
         n_logits += 1;

         for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {

@@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }

     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }

@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < tokens.size(); i++) {
         common_batch_add(batch, tokens[i], i, {0}, false);
     }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
+    batch.output[batch.n_tokens - 1] = true; // generate next token

     // evaluate prompt
     llama_decode(ctx, batch);

@@ -2413,7 +2413,7 @@ struct server_context {
             std::vector<float> embd_res(n_embd, 0.0f);

             for (int i = 0; i < batch.n_tokens; ++i) {
-                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                     continue;
                 }

@@ -2451,7 +2451,7 @@ struct server_context {
             res->n_tokens = slot.n_prompt_tokens;

             for (int i = 0; i < batch.n_tokens; ++i) {
-                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
                     continue;
                 }

@@ -3109,7 +3109,7 @@ struct server_context {
                 }

                 // extract the logits only for the last token
-                batch.logits[batch.n_tokens - 1] = true;
+                batch.output[batch.n_tokens - 1] = true;

                 slot.n_decoded = 0;
                 slot.i_batch   = batch.n_tokens - 1;

@@ -3149,7 +3149,7 @@ struct server_context {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
             };

             const int ret = llama_decode(ctx, batch_view);

@@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
     GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());

     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     if (llama_decode(ctx_ttc, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

@@ -252,7 +252,7 @@ extern "C" {
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
+        int8_t       *  output;
     } llama_batch;

     enum llama_model_kv_override_type {

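For reference, after this change the public batch struct in `llama.h` reads roughly as follows (a sketch; the surrounding fields are abbreviated from the header and the comments are mine):

    typedef struct llama_batch {
        int32_t n_tokens;

        llama_token  *  token;    // token ids (used when embd is NULL)
        float        *  embd;     // token embeddings (used when token is NULL)
        llama_pos    *  pos;      // position of each token in its sequence
        int32_t      *  n_seq_id; // number of sequence ids per token
        llama_seq_id ** seq_id;   // sequence ids per token
        int8_t       *  output;   // if non-zero, compute output (logits or embeddings) for this token
    } llama_batch;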
@@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
                 ubatch.output[ubatch.n_tokens + i] = 1;
                 out_ids.push_back(ids[seq.offset + i]);
             }
-        } else if (batch->logits) {
+        } else if (batch->output) {
             if (ubatch.equal_seqs) {
                 for (size_t i = 0; i < length; ++i) {
                     size_t id = ids[seq.offset + i];
-                    int8_t is_output = batch->logits[id];
+                    int8_t is_output = batch->output[id];
                     ubatch.output[ubatch.n_tokens + i] = is_output;
                     if (is_output) { out_ids.push_back(id); }
                 }
             } else {
                 // simple split
-                ubatch.output = batch->logits + seq.offset;
+                ubatch.output = batch->output + seq.offset;
                 for (size_t i = 0; i < length; ++i) {
                     if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                 }

@@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.seq_id = seq_id.data();
     }
-    if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+    if (!batch.output) {
+        outputs.resize(batch.n_tokens);
+        outputs[outputs.size() - 1] = true;
+        batch.output = outputs.data();
     }
 }

@@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     }
     batch.seq_id[n_tokens_alloc] = nullptr;

-    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+    batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

     return batch;
 }

@@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) {
         }
         free(batch.seq_id);
     }
-    if (batch.logits) free(batch.logits);
+    if (batch.output) free(batch.output);
 }

@@ -81,7 +81,7 @@ struct llama_batch_allocr {
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         outputs;

     // optionally fulfill the batch returned by llama_batch_get_one
     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);

@@ -8473,9 +8473,9 @@ static int llama_prepare_sbatch(
     lctx.embd_seq.clear();

     // count outputs
-    if (batch.logits && !embd_pooled) {
+    if (batch.output && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
+            n_outputs += batch.output[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;

@@ -9972,7 +9972,6 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
     return llama_kv_cache_can_shift(ctx->kv_self);
 }

 ///
-
 int32_t llama_encode(
         struct llama_context * ctx,