From 754ea680a6744fa537bd59d69b07528cb867cc06 Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Sat, 22 Jul 2023 10:16:18 +0200
Subject: [PATCH] Basic offloading support with mul_f32 and dmmv for q4_0

---
 Makefile                                 |   2 +
 ggml-vulkan.cpp                          | 283 +++++++++++++++++++++--
 ggml-vulkan.h                            |   5 +-
 ggml.c                                   |   9 +-
 llama.cpp                                |  25 +-
 llama.h                                  |   2 +-
 vk_shaders/dequant_mul_mat_vec_q4_0.glsl |  73 ++++++
 vk_shaders/mul_f32.glsl                  |  30 +++
 8 files changed, 399 insertions(+), 30 deletions(-)
 create mode 100644 vk_shaders/dequant_mul_mat_vec_q4_0.glsl
 create mode 100644 vk_shaders/mul_f32.glsl

diff --git a/Makefile b/Makefile
index 3519d52ad..2c80ab24f 100644
--- a/Makefile
+++ b/Makefile
@@ -239,6 +239,8 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv
+	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0.spv
+	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/mul_f32.glsl -o vk_shaders/mul_f32.spv
 endif

 ifneq ($(filter aarch64%,$(UNAME_M)),)
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 75448f6ed..ad64c8f8a 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -113,6 +113,8 @@ vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT];
 vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s, vk_pipeline_matmul_f16_l, vk_pipeline_matmul_f16_m, vk_pipeline_matmul_f16_s;
 vk_pipeline vk_pipeline_matmul_f32_aligned_l, vk_pipeline_matmul_f32_aligned_m, vk_pipeline_matmul_f32_aligned_s, vk_pipeline_matmul_f16_aligned_l, vk_pipeline_matmul_f16_aligned_m, vk_pipeline_matmul_f16_aligned_s;
 vk_pipeline vk_pipeline_matmul_split_k_reduce;
+vk_pipeline vk_pipeline_dequant_mul_mat_vec_q4_0;
+vk_pipeline vk_pipeline_mul_f32;
 vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;

 bool vk_fp16_support = false;
@@ -260,7 +262,7 @@ static vk_sequence ggml_vk_create_sequence_1(vk_queue& q, std::vector

 static void ggml_vk_submit(vk_queue& q, std::vector<vk_sequence>& sequences, vk::Fence fence) {
 #ifdef VK_DEBUG
-    std::cerr << "ggml_vk_submit(" << q.queue_family_index << ", " << sequences.size() << ")" << std::endl;
+    std::cerr << "ggml_vk_submit(" << q.queue_family_index << " (" << q.queue << "), " << sequences.size() << ")" << std::endl;
 #endif
     if (sequences.empty()) {
         return;
@@ -640,6 +642,10 @@ void ggml_vk_init(void) {
     vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
     vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);

+    vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_q4_0.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+
+    vk_pipeline_mul_f32 = ggml_vk_create_pipeline("vk_shaders/mul_f32.spv", "main", 3, 8 * sizeof(int), {32, 32, 1}, {}, 1);
+
     // Queues
     vk_compute_queue = ggml_vk_create_queue(compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader });
     for (int i = 0; i < VK_TRANSFER_QUEUE_COUNT; i++) {
@@ -656,6 +662,7 @@ void ggml_vk_init(void) {
         ggml_vk_test_transfer(1024 * 1024 * m);
     }
     const std::vector<size_t> vals {
+        4096, 1, 11008,
         128, 110, 622,
         511, 511, 127,
         511, 511, 7,
@@ -693,7 +700,7 @@ void ggml_vk_init(void) {
 #endif
 }

-static vk_pipeline* ggml_get_to_fp32_vk(ggml_type type) {
+static vk_pipeline* ggml_vk_get_to_fp32(ggml_type type) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_get_to_fp32_vk()" << std::endl;
 #endif
@@ -715,6 +722,35 @@ static vk_pipeline* ggml_get_to_fp32_vk(ggml_type type) {
     }
 }

+static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return &vk_pipeline_dequant_mul_mat_vec_q4_0;
+        // case GGML_TYPE_Q4_1:
+        //     return &dequantize_mul_mat_vec_q4_1_cl;
+        // case GGML_TYPE_Q5_0:
+        //     return &dequantize_mul_mat_vec_q5_0_cl;
+        // case GGML_TYPE_Q5_1:
+        //     return &dequantize_mul_mat_vec_q5_1_cl;
+        // case GGML_TYPE_Q8_0:
+        //     return &dequantize_mul_mat_vec_q8_0_cl;
+        // case GGML_TYPE_F16:
+        //     return &convert_mul_mat_vec_f16_cl;
+        // case GGML_TYPE_Q2_K:
+        //     return &dequantize_mul_mat_vec_q2_K_cl;
+        // case GGML_TYPE_Q3_K:
+        //     return &dequantize_mul_mat_vec_q3_K_cl;
+        // case GGML_TYPE_Q4_K:
+        //     return &dequantize_mul_mat_vec_q4_K_cl;
+        // case GGML_TYPE_Q5_K:
+        //     return &dequantize_mul_mat_vec_q5_K_cl;
+        // case GGML_TYPE_Q6_K:
+        //     return &dequantize_mul_mat_vec_q6_K_cl;
+        default:
+            return nullptr;
+    }
+}
+
 // buffer pool for vulkan
 #define MAX_VK_BUFFERS 256
@@ -791,6 +827,16 @@ static void ggml_vk_pool_free(vk_buffer& buffer) {
     ggml_vk_destroy_buffer(buffer);
 }

+void ggml_vk_free_data(const struct ggml_tensor* tensor) {
+    if (tensor->backend != GGML_BACKEND_GPU) {
+        return;
+    }
+
+    vk_buffer& buf = *(vk_buffer *)tensor->data;
+    ggml_vk_destroy_buffer(buf);
+    free(tensor->data);
+}
+
 void* ggml_vk_host_malloc(size_t size) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_host_malloc(" << size << ")" << std::endl;
 #endif
@@ -848,7 +894,7 @@ static vk_submission ggml_vk_begin_submission(vk_queue& q) {
     return s;
 }

-static void ggml_vk_dispatch_pipeline(vk_submission& s, vk_pipeline& pipeline, std::vector<vk_buffer> buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements, vk_queue& q) {
+static void ggml_vk_dispatch_pipeline(vk_submission& s, vk_pipeline& pipeline, std::vector<vk_buffer> buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
     uint32_t wg0 = CEIL_DIV(elements[0], pipeline.wg_denoms[0]);
     uint32_t wg1 = CEIL_DIV(elements[1], pipeline.wg_denoms[1]);
     uint32_t wg2 = CEIL_DIV(elements[2], pipeline.wg_denoms[2]);
@@ -1287,17 +1333,17 @@ static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_buffer& a, vk_buffer
     ggml_vk_sync_buffers(s.buffer, { d }, q, vk::AccessFlagBits::eMemoryRead, vk::AccessFlagBits::eShaderWrite, false);
     if (split_k == 1) {
         const std::vector<int> pc = { m, n, k, stride_a, stride_b, stride_d, k };
-        ggml_vk_dispatch_pipeline(s, pipeline, { a, b, d }, pc.size() * sizeof(int), pc.data(), { (uint32_t)m, (uint32_t)n, 1 }, q);
+        ggml_vk_dispatch_pipeline(s, pipeline, { a, b, d }, pc.size() * sizeof(int), pc.data(), { (uint32_t)m, (uint32_t)n, 1 });
         ggml_vk_end_submission(s, std::move(wait_semaphores), std::move(signal_semaphores));
         return { s };
     }

     // Synchronize the two submissions
     const std::vector<int> pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(stride_a, split_k) };
-    ggml_vk_dispatch_pipeline(s, pipeline, { a, b, d }, pc1.size() * sizeof(int), pc1.data(), { (uint32_t)m * split_k, (uint32_t)n, 1 }, q);
+    ggml_vk_dispatch_pipeline(s, pipeline, { a, b, d }, pc1.size() * sizeof(int), pc1.data(), { (uint32_t)m * split_k, (uint32_t)n, 1 });
     ggml_vk_sync_buffers(s.buffer, { d }, q, vk::AccessFlagBits::eMemoryWrite, vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite, true);
     const std::vector<int> pc2 = { m, n, split_k };
-    ggml_vk_dispatch_pipeline(s, vk_pipeline_matmul_split_k_reduce, { d }, pc2.size() * sizeof(int), pc2.data(), { (uint32_t)m, (uint32_t)n, 1 }, q);
+    ggml_vk_dispatch_pipeline(s, vk_pipeline_matmul_split_k_reduce, { d }, pc2.size() * sizeof(int), pc2.data(), { (uint32_t)m, (uint32_t)n, 1 });

     ggml_vk_end_submission(s, std::move(wait_semaphores), std::move(signal_semaphores));
     return { s };
@@ -1455,6 +1501,7 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

     vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, ne01, ne11, ne10 == kpad);

+    // TODO use larger buffers to parallelize execution
     vk_buffer d_X;
     vk_buffer d_Y;
     vk_buffer d_D;
@@ -1476,7 +1523,7 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     vk::Semaphore s_it_x;
     vk::Semaphore s_it_y;

-    const bool load_x = src1->backend != GGML_BACKEND_GPU;
+    const bool load_x = src0->backend != GGML_BACKEND_GPU;

     ggml_fp16_t * fp16_staging = (ggml_fp16_t *) ggml_vk_host_malloc(sizeof(ggml_fp16_t) * (ne11 * ne10) * (ne02 * ne03));
@@ -1589,7 +1636,7 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
     const ggml_type type = src0->type;
-    const bool mul_mat_vec = false; // ne11 == 1;
+    const bool mul_mat_vec = ne11 == 1;

     const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
@@ -1615,8 +1662,8 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
         ggml_vk_pool_malloc(q_sz, &d_Q, {});
     }

-    vk_pipeline* to_fp32_vk = ggml_get_to_fp32_vk(type);
-    // vk_pipeline* dmmv = ggml_get_dequantize_mul_mat_vec_vk(type);
+    vk_pipeline* to_fp32_vk = ggml_vk_get_to_fp32(type);
+    vk_pipeline* dmmv = ggml_vk_get_dequantize_mul_mat_vec(type);
     GGML_ASSERT(to_fp32_vk != nullptr);

     std::vector<vk_sequence> compute_seqs;
@@ -1634,10 +1681,9 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             const bool last = i03 == ne03 - 1 && i02 == ne02 - 1;

             vk::Semaphore s_x;
-            vk::Semaphore s_y = ggml_vk_create_semaphore(vk_compute_queue);
-            vk::Semaphore s_q = ggml_vk_create_semaphore(vk_compute_queue);
+            vk::Semaphore s_y = ggml_vk_create_semaphore(vk_transfer_queues[0]);
+            vk::Semaphore s_q = ggml_vk_create_semaphore(vk_transfer_queues[0]);

-            std::vector<vk::Semaphore> semaphores = { s_q, s_y };
             std::vector<vk::Semaphore> q_semaphores;

             vk::Semaphore s_mm = ggml_vk_create_semaphore(vk_compute_queue);
@@ -1669,11 +1715,6 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             }

             if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
-                GGML_ASSERT(false);
-                // // copy src1 to device
-                // events.emplace_back();
-                // VK_CHECK(ggml_vk_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, events.data() + ev_idx++));
-
                 // // compute
                 // const size_t global = ne01 * VK_DMMV_BLOCK_SIZE;
                 // const size_t local = VK_DMMV_BLOCK_SIZE;
@@ -1685,6 +1726,25 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 // VK_CHECK(vkSetKernelArg(*dmmv, 3, sizeof(vk_buffer), &d_D));
                 // VK_CHECK(vkSetKernelArg(*dmmv, 4, sizeof(vk_int), &ncols));
                 // VK_CHECK(vkEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
+                q_semaphores.push_back(s_y);
+                const int ncols = ne00;
+                vk_submission s = ggml_vk_begin_submission(vk_compute_queue);
+                ggml_vk_sync_buffers(s.buffer, { d_Q, d_Y }, vk_compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+                ggml_vk_sync_buffers(s.buffer, { d_D }, vk_compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
+                ggml_vk_dispatch_pipeline(s, *dmmv, {d_Q, d_Y, d_D}, sizeof(int), &ncols, { (uint32_t)ne01, 1, 1});
+                if (!last) {
+                    if (load_x) {
+                        s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
+                        s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
+                        ggml_vk_end_submission(s, std::move(q_semaphores), { s_mm, s_it_x, s_it_y });
+                    } else {
+                        s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
+                        ggml_vk_end_submission(s, std::move(q_semaphores), { s_mm, s_it_y });
+                    }
+                } else {
+                    ggml_vk_end_submission(s, std::move(q_semaphores), { s_mm });
+                }
+                compute_seqs.push_back({ s });
             } else { // general dequantization kernel + VK matrix matrix multiplication
                 // convert src0 to fp32 on device
@@ -1692,7 +1752,7 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
                 ggml_vk_sync_buffers(s.buffer, { d_Q }, vk_compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
                 ggml_vk_sync_buffers(s.buffer, { d_X }, vk_compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
-                ggml_vk_dispatch_pipeline(s, *to_fp32_vk, {d_Q, d_X}, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1}, vk_compute_queue);
+                ggml_vk_dispatch_pipeline(s, *to_fp32_vk, {d_Q, d_X}, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1});
                 if (load_x && !last) {
                     s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
                     ggml_vk_end_submission(s, std::move(q_semaphores), { s_q, s_it_x });
@@ -1704,9 +1764,9 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *

                 // compute
                 if (!last) {
                     s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
-                    compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_y }));
+                    compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, { s_q, s_y }, { s_mm, s_it_y }));
                 } else {
-                    compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm }));
+                    compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, { s_q, s_y }, { s_mm }));
                 }
             }
@@ -1755,6 +1815,9 @@ bool ggml_vk_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 }

 bool ggml_vk_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
+#ifdef VK_DEBUG
+    std::cerr << "ggml_vk_mul_mat_use_f16(" << src0 << ", " << src1 << ")" << std::endl;
+#endif
     // If device doesn't support FP16
     if (!vk_fp16_support) {
         return false;
@@ -1775,6 +1838,9 @@ bool ggml_vk_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_
 }

 void ggml_vk_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+#ifdef VK_DEBUG
+    std::cerr << "ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")" << std::endl;
+#endif
     GGML_ASSERT(ggml_vk_can_mul_mat(src0, src1, dst));

     if (src0->type == GGML_TYPE_F32) {
@@ -1797,12 +1863,187 @@ void ggml_vk_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
 }

 size_t ggml_vk_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+#ifdef VK_DEBUG
+    std::cerr << "ggml_vk_mul_mat_get_wsize(" << src0 << ", " << src1 << ", " << dst << ")" << std::endl;
+#endif
     if (ggml_vk_mul_mat_use_f16(src0, src1, dst)) {
         return ggml_nelements(src1) * sizeof(ggml_fp16_t);
     }
     return 0;
 }

+static void ggml_vk_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+#ifdef VK_DEBUG
+    std::cerr << "ggml_vk_mul_f32((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
+    std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
+    std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
+#endif
+    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int64_t nb10 = src1->nb[0];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    vk_buffer d_X;
+    vk_buffer d_Y = *(vk_buffer*) src1->data;
+    vk_buffer d_D;
+    ggml_vk_pool_malloc(sizeof(float) * ne0, &d_X, {});
+    ggml_vk_pool_malloc(sizeof(float) * ne0, &d_D, {});
+
+    std::vector<vk_sequence> compute_seqs;
+    std::vector<vk_sequence> transfer_0_seqs;
+    std::vector<vk_sequence> transfer_1_seqs;
+
+    vk::Semaphore s_it_x;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const bool first = i03 == 0 && i02 == 0;
+            const bool last = i03 == ne03 - 1 && i02 == ne02 - 1;
+
+            vk::Semaphore s_x = ggml_vk_create_semaphore(vk_compute_queue);
+            vk::Semaphore s_mm = ggml_vk_create_semaphore(vk_compute_queue);
+            // copy src0 to device
+            if (first) {
+                transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], {}, { s_x }));
+            } else {
+                // Wait for previous matmul to be done before writing to the input buffers again
+                transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], { s_it_x }, { s_x }));
+            }
+
+            ggml_vk_submit(vk_transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
+
+            if (nb10 == sizeof(float)) {
+                // Contiguous, avoid overhead from queueing many kernel runs
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int i1 = i13*ne12*ne11 + i12*ne11;
+
+                // cl_int x_offset = 0;
+                // cl_int y_offset = i1*ne10;
+                // cl_int d_offset = 0;
+
+                // size_t global = ne00 * ne01;
+                // cl_int ky = ne10;
+                // CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+                // CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+                // CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+                // CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+                // CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+                // CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+                // CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+                // CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
+                const std::vector<int> pc = { (int)ne00, (int)ne01, (int)ne00, (int)ne00, (int)ne00, 0, (int)(i1 * ne10), 0 };
+                vk_submission s = ggml_vk_begin_submission(vk_compute_queue);
+                ggml_vk_sync_buffers(s.buffer, { d_X, d_Y }, vk_compute_queue, vk::AccessFlagBits::eMemoryWrite, vk::AccessFlagBits::eShaderRead, false);
+                ggml_vk_sync_buffers(s.buffer, { d_D }, vk_compute_queue, vk::AccessFlagBits::eMemoryRead, vk::AccessFlagBits::eShaderWrite, false);
+                ggml_vk_dispatch_pipeline(s, vk_pipeline_mul_f32, {d_X, d_Y, d_D}, sizeof(int) * pc.size(), pc.data(), { (uint32_t)ne00, (uint32_t)ne01, 1});
+                if (!last) {
+                    s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
+                    ggml_vk_end_submission(s, { s_x }, { s_mm, s_it_x });
+                } else {
+                    ggml_vk_end_submission(s, { s_x }, { s_mm });
+                }
+                compute_seqs.push_back({ s });
+            } else {
+                GGML_ASSERT(false);
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const int64_t i13 = i03%ne13;
+                    const int64_t i12 = i02%ne12;
+                    const int64_t i11 = i01%ne11;
+                    const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                    // cl_int x_offset = i01*ne00;
+                    // cl_int y_offset = i1*ne10;
+                    // cl_int d_offset = i01*ne00;
+
+                    // // compute
+                    // size_t global = ne00;
+                    // cl_int ky = ne10;
+                    // CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+                    // CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+                    // CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+                    // CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+                    // CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+                    // CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+                    // CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+                    // CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
+                    const std::vector<int> pc = { (int)ne00, 1, (int)ne00, (int)ne00, (int)ne00, (int)(i01 * ne00), (int)(i1 * ne10), (int)(i01*ne00) };
+                    vk_submission s = ggml_vk_begin_submission(vk_compute_queue);
+                    ggml_vk_sync_buffers(s.buffer, { d_X, d_Y }, vk_compute_queue, vk::AccessFlagBits::eMemoryWrite, vk::AccessFlagBits::eShaderRead, false);
+                    ggml_vk_sync_buffers(s.buffer, { d_D }, vk_compute_queue, vk::AccessFlagBits::eMemoryRead, vk::AccessFlagBits::eShaderWrite, false);
+                    ggml_vk_dispatch_pipeline(s, vk_pipeline_mul_f32, {d_X, d_Y, d_D}, sizeof(int) * pc.size(), pc.data(), { (uint32_t)ne00, 1, 1});
+                    if (!last) {
+                        s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
+                        ggml_vk_end_submission(s, { s_x }, { s_mm, s_it_x });
+                    } else {
+                        ggml_vk_end_submission(s, { s_x }, { s_mm });
+                    }
+                    compute_seqs.push_back({ s });
+                }
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            transfer_1_seqs.push_back(ggml_vk_buffer_read_async(&d_D, 0, d, sizeof(float) * ne00 * ne01, vk_transfer_queues[1], { s_mm }, {}));
+
+            ggml_vk_submit(vk_compute_queue, compute_seqs, VK_NULL_HANDLE);
+            ggml_vk_submit(vk_transfer_queues[1], transfer_1_seqs, VK_NULL_HANDLE);
+        }
+    }
+
+    // cleanup waits for the queue to be done
+    ggml_vk_queue_cleanup(vk_transfer_queues[0]);
+    ggml_vk_queue_cleanup(vk_transfer_queues[1]);
+    ggml_vk_queue_cleanup(vk_compute_queue);
+
+    ggml_vk_pool_free(d_X);
+    ggml_vk_pool_free(d_D);
+}
+
+void ggml_vk_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_vk_mul_f32(src0, src1, dst);
+}
+
+void ggml_vk_transform_tensor(void * data, ggml_tensor * tensor) {
+#ifdef VK_DEBUG
+    std::cerr << "ggml_vk_transform_tensor(" << data << ", " << tensor << ")" << std::endl;
+#endif
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    GGML_ASSERT(ne2 == 1 && ne3 == 1);
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    vk_buffer dst = ggml_vk_create_buffer(q_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+
+    std::vector<vk_sequence> seqs;
+
+    tensor->data = data;
+    // copy tensor to device
+    seqs.push_back(ggml_vk_h2d_tensor_2d(&dst, 0, tensor, 0, 0, vk_transfer_queues[0], {}, {}));
+
+    ggml_vk_submit(vk_transfer_queues[0], seqs, VK_NULL_HANDLE);
+    vk_transfer_queues[0].queue.waitIdle();
+
+    tensor->data = malloc(sizeof(vk_buffer));
+    *(vk_buffer*) tensor->data = dst;
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+}
+
 #ifdef VK_CHK_KERNEL
 void ggml_vk_test_transfer(size_t ne) {
 #ifdef VK_DEBUG
diff --git a/ggml-vulkan.h b/ggml-vulkan.h
index ece0ec7df..e5880f448 100644
--- a/ggml-vulkan.h
+++ b/ggml-vulkan.h
@@ -16,10 +16,9 @@ void ggml_vk_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor
 void * ggml_vk_host_malloc(size_t size);
 void ggml_vk_host_free(void * ptr);

-void ggml_vk_free_data(const struct ggml_tensor* tensor);
+void ggml_vk_free_data(const struct ggml_tensor * tensor);

-void ggml_vk_transform_tensor(struct ggml_tensor * tensor);
-void ggml_vk_load_data(const char * fname, struct ggml_tensor * tensor, size_t offset);
+void ggml_vk_transform_tensor(void * data, struct ggml_tensor * tensor);

 #ifdef __cplusplus
 }
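
Note on the ggml-vulkan.h change above: ggml_vk_transform_tensor() now takes the host data pointer explicitly, uploads it to a device-local vk_buffer, and replaces tensor->data with a heap-allocated vk_buffer handle that ggml_vk_free_data() later releases. A minimal usage sketch follows; it is not part of the patch, the helper names offload_tensor/release_tensor are made up for illustration, and error handling plus the surrounding loader logic are omitted.

// Sketch: offload one weight tensor to the Vulkan device and release it later.
// Mirrors what the llama.cpp loader hunk below does.
#include "ggml.h"
#include "ggml-vulkan.h"

static void offload_tensor(void * host_data, struct ggml_tensor * t) {
    t->backend = GGML_BACKEND_GPU;          // ggml_vk_transform_tensor asserts this at the end
    ggml_vk_transform_tensor(host_data, t); // uploads host_data; t->data now points to a vk_buffer handle
}

static void release_tensor(struct ggml_tensor * t) {
    ggml_vk_free_data(t);                   // destroys the vk_buffer and frees the handle allocation
}
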
diff --git a/ggml.c b/ggml.c
index 8c6f2f11f..413425135 100644
--- a/ggml.c
+++ b/ggml.c
@@ -9190,13 +9190,20 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;

-#ifdef GGML_USE_CLBLAST
+#if defined(GGML_USE_CLBLAST)
     if (src1->backend == GGML_BACKEND_GPU) {
         if (ith == 0) {
             ggml_cl_mul(src0, src1, dst);
         }
         return;
     }
+#elif defined(GGML_USE_VULKAN)
+    if (src1->backend == GGML_BACKEND_GPU) {
+        if (ith == 0) {
+            ggml_vk_mul(src0, src1, dst);
+        }
+        return;
+    }
 #endif

     const int64_t nr = ggml_nrows(src0);
diff --git a/llama.cpp b/llama.cpp
index 0f9d5346d..1374687a0 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -14,6 +14,8 @@
 #include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
 #include "ggml-opencl.h"
+#elif defined(GGML_USE_VULKAN)
+#include "ggml-vulkan.h"
 #endif

 #ifdef GGML_USE_METAL
@@ -303,6 +305,10 @@ struct llama_model {
         for (size_t i = 0; i < tensors_by_name.size(); ++i) {
             ggml_cl_free_data(tensors_by_name[i].second);
         }
+#elif defined(GGML_USE_VULKAN)
+        for (size_t i = 0; i < tensors_by_name.size(); ++i) {
+            ggml_vk_free_data(tensors_by_name[i].second);
+        }
 #endif
     }
 };
@@ -756,6 +762,13 @@ struct llama_model_loader {
                     free(lt.data);
                 }
                 break;
+#elif defined(GGML_USE_VULKAN)
+            case GGML_BACKEND_GPU:
+                ggml_vk_transform_tensor(lt.data, lt.ggml_tensor);
+                if (!use_mmap) {
+                    free(lt.data);
+                }
+                break;
 #endif
             default:
                 continue;
@@ -1089,6 +1102,10 @@ static void llama_model_load_internal(
     fprintf(stderr, "%s: using OpenCL for GPU acceleration\n", __func__);
 #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
+#elif defined(GGML_USE_VULKAN)
+    fprintf(stderr, "%s: using Vulkan for GPU acceleration\n", __func__);
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
+#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
 #else
 #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
@@ -1208,7 +1225,7 @@ static void llama_model_load_internal(
     }
 #endif // GGML_USE_CUBLAS

-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_VULKAN)
     const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));

     fprintf(stderr, "%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
@@ -1236,10 +1253,10 @@ static void llama_model_load_internal(
             vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
         }
     }
-#elif defined(GGML_USE_CLBLAST)
+#elif defined(GGML_USE_CLBLAST) || defined(GGML_USE_VULKAN)
     const int max_backend_supported_layers = hparams.n_layer + 1;
     const int max_offloadable_layers = hparams.n_layer + 1;
-#endif // GGML_USE_CUBLAS
+#endif

     fprintf(stderr, "%s: offloaded %d/%d layers to GPU\n",
             __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
@@ -1247,7 +1264,7 @@ static void llama_model_load_internal(
             __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
 #else
     (void) n_gpu_layers;
-#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
+#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_VULKAN)
 }

 // populate `tensors_by_name`
diff --git a/llama.h b/llama.h
index e744584f2..46d0e7e20 100644
--- a/llama.h
+++ b/llama.h
@@ -48,7 +48,7 @@

 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF

-#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
+#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
 #endif
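
For context on how the two new shaders below are launched: ggml_vk_create_pipeline() records per-pipeline wg_denoms ({1, 1, 1} for the dmmv kernel, {32, 32, 1} for mul_f32, as in the ggml_vk_init() hunk above), and ggml_vk_dispatch_pipeline() divides the requested element counts by them, rounding up. The sketch below is illustrative only and not from the patch; group_counts is a made-up helper, and CEIL_DIV mirrors the macro already used in ggml-vulkan.cpp.

#include <array>
#include <cstdint>
#include <cstdio>

// Same rounding-up division the patch uses for workgroup counts.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

// Compute the vkCmdDispatch group counts for given element counts and the
// wg_denoms a pipeline was created with.
static std::array<uint32_t, 3> group_counts(std::array<uint32_t, 3> elements,
                                            std::array<uint32_t, 3> wg_denoms) {
    return { CEIL_DIV(elements[0], wg_denoms[0]),
             CEIL_DIV(elements[1], wg_denoms[1]),
             CEIL_DIV(elements[2], wg_denoms[2]) };
}

int main() {
    // dequant_mul_mat_vec_q4_0: one workgroup per output row (elements = {ne01, 1, 1}, wg_denoms = {1, 1, 1})
    std::array<uint32_t, 3> dmmv = group_counts({4096, 1, 1}, {1, 1, 1});
    // mul_f32: one 32x32 tile of the destination per workgroup (elements = {ne00, ne01, 1}, wg_denoms = {32, 32, 1})
    std::array<uint32_t, 3> mul  = group_counts({4096, 32, 1}, {32, 32, 1});
    std::printf("dmmv: %u x %u x %u, mul_f32: %u x %u x %u\n",
                dmmv[0], dmmv[1], dmmv[2], mul[0], mul[1], mul[2]);
    return 0;
}
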
diff --git a/vk_shaders/dequant_mul_mat_vec_q4_0.glsl b/vk_shaders/dequant_mul_mat_vec_q4_0.glsl
new file mode 100644
index 000000000..d713497ea
--- /dev/null
+++ b/vk_shaders/dequant_mul_mat_vec_q4_0.glsl
@@ -0,0 +1,73 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#define QUANT_K 32
+#define QUANT_R 2
+#define BLOCK_SIZE 32
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+struct block_q4_0
+{
+    float16_t d;
+    uint8_t qs[16];
+};
+
+layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
+layout (binding = 1) readonly buffer B { float y[]; };
+layout (binding = 2) writeonly buffer D { float dst[]; };
+
+layout (push_constant) uniform parameter
+{
+    int ncols;
+} p;
+
+shared float tmp[BLOCK_SIZE];
+
+void main() {
+    const int block_size = int(gl_WorkGroupSize.x);
+    const int row = int(gl_WorkGroupID.x);
+    const int tid = int(gl_LocalInvocationID.x);
+
+    const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+
+    tmp[tid] = 0;
+
+    for (int i = 0; i < p.ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
+        const int ib = (row*p.ncols + col)/QUANT_K; // block index
+        const int iqs = (col%QUANT_K)/QUANT_R; // quant index
+        const int iybs = col - col%QUANT_K; // y block start index
+
+        // dequantize
+        const float d = float(x[ib].d);
+
+        const uint8_t vui = x[ib].qs[iqs];
+
+        const int8_t vi0 = int8_t(vui & 0xF);
+        const int8_t vi1 = int8_t(vui >> 4);
+
+        float v0 = (vi0 - 8)*d;
+        float v1 = (vi1 - 8)*d;
+
+        // matrix multiplication
+        tmp[tid] += v0 * y[iybs + iqs + 0];
+        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    for (int s=block_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier();
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
diff --git a/vk_shaders/mul_f32.glsl b/vk_shaders/mul_f32.glsl
new file mode 100644
index 000000000..21cc0a795
--- /dev/null
+++ b/vk_shaders/mul_f32.glsl
@@ -0,0 +1,30 @@
+#version 450
+
+layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
+
+layout (binding = 0) buffer X { float data_x[]; };
+layout (binding = 1) buffer Y { float data_y[]; };
+layout (binding = 2) buffer D { float data_d[]; };
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int stride_x;
+    int stride_y;
+    int stride_d;
+    int x_offset;
+    int y_offset;
+    int d_offset;
+} p;
+
+void main() {
+    const int x = int(gl_GlobalInvocationID.x);
+    const int y = int(gl_GlobalInvocationID.y);
+
+    if (x >= p.M || y >= p.N) {
+        return;
+    }
+
+    data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * data_y[p.y_offset + x];
+}
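
For readers new to the q4_0 layout used by dequant_mul_mat_vec_q4_0.glsl above: each 32-value block stores one fp16 scale d and 16 bytes of packed 4-bit quants, and a value is recovered as (q - 8) * d. Below is a plain C++ reference for the per-row dot product the shader parallelizes across its 32 threads. It is illustrative only and not part of the patch; block_q4_0_ref and dmmv_q4_0_row are made-up names, the scale is widened to float for clarity, and ncols is assumed to be a multiple of QUANT_K, as in the shader.

#include <cstdint>

// Same constants as the shader.
#define QUANT_K 32   // values per q4_0 block
#define QUANT_R 2    // two values packed per byte

struct block_q4_0_ref {
    float   d;                  // scale (fp16 in the real format)
    uint8_t qs[QUANT_K / 2];    // packed 4-bit quants
};

// Reference for one output row: dst[row] = dot(dequantize(x_row), y).
// The shader computes the same sum, with each of its 32 threads taking a
// strided subset of the columns and a shared-memory reduction at the end.
static float dmmv_q4_0_row(const block_q4_0_ref * x_row, const float * y, int ncols) {
    float sum = 0.0f;
    for (int col = 0; col < ncols; col += QUANT_K) {
        const block_q4_0_ref & b = x_row[col / QUANT_K];
        for (int j = 0; j < QUANT_K / 2; ++j) {
            const int v0 = (b.qs[j] & 0xF) - 8;   // low nibble
            const int v1 = (b.qs[j] >>  4) - 8;   // high nibble
            // Low nibbles pair with the first half of the block, high nibbles with
            // the second half, matching y[iybs + iqs] and y[iybs + iqs + QUANT_K/2]
            // in the shader.
            sum += v0 * b.d * y[col + j];
            sum += v1 * b.d * y[col + j + QUANT_K / 2];
        }
    }
    return sum;
}
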