llama : second attempt to refactor vision API

Xuan Son Nguyen 2025-01-18 20:56:35 +01:00
parent 2a458d1a9d
commit 0a81051ae2
20 changed files with 695 additions and 145 deletions

@@ -31,6 +31,7 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
/*n_seq_id =*/ ubatch_n_seq_id.data(),
/*seq_id =*/ ubatch_seq_id.data(),
/*output =*/ ubatch_output.data(),
/*embd_tensor =*/ nullptr,
};
return ubatch;
}
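For orientation, the tail of llama_ubatch presumably ends up looking as follows after this hunk; the member names and order are taken from the designated-initializer comments above, while the existing types are assumed to match the usual llama.cpp definitions, so treat this as a sketch rather than the literal header:

struct llama_ubatch {
    // ... earlier members (token, embd, pos, ...) unchanged ...
    int32_t      *  n_seq_id;
    llama_seq_id ** seq_id;
    int8_t       *  output;
    struct ggml_tensor * embd_tensor; // new: pre-computed embeddings passed as a ggml tensor
};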
@@ -55,7 +56,9 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
} else {
ubatch.token = nullptr;
}
-if (batch->embd) {
+if (batch->embd_tensor) {
+ubatch.embd_tensor = batch->embd_tensor;
+} else if (batch->embd) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
memcpy(
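The new branch gives a tensor-backed batch priority over a raw float buffer. A compressed restatement of the selection order in this function, using the names from the diff (a sketch, not the literal code):

// inside llama_sbatch::add_seq_to_ubatch(), when filling a ubatch:
if (batch->embd_tensor) {
    ubatch.embd_tensor = batch->embd_tensor; // forward the tensor by pointer, no float copy
} else if (batch->embd) {
    // existing path: copy the per-token float embeddings into ubatch.embd (the memcpy above)
} else {
    // token-only batch: ubatch.token was already filled earlier in this function
}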
@@ -139,7 +142,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
ubatch.equal_seqs = false;
if (!seq.empty()) {
llama_sbatch_seq & s = seq[0];
@@ -152,7 +155,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
if (!seq.empty()) {
size_t length = 0;
size_t n_tokens_in_ubatch = 0;
@@ -179,7 +182,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
+llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr);
if (!seq.empty()) {
llama_sbatch_seq & s = seq[seq.size() - 1];
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
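The same change appears in split_simple, split_equal and split_seq: a batch now counts as carrying embeddings when either the float buffer or the new tensor is set. The repeated condition can be read as a single predicate; a hypothetical helper, not part of this commit:

static bool batch_has_embd(const struct llama_batch * batch) {
    // embeddings can arrive either as a raw float buffer (embd)
    // or, new in this commit, as a pre-computed ggml tensor (embd_tensor)
    return batch->embd != nullptr || batch->embd_tensor != nullptr;
}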
@@ -320,6 +323,7 @@ struct llama_batch llama_batch_get_one(
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*embd_tensor =*/ nullptr,
};
}
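Taken together, the initializer lists here and in llama_batch_init below imply that the public llama_batch struct gains one trailing member. A sketch of the resulting layout, with the member order taken from the comments above and the existing types as declared in llama.h (an approximation, not the actual header):

struct llama_batch {
    int32_t              n_tokens;
    llama_token        * token;       // token ids, or NULL when embeddings are used
    float              * embd;        // raw float embeddings (existing path)
    llama_pos          * pos;
    int32_t            * n_seq_id;
    llama_seq_id      ** seq_id;
    int8_t             * logits;
    struct ggml_tensor * embd_tensor; // new: embeddings provided as a ggml tensor
};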
@@ -332,6 +336,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*embd_tensor =*/ nullptr,
};
if (embd) {
@@ -353,6 +358,35 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
return batch;
}
struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id) {
GGML_ASSERT(tensor->ne[2] == 1 && tensor->ne[3] == 1);
int32_t n_tokens = tensor->ne[1];
llama_batch batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*embd_tensor =*/ tensor,
};
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens + 1));
for (int i = 0; i < n_tokens; ++i) {
batch.pos [i] = p0 + i;
batch.seq_id [i] = (llama_seq_id *) malloc(sizeof(llama_seq_id));
batch.seq_id [i][0] = seq_id;
batch.n_seq_id[i] = 1;
}
batch.seq_id[n_tokens] = nullptr;
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
void llama_batch_free(struct llama_batch batch) {
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);
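To close, a hedged sketch of how a caller might use the new helper with the usual llama.cpp decode flow. llama_batch_get_one_from_tensor() is the only call introduced by this commit; embd_tensor stands for a ggml tensor whose second dimension is the token count (per the assertion in the new function, e.g. the output of a vision encoder), and ctx is an already initialized llama_context:

// build a batch directly from pre-computed embeddings and decode it
struct llama_batch batch = llama_batch_get_one_from_tensor(embd_tensor, /*p0 =*/ 0, /*seq_id =*/ 0);
if (llama_decode(ctx, batch) != 0) {
    fprintf(stderr, "%s: llama_decode() failed\n", __func__);
}
llama_batch_free(batch); // frees the pos/n_seq_id/seq_id/logits buffers allocated by the helper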