llama : assert input batch with pooling enabled

This commit is contained in:
Georgi Gerganov 2024-03-04 19:56:40 +02:00
parent c23c554744
commit 1af2d06139
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -8097,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
float * data = (float *) lctx.inp_mean->data;
float * data = (float *) lctx.inp_mean->data;
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
std::vector<uint64_t> sum(n_tokens, 0);
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
sum[seq_id] += 1;
}
@ -8127,10 +8130,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
if (pos == 0) {
data[seq_id] = i;
}