diff --git a/.gitignore b/.gitignore
index d231f3ff8..07528b5c6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,3 +48,12 @@
 qnt-*.txt
 perf-*.txt
 examples/jeopardy/results.txt
+
+/prompts
+*.sh
+*.log
+*.py
+*.txt
+/wikitext-2-raw/
+*.org
+/libllama.so
diff --git a/Makefile b/Makefile
index f9ec8797a..93c149edc 100644
--- a/Makefile
+++ b/Makefile
@@ -125,7 +125,7 @@ endif
 ifdef LLAMA_CUBLAS
	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
	CXXFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
+	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lcufile -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
	OBJS      += ggml-cuda.o
	NVCC      = nvcc
	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index b93c3ea8c..893c7d4a9 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2,11 +2,13 @@
 #include <cstdint>
 #include <stdint.h>
 #include <stdio.h>
+#include <fcntl.h>
 #include <atomic>

 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>

+#include <cufile.h>
 #include "ggml-cuda.h"
 #include "ggml.h"
@@ -32,6 +34,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         } \
     } while (0)

+#define CUFILE_CHECK(status) \
+    do { \
+        CUfileError_t status_ = (status); \
+        if (status_.err != CU_FILE_SUCCESS) { \
+            fprintf(stderr, "cuFile error %d at %s:%d\n", status_.err, __FILE__, __LINE__); \
+            exit(1); \
+        } \
+    } while (0)
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
@@ -372,7 +383,7 @@ struct cuda_buffer {
 static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
 static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;

-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);

     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -391,7 +402,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }

-static void ggml_cuda_pool_free(void * ptr, size_t size) {
+void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);

     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -431,6 +442,9 @@ void ggml_init_cublas() {

         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+        // initialize cuFile for loading model parameters directly to VRAM
+        CUFILE_CHECK(cuFileDriverOpen());
     }
 }

@@ -893,3 +907,34 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) {
+    CUfileDescr_t cf_descr;
+    memset((void *)&cf_descr, 0, sizeof(CUfileDescr_t));
+    const int fd_cf = open(fname, O_RDONLY|O_DIRECT, 0644);
+    cf_descr.handle.fd = fd_cf;
+    cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
+
+    CUfileHandle_t cf_handle;
+    CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
+
+    if (status.err == CU_FILE_INTERNAL_ERROR) {
+        fprintf(stderr, "WARNING: cuFile experienced an internal error while loading weights from \"%s\". Using a workaround (slower). "
+                "This happens with weight files on Btrfs partitions. ext4 and NTFS are confirmed to work.\n", fname);
+    }
+    if (status.err != CU_FILE_SUCCESS) {
+        return false;
+    }
+
+    for (int i = 0; i < num_tensors; ++i) {
+        ggml_tensor * tensor = tensors[i];
+        const size_t size   = ggml_nbytes(tensor);
+        const size_t offset = offsets[i];
+
+        size_t actual_size;
+        void * buf = ggml_cuda_pool_malloc(size, &actual_size);
+        cuFileRead(cf_handle, buf, size, offset, 0);
+        tensor->data = buf;
+    }
+    return true;
+}
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 682a2ce0a..7485af9f7 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -14,8 +14,11 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 // TODO: export these with GGML_API
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);

 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets);

 #ifdef __cplusplus
 }
diff --git a/llama.cpp b/llama.cpp
index 0b237a3d9..644b6a57b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -10,6 +10,7 @@

 #include "ggml.h"
 #ifdef GGML_USE_CUBLAS
+#include <cuda_runtime.h>
 #include "ggml-cuda.h"
 #endif

@@ -641,7 +642,7 @@ struct llama_model_loader {
         }
     }

-    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne) {
+    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
         auto it = tensors_map.name_to_idx.find(name);
         if (it == tensors_map.name_to_idx.end()) {
             throw format("llama.cpp: tensor '%s' is missing from model", name.c_str());
@@ -652,10 +653,10 @@ struct llama_model_loader {
                          name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }

-        return get_tensor_for(lt);
+        return get_tensor_for(lt, backend);
     }

-    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt) {
+    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
         struct ggml_tensor * tensor;
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
@@ -665,6 +666,7 @@ struct llama_model_loader {
         }
         ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
+        tensor->backend = backend;
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
         return tensor;
@@ -683,7 +685,7 @@ struct llama_model_loader {
         }

         if (use_mmap) {
-            mapping.reset(new llama_mmap(&file_loaders.at(0)->file));
+            mapping.reset(new llama_mmap(&file_loaders.at(0)->file, false));
             if (!lmlock) {
                 // Don't call the callback since the actual loading will be lazy
                 // and we can't measure it.
@@ -696,6 +698,9 @@ struct llama_model_loader {

         size_t done_size = 0;
         for (llama_load_tensor & lt : tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) {
+                continue;
+            }
             if (progress_callback) {
                 progress_callback((float) done_size / data_size, progress_callback_user_data);
             }
@@ -944,26 +949,6 @@ static void llama_model_load_internal(
     ml->calc_sizes(&ctx_size, &mmapped_size);
     fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);

-    // print memory requirements
-    {
-        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
-
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-            ctx_size +
-            mmapped_size +
-            MEM_REQ_SCRATCH0().at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-
-        // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*MEM_REQ_KV_SELF().at(model.type);
-
-        fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
-                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
-    }
-
     // create the ggml context
     {
         lctx.model.buf.resize(ctx_size);
@@ -985,6 +970,7 @@ static void llama_model_load_internal(
     }

     // prepare memory for the weights
+    size_t vram_total = 0;
     {
         const uint32_t n_embd  = hparams.n_embd;
         const uint32_t n_layer = hparams.n_layer;
@@ -992,33 +978,78 @@ static void llama_model_load_internal(

         ml->ggml_ctx = ctx;

-        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab});
-        model.norm           = ml->get_tensor("norm.weight",           {n_embd});
-        model.output         = ml->get_tensor("output.weight",         {n_embd, n_vocab});
+        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
+        model.norm           = ml->get_tensor("norm.weight",           {n_embd},          GGML_BACKEND_CPU);
+        ggml_backend backend_output;
+        if (n_gpu_layers > int(n_layer)) {
+            backend_output = GGML_BACKEND_CUDA;
+        } else {
+            backend_output = GGML_BACKEND_CPU;
+        }
+        model.output         = ml->get_tensor("output.weight",         {n_embd, n_vocab}, backend_output);

         model.layers.resize(n_layer);
+        const int i_gpu_start = n_layer - n_gpu_layers;
         for (uint32_t i = 0; i < n_layer; ++i) {
             auto & layer = model.layers[i];
+            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : GGML_BACKEND_CUDA;

             std::string layers_i = "layers." + std::to_string(i);

-            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd});
+            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);

-            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd});
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd});
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd});
-            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd});
+            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
+            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);

-            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd});
+            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);

-            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff});
-            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff, n_embd});
-            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff});
+            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff}, backend);
+            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff, n_embd}, backend);
+            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff}, backend);
+            if (backend == GGML_BACKEND_CUDA) {
+                vram_total += ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
+                              ggml_nbytes(layer.wv)             + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
+                              ggml_nbytes(layer.w1)             + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+            }
         }
     }

     ml->done_getting_tensors();

+    // print memory requirements
+    {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+            ctx_size +
+            mmapped_size - vram_total + // weights in VRAM not in memory
+            MEM_REQ_SCRATCH0().at(model.type) +
+            MEM_REQ_SCRATCH1().at(model.type) +
+            MEM_REQ_EVAL().at(model.type);
+
+        // this is the memory required by one llama_state
+        const size_t mem_required_state =
+            scale*MEM_REQ_KV_SELF().at(model.type);
+
+        fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+
+#ifdef GGML_USE_CUBLAS
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
+        }
+        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+#else
+        (void) n_gpu_layers;
+#endif
+    }
+
     // populate `tensors_by_name`
     for (llama_load_tensor & lt : ml->tensors_map.tensors) {
         model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
@@ -1026,38 +1057,44 @@ static void llama_model_load_internal(

     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);

-    model.mapping = std::move(ml->mapping);
 #ifdef GGML_USE_CUBLAS
     {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
-
-        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
-
-        size_t vram_total = 0;
-
-        for (int i = 0; i < n_gpu; ++i) {
-            const auto & layer = model.layers[i];
-
-            ggml_cuda_transform_tensor(layer.attention_norm); vram_total += ggml_nbytes(layer.attention_norm);
-            ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
-            ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
-            ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
-            ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
-            ggml_cuda_transform_tensor(layer.ffn_norm); vram_total += ggml_nbytes(layer.ffn_norm);
-            ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
-            ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
-            ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
-        }
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
-            ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
+        std::vector<ggml_tensor *> tensors;
+        std::vector<size_t> offsets;
+        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                continue;
+            }
+            tensors.emplace_back(lt.ggml_tensor);
+            LLAMA_ASSERT(lt.shards.size() == 1);
+            offsets.emplace_back(lt.shards.at(0).file_off);
         }
+        bool cufile_success = ggml_cuda_load_data_cufile(fname.c_str(), tensors.data(), tensors.size(), offsets.data());

-        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        if (!cufile_success) {
+            for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+                if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                    continue;
+                }
+                size_t actual_size;
+                void * buf = ggml_cuda_pool_malloc(lt.size, &actual_size);
+                void * buf_host = ggml_cuda_host_malloc(lt.size);
+
+                llama_file & file = ml->file_loaders.at(lt.shards.at(0).file_idx)->file;
+                file.seek(lt.shards.at(0).file_off, SEEK_SET);
+                file.read_raw(buf_host, lt.size);
+
+                cudaMemcpy(buf, buf_host, lt.size, cudaMemcpyHostToDevice);
+                cudaDeviceSynchronize();
+
+                lt.ggml_tensor->data = buf;
+                ggml_cuda_host_free(buf_host);
+            }
+        }
     }
-#else
-    (void) n_gpu_layers;
-#endif
+#endif // GGML_USE_CUBLAS
+
+    model.mapping = std::move(ml->mapping);

     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -2395,7 +2432,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         }
         size_t idx = model_loader->tensors_map.name_to_idx[base_name];
         llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
-        base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+        base_t = model_loader->get_tensor(
+            base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
         lt.data = (uint8_t *) lt.ggml_tensor->data;
         model_loader->load_data_for(lt);
         lt.ggml_tensor->data = lt.data;
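
For reference, the following is a minimal, self-contained sketch of the cuFile call sequence that ggml_cuda_load_data_cufile builds on: open the driver, register an O_DIRECT file descriptor, then read a file range straight into device memory with no host staging buffer. It is not part of the patch; the file name "model.bin" and the size/offset values are illustrative placeholders, and cudaMalloc stands in for the ggml CUDA buffer pool.

// Sketch of the direct-to-VRAM path (cuFileDriverOpen / cuFileHandleRegister /
// cuFileRead). "model.bin", the read size and the offset are placeholders.

#include <cstdio>
#include <cstring>
#include <fcntl.h>
#include <unistd.h>

#include <cuda_runtime.h>
#include <cufile.h>

int main() {
    // 1. Open the cuFile driver once per process (the patch does this in ggml_init_cublas).
    CUfileError_t status = cuFileDriverOpen();
    if (status.err != CU_FILE_SUCCESS) { fprintf(stderr, "cuFileDriverOpen failed\n"); return 1; }

    // 2. Open the weights file with O_DIRECT and register it with cuFile.
    const int fd = open("model.bin", O_RDONLY | O_DIRECT);
    if (fd < 0) { perror("open"); return 1; }

    CUfileDescr_t descr;
    memset(&descr, 0, sizeof(descr));
    descr.handle.fd = fd;
    descr.type      = CU_FILE_HANDLE_TYPE_OPAQUE_FD;

    CUfileHandle_t handle;
    status = cuFileHandleRegister(&handle, &descr);
    if (status.err != CU_FILE_SUCCESS) { fprintf(stderr, "cuFileHandleRegister failed\n"); return 1; }

    // 3. Read a region of the file straight into device memory, no host staging buffer.
    //    The patch does this per tensor, with ggml_cuda_pool_malloc instead of cudaMalloc.
    const size_t size   = 1 << 20; // illustrative tensor size in bytes
    const off_t  offset = 0;       // illustrative offset of the tensor in the file

    void * dev_ptr = NULL;
    if (cudaMalloc(&dev_ptr, size) != cudaSuccess) { return 1; }

    const ssize_t n = cuFileRead(handle, dev_ptr, size, offset, 0 /* offset into dev_ptr */);
    if (n < 0 || (size_t) n != size) { fprintf(stderr, "cuFileRead failed\n"); return 1; }

    // 4. Teardown.
    cuFileHandleDeregister(handle);
    close(fd);
    cudaFree(dev_ptr);
    cuFileDriverClose();
    return 0;
}

A build along the lines of "nvcc sketch.cu -lcufile" mirrors the -lcufile flag the Makefile change adds under LLAMA_CUBLAS.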
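
Similarly, a sketch of the slower fallback that llama.cpp takes when cuFile registration fails (the Btrfs case warned about in ggml_cuda_load_data_cufile): read the tensor into a pinned host buffer, then copy it to the device with cudaMemcpy. File name, size and offset are again placeholders, and standard C file I/O stands in for llama_file.

// Sketch of the pinned-host-buffer fallback used when cuFile registration fails.
// "model.bin", size and offset are placeholders; llama.cpp uses its own
// llama_file reader and the ggml CUDA buffer pool instead of cudaMalloc.

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    const size_t size   = 1 << 20; // illustrative tensor size in bytes
    const long   offset = 0;       // illustrative offset of the tensor in the file

    FILE * f = fopen("model.bin", "rb");
    if (f == NULL) { perror("fopen"); return 1; }

    // Pinned (page-locked) host memory gives faster host->device transfers;
    // this is what ggml_cuda_host_malloc provides in the patch.
    void * buf_host = NULL;
    if (cudaMallocHost(&buf_host, size) != cudaSuccess) { return 1; }

    fseek(f, offset, SEEK_SET);
    if (fread(buf_host, 1, size, f) != size) { fprintf(stderr, "short read\n"); return 1; }

    void * buf_device = NULL;
    if (cudaMalloc(&buf_device, size) != cudaSuccess) { return 1; }
    cudaMemcpy(buf_device, buf_host, size, cudaMemcpyHostToDevice);
    cudaDeviceSynchronize();

    cudaFreeHost(buf_host);
    // The patch keeps buf_device and stores it in tensor->data; it is freed here
    // only to keep the sketch self-contained.
    cudaFree(buf_device);
    fclose(f);
    return 0;
}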