context : minor
ggml-ci
commit a19f671fe0
parent 17b363afd3

5 changed files with 37 additions and 47 deletions
@@ -8,30 +8,6 @@
 #include <cstring>
 #include <stdexcept>
 
-void llama_set_k_shift(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].delta;
-    }
-}
-
-void llama_set_s_copy(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
-
 // llama input
 
 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
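Both deleted helpers share one staging pattern: assert that the input tensor is backed by a host buffer (so its data pointer can be written directly), then copy one int32_t per KV cell into it; llama_set_k_shift copies each cell's position delta, llama_set_s_copy each cell's copy-source index. The next hunk re-adds the K-shift variant as a member function of llama_context. The matching declaration is not shown here; presumably it lands in a header among the other changed files, along the lines of this hypothetical sketch:

    // hypothetical declaration sketch; the header hunk is not part of this excerpt
    struct llama_context {
        // ...
        // stage per-cell position deltas from the KV cache into inp_K_shift
        void set_k_shift(llama_kv_cache & kv);
        // ...
    };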
@@ -58,6 +34,16 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
+void llama_context::set_k_shift(llama_kv_cache & kv) {
+    assert(ggml_backend_buffer_is_host(inp_K_shift->buffer));
+
+    int32_t * data = (int32_t *) inp_K_shift->data;
+
+    for (uint32_t i = 0; i < kv.size; ++i) {
+        data[i] = kv.cells[i].delta;
+    }
+}
+
 void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
     //
     // set input data
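With set_k_shift as a method, the body reads inp_K_shift directly instead of going through an lctx parameter, and the KV cache becomes an explicit argument; the loop index also changes from int (compared against an int64_t bound) to uint32_t, matching the type of kv.size. A hedged call-site sketch, assuming the K-shift path keeps its previous shape (the actual caller is not shown in this excerpt):

    // assumed call site: stage each cell's delta before running the graph
    // that applies the shift to the K cache
    lctx.set_k_shift(lctx.kv_self);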
@@ -134,7 +120,6 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
     const int64_t n_seqs       = ubatch.n_seqs;
 
-
     float * data     = nullptr;
     float * data_swa = nullptr;
 
@@ -599,6 +584,7 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
 }
 
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
+    // TODO: add notion of n_seq_max to llama_kv_cache and use it here
     return ctx->kv_self.size;
 }
 
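llama_n_seq_max is part of the public C API declared in llama.h, so the TODO concerns only where the value comes from internally (currently the KV cache size), not the calling contract. A hedged usage sketch, assuming the KV-cache sequence API from llama.h of the same era:

    // query the per-context sequence limit, then clear every sequence
    const uint32_t n_seq_max = llama_n_seq_max(ctx);
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        // p0 = -1, p1 = -1 means the whole range of sequence s
        llama_kv_cache_seq_rm(ctx, (llama_seq_id) s, -1, -1);
    }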