From b1ca8f36a9cdbcee5f5c425df717611a1040a897 Mon Sep 17 00:00:00 2001 From: Qingyou Meng Date: Sat, 1 Jul 2023 23:42:43 +0800 Subject: [PATCH 01/10] ggml : disable GGML_TASK_INIT and GGML_TASK_FINALIZE by default (#1995) Will not be scheduled unless explicitly enabled. --- ggml.c | 61 +++++++++++++++++++++++++++++++++++++++++++++++++--------- ggml.h | 3 +++ 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/ggml.c b/ggml.c index 684caaa37..75cc44baa 100644 --- a/ggml.c +++ b/ggml.c @@ -3846,6 +3846,40 @@ static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); +// WARN: +// Mis-confguration can lead to problem that's hard to reason about: +// * At best it crash or talks nosense. +// * At worst it talks slightly difference but hard to perceive. +// +// An op has to enable INIT or FINALIZE when any of it's branch needs that pass. +// Take care about compile options (e.g., GGML_USE_xxx). +static bool GGML_OP_HAS_INIT [GGML_OP_COUNT] = { 0 }; +static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 }; +static void ggml_setup_op_has_task_pass(void) { + { // INIT + bool * I = GGML_OP_HAS_INIT; + + I[GGML_OP_ACC ] = true; + I[GGML_OP_MUL_MAT ] = true; + I[GGML_OP_OUT_PROD ] = true; + I[GGML_OP_SET ] = true; + I[GGML_OP_GET_ROWS_BACK ] = true; + I[GGML_OP_DIAG_MASK_INF ] = true; + I[GGML_OP_DIAG_MASK_ZERO ] = true; + I[GGML_OP_CONV_1D_S1_PH ] = true; + I[GGML_OP_CONV_1D_S2_PH ] = true; + I[GGML_OP_CONV_2D_SK_P0 ] = true; + I[GGML_OP_FLASH_ATTN_BACK ] = true; + I[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } + + { // FINALIZE + bool * F = GGML_OP_HAS_FINALIZE; + + F[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } +} + // // ggml context // @@ -4267,6 +4301,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { ggml_cl_init(); #endif + ggml_setup_op_has_task_pass(); + is_first_call = false; } @@ -16791,9 +16827,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (node_n != -1) { /* FINALIZE */ struct ggml_tensor * node = state->shared->cgraph->nodes[node_n]; - params.nth = node->n_tasks; - ggml_compute_forward(¶ms, node); - ggml_graph_compute_perf_stats_node(node, state->shared); + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.nth = node->n_tasks; + ggml_compute_forward(¶ms, node); + ggml_graph_compute_perf_stats_node(node, state->shared); + } } // distribute new work or execute it direct if 1T @@ -16805,10 +16843,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { state->shared->perf_node_start_cycles = ggml_perf_cycles(); state->shared->perf_node_start_time_us = ggml_perf_time_us(); + params.nth = node->n_tasks; + /* INIT */ - params.type = GGML_TASK_INIT; - params.nth = node->n_tasks; - ggml_compute_forward(¶ms, node); + if (GGML_OP_HAS_INIT[node->op]) { + params.type = GGML_TASK_INIT; + ggml_compute_forward(¶ms, node); + } if (node->n_tasks == 1) { // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, @@ -16816,9 +16857,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { params.type = GGML_TASK_COMPUTE; ggml_compute_forward(¶ms, node); - params.type = GGML_TASK_FINALIZE; - ggml_compute_forward(¶ms, node); - ggml_graph_compute_perf_stats_node(node, state->shared); + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + ggml_graph_compute_perf_stats_node(node, state->shared); + } } else { break; } diff --git a/ggml.h b/ggml.h index 459913222..11b51f8bd 100644 --- a/ggml.h +++ b/ggml.h @@ -444,6 +444,9 @@ extern "C" { // compute types + + // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. + // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. enum ggml_task_type { GGML_TASK_INIT = 0, GGML_TASK_COMPUTE, From 04606a159947566b27810508433e6ca5dbc684ba Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 1 Jul 2023 18:45:44 +0300 Subject: [PATCH 02/10] train : fix compile warning --- examples/train-text-from-scratch/train-text-from-scratch.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 05bfa8016..c50eeb343 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -2671,7 +2671,8 @@ struct train_params { const char * fn_checkpoint_out; const char * fn_model_out; - int seed; + uint32_t seed; + int n_ctx; int n_embd; int n_mult; From 79f634a19d1c32a6cfb1befc21551ee684fced6b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 1 Jul 2023 18:46:00 +0300 Subject: [PATCH 03/10] embd-input : fix returning ptr to temporary --- examples/embd-input/embd-input-lib.cpp | 9 ++++++--- examples/embd-input/embd-input.h | 4 +--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 37de52ad6..570e273fc 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -210,9 +210,12 @@ llama_token sampling_id(struct MyModel* mymodel) { const char * sampling(struct MyModel * mymodel) { llama_context * ctx = mymodel->ctx; int id = sampling_id(mymodel); - std::string ret; - if (id == llama_token_eos()) ret = ""; - else ret = llama_token_to_str(ctx, id); + static std::string ret; + if (id == llama_token_eos()) { + ret = ""; + } else { + ret = llama_token_to_str(ctx, id); + } eval_id(mymodel, id); return ret.c_str(); } diff --git a/examples/embd-input/embd-input.h b/examples/embd-input/embd-input.h index 4fefabd42..efb5ba5e2 100644 --- a/examples/embd-input/embd-input.h +++ b/examples/embd-input/embd-input.h @@ -5,7 +5,6 @@ #include "llama.h" #include "build-info.h" - extern "C" { typedef struct MyModel { @@ -14,14 +13,13 @@ typedef struct MyModel { int n_past = 0; } MyModel; - struct MyModel* create_mymodel(int argc, char ** argv); bool eval_float(void* model, float* input, int N); bool eval_tokens(void* model, std::vector tokens); bool eval_id(struct MyModel* mymodel, int id); bool eval_string(struct MyModel* mymodel, const char* str); -const char* sampling(struct MyModel* mymodel); +const char * sampling(struct MyModel* mymodel); llama_token sampling_id(struct MyModel* mymodel); void free_mymodel(struct MyModel* mymodel); From cb44dbc7de287b3d17772cfb1aa49d55e082ce5b Mon Sep 17 00:00:00 2001 From: Rand Xie Date: Sun, 2 Jul 2023 00:02:58 +0800 Subject: [PATCH 04/10] llama : catch llama_load_session_file_internal exceptions (#2022) * convert checks in llama_load_session_file to throw and handle them * make llama_load_session_file_internal static * address feedbacks to avoid using exceptions --- llama.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 049f73e44..3a7a0d5da 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3219,7 +3219,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { return nread; } -bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(path_session, "rb"); // sanity checks @@ -3269,8 +3269,15 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi llama_set_state_data(ctx, state_data.data()); } +} - return true; +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + fprintf(stderr, "error loading session file: %s\n", err.what()); + return false; + } } bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { From 463f2f4c4f8dd5ca9594b7d65849f346f0effe05 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 1 Jul 2023 19:05:09 +0300 Subject: [PATCH 05/10] llama : fix return value of llama_load_session_file_internal (#2022) --- llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama.cpp b/llama.cpp index 3a7a0d5da..69c2ab01b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3269,6 +3269,8 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c llama_set_state_data(ctx, state_data.data()); } + + return true; } bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { From 471aab6e4cb89d8ef6d043f1bc93acb6eb78ab67 Mon Sep 17 00:00:00 2001 From: Judd Date: Sun, 2 Jul 2023 01:00:25 +0800 Subject: [PATCH 06/10] convert : add support of baichuan-7b (#2055) Co-authored-by: Judd --- README.md | 1 + convert.py | 41 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ee56988c7..e890dc9c2 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ as the main playground for developing new features for the [ggml](https://github - [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy) - [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b) - [X] [WizardLM](https://github.com/nlpxucan/WizardLM) +- [X] [Baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) **Bindings:** diff --git a/convert.py b/convert.py index e340d2273..142692776 100644 --- a/convert.py +++ b/convert.py @@ -136,7 +136,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int: calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult if calc_ff == n_ff: return n_mult - return 1 + raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).") @dataclass class Params: @@ -321,6 +321,10 @@ class Tensor(metaclass=ABCMeta): @abstractmethod def permute(self, n_head: int) -> 'Tensor': ... @abstractmethod + def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ... + @abstractmethod + def part(self, n_part: int) -> 'UnquantizedTensor': ... + @abstractmethod def to_ggml(self) -> 'GGMLCompatibleTensor': ... @@ -345,6 +349,14 @@ class UnquantizedTensor(Tensor): def to_ggml(self) -> 'UnquantizedTensor': return self + def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head)) + + def part(self, n_part: int) -> 'UnquantizedTensor': + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) + def permute(self, n_head: int) -> 'UnquantizedTensor': return UnquantizedTensor(permute(self.ndarray, n_head)) @@ -642,6 +654,19 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor: return lazy_tensor.load().permute(n_head) return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description) +def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute_part(n_part, n_head) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description) + +def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel: out: LazyModel = {} @@ -650,11 +675,17 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel: out["output.weight"] = model["lm_head.weight"] for i in itertools.count(): - if f"model.layers.{i}.self_attn.q_proj.weight" not in model: + if f"model.layers.{i}.self_attn.q_proj.weight" in model: + out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head) + out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head) + out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + elif f"model.layers.{i}.self_attn.W_pack.weight" in model: + out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head) + out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head) + out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + else: break - out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head) - out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head) - out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"] out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"] From 2f8cd979ecd1fa582852e7136e92ff8990b98fd8 Mon Sep 17 00:00:00 2001 From: Aaron Miller Date: Sat, 1 Jul 2023 11:14:59 -0700 Subject: [PATCH 07/10] metal : release buffers when freeing metal context (#2062) --- ggml-metal.m | 4 +++- llama.cpp | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 7551231b9..fd69c41fe 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -202,7 +202,9 @@ struct ggml_metal_context * ggml_metal_init(void) { void ggml_metal_free(struct ggml_metal_context * ctx) { fprintf(stderr, "%s: deallocating\n", __func__); - + for (int i = 0; i < ctx->n_buffers; ++i) { + [ctx->buffers[i].metal release]; + } free(ctx); } diff --git a/llama.cpp b/llama.cpp index 69c2ab01b..561accf88 100644 --- a/llama.cpp +++ b/llama.cpp @@ -253,7 +253,13 @@ struct llama_model { struct llama_context { llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} - +#ifdef GGML_USE_METAL + ~llama_context() { + if (ctx_metal) { + ggml_metal_free(ctx_metal); + } + } +#endif std::mt19937 rng; bool has_evaluated_once = false; From b2132270678c473f7cd9ba871b03d694126bc33a Mon Sep 17 00:00:00 2001 From: Daniel Drake Date: Sat, 1 Jul 2023 20:31:44 +0200 Subject: [PATCH 08/10] cmake : don't force -mcpu=native on aarch64 (#2063) It's currently not possible to cross-compile llama.cpp for aarch64 because CMakeLists.txt forces -mcpu=native for that target. -mcpu=native doesn't make sense if your build host is not the target architecture, and clang rejects it for that reason, aborting the build. This can be easily reproduced using the current Android NDK to build for aarch64 on an x86_64 host. If there is not a specific CPU-tuning target for aarch64 then -mcpu should be omitted completely. I think that makes sense, there is not enough variance in the aarch64 instruction set to warrant a fixed -mcpu optimization at this point. And if someone is building natively and wishes to enable any possible optimizations for the host device, then there is already the LLAMA_NATIVE option available. Fixes #495. --- CMakeLists.txt | 5 ----- 1 file changed, 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ffda74a70..34a897327 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -386,11 +386,6 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES if (MSVC) # TODO: arm msvc? else() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") - # Apple M1, M2, etc. - # Raspberry Pi 3, 4, Zero 2 (64-bit) - add_compile_options(-mcpu=native) - endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") # Raspberry Pi 1, Zero add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) From befb3a35627432473f143c90994557d78ff5bc67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 1 Jul 2023 21:47:26 +0200 Subject: [PATCH 09/10] Test-based VRAM scratch size + context adjustment (#2056) --- llama.cpp | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 561accf88..a869bbac8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -66,6 +66,7 @@ enum e_model { MODEL_65B, }; +static const size_t kB = 1024; static const size_t MB = 1024*1024; // computed for n_ctx == 2048 @@ -129,6 +130,34 @@ static const std::map & MEM_REQ_EVAL() return k_sizes; } +// amount of VRAM needed per batch size to hold temporary results +// the values for 3b and 65b are not derived from testing but instead chosen conservatively +static const std::map & VRAM_REQ_SCRATCH_BASE() +{ + static std::map k_sizes = { + { MODEL_3B, 512ull * kB }, + { MODEL_7B, 512ull * kB }, + { MODEL_13B, 640ull * kB }, + { MODEL_30B, 768ull * kB }, + { MODEL_65B, 1536ull * kB }, + }; + return k_sizes; +} + +// amount of VRAM needed per batch size and context to hold temporary results +// the values for 3b and 65b are not derived from testing but instead chosen conservatively +static const std::map & VRAM_REQ_SCRATCH_PER_CONTEXT() +{ + static std::map k_sizes = { + { MODEL_3B, 128ull }, + { MODEL_7B, 128ull }, + { MODEL_13B, 160ull }, + { MODEL_30B, 208ull }, + { MODEL_65B, 416ull }, + }; + return k_sizes; +} + // default hparams (LLaMA 7B) struct llama_hparams { uint32_t n_vocab = 32000; @@ -1118,11 +1147,14 @@ static void llama_model_load_internal( fprintf(stderr, "%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); ggml_cuda_set_scratch_size(0); // disable scratch } else { - vram_scratch = n_batch * MB; + const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type); + const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type); + vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context); ggml_cuda_set_scratch_size(vram_scratch); if (n_gpu_layers > 0) { - fprintf(stderr, "%s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n", - __func__, vram_scratch / MB); + fprintf(stderr, "%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n", + __func__, vram_scratch_base / kB, vram_scratch_per_context, + (vram_scratch + MB - 1) / MB); // round up } } #endif // GGML_USE_CUBLAS From 0bc2cdfc875fa7877d8e01c8bb17066f99c08f21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 1 Jul 2023 21:49:44 +0200 Subject: [PATCH 10/10] Better CUDA synchronization logic (#2057) --- ggml-cuda.cu | 63 ++++++++++++++++++++++++++++++++++++++-------------- ggml-cuda.h | 4 ---- 2 files changed, 46 insertions(+), 21 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4e0d3dbde..50df20edd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif +struct ggml_tensor_extra_gpu { + void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors + cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs +}; + static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -1970,7 +1975,6 @@ inline void ggml_cuda_op_add( } else { GGML_ASSERT(false); } - CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; @@ -2002,7 +2006,6 @@ inline void ggml_cuda_op_mul( // compute mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); } (void) dst; @@ -2023,7 +2026,6 @@ inline void ggml_cuda_op_silu( // compute silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; @@ -2046,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm( // compute rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; @@ -2125,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); #ifdef GGML_CUDA_DMMV_F16 if (src1_convert_f16) { @@ -2202,7 +2202,6 @@ inline void ggml_cuda_op_rope( // compute rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) dst; (void) src0_ddq_i; @@ -2226,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf( // compute diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) dst; (void) src0_ddq_i; @@ -2248,7 +2246,6 @@ inline void ggml_cuda_op_soft_max( // compute soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); - CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; @@ -2344,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0}; - // if multiple GPUs are used they need to wait for the main GPU to finish + // if multiple devices are used they need to wait for the main device + // here an event is recorded that signifies that the main device has finished calculating the input data if (split && g_device_count > 1) { CUDA_CHECK(cudaSetDevice(g_main_device)); - CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device])); } for (int id = 0; id < g_device_count; ++id) { @@ -2373,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm int64_t row_diff = row_high - row_low; cudaSetDevice(id); + cudaStream_t cudaStream_main = g_cudaStreams_main[id]; + + // wait for main GPU data if necessary + if (split && id != g_main_device) { + CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device])); + } if (src0_on_device && src0_is_contiguous) { if (src0_is_f32) { @@ -2448,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } const int64_t i11 = i13*ne12 + i12; - cudaStream_t cudaStream_main = g_cudaStreams_main[id]; - // for split tensors the data begins at i0 == i0_offset_low char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride; @@ -2509,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm // do the computation op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main); + CUDA_CHECK(cudaGetLastError()); // copy dst to host or other device if necessary if (!dst_on_device) { @@ -2538,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main)); } } + + // signify to main device that other device is done + if (split && g_device_count > 1 && id != g_main_device) { + CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main)); + } } } } @@ -2549,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } CUDA_CHECK(cudaSetDevice(id)); - CUDA_CHECK(cudaDeviceSynchronize()); if (src0_asq[id] > 0) { ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]); @@ -2564,6 +2571,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]); } } + + // main device waits for all other devices to be finished + if (split && g_device_count > 1) { + CUDA_CHECK(cudaSetDevice(g_main_device)); + for (int id = 0; id < g_device_count; ++id) { + if (id != g_main_device) { + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id])); + } + } + } + + if (dst->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaSetDevice(g_main_device)); + CUDA_CHECK(cudaDeviceSynchronize()); + } } void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -2803,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); extra->data_device[id] = buf; + + if (backend == GGML_BACKEND_GPU_SPLIT) { + CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming)); + } } tensor->extra = extra; @@ -2816,12 +2842,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; for (int id = 0; id < g_device_count; ++id) { - if (extra->data_device[id] == nullptr) { - continue; + if (extra->data_device[id] != nullptr) { + CUDA_CHECK(cudaSetDevice(id)); + CUDA_CHECK(cudaFree(extra->data_device[id])); } - CUDA_CHECK(cudaSetDevice(id)); - CUDA_CHECK(cudaFree(extra->data_device[id])); + if (extra->events[id] != nullptr) { + CUDA_CHECK(cudaSetDevice(id)); + CUDA_CHECK(cudaEventDestroy(extra->events[id])); + } } delete extra; diff --git a/ggml-cuda.h b/ggml-cuda.h index 7a65a3558..3c1e8deb6 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -8,10 +8,6 @@ extern "C" { #define GGML_CUDA_MAX_DEVICES 16 -struct ggml_tensor_extra_gpu { - void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors -}; - void ggml_init_cublas(void); void ggml_cuda_set_tensor_split(const float * tensor_split);