llama : second attempt to refactor vision API
parent 2a458d1a9d
commit 0a81051ae2
20 changed files with 695 additions and 145 deletions
@@ -31,6 +31,7 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
         /*n_seq_id    =*/ ubatch_n_seq_id.data(),
         /*seq_id      =*/ ubatch_seq_id.data(),
         /*output      =*/ ubatch_output.data(),
+        /*embd_tensor =*/ nullptr,
     };
     return ubatch;
 }
@@ -55,7 +56,9 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
     } else {
         ubatch.token = nullptr;
     }
-    if (batch->embd) {
+    if (batch->embd_tensor) {
+        ubatch.embd_tensor = batch->embd_tensor;
+    } else if (batch->embd) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 memcpy(
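The hunk above makes add_seq_to_ubatch prefer a pre-computed tensor over raw float rows: a tensor-bearing batch is handed off by pointer, while the classic embd path still copies per-token floats. A minimal self-contained sketch of that dispatch, with ggml_tensor stubbed out (the real code aliases into llama_ubatch):

#include <cstring>
#include <vector>

struct ggml_tensor_stub {}; // stand-in for ggml_tensor, to keep the sketch self-contained

struct batch_embd_view {
    const float      * embd        = nullptr; // raw per-token rows, or null
    ggml_tensor_stub * embd_tensor = nullptr; // pre-computed tensor, or null
};

struct ubatch_embd_view {
    std::vector<float> embd;                  // owned copy (classic path)
    ggml_tensor_stub * embd_tensor = nullptr; // borrowed pointer (tensor path)
};

// mirrors the new branch: alias the tensor when present, otherwise copy floats
void add_embd(ubatch_embd_view & ub, const batch_embd_view & b, size_t n_floats) {
    if (b.embd_tensor) {
        ub.embd_tensor = b.embd_tensor;       // zero-copy hand-off
    } else if (b.embd) {
        ub.embd.resize(n_floats);
        std::memcpy(ub.embd.data(), b.embd, n_floats * sizeof(float));
    }
}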
@@ -139,7 +142,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
 
 llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
     n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
     ubatch.equal_seqs = false;
     if (!seq.empty()) {
         llama_sbatch_seq & s = seq[0];
@@ -152,7 +155,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
 
 llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
     n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
     if (!seq.empty()) {
         size_t length = 0;
         size_t n_tokens_in_ubatch = 0;
@@ -179,7 +182,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
 
 llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
     if (!seq.empty()) {
         llama_sbatch_seq & s = seq[seq.size() - 1];
         size_t length = s.length < n_ubatch ? s.length : n_ubatch;
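All three splitters — split_simple, split_equal, and split_seq — receive the identical widening of the has_embd argument: without the extra || batch->embd_tensor != nullptr, a batch that carries only a tensor would reserve no embedding storage in the ubatch. The shared predicate, isolated as a sketch:

// the ubatch needs an embedding slot when either embedding source is present
struct batch_embd_sources {
    const float * embd        = nullptr; // raw float rows
    const void  * embd_tensor = nullptr; // pre-computed tensor (e.g. a vision encoder's output)
};

bool has_embd(const batch_embd_sources & b) {
    return b.embd != nullptr || b.embd_tensor != nullptr;
}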
@@ -320,6 +323,7 @@ struct llama_batch llama_batch_get_one(
         /*n_seq_id    =*/ nullptr,
         /*seq_id      =*/ nullptr,
         /*logits      =*/ nullptr,
+        /*embd_tensor =*/ nullptr,
     };
 }
 
@@ -332,6 +336,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id    =*/ nullptr,
         /*seq_id      =*/ nullptr,
         /*logits      =*/ nullptr,
+        /*embd_tensor =*/ nullptr,
     };
 
     if (embd) {
@@ -353,6 +358,35 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     return batch;
 }
 
+struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id) {
+    GGML_ASSERT(tensor->ne[2] == 1 && tensor->ne[3] == 1);
+    int32_t n_tokens = tensor->ne[1];
+    llama_batch batch = {
+        /*n_tokens    =*/ n_tokens,
+        /*tokens      =*/ nullptr,
+        /*embd        =*/ nullptr,
+        /*pos         =*/ nullptr,
+        /*n_seq_id    =*/ nullptr,
+        /*seq_id      =*/ nullptr,
+        /*logits      =*/ nullptr,
+        /*embd_tensor =*/ tensor,
+    };
+    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
+    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
+    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens + 1));
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.pos     [i]    = p0 + i;
+        batch.seq_id  [i]    = (llama_seq_id *) malloc(sizeof(llama_seq_id));
+        batch.seq_id  [i][0] = seq_id;
+        batch.n_seq_id[i]    = 1;
+    }
+    batch.seq_id[n_tokens] = nullptr;
+
+    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+
+    return batch;
+}
+
 void llama_batch_free(struct llama_batch batch) {
     if (batch.token) free(batch.token);
     if (batch.embd)  free(batch.embd);
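The new llama_batch_get_one_from_tensor builds a single-sequence batch around an [n_embd, n_tokens] embeddings tensor, assigning consecutive positions starting at p0. A hypothetical usage sketch — decode_embd_tensor is an invented helper name, and the llama_decode call assumes the mainline signature; only the two batch functions come from this diff:

#include "ggml.h"
#include "llama.h"

// `embeddings` is assumed to be [n_embd, n_tokens, 1, 1], e.g. a vision encoder's output
static void decode_embd_tensor(struct llama_context * ctx, struct ggml_tensor * embeddings,
                               llama_pos p0, llama_seq_id seq) {
    struct llama_batch batch = llama_batch_get_one_from_tensor(embeddings, p0, seq);
    if (llama_decode(ctx, batch) != 0) {
        // handle decode failure
    }
    llama_batch_free(batch); // frees pos/n_seq_id/seq_id/logits; the tensor is not owned by the batch
}

Note that seq_id is over-allocated by one entry and null-terminated, matching the convention of llama_batch_init so that llama_batch_free can walk the per-token arrays without knowing n_tokens.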