diff --git a/CMakeLists.txt b/CMakeLists.txt
index ffda74a70..34a897327 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -386,11 +386,6 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES
     if (MSVC)
         # TODO: arm msvc?
     else()
-        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
-            # Apple M1, M2, etc.
-            # Raspberry Pi 3, 4, Zero 2 (64-bit)
-            add_compile_options(-mcpu=native)
-        endif()
         if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
             # Raspberry Pi 1, Zero
             add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access)
diff --git a/README.md b/README.md
index ee56988c7..e890dc9c2 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy)
 - [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b)
 - [X] [WizardLM](https://github.com/nlpxucan/WizardLM)
+- [X] [Baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
 
 **Bindings:**
 
diff --git a/convert.py b/convert.py
index e340d2273..142692776 100644
--- a/convert.py
+++ b/convert.py
@@ -136,7 +136,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
         calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
         if calc_ff == n_ff:
             return n_mult
-    return 1
+    raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
 
 @dataclass
 class Params:
@@ -321,6 +321,10 @@ class Tensor(metaclass=ABCMeta):
     @abstractmethod
     def permute(self, n_head: int) -> 'Tensor': ...
     @abstractmethod
+    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
+    @abstractmethod
+    def part(self, n_part: int) -> 'UnquantizedTensor': ...
+    @abstractmethod
     def to_ggml(self) -> 'GGMLCompatibleTensor': ...
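Note on the find_n_mult() change above: the converter now fails loudly instead of silently falling back to n_mult = 1 when no candidate matches. A quick sanity check of the formula, using the standard LLaMA-7B shapes purely as an illustration:

    # illustration only: LLaMA-7B uses n_embd = 4096 and n_ff = 11008
    n_embd, n_ff, n_mult = 4096, 11008, 256
    calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
    assert calc_ff == n_ff  # 43*256 == 11008, so a matching n_mult exists and no exception is raised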
@@ -345,6 +349,14 @@ class UnquantizedTensor(Tensor):
     def to_ggml(self) -> 'UnquantizedTensor':
         return self
 
+    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head))
+
+    def part(self, n_part: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
+
     def permute(self, n_head: int) -> 'UnquantizedTensor':
         return UnquantizedTensor(permute(self.ndarray, n_head))
 
@@ -642,6 +654,19 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
         return lazy_tensor.load().permute(n_head)
     return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
 
+def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
+    def load() -> Tensor:
+        return lazy_tensor.load().permute_part(n_part, n_head)
+    s = lazy_tensor.shape.copy()
+    s[0] = s[0] // 3
+    return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+
+def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
+    def load() -> Tensor:
+        return lazy_tensor.load().part(n_part)
+    s = lazy_tensor.shape.copy()
+    s[0] = s[0] // 3
+    return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
 
 def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     out: LazyModel = {}
@@ -650,11 +675,17 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     out["output.weight"] = model["lm_head.weight"]
 
     for i in itertools.count():
-        if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
+        if f"model.layers.{i}.self_attn.q_proj.weight" in model:
+            out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
+        elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
+            out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
+            out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
+        else:
             break
-        out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
-        out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
-        out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
+
         out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
 
         out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
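Note: Baichuan-7B stores the attention projections as a single packed W_pack tensor instead of separate q_proj/k_proj/v_proj tensors; the new permute_part()/part() helpers above slice out one third of its rows before applying the usual head permutation. A rough NumPy sketch of that slicing (shapes and names are illustrative, not the converter's API):

    import numpy as np

    n_embd = 8  # tiny stand-in for the real hidden size
    w_pack = np.arange(3 * n_embd * n_embd, dtype=np.float32).reshape(3 * n_embd, n_embd)

    r = w_pack.shape[0] // 3  # rows per projection
    wq, wk, wv = (w_pack[i * r : (i + 1) * r, :] for i in range(3))
    assert wq.shape == wk.shape == wv.shape == (n_embd, n_embd)
    # wq and wk are then reordered with the same permute() used for q_proj/k_proj weights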
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 37de52ad6..570e273fc 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -210,9 +210,12 @@ llama_token sampling_id(struct MyModel* mymodel) {
 const char * sampling(struct MyModel * mymodel) {
     llama_context * ctx = mymodel->ctx;
     int id = sampling_id(mymodel);
-    std::string ret;
-    if (id == llama_token_eos()) ret = "";
-    else ret = llama_token_to_str(ctx, id);
+    static std::string ret;
+    if (id == llama_token_eos()) {
+        ret = "";
+    } else {
+        ret = llama_token_to_str(ctx, id);
+    }
     eval_id(mymodel, id);
     return ret.c_str();
 }
diff --git a/examples/embd-input/embd-input.h b/examples/embd-input/embd-input.h
index 4fefabd42..efb5ba5e2 100644
--- a/examples/embd-input/embd-input.h
+++ b/examples/embd-input/embd-input.h
@@ -5,7 +5,6 @@
 #include "llama.h"
 #include "build-info.h"
 
-
 extern "C" {
 
 typedef struct MyModel {
@@ -14,14 +13,13 @@ typedef struct MyModel {
     int n_past = 0;
 } MyModel;
 
-
 struct MyModel* create_mymodel(int argc, char ** argv);
 
 bool eval_float(void* model, float* input, int N);
 bool eval_tokens(void* model, std::vector<llama_token> tokens);
 bool eval_id(struct MyModel* mymodel, int id);
 bool eval_string(struct MyModel* mymodel, const char* str);
-const char* sampling(struct MyModel* mymodel);
+const char * sampling(struct MyModel* mymodel);
 llama_token sampling_id(struct MyModel* mymodel);
 void free_mymodel(struct MyModel* mymodel);
 
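Note: the static in sampling() above is what makes returning ret.c_str() safe; with a plain local, the std::string is destroyed when the function returns and the caller is left with a dangling pointer. A minimal C++ illustration (function names are made up for the example):

    #include <string>

    const char * broken()  { std::string s = "token"; return s.c_str(); }          // s is destroyed here -> dangling pointer
    const char * patched() { static std::string s; s = "token"; return s.c_str(); } // buffer outlives the call

The usual trade-off is that the static buffer is overwritten by the next call and is not thread-safe.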
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 05bfa8016..c50eeb343 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -2671,7 +2671,8 @@ struct train_params {
     const char * fn_checkpoint_out;
     const char * fn_model_out;
 
-    int seed;
+    uint32_t seed;
+
     int n_ctx;
     int n_embd;
     int n_mult;
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 4e0d3dbde..50df20edd 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+};
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1970,7 +1975,6 @@ inline void ggml_cuda_op_add(
     } else {
         GGML_ASSERT(false);
     }
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2002,7 +2006,6 @@ inline void ggml_cuda_op_mul(
 
         // compute
         mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-        CUDA_CHECK(cudaGetLastError());
     }
 
     (void) dst;
@@ -2023,7 +2026,6 @@ inline void ggml_cuda_op_silu(
 
     // compute
     silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2046,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm(
 
     // compute
     rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2125,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             GGML_ASSERT(false);
             break;
     }
-    CUDA_CHECK(cudaGetLastError());
 
 #ifdef GGML_CUDA_DMMV_F16
     if (src1_convert_f16) {
@@ -2202,7 +2202,6 @@ inline void ggml_cuda_op_rope(
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2226,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf(
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2248,7 +2246,6 @@ inline void ggml_cuda_op_soft_max(
 
     // compute
     soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2344,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
-    // if multiple GPUs are used they need to wait for the main GPU to finish
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signifies that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
     }
 
     for (int id = 0; id < g_device_count; ++id) {
@@ -2373,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_diff = row_high - row_low;
 
         cudaSetDevice(id);
+        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+        // wait for main GPU data if necessary
+        if (split && id != g_main_device) {
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+        }
 
         if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
@@ -2448,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                 float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2509,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2538,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                     }
                 }
+
+                // signal to the main device that this device is done
+                if (split && g_device_count > 1 && id != g_main_device) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                }
             }
         }
     }
@@ -2549,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaDeviceSynchronize());
 
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2564,6 +2571,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
         }
     }
+
+    // main device waits for all other devices to be finished
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        for (int id = 0; id < g_device_count; ++id) {
+            if (id != g_main_device) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+            }
+        }
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2803,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
         extra->data_device[id] = buf;
+
+        if (backend == GGML_BACKEND_GPU_SPLIT) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+        }
     }
 
     tensor->extra = extra;
@@ -2816,12 +2842,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     for (int id = 0; id < g_device_count; ++id) {
-        if (extra->data_device[id] == nullptr) {
-            continue;
+        if (extra->data_device[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaFree(extra->data_device[id]));
+        if (extra->events[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        }
     }
 
     delete extra;
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 7a65a3558..3c1e8deb6 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -8,10 +8,6 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES 16
 
-struct ggml_tensor_extra_gpu {
-    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-};
-
 void ggml_init_cublas(void);
 void ggml_cuda_set_tensor_split(const float * tensor_split);
 
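Note: taken together, the ggml-cuda.cu changes above consolidate the per-op cudaGetLastError() checks into a single check after the op callback and replace the cudaDeviceSynchronize() calls inside ggml_cuda_op() with stream events: the main device records an event once the input data is ready, the other devices make their streams wait on it and record their own completion events, and the main stream waits on those in turn; only a result destined for the CPU still forces a full device synchronization. A stand-alone sketch of that pattern (illustrative names, not the ggml API; assumes at least one CUDA device):

    #include <cuda_runtime.h>
    #include <cstdio>

    #define CHECK(call) do { cudaError_t err_ = (call); if (err_ != cudaSuccess) { \
        std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); return 1; } } while (0)

    int main() {
        int n_dev = 0;
        CHECK(cudaGetDeviceCount(&n_dev));
        if (n_dev > 16) n_dev = 16; // mirror a fixed upper bound like GGML_CUDA_MAX_DEVICES
        const int main_dev = 0;

        cudaStream_t stream[16];
        cudaEvent_t  done[16];
        for (int id = 0; id < n_dev; ++id) {
            CHECK(cudaSetDevice(id));
            CHECK(cudaStreamCreate(&stream[id]));
            CHECK(cudaEventCreateWithFlags(&done[id], cudaEventDisableTiming));
        }

        // main device: enqueue whatever produces the shared input, then record an event
        CHECK(cudaSetDevice(main_dev));
        // ... kernels / copies on stream[main_dev] ...
        CHECK(cudaEventRecord(done[main_dev], stream[main_dev]));

        // secondary devices: their streams wait on the main device's event (the host is not
        // blocked), do their slice of the work, then record their own completion event
        for (int id = 0; id < n_dev; ++id) {
            if (id == main_dev) continue;
            CHECK(cudaSetDevice(id));
            CHECK(cudaStreamWaitEvent(stream[id], done[main_dev], 0));
            // ... kernels / copies on stream[id] ...
            CHECK(cudaEventRecord(done[id], stream[id]));
        }

        // main device: wait for all secondaries before consuming their results
        CHECK(cudaSetDevice(main_dev));
        for (int id = 0; id < n_dev; ++id) {
            if (id != main_dev) CHECK(cudaStreamWaitEvent(stream[main_dev], done[id], 0));
        }
        CHECK(cudaDeviceSynchronize()); // only this final, host-visible sync blocks the CPU
        return 0;
    }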
diff --git a/ggml-metal.m b/ggml-metal.m
index 7551231b9..fd69c41fe 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -202,7 +202,9 @@ struct ggml_metal_context * ggml_metal_init(void) {
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
     fprintf(stderr, "%s: deallocating\n", __func__);
-
+    for (int i = 0; i < ctx->n_buffers; ++i) {
+        [ctx->buffers[i].metal release];
+    }
     free(ctx);
 }
 
diff --git a/ggml.c b/ggml.c
index 684caaa37..75cc44baa 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3846,6 +3846,40 @@ static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
+// WARN:
+// Misconfiguration can lead to problems that are hard to reason about:
+// * At best it crashes or talks nonsense.
+// * At worst the output is subtly different and hard to notice.
+//
+// An op has to enable INIT or FINALIZE when any of its branches needs that pass.
+// Take care with compile options (e.g., GGML_USE_xxx).
+static bool GGML_OP_HAS_INIT    [GGML_OP_COUNT] = { 0 };
+static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 };
+static void ggml_setup_op_has_task_pass(void) {
+    {   // INIT
+        bool * I = GGML_OP_HAS_INIT;
+
+        I[GGML_OP_ACC                    ] = true;
+        I[GGML_OP_MUL_MAT                ] = true;
+        I[GGML_OP_OUT_PROD               ] = true;
+        I[GGML_OP_SET                    ] = true;
+        I[GGML_OP_GET_ROWS_BACK          ] = true;
+        I[GGML_OP_DIAG_MASK_INF          ] = true;
+        I[GGML_OP_DIAG_MASK_ZERO         ] = true;
+        I[GGML_OP_CONV_1D_S1_PH          ] = true;
+        I[GGML_OP_CONV_1D_S2_PH          ] = true;
+        I[GGML_OP_CONV_2D_SK_P0          ] = true;
+        I[GGML_OP_FLASH_ATTN_BACK        ] = true;
+        I[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
+    }
+
+    {   // FINALIZE
+        bool * F = GGML_OP_HAS_FINALIZE;
+
+        F[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
+    }
+}
+
 //
 // ggml context
 //
@@ -4267,6 +4301,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         ggml_cl_init();
 #endif
 
+        ggml_setup_op_has_task_pass();
+
         is_first_call = false;
     }
 
@@ -16791,9 +16827,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         if (node_n != -1) {
             /* FINALIZE */
             struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
-            params.nth = node->n_tasks;
-            ggml_compute_forward(&params, node);
-            ggml_graph_compute_perf_stats_node(node, state->shared);
+            if (GGML_OP_HAS_FINALIZE[node->op]) {
+                params.nth = node->n_tasks;
+                ggml_compute_forward(&params, node);
+                ggml_graph_compute_perf_stats_node(node, state->shared);
+            }
         }
 
         // distribute new work or execute it direct if 1T
@@ -16805,10 +16843,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             state->shared->perf_node_start_cycles  = ggml_perf_cycles();
             state->shared->perf_node_start_time_us = ggml_perf_time_us();
 
+            params.nth = node->n_tasks;
+
             /* INIT */
-            params.type = GGML_TASK_INIT;
-            params.nth  = node->n_tasks;
-            ggml_compute_forward(&params, node);
+            if (GGML_OP_HAS_INIT[node->op]) {
+                params.type = GGML_TASK_INIT;
+                ggml_compute_forward(&params, node);
+            }
 
             if (node->n_tasks == 1) {
                 // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
@@ -16816,9 +16857,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 params.type = GGML_TASK_COMPUTE;
                 ggml_compute_forward(&params, node);
 
-                params.type = GGML_TASK_FINALIZE;
-                ggml_compute_forward(&params, node);
-                ggml_graph_compute_perf_stats_node(node, state->shared);
+                if (GGML_OP_HAS_FINALIZE[node->op]) {
+                    params.type = GGML_TASK_FINALIZE;
+                    ggml_compute_forward(&params, node);
+                    ggml_graph_compute_perf_stats_node(node, state->shared);
+                }
             } else {
                 break;
             }
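Note: after this change, GGML_TASK_INIT and GGML_TASK_FINALIZE are only dispatched for ops whose flag is set in ggml_setup_op_has_task_pass(); an op that does real work in one of those passes but is not registered there will silently skip it, which is what the WARN comment above cautions about. For a hypothetical new op (GGML_OP_MY_OP is invented for this example) that needs an INIT pass, the registration is a single line:

    // inside ggml_setup_op_has_task_pass(), in the INIT block
    I[GGML_OP_MY_OP] = true;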
diff --git a/ggml.h b/ggml.h
index 459913222..11b51f8bd 100644
--- a/ggml.h
+++ b/ggml.h
@@ -444,6 +444,9 @@ extern "C" {
 
     // compute types
+
+    // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
+    // This behavior was changed in https://github.com/ggerganov/llama.cpp/pull/1995.
     enum ggml_task_type {
         GGML_TASK_INIT = 0,
         GGML_TASK_COMPUTE,
diff --git a/llama.cpp b/llama.cpp
index 049f73e44..a869bbac8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -66,6 +66,7 @@ enum e_model {
     MODEL_65B,
 };
 
+static const size_t kB = 1024;
 static const size_t MB = 1024*1024;
 
 // computed for n_ctx == 2048
@@ -129,6 +130,34 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
     return k_sizes;
 }
 
+// amount of VRAM needed per batch size to hold temporary results
+// the values for 3b and 65b are not derived from testing but instead chosen conservatively
+static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
+{
+    static std::map<e_model, size_t> k_sizes = {
+        { MODEL_3B,  512ull * kB },
+        { MODEL_7B,  512ull * kB },
+        { MODEL_13B, 640ull * kB },
+        { MODEL_30B, 768ull * kB },
+        { MODEL_65B, 1536ull * kB },
+    };
+    return k_sizes;
+}
+
+// amount of VRAM needed per batch size and context to hold temporary results
+// the values for 3b and 65b are not derived from testing but instead chosen conservatively
+static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
+{
+    static std::map<e_model, size_t> k_sizes = {
+        { MODEL_3B,  128ull },
+        { MODEL_7B,  128ull },
+        { MODEL_13B, 160ull },
+        { MODEL_30B, 208ull },
+        { MODEL_65B, 416ull },
+    };
+    return k_sizes;
+}
+
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
@@ -253,7 +282,13 @@ struct llama_model {
 
 struct llama_context {
     llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
-
+#ifdef GGML_USE_METAL
+    ~llama_context() {
+        if (ctx_metal) {
+            ggml_metal_free(ctx_metal);
+        }
+    }
+#endif
     std::mt19937 rng;
 
     bool has_evaluated_once = false;
@@ -1112,11 +1147,14 @@ static void llama_model_load_internal(
             fprintf(stderr, "%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
             ggml_cuda_set_scratch_size(0); // disable scratch
         } else {
-            vram_scratch = n_batch * MB;
+            const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
+            const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
+            vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
             ggml_cuda_set_scratch_size(vram_scratch);
             if (n_gpu_layers > 0) {
-                fprintf(stderr, "%s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n",
-                        __func__, vram_scratch / MB);
+                fprintf(stderr, "%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
+                        __func__, vram_scratch_base / kB, vram_scratch_per_context,
+                        (vram_scratch + MB - 1) / MB); // round up
             }
         }
 #endif // GGML_USE_CUBLAS
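Note: the scratch buffer is now sized as n_batch * (base + n_ctx * per_context) instead of a flat n_batch * 1 MB. As a worked example using the table values above (taking n_batch = 512 and n_ctx = 2048 for a 7B model, chosen only for illustration): 512 * (512 KiB + 2048 * 128 B) = 512 * 768 KiB = 384 MiB, where the old formula would have reserved 512 MiB regardless of context size.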
@@ -3219,7 +3257,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     return nread;
 }
 
-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");
 
     // sanity checks
@@ -3273,6 +3311,15 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     return true;
 }
 
+bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        fprintf(stderr, "error loading session file: %s\n", err.what());
+        return false;
+    }
+}
+
 bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
     llama_file file(path_session, "wb");