diff --git a/llama.cpp b/llama.cpp
index cef984e8f..311bb89fe 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8097,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
 
-        float * data = (float *) lctx.inp_mean->data;
+        float * data = (float *) lctx.inp_mean->data;
         memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
 
         std::vector<uint64_t> sum(n_tokens, 0);
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+
             sum[seq_id] += 1;
         }
 
@@ -8127,10 +8130,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
 
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
             const llama_pos pos = batch.pos[i];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+
             if (pos == 0) {
                 data[seq_id] = i;
             }
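
Note on why the bound matters (a minimal standalone sketch, not part of the patch: fill_mean_pooling, its plain std::vector buffers, and the main() below are hypothetical stand-ins for lctx.inp_mean->data and batch.seq_id): with pooling_type == MEAN, inp_mean is an n_tokens x n_tokens matrix written at data[seq_id*n_tokens + i], and with CLS the write is data[seq_id], so any seq_id >= n_tokens lands past the end of the buffer; the new GGML_ASSERTs turn that silent corruption into a hard failure.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the MEAN-pooling fill in llama_set_inputs():
// 'data' plays the role of lctx.inp_mean->data (n_tokens * n_tokens floats),
// 'seq_id' the role of batch.seq_id[i][0].
static void fill_mean_pooling(std::vector<float> & data,
                              const std::vector<int32_t> & seq_id,
                              int64_t n_tokens) {
    std::vector<uint64_t> sum(n_tokens, 0);
    for (int64_t i = 0; i < n_tokens; ++i) {
        // the bound the patch enforces with GGML_ASSERT: without it, the
        // write below at data[seq_id[i]*n_tokens + i] is out of bounds
        assert(seq_id[i] < n_tokens);
        sum[seq_id[i]] += 1;
    }
    for (int64_t i = 0; i < n_tokens; ++i) {
        // row seq_id[i] of the matrix averages the tokens of that sequence
        data[seq_id[i]*n_tokens + i] = 1.0f/float(sum[seq_id[i]]);
    }
}

int main() {
    const int64_t n_tokens = 4;
    std::vector<float> data(n_tokens*n_tokens, 0.0f);
    fill_mean_pooling(data, {0, 0, 1, 1}, n_tokens); // two sequences, two tokens each
    return 0;
}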