Fix alignment bug in llama.com

This commit is contained in:
Justine Tunney 2023-05-10 06:15:32 -07:00
parent ca990ef091
commit 6cb9553706
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
2 changed files with 10 additions and 7 deletions

View file

@ -27,6 +27,7 @@
*/ */
#include "third_party/ggml/llama.h" #include "third_party/ggml/llama.h"
#include "libc/assert.h"
#include "libc/intrin/bits.h" #include "libc/intrin/bits.h"
#include "third_party/ggml/ggml.h" #include "third_party/ggml/ggml.h"
#include "third_party/ggml/llama_util.h" #include "third_party/ggml/llama_util.h"
@ -2540,8 +2541,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
if (kv_size) { if (kv_size) {
const size_t elt_size = ggml_element_size(kv_self.k); const size_t elt_size = ggml_element_size(kv_self.k);
char buffer[4096]; llama_buffer buffer;
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); buffer.resize(4096);
ggml_context * cpy_ctx = ggml_init({ buffer.size, buffer.addr, /* no_alloc */ true });
ggml_cgraph gf{}; ggml_cgraph gf{};
gf.n_threads = 1; gf.n_threads = 1;
@ -2644,8 +2646,9 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
LLAMA_ASSERT(kv_self.buf.size == kv_size); LLAMA_ASSERT(kv_self.buf.size == kv_size);
const size_t elt_size = ggml_element_size(kv_self.k); const size_t elt_size = ggml_element_size(kv_self.k);
char buffer[4096]; llama_buffer buffer;
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); buffer.resize(4096);
ggml_context * cpy_ctx = ggml_init({ buffer.size, buffer.addr, /* no_alloc */ true });
ggml_cgraph gf{}; ggml_cgraph gf{};
gf.n_threads = 1; gf.n_threads = 1;

View file

@ -377,13 +377,13 @@ struct llama_buffer {
size_t size = 0; size_t size = 0;
void resize(size_t size) { void resize(size_t size) {
delete[] addr; free(addr);
addr = new uint8_t[size]; addr = (uint8_t *)memalign(32, size);
this->size = size; this->size = size;
} }
~llama_buffer() { ~llama_buffer() {
delete[] addr; free(addr);
} }
}; };
#endif #endif