cuda backend can be used through ggml-backend with LLAMA_GGML_BACKEND_CUDA_TEST

access all tensor data with ggml_backend_tensor_get/set
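
For context, the access pattern this commit switches to looks roughly like the sketch below. It is illustrative only: the function and variable names are made up, and only the ggml_backend_tensor_set/get calls (tensor, host pointer, byte offset, byte size) mirror the API used in the diff.

    #include "ggml.h"
    #include "ggml-backend.h"
    #include <cstdint>
    #include <vector>

    // Instead of reading/writing tensor->data directly (valid only for CPU buffers),
    // all host <-> backend transfers go through ggml-backend, which also works when
    // the tensor lives in a CUDA or Metal buffer.
    static void copy_inputs_and_outputs(ggml_tensor * inp_tokens, ggml_tensor * res, int64_t n_vocab) {
        std::vector<int32_t> tokens(inp_tokens->ne[0], 0);  // host-side input
        ggml_backend_tensor_set(inp_tokens, tokens.data(), 0, tokens.size()*ggml_element_size(inp_tokens));

        // ... ggml_backend_graph_compute(...) ...

        std::vector<float> logits(n_vocab);                  // host-side output
        ggml_backend_tensor_get(res, logits.data(), 0, n_vocab*sizeof(float));
    }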
slaren 2023-12-19 17:55:37 +01:00
parent 94507911bb
commit 0c5ee7c417

llama.cpp

@@ -1,4 +1,5 @@
 #define LLAMA_API_INTERNAL
+#define LLAMA_GGML_BACKEND_CUDA_TEST // for testing only - disables partial offloading
 #include "llama.h"
 #include "unicode.h"
@@ -1289,7 +1290,7 @@ struct llama_kv_cache {
     ggml_backend_buffer_t buf = NULL;
     ~llama_kv_cache() {
-#ifdef GGML_USE_CUBLAS
+#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST)
         if (ggml_cublas_loaded()) {
             for (size_t i = 0; i < k_l.size(); ++i) {
                 ggml_cuda_free_data(k_l[i]);
@@ -1403,7 +1404,7 @@ struct llama_model {
     int64_t t_start_us = 0;
     ~llama_model() {
-#ifdef GGML_USE_CUBLAS
+#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST)
         if (ggml_cublas_loaded()) {
             for (size_t i = 0; i < tensors_by_name.size(); ++i) {
                 ggml_cuda_free_data(tensors_by_name[i].second);
@@ -1472,6 +1473,9 @@ struct llama_context {
     ggml_backend_buffer_t buf_alloc = NULL;
     ggml_allocr * alloc = NULL;
+    // temporary buffer for copying data to/from the backend
+    std::vector<no_init<uint8_t>> buf_copy;
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
@@ -1549,7 +1553,7 @@ static bool llama_kv_cache_init(
     if (cache.buf) {
         // TODO: ggml_backend_buffer_memset
         // this is only valid with CPU buffers!
-        memset(ggml_backend_buffer_get_base(cache.buf), 0, ggml_backend_buffer_get_size(cache.buf));
+        //memset(ggml_backend_buffer_get_base(cache.buf), 0, ggml_backend_buffer_get_size(cache.buf));
     }
     if (vram_kv_cache > 0) {
@@ -2249,7 +2253,7 @@ struct llama_model_loader {
         }
     }
-    void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, ggml_backend_buffer_t mmap_buf, llama_mlock * lmlock) {
+    void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, ggml_backend_buffer_t buf_mmap, llama_mlock * lmlock) {
         size_t size_lock = 0;
         size_t size_data = 0;
@@ -2276,12 +2280,13 @@ struct llama_model_loader {
             switch (cur->backend) {
                 case GGML_BACKEND_CPU:
                     if (use_mmap) {
-                        if (mmap_buf) {
-                            ggml_backend_tensor_alloc(mmap_buf, cur, (uint8_t *)mapping->addr + offs);
+                        if (buf_mmap) {
+                            ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *)mapping->addr + offs);
                         } else {
                             ggml_backend_tensor_set(cur, (uint8_t *)mapping->addr + offs, 0, ggml_nbytes(cur));
                         }
                     } else {
+                        // FIXME: use read_buf for device buffers without unified memory
                         file.seek(offs, SEEK_SET);
                         file.read_raw(cur->data, ggml_nbytes(cur));
                     }
@@ -2905,15 +2910,16 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
 }
+// TODO: metal should be disabled with ngl=0 -> cpu_buffer_type
 static ggml_backend_buffer_type_t llama_default_buffer_type(int n_gpu_layers) {
 #ifdef GGML_USE_METAL
     if (n_gpu_layers > 0) {
         return ggml_backend_metal_buffer_type();
     }
-#elif GGML_USE_CUBLAS
+#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST)
+    return ggml_backend_cuda_buffer_type(0);
+#elif defined(GGML_USE_CUBLAS)
     return ggml_backend_cuda_host_buffer_type();
-#elif GGML_USE_CPU_HBM
+#elif defined(GGML_USE_CPU_HBM)
     return ggml_backend_cpu_hbm_buffer_type();
 #endif
@@ -3520,20 +3526,26 @@ static void llm_load_tensors(
     // create backend buffer
     bool sys_mem_buf = false;
+    ggml_backend_buffer_t buf_mmap = nullptr;
 #ifdef GGML_USE_METAL
     // todo: disable with 0 gpu layers
     if (ml.use_mmap) {
         const size_t max_size = ggml_get_max_tensor_size(ctx);
         model.buf = ggml_backend_metal_buffer_from_ptr(ml.mapping->addr, ml.mapping->size, max_size);
+        buf_mmap = model.buf;
     } else {
         model.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_metal_buffer_type());
         sys_mem_buf = true;
     }
+#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST)
+    // for testing only
+    model.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cuda_buffer_type(0));
 #else
     // CPU backend, and indirectly CUDA and OpenCL
     if (ml.use_mmap) {
         model.buf = ggml_backend_cpu_buffer_from_ptr(ml.mapping->addr, ml.mapping->size);
+        buf_mmap = model.buf;
     } else {
         // allocate only CPU tensors
         model.buf = ggml_backend_buft_alloc_buffer(buft, buf_size);
@@ -3549,7 +3561,6 @@ static void llm_load_tensors(
 #endif
     if (use_mlock && sys_mem_buf) {
-        // TODO: CPU/metal only
         model.mlock_buf.init   (ggml_backend_buffer_get_base(model.buf));
         model.mlock_buf.grow_to(ggml_backend_buffer_get_size(model.buf));
     }
@@ -3591,7 +3602,7 @@ static void llm_load_tensors(
 #endif // GGML_USE_CUBLAS
     // TODO: only pass buf if it is a mmap buffer
-    ml.load_all_data(ctx, progress_callback, progress_callback_user_data, model.buf, use_mlock ? &model.mlock_mmap : NULL);
+    ml.load_all_data(ctx, progress_callback, progress_callback_user_data, buf_mmap, use_mlock ? &model.mlock_mmap : NULL);
     if (progress_callback) {
         progress_callback(1.0f, progress_callback_user_data);
@@ -5690,7 +5701,7 @@ static struct ggml_cgraph * llama_build_graph(
         if (!ggml_allocr_is_measure(lctx.alloc) && batch.token) {
             const int64_t n_tokens = cur->ne[0];
-            memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur));
+            ggml_backend_tensor_set(cur, batch.token, 0, n_tokens*ggml_element_size(cur));
         }
         alloc_inp_tokens = true;
@@ -5703,7 +5714,7 @@ static struct ggml_cgraph * llama_build_graph(
             const int64_t n_embd   = cur->ne[0];
             const int64_t n_tokens = cur->ne[1];
-            memcpy(cur->data, batch.embd, n_tokens*n_embd*ggml_element_size(cur));
+            ggml_backend_tensor_set(cur, batch.embd, 0, n_tokens*n_embd*ggml_element_size(cur));
         }
         alloc_inp_embd = true;
@@ -5715,11 +5726,8 @@ static struct ggml_cgraph * llama_build_graph(
         if (!ggml_allocr_is_measure(lctx.alloc) && batch.pos) {
             const int64_t n_tokens = cur->ne[0];
-            int32_t * data = (int32_t *) cur->data;
-            for (int i = 0; i < n_tokens; ++i) {
-                data[i] = batch.pos[i];
-            }
+            static_assert(std::is_same<llama_pos, int32_t>::value, "llama_pos must be int32_t");
+            ggml_backend_tensor_set(cur, batch.pos, 0, n_tokens*ggml_element_size(cur));
         }
         alloc_inp_pos = true;
@@ -5730,7 +5738,8 @@ static struct ggml_cgraph * llama_build_graph(
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             const int64_t n_embd_head = model.hparams.n_embd_head();
-            ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
+            float f = 1.0f/sqrtf(float(n_embd_head));
+            ggml_backend_tensor_set(cur, &f, 0, sizeof(f));
         }
         alloc_inp_Q_scale = true;
@@ -5741,13 +5750,15 @@ static struct ggml_cgraph * llama_build_graph(
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             const int64_t n_embd_head = model.hparams.n_embd_head();
+            float f;
             if (model.arch == LLM_ARCH_PHI2) {
                 // with phi2, we scale the Q to avoid precision issues
                 // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
-                ggml_set_f32(cur, 1.0f);
+                f = 1.0f;
             } else {
-                ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
+                f = 1.0f/sqrtf(float(n_embd_head));
             }
+            ggml_backend_tensor_set(cur, &f, 0, sizeof(f));
         }
         alloc_inp_KQ_scale = true;
@@ -5760,8 +5771,13 @@ static struct ggml_cgraph * llama_build_graph(
             const int64_t n_kv     = cur->ne[0];
             const int64_t n_tokens = cur->ne[1];
-            float * data = (float *) cur->data;
-            memset(data, 0, ggml_nbytes(cur));
+            float * data;
+            if (/*is_sys_mem_buf(cur->buffer)*/false) { // TODO
+                data = (float *) cur->data;
+            } else {
+                lctx.buf_copy.resize(ggml_nbytes(cur));
+                data = (float *) lctx.buf_copy.data();
+            }
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
@@ -5769,11 +5785,19 @@ static struct ggml_cgraph * llama_build_graph(
                     const llama_seq_id seq_id = batch.seq_id[j][0];
                     for (int i = 0; i < n_kv; ++i) {
+                        float f;
                         if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                            data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                            f = -INFINITY;
+                        } else {
+                            f = 0;
                         }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
                 }
             }
+            if (data != cur->data) {
+                ggml_backend_tensor_set(cur, data, 0, ggml_nbytes(cur));
+            }
         }
@@ -5786,11 +5810,21 @@ static struct ggml_cgraph * llama_build_graph(
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             const int64_t n_ctx = cur->ne[0];
-            int32_t * data = (int32_t *) cur->data;
+            int32_t * data;
+            if (/*is_sys_mem_buf(cur->buffer)*/false) { // TODO
+                data = (int32_t *) cur->data;
+            } else {
+                lctx.buf_copy.resize(ggml_nbytes(cur));
+                data = (int32_t *) lctx.buf_copy.data();
+            }
             for (int i = 0; i < n_ctx; ++i) {
                 data[i] = lctx.kv_self.cells[i].delta;
             }
+            if (data != cur->data) {
+                ggml_backend_tensor_set(cur, data, 0, ggml_nbytes(cur));
+            }
         }
         alloc_inp_K_shift = true;
@@ -6122,7 +6156,7 @@ static int llama_decode_internal(
         GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
     }
-#ifdef GGML_USE_CUBLAS
+#if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST)
     char * buf_alloc_base = (char *)ggml_backend_buffer_get_base(lctx.buf_alloc);
     for (int i = 0; i < gf->n_leafs; i++) {
         ggml_tensor * node = gf->leafs[i];
@@ -6162,7 +6196,7 @@ static int llama_decode_internal(
         n_threads = 1;
     }
-#if GGML_USE_MPI
+#ifdef GGML_USE_MPI
     const int64_t n_layer = hparams.n_layer;
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
 #endif
@@ -6178,7 +6212,7 @@ static int llama_decode_internal(
     }
     ggml_backend_graph_compute(lctx.backend, gf);
-#if GGML_USE_MPI
+#ifdef GGML_USE_MPI
     ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
 #endif
@@ -6230,20 +6264,20 @@ static int llama_decode_internal(
                 if (batch.logits[i] == 0) {
                     continue;
                 }
-                memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
+                ggml_backend_tensor_get(res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
                 logits_valid[i] = true;
 #endif
             }
         } else if (lctx.logits_all) {
             logits_out.resize(n_vocab * n_tokens);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
+            ggml_backend_tensor_get(res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
             std::fill(logits_valid.begin(), logits_valid.end(), true);
 #endif
         } else {
             logits_out.resize(n_vocab);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
+            ggml_backend_tensor_get(res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
             logits_valid[0] = true;
 #endif
@@ -6255,7 +6289,7 @@ static int llama_decode_internal(
         auto & embedding_out = lctx.embedding;
         embedding_out.resize(n_embd);
-        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd);
+        ggml_backend_tensor_get(embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
     }
     // measure the performance only for the single-token evals
@@ -9187,8 +9221,16 @@ struct llama_context * llama_new_context_with_model(
             LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__);
         }
     }
-#endif
+#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST)
+        // for testing only
+        ctx->backend = ggml_backend_cuda_init(0);
+        if (ctx->backend == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CUDA backend\n", __func__);
+        }
+#endif
     if (ctx->backend == nullptr) {
+        // FIXME: this may fail if the model buffer is not compatible with the CPU backend
        ctx->backend = ggml_backend_cpu_init();
     }
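
Note on the buf_copy staging used above for the KQ_mask and K_shift inputs: values that are built element by element are first written into a temporary host buffer and then uploaded with ggml_backend_tensor_set when the destination tensor is not in host memory. A minimal sketch of that pattern follows; the host-buffer check is a placeholder for the TODO that this commit hardcodes to false.

    // sketch; 'cur' is the input tensor, 'buf_copy' the scratch buffer
    std::vector<uint8_t> buf_copy;
    float * data;
    if (/* is_sys_mem_buf(cur->buffer) */ false) {   // TODO in the commit: detect host buffers
        data = (float *) cur->data;                  // fill in place (CPU buffer)
    } else {
        buf_copy.resize(ggml_nbytes(cur));
        data = (float *) buf_copy.data();            // fill a host staging buffer
    }
    // ... fill data[0..n) ...
    if (data != cur->data) {
        ggml_backend_tensor_set(cur, data, 0, ggml_nbytes(cur));  // upload to the backend buffer
    }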