From 46f577bfc1e29b4397e6958bc4400b326456c314 Mon Sep 17 00:00:00 2001
From: niansa
Date: Fri, 23 Jun 2023 17:10:45 +0200
Subject: [PATCH] h2d tensors during loadup

---
 ggml-vulkan.cpp | 86 +++++++++++++++++++++++++++++++++++++++++++------
 llama.cpp       | 12 +++++--
 2 files changed, 86 insertions(+), 12 deletions(-)

diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index c260c59c2..0f454c899 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -378,6 +378,19 @@ void ggml_vk_scale(kp::Sequence& seq,
     seq.record(mgr.algorithm({in, out}, spirv, {size}, {}, {pushConsts}));
 }
 
+void ggml_vk_xxlu(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
+                  const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
+                  const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
+                  uint32_t size) {
+    struct PushConstants {
+        uint32_t inOff, outOff;
+    } pushConsts {
+        inOff, outOff
+    };
+
+    seq.record(mgr.algorithm({in, out}, spirv, {size}, {}, {pushConsts}));
+}
+
 
 static const std::string program_silu =
         MULTILINE_QUOTE(
@@ -398,19 +411,64 @@ void main() {
 }
 );
 
-void ggml_vk_silu(kp::Sequence& seq,
-                  const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
-                  const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
-                  uint32_t size) {
+template <typename... Args>
+void ggml_vk_silu(Args&&... args) {
     const static auto spirv = compileSource(program_source_head+program_silu);
 
-    struct PushConstants {
-        uint32_t inOff, outOff;
-    } pushConsts {
-        inOff, outOff
-    };
+    ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
+}
 
-    seq.record(mgr.algorithm({in, out}, spirv, {size}, {}, {pushConsts}));
+
+static const std::string program_relu =
+        MULTILINE_QUOTE(
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+} pcs;
+
+layout(local_size_x = 1) in;
+layout(binding = 0) buffer tensorInA { float in_[]; };
+layout(binding = 1) buffer tensorOut { float out_[]; };
+
+void main() {
+    const uint i = gl_GlobalInvocationID.x;
+
+    out_[pcs.outOff+i] = max(0.0, in_[pcs.inOff+i]);
+}
+);
+
+template <typename... Args>
+void ggml_vk_relu(Args&&... args) {
+    const static auto spirv = compileSource(program_source_head+program_relu);
+
+    ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
+}
+
+
+static const std::string program_gelu =
+        MULTILINE_QUOTE(
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+} pcs;
+
+layout(local_size_x = 1) in;
+layout(binding = 0) buffer tensorInA { float in_[]; };
+layout(binding = 1) buffer tensorOut { float out_[]; };
+
+void main() {
+    const uint i = gl_GlobalInvocationID.x;
+    const float x = in_[pcs.inOff+i];
+
+    out_[pcs.outOff+i] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+}
+);
+
+template <typename... Args>
+void ggml_vk_gelu(Args&&... args) {
+    const static auto spirv = compileSource(program_source_head+program_gelu);
+
+    ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
 }
 
 
@@ -515,6 +573,14 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                     {
                         ggml_vk_silu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
                     } break;
+                case GGML_OP_RELU:
+                    {
+                        ggml_vk_relu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
+                    } break;
+                case GGML_OP_GELU:
+                    {
+                        ggml_vk_gelu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
+                    } break;
             }
         }
     });
diff --git a/llama.cpp b/llama.cpp
index cbe285afb..be4b5ca68 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -753,7 +753,7 @@ struct llama_model_loader {
         }
     }
 
-    void load_all_data(llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
+    void load_all_data(llama_context & lctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
         size_t data_size = 0;
         size_t prefetch_size = 0;
         size_t lock_size = 0;
@@ -810,6 +810,14 @@ struct llama_model_loader {
                         free(lt.data);
                     }
                     break;
+#elif defined(GGML_USE_KOMPUTE)
+                case GGML_BACKEND_GPU:
+                    lt.ggml_tensor->data = lt.data;
+                    ggml_vk_h2d_tensor(lctx.ctx_kompute, lt.ggml_tensor);
+                    if (!use_mmap) {
+                        free(lt.data);
+                    }
+                    break;
 #endif
                 default:
                     continue;
@@ -1315,7 +1323,7 @@ static void llama_model_load_internal(
     }
 #endif
 
-    ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
+    ml->load_all_data(lctx, progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
     if (progress_callback) {
         progress_callback(1.0f, progress_callback_user_data);
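
For reviewers, a minimal CPU sketch of the math the new relu/gelu shaders encode, useful for spot-checking GPU output after the h2d upload. The constant values below are ggml's usual GELU_COEF_A and SQRT_2_OVER_PI definitions, which the shaders assume are injected via program_source_head; the relu_ref/gelu_ref helper names are illustrative only and not part of the patch.

// Standalone reference for the shader bodies above (illustrative sketch).
// GELU_COEF_A / SQRT_2_OVER_PI are assumed to match ggml's definitions,
// which the GLSL pulls in via program_source_head.
#include <cmath>
#include <cstdio>

static const float GELU_COEF_A    = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

// Same as program_relu: out_[i] = max(0.0, in_[i])
static float relu_ref(float x) {
    return std::fmax(0.0f, x);
}

// Same tanh approximation of GELU as program_gelu
static float gelu_ref(float x) {
    return 0.5f*x*(1.0f + std::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

int main() {
    // Spot-check a few values; compare against tensors read back after
    // ggml_vk_graph_compute runs the GGML_OP_RELU / GGML_OP_GELU cases.
    for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
        std::printf("x=%6.2f  relu=%8.5f  gelu=%8.5f\n", x, relu_ref(x), gelu_ref(x));
    }
    return 0;
}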