context : minor

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-15 10:54:21 +02:00
parent 17b363afd3
commit a19f671fe0
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 37 additions and 47 deletions

View file

@ -8,30 +8,6 @@
#include <cstring>
#include <stdexcept>
void llama_set_k_shift(struct llama_context & lctx) {
const int64_t kv_size = lctx.kv_self.size;
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].delta;
}
}
void llama_set_s_copy(struct llama_context & lctx) {
const int64_t kv_size = lctx.kv_self.size;
assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
int32_t * data = (int32_t *) lctx.inp_s_copy->data;
for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].src;
}
}
// llama input
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
@ -58,6 +34,16 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
return relative_bucket;
}
void llama_context::set_k_shift(llama_kv_cache & kv) {
assert(ggml_backend_buffer_is_host(inp_K_shift->buffer));
int32_t * data = (int32_t *) inp_K_shift->data;
for (uint32_t i = 0; i < kv.size; ++i) {
data[i] = kv.cells[i].delta;
}
}
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
//
// set input data
@ -134,7 +120,6 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
float * data = nullptr;
float * data_swa = nullptr;
@ -599,6 +584,7 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
}
uint32_t llama_n_seq_max(const struct llama_context * ctx) {
// TODO: add notion of n_seq_max to llama_kv_cache and use it here
return ctx->kv_self.size;
}