fix build
This commit is contained in:
parent
6ab116ac5a
commit
540c3016d8
1 changed files with 52 additions and 50 deletions
102
src/llama.cpp
102
src/llama.cpp
|
@ -5190,6 +5190,56 @@ struct llama_model_loader {
|
|||
}
|
||||
};
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
static const llama_seq_id batch_default_seq_id = 0;
|
||||
struct llama_batch_allocr {
|
||||
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<int8_t> logits;
|
||||
struct llama_batch batch;
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
|
||||
batch = in_batch;
|
||||
if (!batch.pos) {
|
||||
// determine the last position in KV cache
|
||||
llama_pos last_pos = -1;
|
||||
for (const auto & cell : ctx.kv_self.cells) {
|
||||
if (cell.has_seq_id(batch_default_seq_id)) {
|
||||
last_pos = std::max(last_pos, cell.pos);
|
||||
}
|
||||
}
|
||||
last_pos++; // next position
|
||||
pos.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
pos[i] = i+last_pos;
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
}
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
n_seq_id[i] = seq_id_0.size();
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
}
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
seq_id[i] = seq_id_0.data();
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
}
|
||||
if (!batch.logits) {
|
||||
logits.resize(batch.n_tokens);
|
||||
logits[logits.size() - 1] = true;
|
||||
batch.logits = logits.data();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
|
||||
uint32_t tmp;
|
||||
|
@ -17117,6 +17167,7 @@ static int llama_decode_internal(
|
|||
return -1;
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(lctx, inp_batch);
|
||||
llama_batch batch = batch_allocr.batch;
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
|
@ -17434,6 +17485,7 @@ static int llama_encode_internal(
|
|||
return -1;
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(lctx, inp_batch);
|
||||
llama_batch batch = batch_allocr.batch;
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
@ -21132,56 +21184,6 @@ void llama_batch_free(struct llama_batch batch) {
|
|||
if (batch.logits) free(batch.logits);
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
static const llama_seq_id batch_default_seq_id = 0;
|
||||
struct llama_batch_allocr {
|
||||
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<int8_t> logits;
|
||||
struct llama_batch batch;
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
|
||||
batch = in_batch;
|
||||
if (!batch.pos) {
|
||||
// determine the last position in KV cache
|
||||
llama_pos last_pos = -1;
|
||||
for (const auto & cell : ctx.kv_self.cells) {
|
||||
if (cell.has_seq_id(batch_default_seq_id)) {
|
||||
last_pos = std::max(last_pos, cell.pos);
|
||||
}
|
||||
}
|
||||
last_pos++; // next position
|
||||
pos.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
pos[i] = i+last_pos;
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
}
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
n_seq_id[i] = seq_id_0.size();
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
}
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
seq_id[i] = seq_id_0.data();
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
}
|
||||
if (!batch.logits) {
|
||||
logits.resize(batch.n_tokens);
|
||||
logits[logits.size() - 1] = true;
|
||||
batch.logits = logits.data();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue