Merge 395ae48cb0 into 8ebe8ddebd
commit 00d129f87f
18 changed files with 72 additions and 64 deletions
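This commit renames the llama_batch field logits to output across the core library, the examples, and the Swift and JNI bindings; the hunks below are otherwise largely mechanical. As a caller-side illustration (a hypothetical helper, not part of the commit, assuming the llama.h declarations shown in the hunks further down):

    #include "llama.h"

    // Hypothetical helper: mark only the last token of a prompt for output.
    static void mark_last_token_for_output(struct llama_batch & batch) {
        // before this commit: batch.logits[batch.n_tokens - 1] = true;
        batch.output[batch.n_tokens - 1] = true; // logits/embeddings are computed only for this token
    }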
@@ -3037,7 +3037,7 @@ void llama_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits [batch.n_tokens] = logits;
+    batch.output [batch.n_tokens] = logits;
 
     batch.n_tokens++;
 }

@@ -686,7 +686,7 @@ inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
             << ":pos " << std::to_string(batch.pos[i])
             << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
+            << ":logits " << std::to_string(batch.output[i]);
     }
     buf << " ]";
@@ -94,7 +94,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
                 0, 0, 0, // unused
             };
 

@@ -149,7 +149,7 @@ int main(int argc, char ** argv) {
                 llama_batch_add(batch, 0, i, { j }, false);
             }
         }
-        batch.logits[batch.n_tokens - 1] = true;
+        batch.output[batch.n_tokens - 1] = true;
 
         const auto t_pp_start = ggml_time_us();
 
@@ -86,11 +86,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }
 
 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1
 
 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")

@@ -178,7 +178,7 @@ while n_cur <= n_len {
             if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
                 seq_id[0] = Int32(i)
             }
-            batch.logits[Int(batch.n_tokens)] = 1
+            batch.output[Int(batch.n_tokens)] = 1
 
             i_batch[i] = batch.n_tokens
 
@@ -122,7 +122,7 @@ int main(int argc, char ** argv) {
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
@@ -52,7 +52,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }
 
@@ -102,21 +102,21 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
-    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
+    llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
     int32_t i_current_token = 0;
 
     while (true) {
-        llama_batch_clear(bat);
+        llama_batch_clear(batch);
         auto n_inputs = (int32_t)inputs.size();
         for (int32_t i = 0; i < n_inputs; i++) {
-            llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+            llama_batch_add(batch, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
         }
         inputs.clear();
 
-        llama_decode(ctx, bat);
-        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+        llama_decode(ctx, batch);
+        auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
 
         auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
         auto n_candidates = (int32_t)candidates.size();

@@ -145,7 +145,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
         std::printf("\n");
     }
 
-    llama_batch_free(bat);
+    llama_batch_free(batch);
 
     return result;
 }
@@ -513,7 +513,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }
 
-            // TODO: use batch.logits to save computations instead of relying on logits_all == true
+            // TODO: use batch.output to save computations instead of relying on logits_all == true
             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return false;
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         llama_batch_add(*batch, 0, i, { 0 }, false);
     }
 
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
     llama_kv_cache_clear(context);
 
     const auto t_pp_start = ggml_time_us();

@@ -306,7 +306,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     for (int i = 0; i < n_tokens; ++i) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
 
     return reinterpret_cast<jlong>(batch);
 }

@@ -363,7 +363,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
 
     if (llama_decode(context, *batch) != 0) {
         LOGe("llama_decode() failed");
@@ -16,7 +16,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
     for i in 0..<seq_ids.count {
         batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
     }
-    batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
+    batch.output [Int(batch.n_tokens)] = logits ? 1 : 0
 
     batch.n_tokens += 1
 }

@@ -132,7 +132,7 @@ actor LlamaContext {
             let i = Int(i1)
             llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
         if llama_decode(context, batch) != 0 {
             print("llama_decode() failed")

@@ -214,7 +214,7 @@ actor LlamaContext {
         for i in 0..<n_tokens {
             llama_batch_add(&batch, 0, Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
         llama_kv_cache_clear(context)
 
@@ -265,7 +265,7 @@ int main(int argc, char ** argv) {
 
             // extract the logits only for the last token
             if (batch.n_tokens > 0) {
-                batch.logits[batch.n_tokens - 1] = true;
+                batch.output[batch.n_tokens - 1] = true;
             }
 
             client.n_prompt = tokens_prompt.size();

@@ -308,7 +308,7 @@ int main(int argc, char ** argv) {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
                 0, 0, 0, // unused
             };
 
@@ -140,7 +140,7 @@ int main(int argc, char ** argv) {
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         if (llama_decode(ctx, batch) != 0) {

@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
         }
 
         if (llama_decode(ctx, batch) != 0) {
@@ -367,17 +367,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         return {tokens, -1, logit_history, prob_history};
     }
 
-    const int calc_chunk = n_ctx;
-
-    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
+    fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), n_ctx);
 
-    if (int(tokens.size()) <= calc_chunk) {
+    if (int(tokens.size()) <= n_ctx) {
         fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                 tokens.size(), n_ctx, params.ppl_stride);
         return {tokens, -1, logit_history, prob_history};
     }
 
-    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
+    const int n_chunk_max = (tokens.size() - n_ctx + params.ppl_stride - 1) / params.ppl_stride;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));

@@ -386,13 +384,14 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     int count = 0;
     double nll = 0.0;
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.ppl_stride;
-        const int end = start + calc_chunk;
+        const int end = start + n_ctx;
 
-        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
-
         //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
 
         std::vector<float> logits;

@@ -406,9 +405,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);
 
+            llama_batch_clear(batch);
+
+            for (int k = 0; k < batch_size; ++k) {
+                const int idx = batch_start + k;
+                llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true);
+            }
+
             //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-            // TODO: use llama_batch.logits instead of relying on logits_all == true
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 //fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }

@@ -465,6 +470,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         }
         fflush(stdout);
     }
 
+    llama_batch_free(batch);
+
     printf("\n");
 
     return {tokens, std::exp(nll / count), logit_history, prob_history};
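Note that the perplexity_v2 hunks above are more than a rename: the llama_batch_get_one() call that relied on logits_all == true is replaced by an explicitly built batch whose per-token output flags request logits for the whole chunk. A minimal sketch of that pattern, reusing the variables from the hunk context (tokens, batch_start, batch_size, j, n_batch):

    // Sketch: build the chunk explicitly instead of using llama_batch_get_one().
    llama_batch_clear(batch);
    for (int k = 0; k < batch_size; ++k) {
        const int idx = batch_start + k;
        // final argument = true: request logits for every token of the chunk
        llama_batch_add(batch, tokens[idx], j*n_batch + k, { 0 }, true);
    }
    if (llama_decode(ctx, batch)) {
        // handle the failed evaluation here
    }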
@@ -601,9 +609,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             batch.pos     [idx]    = j*n_batch + k;
             batch.n_seq_id[idx]    = 1;
             batch.seq_id  [idx][0] = seq;
-            batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+            batch.output  [idx]    = batch.pos[idx] >= first ? 1 : 0;
 
-            n_outputs += batch.logits[idx] != 0;
+            n_outputs += batch.output[idx] != 0;
         }
         batch.n_tokens += batch_size;
 

@@ -697,7 +705,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.pos      + i,
             batch.n_seq_id + i,
             batch.seq_id   + i,
-            batch.logits   + i,
+            batch.output   + i,
             0, 0, 0, // unused
         };
 

@@ -709,7 +717,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
 
         int n_outputs = 0;
         for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
+            n_outputs += batch_view.output[i] != 0;
         }
 
         memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));

@@ -917,7 +925,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
                 llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
             }
-            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
             n_logits += 1;
 
             for (int s = 0; s < 4; ++s) {

@@ -1196,7 +1204,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
                 llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
             }
-            batch.logits[batch.n_tokens - 1] = true;
+            batch.output[batch.n_tokens - 1] = true;
             n_logits += 1;
 
             for (int s = 0; s < 2; ++s) {

@@ -1565,7 +1573,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
                 llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
             }
-            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
             n_logits += 1;
 
             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {

@@ -1794,7 +1802,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }
 
-            // TODO: use llama_batch.logits instead of relying on logits_all == true
+            // TODO: use llama_batch.output instead of relying on logits_all == true
             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
@@ -91,7 +91,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }
 
@@ -1448,7 +1448,7 @@ struct server_context {
         std::vector<float> embd_res(n_embd, 0.0f);
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.output[i] || batch.seq_id[i][0] != slot.id + 1) {
                 continue;
             }
 

@@ -2260,7 +2260,7 @@ struct server_context {
                     GGML_ASSERT(batch.n_tokens > 0);
 
                     // extract the logits only for the last token
-                    batch.logits[batch.n_tokens - 1] = true;
+                    batch.output[batch.n_tokens - 1] = true;
 
                     slot.n_decoded = 0;
                     slot.i_batch   = batch.n_tokens - 1;

@@ -2332,7 +2332,7 @@ struct server_context {
                 batch.pos      + i,
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
-                batch.logits   + i,
+                batch.output   + i,
                 0, 0, 0, // unused
             };
 
@@ -93,7 +93,7 @@ int main(int argc, char ** argv) {
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
@@ -223,7 +223,7 @@ extern "C" {
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
     // - seq_id : the sequence to which the respective token belongs
-    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+    // - output : if zero, the logits (and/or the embeddings) for the respective token will not be output
     //
     typedef struct llama_batch {
         int32_t n_tokens;

@@ -233,7 +233,7 @@ extern "C" {
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
+        int8_t       *  output; // Previously named "logits", renamed to "output" now.
 
         // NOTE: helpers for smooth API transition - can be deprecated in the future
         // for future-proof code, use the above fields instead and ignore everything below

@@ -331,7 +331,7 @@ extern "C" {
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.output instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]

@@ -872,9 +872,9 @@ extern "C" {
     LLAMA_API void llama_synchronize(struct llama_context * ctx);
 
     // Token logits obtained from the last call to llama_decode()
-    // The logits for which llama_batch.logits[i] != 0 are stored contiguously
+    // The logits for which llama_batch.output[i] != 0 are stored contiguously
     // in the order they have appeared in the batch.
-    // Rows: number of tokens for which llama_batch.logits[i] != 0
+    // Rows: number of tokens for which llama_batch.output[i] != 0
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 

@@ -886,7 +886,7 @@ extern "C" {
 
     // Get all output token embeddings.
     // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
-    // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
+    // the embeddings for which llama_batch.output[i] != 0 are stored contiguously
     // in the order they have appeared in the batch.
     // shape: [n_outputs*n_embd]
     // Otherwise, returns NULL.
@@ -2995,17 +2995,17 @@ struct llama_sbatch {
                     ubatch.output[ubatch.n_tokens + i] = 1;
                     out_ids.push_back(ids[seq.offset + i]);
                 }
-            } else if (batch->logits) {
+            } else if (batch->output) {
                 if (ubatch.equal_seqs) {
                     for (size_t i = 0; i < length; ++i) {
                         size_t id = ids[seq.offset + i];
-                        int8_t is_output = batch->logits[id];
+                        int8_t is_output = batch->output[id];
                         ubatch.output[ubatch.n_tokens + i] = is_output;
                         if (is_output) { out_ids.push_back(id); }
                     }
                 } else {
                     // simple split
-                    ubatch.output = batch->logits + seq.offset;
+                    ubatch.output = batch->output + seq.offset;
                     for (size_t i = 0; i < length; ++i) {
                         if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                     }

@@ -16099,9 +16099,9 @@ static int llama_decode_internal(
     lctx.embd_seq.clear();
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch_all.output && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch_all.output[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;

@@ -20001,7 +20001,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     }
     batch.seq_id[n_tokens_alloc] = nullptr;
 
-    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+    batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
 
     return batch;
 }

@@ -20017,7 +20017,7 @@ void llama_batch_free(struct llama_batch batch) {
         }
         free(batch.seq_id);
     }
-    if (batch.logits) free(batch.logits);
+    if (batch.output) free(batch.output);
 }
 
 int32_t llama_encode(

@@ -20099,7 +20099,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
         }
 
         if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
+            throw std::runtime_error(format("batch.output[%d] != true", i));
         }
         if (j >= ctx->n_outputs) {
             // This should not happen

@@ -20148,7 +20148,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
         }
 
         if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
+            throw std::runtime_error(format("batch.output[%d] != true", i));
        }
         if (j >= ctx->n_outputs) {
             // This should not happen
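Taken together, a minimal decode loop against the renamed API might look like the following sketch (assuming a valid llama_context * ctx and a tokenized prompt in std::vector<llama_token> tokens; field semantics as in the llama.h hunks above):

    // Sketch: request output only for the final prompt token, then read its logits.
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    for (size_t k = 0; k < tokens.size(); ++k) {
        batch.token   [batch.n_tokens]    = tokens[k];
        batch.pos     [batch.n_tokens]    = (llama_pos) k;
        batch.n_seq_id[batch.n_tokens]    = 1;
        batch.seq_id  [batch.n_tokens][0] = 0;
        batch.output  [batch.n_tokens]    = 0; // no output for intermediate prompt tokens
        batch.n_tokens++;
    }
    batch.output[batch.n_tokens - 1] = 1; // logits only for the last token

    if (llama_decode(ctx, batch) == 0) {
        const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
        (void) logits; // sample the next token from these logits
    }
    llama_batch_free(batch);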