From 339bc36cdda3014a80c45051ca89bf982e76f750 Mon Sep 17 00:00:00 2001
From: niansa
Date: Fri, 23 Jun 2023 11:50:30 +0200
Subject: [PATCH] Added more functions from Metal

---
 ggml-vulkan.cpp | 142 ++++++++++++++++++++++++++++++++++++++++++++++--
 ggml-vulkan.h   |  26 ++++++++-
 llama.cpp       |  60 ++++++++++++++++++++
 3 files changed, 222 insertions(+), 6 deletions(-)

diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index c722609a9..b7e70e221 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -5,8 +5,10 @@
 #include <string>
 #include <vector>
 #include <fstream>
-#include <iostream>
+#include <cstring>
 #include <exception>
+#include <memory>
+#include <unordered_map>
 #include <cstdio>
 #include <kompute/Kompute.hpp>
 
@@ -14,8 +16,6 @@
 #error Your C implementation is not IEC 559 compliant, which is required for proper Vulkan interop.
 #endif
 
-typedef ggml_fp16_t half;
-
 #define MULTILINE_QUOTE(...) #__VA_ARGS__
 #define STRINGIFY(x) STRINGIFY2(x)
 #define STRINGIFY2(x) #x
@@ -24,6 +24,10 @@ #define QK4_0 32
 #define QR4_0 2
 #define QK4_1 32
+
+typedef ggml_fp16_t half;
+enum class byte : unsigned char {};
+
 typedef struct {
     half d;
     uint8_t qs[QK4_0 / 2];
 } block_q4_0;
@@ -35,12 +39,82 @@ typedef struct {
     uint8_t qs[QK4_1 / 2];
 } block_q4_1;
 
+struct ggml_kompute_context {
+    std::unordered_map<std::string, std::shared_ptr<kp::Tensor>> buffers;
+    std::unordered_map<struct ggml_tensor *, std::shared_ptr<kp::Tensor>> tensors;
+};
+
 kp::Manager mgr;
 
+ggml_kompute_context *ggml_vk_init() {
+    return new ggml_kompute_context;
+}
 
-std::vector<uint32_t> compileSource(const std::string& source) {
+void ggml_vk_free(struct ggml_kompute_context * ctx) {
+    delete ctx;
+}
+
+
+bool ggml_vk_add_buffer(
+        struct ggml_kompute_context * ctx,
+        const char * name,
+        void * data,
+        size_t size,
+        size_t max_size) {
+    try {
+        std::vector<byte> vec(std::max(size, max_size));
+        std::memcpy(vec.data(), data, size);
+        auto tensor = mgr.tensorT<byte>(vec);
+        ctx->buffers.emplace(name, std::move(tensor));
+    } catch (const std::exception & e) {
+        fprintf(stderr, "ggml_vk: failed to add buffer '%s': %s\n", name, e.what());
+        return false;
+    }
+    return true;
+}
+
+std::shared_ptr<kp::Tensor> ggml_vk_get_buffer(struct ggml_kompute_context * ctx, const char * name) {
+    auto res = ctx->buffers.find(name);
+    if (res == ctx->buffers.end()) return nullptr;
+    return res->second;
+}
+
+
+void ggml_vk_set_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t) {
+    if (t->backend != GGML_BACKEND_GPU) {
+        return;
+    }
+
+    auto data = t->data;
+    auto size = ggml_nbytes(t);
+
+    std::vector<byte> vec(size);
+    memcpy(vec.data(), data, size);
+
+    auto tensor = mgr.tensorT<byte>(vec);
+    mgr.sequence()->eval<kp::OpTensorSyncDevice>({tensor});
+    ctx->tensors.emplace(t, std::move(tensor));
+}
+
+void ggml_vk_get_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t) {
+    if (t->backend != GGML_BACKEND_GPU) {
+        return;
+    }
+
+    auto data = t->data;
+    auto size = ggml_nbytes(t);
+
+    auto res = ctx->tensors.find(t);
+    if (res == ctx->tensors.end()) return;
+    auto tensor = res->second;
+    mgr.sequence()->eval<kp::OpTensorSyncLocal>({tensor});
+    memcpy(data, tensor->data<byte>(), size);
+}
+
+
+static std::vector<uint32_t> compileSource(const std::string& source) {
     //FIXME: Terrible solution!!!!
     std::ofstream fileOut("tmp_kp_shader.comp");
     fileOut << source;
@@ -53,6 +127,7 @@ std::vector<uint32_t> compileSource(const std::string& source) {
     return {(uint32_t*)buffer.data(), (uint32_t*)(buffer.data() + buffer.size())};
 }
 
+
 template <typename T>
 std::vector<half> getVecBlockQ4_0D(T *x, unsigned nb) {
     std::vector<half> fres(nb);
@@ -90,12 +165,12 @@ static const std::string program_source_head = R"(
 #define QK4_0 32
 #define QR4_0 2
 #define QK4_1 32
-layout (local_size_x = 1) in;
 )";
 
 
 static const std::string program_dequantize_row_q4_0 =
     program_source_head+'\n'+MULTILINE_QUOTE(
+layout(local_size_x = 1, local_size_y = 1) in;
 layout(binding = 0) buffer tensorBlockQ4_0D { float16_t x_d[]; };
 layout(binding = 1) buffer tensorBlockQ4_0QS { uint8_t x_qs[]; };
 layout(binding = 2) buffer tensorY { float y[]; };
@@ -143,6 +218,7 @@ void ggml_vk_dequantize_row_q4_0(const void *x_, float *y, int k) {
 
 static const std::string program_dequantize_row_q4_1 =
     program_source_head+'\n'+MULTILINE_QUOTE(
+layout(local_size_x = 1, local_size_y = 1) in;
 layout(binding = 0) buffer tensorBlockQ4_0D { float16_t x_d[]; };
 layout(binding = 1) buffer tensorBlockQ4_0M { float16_t x_m[]; };
 layout(binding = 2) buffer tensorBlockQ4_0QS { uint8_t x_qs[]; };
@@ -191,6 +267,55 @@ void ggml_vk_dequantize_row_q4_1(const void *x_, float *y, int k) {
 }
 
+
+static const std::string program_abmath =
+    program_source_head+'\n'+MULTILINE_QUOTE(
+layout(push_constant) uniform PushConstants {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+} pcs;
+
+
+layout(local_size_x = 1) in;
+layout(binding = 0) buffer tensorInA { float inA[]; };
+layout(binding = 1) buffer tensorInB { float inB[]; };
+layout(binding = 2) buffer tensorOut { float out_[]; };
+
+
+void main() {
+    const int i = int(gl_GlobalInvocationID.x);
+
+    out_[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+i];
+}
+);
+
+template <char mathOP>
+void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
+                    const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
+                    std::shared_ptr<kp::Tensor>& out, uint32_t outOff) {
+    const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+'\n'+program_abmath);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+    } pushConsts {
+        inAOff, inBOff, outOff
+    };
+
+    mgr.sequence()
+       ->eval<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {std::min(inA->size(), inB->size()), 1, 1}, {}, {pushConsts}));
+}
+
+template <typename... Args>
+void ggml_vk_add(Args&&... args) {
+    return ggml_vk_abmath<'+'>(std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+void ggml_vk_mul(Args&&... args) {
+    return ggml_vk_abmath<'*'>(std::forward<Args>(args)...);
+}
+
+
 template<>
 kp::Tensor::TensorDataTypes
 kp::TensorT<half>::dataType()
@@ -204,3 +329,10 @@ kp::TensorT<uint8_t>::dataType()
 {
     return TensorDataTypes::eUnsignedInt;
 }
+
+template<>
+kp::Tensor::TensorDataTypes
+kp::TensorT<byte>::dataType()
+{
+    return TensorDataTypes::eUnsignedInt;
+}
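Note for reviewers: the sketch below is not part of the patch; it shows how the element-wise helpers above are meant to be driven. It must live in the same translation unit as the file-scope kp::Manager mgr, and the tensor values and the function name example_elementwise_add are illustrative only.

    void example_elementwise_add() {
        // host-side inputs; mgr.tensor() yields host-staged device tensors
        auto a = mgr.tensor({1.0f, 2.0f, 3.0f});
        auto b = mgr.tensor({4.0f, 5.0f, 6.0f});
        std::shared_ptr<kp::Tensor> c = mgr.tensor({0.0f, 0.0f, 0.0f});

        // inputs must be resident on the device before dispatch
        mgr.sequence()->eval<kp::OpTensorSyncDevice>({a, b, c});

        // c[i] = a[i] + b[i] via the MATH_OP shader specialized with '+'
        ggml_vk_add(a, 0u, b, 0u, c, 0u);

        // read the result back to host memory; expect {5, 7, 9}
        mgr.sequence()->eval<kp::OpTensorSyncLocal>({c});
    }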
diff --git a/ggml-vulkan.h b/ggml-vulkan.h
index 34e6d46b3..649c34b53 100644
--- a/ggml-vulkan.h
+++ b/ggml-vulkan.h
@@ -1,12 +1,36 @@
 #pragma once
 
+#include <stddef.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-void ggml_vk_init(void);
+struct ggml_kompute_context;
+
+
+struct ggml_kompute_context * ggml_vk_init(void);
+void ggml_vk_free(struct ggml_kompute_context * ctx);
+
+// creates a mapping between a host memory buffer and a device memory buffer
+// - make sure to map all buffers used in the graph before calling ggml_vk_graph_compute
+// - the mapping is used during computation to determine the arguments of the compute kernels
+// - you don't need to keep the host memory buffer allocated as it is never accessed by Vulkan
+// - max_size specifies the maximum size of a tensor and is used to create shared views such
+//   that it is guaranteed that the tensor will fit in at least one of the views
+//
+bool ggml_vk_add_buffer(
+        struct ggml_kompute_context * ctx,
+        const char * name,
+        void * data,
+        size_t size,
+        size_t max_size);
+
+void ggml_vk_set_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t);
+void ggml_vk_get_tensor(struct ggml_kompute_context * ctx, struct ggml_tensor * t);
 
 void ggml_vk_dequantize_row_q4_0(const void * x, float * y, int k);
+void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * cgraph);
 
 #ifdef __cplusplus
 }
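To make the intended lifecycle concrete, here is a minimal sketch of how a host application is expected to use this interface, mirroring what the llama.cpp hunk below does. The names upload_model, weights, and weights_size are placeholders, not part of the API.

    bool upload_model(void * weights, size_t weights_size, size_t max_tensor_size,
                      struct ggml_tensor * t) {
        struct ggml_kompute_context * ctx = ggml_vk_init();

        // map every host buffer the graph touches before computing
        if (!ggml_vk_add_buffer(ctx, "data", weights, weights_size, max_tensor_size)) {
            ggml_vk_free(ctx);
            return false;
        }

        ggml_vk_set_tensor(ctx, t);   // push a GPU-resident tensor to the device
        // ... ggml_vk_graph_compute(ctx, &gf) would run here ...
        ggml_vk_get_tensor(ctx, t);   // pull results back into t->data

        ggml_vk_free(ctx);
        return true;
    }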
diff --git a/llama.cpp b/llama.cpp
index e597f5048..824ed6121 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -14,6 +14,8 @@
 #include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
 #include "ggml-opencl.h"
+#elif defined(GGML_USE_KOMPUTE)
+#include "ggml-vulkan.h"
 #endif
 
 #ifdef GGML_USE_METAL
@@ -280,6 +282,8 @@ struct llama_context {
 
 #ifdef GGML_USE_METAL
     ggml_metal_context * ctx_metal = NULL;
+#elif defined(GGML_USE_KOMPUTE)
+    ggml_kompute_context * ctx_kompute = NULL;
 #endif
 
     int buf_last = 0;
@@ -1701,6 +1705,26 @@ static bool llama_eval_internal(
             ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
         }
 
         ggml_graph_compute(ctx0, &gf);
     }
+#elif defined(GGML_USE_KOMPUTE)
+    if (lctx.ctx_kompute && N == 1) {
+        ggml_vk_graph_compute(lctx.ctx_kompute, &gf);
+        ggml_vk_get_tensor   (lctx.ctx_kompute, cur);
+    } else {
+        // IMPORTANT:
+        // Since we don't have efficient Matrix x Matrix Vulkan multiplication yet, we fall back to vanilla
+        // ggml_graph_compute().
+        //
+        // When we implement Matrix x Matrix Vulkan multiplication, we can avoid this branch.
+        // But for now, we have focused only on Matrix x Vector Vulkan multiplication.
+        //
+        if (lctx.ctx_kompute) {
+            // We need to sync the GPU KV cache with the CPU KV cache
+            ggml_vk_get_tensor(lctx.ctx_kompute, kv_self.k);
+            ggml_vk_get_tensor(lctx.ctx_kompute, kv_self.v);
+        }
+
+        ggml_graph_compute(ctx0, &gf);
+    }
 #else
@@ -2743,6 +2767,42 @@ struct llama_context * llama_init_from_file(
         LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
 #undef LLAMA_METAL_CHECK_BUF
     }
+#elif defined(GGML_USE_KOMPUTE)
+    if (params.n_gpu_layers > 0) {
+        // this allocates all Vulkan resources and memory buffers
+        ctx->ctx_kompute = ggml_vk_init();
+
+        void * data_ptr  = NULL;
+        size_t data_size = 0;
+
+        if (params.use_mmap) {
+            data_ptr  = ctx->model.mapping->addr;
+            data_size = ctx->model.mapping->size;
+        } else {
+            data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
+            data_size = ggml_get_mem_size  (ctx->model.ctx);
+        }
+
+        const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
+
+        printf("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
+
+#define LLAMA_VK_CHECK_BUF(result)                                    \
+        if (!(result)) {                                              \
+            fprintf(stderr, "%s: failed to add buffer\n", __func__);  \
+            llama_free(ctx);                                          \
+            return NULL;                                              \
+        }
+
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "data", data_ptr, data_size, max_size));
+
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "kv",   ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size, 0));
+
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
+        LLAMA_VK_CHECK_BUF(ggml_vk_add_buffer(ctx->ctx_kompute, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
+#undef LLAMA_VK_CHECK_BUF
+    }
 #endif
 
     return ctx;
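For readability, the control flow that the llama_eval_internal hunk adds, shown without diff markers: single-token evaluation (Matrix x Vector) runs the whole graph on the GPU, while batched evaluation falls back to the CPU after making its copy of the KV cache coherent.

    if (lctx.ctx_kompute && N == 1) {
        ggml_vk_graph_compute(lctx.ctx_kompute, &gf);   // whole graph on the GPU
        ggml_vk_get_tensor   (lctx.ctx_kompute, cur);   // fetch the output tensor
    } else {
        if (lctx.ctx_kompute) {
            // sync the GPU KV cache back before computing on the CPU
            ggml_vk_get_tensor(lctx.ctx_kompute, kv_self.k);
            ggml_vk_get_tensor(lctx.ctx_kompute, kv_self.v);
        }
        ggml_graph_compute(ctx0, &gf);
    }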