GPU weights not in RAM, direct loading with cuFile
commit fa1a29f36f
parent 2365a2a970

5 changed files with 162 additions and 67 deletions
.gitignore (vendored): 9 changes

@@ -48,3 +48,12 @@ qnt-*.txt
 perf-*.txt
 
 examples/jeopardy/results.txt
+
+/prompts
+*.sh
+*.log
+*.py
+*.txt
+/wikitext-2-raw/
+*.org
+/libllama.so
Makefile: 2 changes

@@ -125,7 +125,7 @@ endif
 ifdef LLAMA_CUBLAS
 	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
 	CXXFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
+	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lcufile -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
 	OBJS      += ggml-cuda.o
 	NVCC      = nvcc
 	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
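The only functional build change is the extra -lcufile in the LLAMA_CUBLAS link line. As a quick, hypothetical sanity check (not part of this commit) that libcufile and the GPUDirect Storage driver are usable on a given machine, one can simply open and close the cuFile driver:

// cufile_check.cpp: standalone check for the new -lcufile dependency (assumed CUDA install paths).
// Build: g++ cufile_check.cpp -I/usr/local/cuda/include -L/usr/local/cuda/lib64 -lcufile -lcudart
#include <cstdio>
#include <cufile.h>

int main() {
    CUfileError_t status = cuFileDriverOpen();   // initializes the GPUDirect Storage driver state
    if (status.err != CU_FILE_SUCCESS) {
        fprintf(stderr, "cuFileDriverOpen failed with error %d\n", status.err);
        return 1;
    }
    printf("cuFile driver opened successfully\n");
    cuFileDriverClose();                         // releases driver resources
    return 0;
}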
ggml-cuda.cu: 49 changes

@@ -2,11 +2,13 @@
 #include <cstdint>
 #include <stdint.h>
 #include <stdio.h>
+#include <fcntl.h>
 #include <atomic>
 
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <cufile.h>
 
 #include "ggml-cuda.h"
 #include "ggml.h"
@@ -32,6 +34,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     } \
 } while (0)
 
+#define CUFILE_CHECK(status)                                                                    \
+    do {                                                                                        \
+        CUfileError_t status_ = (status);                                                       \
+        if (status_.err != CU_FILE_SUCCESS) {                                                   \
+            fprintf(stderr, "cuFile error %d at %s:%d\n", status_.err, __FILE__, __LINE__);     \
+            exit(1);                                                                            \
+        }                                                                                       \
+    } while (0)
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
@@ -372,7 +383,7 @@ struct cuda_buffer {
 static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
 static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
 
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -391,7 +402,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void ggml_cuda_pool_free(void * ptr, size_t size) {
+void ggml_cuda_pool_free(void * ptr, size_t size) {
    scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -431,6 +442,9 @@ void ggml_init_cublas() {
 
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
+        // initialize cuFile for loading model parameters directly to VRAM
+        CUFILE_CHECK(cuFileDriverOpen());
     }
 }
@@ -893,3 +907,34 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) {
+    CUfileDescr_t cf_descr;
+    memset((void *)&cf_descr, 0, sizeof(CUfileDescr_t));
+    const int fd_cf = open(fname, O_RDONLY|O_DIRECT, 0644);
+    cf_descr.handle.fd = fd_cf;
+    cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
+
+    CUfileHandle_t cf_handle;
+    CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
+
+    if (status.err == CU_FILE_INTERNAL_ERROR) {
+        fprintf(stderr, "WARNING: cuFile experienced an internal error while loading weights from \"%s\". Using a workaround (slower). "
+                "This happens with weight files on Btrfs partitions. ext4 and NTFS are confirmed to work.\n", fname);
+    }
+    if (status.err != CU_FILE_SUCCESS) {
+        return false;
+    }
+
+    for (int i = 0; i < num_tensors; ++i) {
+        ggml_tensor * tensor = tensors[i];
+        const size_t size = ggml_nbytes(tensor);
+        const size_t offset = offsets[i];
+
+        size_t actual_size;
+        void * buf = ggml_cuda_pool_malloc(size, &actual_size);
+        cuFileRead(cf_handle, buf, size, offset, 0);
+        tensor->data = buf;
+    }
+    return true;
+}
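For reference, ggml_cuda_load_data_cufile above follows the usual GPUDirect Storage pattern: open the weight file with O_DIRECT, register the descriptor with cuFileHandleRegister, then issue cuFileRead calls whose destination is a device pointer. A minimal standalone sketch of that flow, with a placeholder file name, size, and offset, and with most error handling trimmed:

// Sketch of the cuFile read path; "weights.bin", `size`, and `file_offset` are placeholders.
#include <cstdio>
#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <cuda_runtime.h>
#include <cufile.h>

int main() {
    cuFileDriverOpen();                                        // initialize GPUDirect Storage

    const int fd = open("weights.bin", O_RDONLY | O_DIRECT);   // O_DIRECT is required for GDS
    CUfileDescr_t descr;
    memset(&descr, 0, sizeof(descr));
    descr.handle.fd = fd;
    descr.type      = CU_FILE_HANDLE_TYPE_OPAQUE_FD;

    CUfileHandle_t handle;
    if (cuFileHandleRegister(&handle, &descr).err != CU_FILE_SUCCESS) {
        fprintf(stderr, "cuFileHandleRegister failed\n");
        return 1;
    }

    const size_t size        = 1 << 20;                        // bytes to read (placeholder)
    const off_t  file_offset = 0;                              // tensor offset within the file
    void * dev_buf = nullptr;
    cudaMalloc(&dev_buf, size);

    // Read straight from disk into VRAM; the final argument is the offset into dev_buf.
    ssize_t n = cuFileRead(handle, dev_buf, size, file_offset, 0);
    printf("read %zd bytes into device memory\n", n);

    cuFileHandleDeregister(handle);
    close(fd);
    cudaFree(dev_buf);
    cuFileDriverClose();
    return 0;
}

If cuFileHandleRegister fails, as the warning above notes can happen with weight files on Btrfs, the caller is expected to fall back to a conventional host-side read followed by cudaMemcpy.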
ggml-cuda.h: 3 changes

@@ -14,8 +14,11 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 // TODO: export these with GGML_API
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);
 
 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets);
 
 #ifdef __cplusplus
 }
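ggml_cuda_pool_malloc and ggml_cuda_pool_free were previously static to ggml-cuda.cu; exporting them here lets llama.cpp place weight buffers in the same reusable device-buffer pool that the CUDA kernels draw from. A hedged usage sketch (the 4 MB size is an arbitrary placeholder, not taken from the commit):

// Using the newly exported buffer-pool API from code that includes ggml-cuda.h.
#include <cstdio>
#include "ggml-cuda.h"

void pool_example() {
    const size_t requested   = 4 * 1024 * 1024;
    size_t       actual_size = 0;                 // the pool may hand back a larger, recycled buffer
    void * dev_buf = ggml_cuda_pool_malloc(requested, &actual_size);
    printf("requested %zu bytes, pool returned %zu\n", requested, actual_size);

    // ... fill dev_buf via cuFileRead or cudaMemcpy ...

    ggml_cuda_pool_free(dev_buf, actual_size);    // return the buffer to the pool for reuse
}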
llama.cpp: 166 changes

@@ -10,6 +10,7 @@
 
 #include "ggml.h"
 #ifdef GGML_USE_CUBLAS
+#include <cuda_runtime.h>
 #include "ggml-cuda.h"
 #endif
 
@@ -641,7 +642,7 @@ struct llama_model_loader {
         }
     }
 
-    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne) {
+    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
         auto it = tensors_map.name_to_idx.find(name);
         if (it == tensors_map.name_to_idx.end()) {
             throw format("llama.cpp: tensor '%s' is missing from model", name.c_str());
@@ -652,10 +653,10 @@ struct llama_model_loader {
                 name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }
 
-        return get_tensor_for(lt);
+        return get_tensor_for(lt, backend);
     }
 
-    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt) {
+    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
         struct ggml_tensor * tensor;
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
@@ -665,6 +666,7 @@ struct llama_model_loader {
         }
         ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
+        tensor->backend = backend;
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
         return tensor;
@@ -683,7 +685,7 @@ struct llama_model_loader {
         }
 
         if (use_mmap) {
-            mapping.reset(new llama_mmap(&file_loaders.at(0)->file));
+            mapping.reset(new llama_mmap(&file_loaders.at(0)->file, false));
             if (!lmlock) {
                 // Don't call the callback since the actual loading will be lazy
                 // and we can't measure it.
@@ -696,6 +698,9 @@ struct llama_model_loader {
 
         size_t done_size = 0;
         for (llama_load_tensor & lt : tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) {
+                continue;
+            }
             if (progress_callback) {
                 progress_callback((float) done_size / data_size, progress_callback_user_data);
             }
@@ -944,26 +949,6 @@ static void llama_model_load_internal(
     ml->calc_sizes(&ctx_size, &mmapped_size);
     fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0);
 
-    // print memory requirements
-    {
-        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
-
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-            ctx_size +
-            mmapped_size +
-            MEM_REQ_SCRATCH0().at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-
-        // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*MEM_REQ_KV_SELF().at(model.type);
-
-        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
-                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
-    }
-
     // create the ggml context
     {
         lctx.model.buf.resize(ctx_size);
@@ -985,6 +970,7 @@ static void llama_model_load_internal(
     }
 
     // prepare memory for the weights
+    size_t vram_total = 0;
     {
         const uint32_t n_embd  = hparams.n_embd;
         const uint32_t n_layer = hparams.n_layer;
@@ -992,33 +978,78 @@ static void llama_model_load_internal(
 
         ml->ggml_ctx = ctx;
 
-        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab});
-        model.norm           = ml->get_tensor("norm.weight",           {n_embd});
-        model.output         = ml->get_tensor("output.weight",         {n_embd, n_vocab});
+        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
+        model.norm           = ml->get_tensor("norm.weight",           {n_embd},          GGML_BACKEND_CPU);
+        ggml_backend backend_output;
+        if (n_gpu_layers > int(n_layer)) {
+            backend_output = GGML_BACKEND_CUDA;
+        } else {
+            backend_output = GGML_BACKEND_CPU;
+        }
+        model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
 
         model.layers.resize(n_layer);
+        const int i_gpu_start = n_layer - n_gpu_layers;
         for (uint32_t i = 0; i < n_layer; ++i) {
             auto & layer = model.layers[i];
+            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : GGML_BACKEND_CUDA;
 
             std::string layers_i = "layers." + std::to_string(i);
 
-            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd});
+            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
 
-            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd});
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd});
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd});
-            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd});
+            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
+            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);
 
-            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd});
+            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
 
-            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff});
-            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff, n_embd});
-            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff});
+            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff}, backend);
+            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff, n_embd}, backend);
+            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff}, backend);
+
+            if (backend == GGML_BACKEND_CUDA) {
+                vram_total += ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)
+                            + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm)
+                            + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+            }
         }
     }
 
     ml->done_getting_tensors();
 
+    // print memory requirements
+    {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+            ctx_size +
+            mmapped_size - vram_total + // weights in VRAM not in memory
+            MEM_REQ_SCRATCH0().at(model.type) +
+            MEM_REQ_SCRATCH1().at(model.type) +
+            MEM_REQ_EVAL().at(model.type);
+
+        // this is the memory required by one llama_state
+        const size_t mem_required_state =
+            scale*MEM_REQ_KV_SELF().at(model.type);
+
+        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+
+#ifdef GGML_USE_CUBLAS
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
+        }
+        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+#else
+        (void) n_gpu_layers;
+#endif
+    }
+
     // populate `tensors_by_name`
     for (llama_load_tensor & lt : ml->tensors_map.tensors) {
         model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
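The layer placement above is driven by i_gpu_start = n_layer - n_gpu_layers: the last n_gpu_layers transformer layers are created with GGML_BACKEND_CUDA, everything before them stays on GGML_BACKEND_CPU, and the output matrix is only offloaded when n_gpu_layers exceeds n_layer. A small self-contained illustration of that rule (the 32 and 10 are made-up example values, and the enum is a stand-in for ggml_backend):

// Illustration of the backend assignment rule; not part of the commit.
#include <cstdio>

enum backend_example { BACKEND_CPU, BACKEND_CUDA };    // stand-in for ggml_backend

int main() {
    const int n_layer      = 32;                       // e.g. a 7B model
    const int n_gpu_layers = 10;                       // requested GPU layers
    const int i_gpu_start  = n_layer - n_gpu_layers;   // = 22

    for (int i = 0; i < n_layer; ++i) {
        const backend_example backend = i < i_gpu_start ? BACKEND_CPU : BACKEND_CUDA;
        if (backend == BACKEND_CUDA) {
            printf("layer %d -> GPU\n", i);            // prints layers 22..31
        }
    }
    printf("output tensor on GPU: %s\n", n_gpu_layers > n_layer ? "yes" : "no");
    return 0;
}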
@@ -1026,38 +1057,44 @@ static void llama_model_load_internal(
 
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
-    model.mapping = std::move(ml->mapping);
-
 #ifdef GGML_USE_CUBLAS
     {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
-
-        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
-
-        size_t vram_total = 0;
-
-        for (int i = 0; i < n_gpu; ++i) {
-            const auto & layer = model.layers[i];
-
-            ggml_cuda_transform_tensor(layer.attention_norm); vram_total += ggml_nbytes(layer.attention_norm);
-            ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
-            ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
-            ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
-            ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
-            ggml_cuda_transform_tensor(layer.ffn_norm); vram_total += ggml_nbytes(layer.ffn_norm);
-            ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
-            ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
-            ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
-        }
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
-            ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
-        }
-
-        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        std::vector<struct ggml_tensor *> tensors;
+        std::vector<size_t> offsets;
+        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                continue;
+            }
+            tensors.emplace_back(lt.ggml_tensor);
+            LLAMA_ASSERT(lt.shards.size() == 1);
+            offsets.emplace_back(lt.shards.at(0).file_off);
+        }
+        bool cufile_success = ggml_cuda_load_data_cufile(fname.c_str(), tensors.data(), tensors.size(), offsets.data());
+
+        if (!cufile_success) {
+            for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+                if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                    continue;
+                }
+                size_t actual_size;
+                void * buf = ggml_cuda_pool_malloc(lt.size, &actual_size);
+                void * buf_host = ggml_cuda_host_malloc(lt.size);
+
+                llama_file & file = ml->file_loaders.at(lt.shards.at(0).file_idx)->file;
+                file.seek(lt.shards.at(0).file_off, SEEK_SET);
+                file.read_raw(buf_host, lt.size);
+
+                cudaMemcpy(buf, buf_host, lt.size, cudaMemcpyHostToDevice);
+                cudaDeviceSynchronize();
+
+                lt.ggml_tensor->data = buf;
+                ggml_cuda_host_free(buf_host);
+            }
+        }
     }
-#else
-    (void) n_gpu_layers;
-#endif
+#endif // GGML_USE_CUBLAS
+
+    model.mapping = std::move(ml->mapping);
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
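When the cuFile path reports failure, the hunk above stages each GPU tensor through a pinned host buffer and copies it into its pool-allocated device buffer with cudaMemcpy. A simplified, hypothetical version of that fallback, with read_tensor_bytes standing in for the llama_file seek plus read_raw calls:

// Hypothetical sketch of the slower fallback copy used when cuFile is unavailable.
#include <cstring>
#include <cuda_runtime.h>

static void read_tensor_bytes(void * dst, size_t size) {   // stand-in for file.read_raw()
    memset(dst, 0, size);                                   // real code reads from the weight file
}

static void * stage_to_vram(size_t size) {
    void * dev_buf  = nullptr;
    void * host_buf = nullptr;
    cudaMalloc(&dev_buf, size);
    cudaMallocHost(&host_buf, size);               // pinned memory, akin to ggml_cuda_host_malloc

    read_tensor_bytes(host_buf, size);             // file.seek() + file.read_raw() in llama.cpp

    cudaMemcpy(dev_buf, host_buf, size, cudaMemcpyHostToDevice);
    cudaDeviceSynchronize();                       // make sure the copy finished before the buffer is reused

    cudaFreeHost(host_buf);                        // the staging buffer is no longer needed
    return dev_buf;                                // tensor->data ends up pointing at VRAM
}

int main() {
    void * weights = stage_to_vram(1 << 20);       // 1 MiB placeholder tensor
    cudaFree(weights);
    return 0;
}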
@@ -2395,7 +2432,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         }
         size_t idx = model_loader->tensors_map.name_to_idx[base_name];
         llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
-        base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+        base_t = model_loader->get_tensor(
+            base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
         lt.data = (uint8_t *) lt.ggml_tensor->data;
         model_loader->load_data_for(lt);
         lt.ggml_tensor->data = lt.data;