llama : assert input batch with pooling enabled
This commit is contained in:
parent
c23c554744
commit
1af2d06139
1 changed file with 8 additions and 1 deletion
|
@ -8097,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
||||||
float * data = (float *) lctx.inp_mean->data;
|
|
||||||
|
|
||||||
|
float * data = (float *) lctx.inp_mean->data;
|
||||||
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
|
memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
|
||||||
|
|
||||||
std::vector<uint64_t> sum(n_tokens, 0);
|
std::vector<uint64_t> sum(n_tokens, 0);
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||||
|
|
||||||
|
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
||||||
|
|
||||||
sum[seq_id] += 1;
|
sum[seq_id] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8127,10 +8130,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
||||||
|
|
||||||
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
||||||
|
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
|
||||||
|
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
const llama_seq_id seq_id = batch.seq_id[i][0];
|
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||||
const llama_pos pos = batch.pos[i];
|
const llama_pos pos = batch.pos[i];
|
||||||
|
|
||||||
|
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
|
||||||
|
|
||||||
if (pos == 0) {
|
if (pos == 0) {
|
||||||
data[seq_id] = i;
|
data[seq_id] = i;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue