llama : assert input batch with pooling enabled
This commit is contained in:
parent
c23c554744
commit
1af2d06139
1 changed files with 8 additions and 1 deletions
|
@ -8097,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
||||
float * data = (float *) lctx.inp_mean->data;
|
||||
|
||||
float * data = (float *) lctx.inp_mean->data;
|
||||
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
|
||||
|
||||
std::vector<uint64_t> sum(n_tokens, 0);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
|
||||
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
||||
|
||||
sum[seq_id] += 1;
|
||||
}
|
||||
|
||||
|
@ -8127,10 +8130,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
||||
|
||||
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
||||
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
const llama_pos pos = batch.pos[i];
|
||||
|
||||
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
|
||||
|
||||
if (pos == 0) {
|
||||
data[seq_id] = i;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue