From c7c761a2b7f1cd6e98e585db6c6c60997398f086 Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Sat, 8 Jul 2023 17:27:05 +0200
Subject: [PATCH] Add split-k optimization for small matrix multiplication

Use semaphores for synchronization instead of fences or waitIdle

Rework async write/read for synchronization

---
 Makefile                              |   1 +
 ggml-vulkan.cpp                       | 522 +++++++++++++-------------
 vk_shaders/f16_to_f32.glsl            |   2 +-
 vk_shaders/matmul_f16.glsl            |  22 +-
 vk_shaders/matmul_f32.glsl            |  20 +-
 vk_shaders/matmul_split_k_reduce.glsl |  27 ++
 6 files changed, 321 insertions(+), 273 deletions(-)
 create mode 100644 vk_shaders/matmul_split_k_reduce.glsl

diff --git a/Makefile b/Makefile
index a2aad1e89..8bc3ad7e4 100644
--- a/Makefile
+++ b/Makefile
@@ -222,6 +222,7 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32.glsl -o vk_shaders/matmul_f32.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16.glsl -o vk_shaders/matmul_f16.spv
+	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv
 endif
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 05bd55f27..1e0ec5f3c 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -97,6 +97,7 @@ struct vk_queue {
     vk::Queue queue;
     vk::CommandPool pool;
     std::vector<vk::CommandBuffer> cmd_buffers;
+    std::vector<vk::Semaphore> semaphores;
     std::mutex mutex;
 };
@@ -106,7 +107,7 @@ vk::Device vk_device;
 vk_queue vk_compute_queue;
 vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT];
 VmaAllocator vk_allocator;
-vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16;
+vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16, vk_pipeline_matmul_split_k_reduce;
 vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
 VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
 vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c;
@@ -185,7 +186,7 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::s
     return pipeline;
 }
-static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buffer*> buffers, const std::vector<int>&& push_constants, std::array<uint32_t, 3> elements, vk::CommandBuffer& cmd_buffer, vk::Fence& fence) {
+static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buffer*> buffers, const std::vector<int>&& push_constants, std::array<uint32_t, 3> elements, vk::CommandBuffer& cmd_buffer, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
     PROFILE("ggml_vk_dispatch_pipeline",
     std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
     std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
@@ -213,11 +214,13 @@ static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector
     std::lock_guard<std::mutex> guard(vk_compute_queue.mutex);
-    vk::SubmitInfo submit_info(0,
-                               nullptr,
+    vk::SubmitInfo submit_info(wait_semaphores.size(),
+                               wait_semaphores.data(),
                                nullptr,
                                1,
-                               &cmd_buffer);
+                               &cmd_buffer,
+                               signal_semaphores.size(),
+                               signal_semaphores.data());
     vk_compute_queue.queue.submit({ submit_info }, fence);
     );
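// [Illustrative sketch, not part of the patch] The wait/signal semaphore lists
// added above are what allow chaining submissions across queues with no
// host-side waitIdle(). Assuming hypothetical command buffers `upload_cmd`
// (recorded for a transfer queue) and `compute_cmd` (compute queue), the
// pattern the new signature enables looks like this:
static void example_chain_submits(vk::CommandBuffer upload_cmd, vk::CommandBuffer compute_cmd, vk::Fence fence) {
    vk::Semaphore s = vk_device.createSemaphore({});
    // The upload signals s when the copy finishes; nothing blocks on the host.
    vk::SubmitInfo upload_submit(0, nullptr, nullptr, 1, &upload_cmd, 1, &s);
    vk_transfer_queues[0].queue.submit({ upload_submit }, VK_NULL_HANDLE);
    // The compute submission waits for s on the GPU, then signals the fence.
    // Note: like the patch, this passes nullptr for pWaitDstStageMask; a fully
    // valid submission would supply one pipeline stage flag per wait semaphore.
    vk::SubmitInfo compute_submit(1, &s, nullptr, 1, &compute_cmd, 0, nullptr);
    vk_compute_queue.queue.submit({ compute_submit }, fence);
}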
ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {64, 64, 1}); + vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7, {64, 64, 1}); if (vk_fp16_support) { - vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {64, 64, 1}); + vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7, {64, 64, 1}); } + vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 2, {64, 1, 1}); - vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1}); + vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {64, 1, 1}); vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {32, 1, 1}); // Queues @@ -466,10 +482,21 @@ void ggml_vk_init(void) { 4096, 49, 11008, 4096, 49, 4096, 32000, 49, 4096, + 512, 512, 128, + 128, 512, 512, + 4096, 512, 4096, + 11008, 512, 4096, + 4096, 512, 11008, + 4096, 512, 4096, + 32000, 512, 4096, + 512, 512, 128, + 128, 512, 512, }; for (size_t i = 0; i < vals.size(); i += 3) { - ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2]); - ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2]); + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 10, 1); + ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 10, 1); + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 10, 4); + ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 10, 4); } #endif } @@ -606,92 +633,93 @@ void ggml_vk_host_free(void* ptr) { ggml_vk_destroy_buffer(*buf); } -static void ggml_vk_buffer_write(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q) { +static void ggml_vk_buffer_write_2d_async(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_queue& q, vk::Fence fence, std::vector&& wait_semaphores, std::vector&& signal_semaphores) { VkMemoryPropertyFlags mem_prop_flags; vmaGetAllocationMemoryProperties(vk_allocator, dst->allocation, &mem_prop_flags); // Buffer is already mapped if(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) { - GGML_ASSERT(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT); - - PROFILE("ggml_vk_buffer_write visible", - memcpy((uint8_t *)dst->info.pMappedData + offset, src, size); - ); - } else { - // Check if src is pinned memory - vk_buffer* buf = nullptr; - size_t buf_offset = 0; - PROFILE("ggml_vk_buffer_write pinned check", - for (size_t i = 0; i < vk_buf_list.size(); i++) { - const uint8_t* addr = (const uint8_t*) std::get<0>(vk_buf_list[i]); - const uint8_t* endr = addr + std::get<1>(vk_buf_list[i]); - if (src >= addr && src < endr) { - buf = &std::get<2>(vk_buf_list[i]); - buf_offset = ((const uint8_t *)src) - addr; - break; - } + std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." 
@@ -606,92 +633,93 @@ void ggml_vk_host_free(void* ptr) {
     ggml_vk_destroy_buffer(*buf);
 }
-static void ggml_vk_buffer_write(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q) {
+static void ggml_vk_buffer_write_2d_async(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
     VkMemoryPropertyFlags mem_prop_flags;
     vmaGetAllocationMemoryProperties(vk_allocator, dst->allocation, &mem_prop_flags);
     // Buffer is already mapped
     if(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) {
-        GGML_ASSERT(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
-
-        PROFILE("ggml_vk_buffer_write visible",
-        memcpy((uint8_t *)dst->info.pMappedData + offset, src, size);
-        );
-    } else {
-        // Check if src is pinned memory
-        vk_buffer* buf = nullptr;
-        size_t buf_offset = 0;
-        PROFILE("ggml_vk_buffer_write pinned check",
-        for (size_t i = 0; i < vk_buf_list.size(); i++) {
-            const uint8_t* addr = (const uint8_t*) std::get<0>(vk_buf_list[i]);
-            const uint8_t* endr = addr + std::get<1>(vk_buf_list[i]);
-            if (src >= addr && src < endr) {
-                buf = &std::get<2>(vk_buf_list[i]);
-                buf_offset = ((const uint8_t *)src) - addr;
-                break;
-            }
+        std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
+    }
+    // Check if src is pinned memory
+    vk_buffer* buf = nullptr;
+    size_t buf_offset = 0;
+    PROFILE("ggml_vk_buffer_write pinned check",
+    for (size_t i = 0; i < vk_buf_list.size(); i++) {
+        const uint8_t* addr = (const uint8_t*) std::get<0>(vk_buf_list[i]);
+        const uint8_t* endr = addr + std::get<1>(vk_buf_list[i]);
+        if (src >= addr && src < endr) {
+            buf = &std::get<2>(vk_buf_list[i]);
+            buf_offset = ((const uint8_t *)src) - addr;
+            break;
         }
-        );
+    }
+    );
-    if (buf != nullptr) {
-        // Memory is pinned, use as staging buffer
-        VkBufferCopy buf_copy = {
-            buf_offset, // srcOffset
-            offset, // dstOffset,
-            size}; // size
-
-        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
-        PROFILE("ggml_vk_buffer_write pinned write",
-        vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
-        cmd_buffer.begin(cmd_buffer_begin_info);
-        vkCmdCopyBuffer(cmd_buffer, buf->buffer, dst->buffer, 1, &buf_copy);
-        cmd_buffer.end();
-        );
-
-        vk::SubmitInfo submit_info(0,
-                                   nullptr,
-                                   nullptr,
-                                   1,
-                                   &cmd_buffer);
-        std::lock_guard<std::mutex> guard(q.mutex);
-        q.queue.submit({ submit_info }, VK_NULL_HANDLE);
-
-        return;
+    if (buf != nullptr) {
+        // Memory is pinned, use as staging buffer
+        std::vector<VkBufferCopy> slices(height);
+        for (size_t i = 0; i < height; i++) {
+            slices[i].srcOffset = buf_offset + i * spitch;
+            slices[i].dstOffset = offset + i * width;
+            slices[i].size = width;
         }
-    // Staging buffer required, malloc because of async transfer
-    if (dst->sb_write == nullptr) {
-        dst->sb_write = new vk_buffer;
-        *dst->sb_write = ggml_vk_create_buffer(dst->size, VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_MAPPED_BIT, VMA_MEMORY_USAGE_AUTO_PREFER_HOST, 0);
-    }
-
-    VkMemoryPropertyFlags mpf_staging;
-    vmaGetAllocationMemoryProperties(vk_allocator, dst->sb_write->allocation, &mpf_staging);
-    GGML_ASSERT(mpf_staging & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
-
-    VkBufferCopy buf_copy = {
-        0, // srcOffset
-        offset, // dstOffset,
-        size}; // size
-
-    PROFILE("ggml_vk_buffer_write staging",
     vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
+    PROFILE("ggml_vk_buffer_write pinned write",
     vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
     cmd_buffer.begin(cmd_buffer_begin_info);
-    vkCmdCopyBuffer(cmd_buffer, dst->sb_write->buffer, dst->buffer, 1, &buf_copy);
+    vkCmdCopyBuffer(cmd_buffer, buf->buffer, dst->buffer, height, slices.data());
     cmd_buffer.end();
+    );
-    memcpy(dst->sb_write->info.pMappedData, src, size);
-
-    vk::SubmitInfo submit_info(0,
-                               nullptr,
+    vk::SubmitInfo submit_info(wait_semaphores.size(),
+                               wait_semaphores.data(),
                                nullptr,
                                1,
-                               &cmd_buffer);
+                               &cmd_buffer,
+                               signal_semaphores.size(),
+                               signal_semaphores.data());
     std::lock_guard<std::mutex> guard(q.mutex);
-    q.queue.submit({ submit_info }, VK_NULL_HANDLE);
-    );
+    q.queue.submit({ submit_info }, fence);
+
+    return;
     }
+
+    // Staging buffer required, malloc because of async transfer
+    if (dst->sb_write == nullptr) {
+        dst->sb_write = new vk_buffer;
+        *dst->sb_write = ggml_vk_create_buffer(dst->size, VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_MAPPED_BIT, VMA_MEMORY_USAGE_AUTO_PREFER_HOST, 0);
+    }
+
+    VkMemoryPropertyFlags mpf_staging;
+    vmaGetAllocationMemoryProperties(vk_allocator, dst->sb_write->allocation, &mpf_staging);
+    GGML_ASSERT(mpf_staging & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
+
+    VkBufferCopy buf_copy = {
+        0,
+        offset,
+        width * height};
+
+    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
+    vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
+    cmd_buffer.begin(cmd_buffer_begin_info);
+    vkCmdCopyBuffer(cmd_buffer, dst->sb_write->buffer, dst->buffer, 1, &buf_copy);
+    cmd_buffer.end();
+
+    for (size_t i = 0; i < height; i++) {
+        memcpy((uint8_t *)dst->sb_write->info.pMappedData + offset + i * width, (const uint8_t *) src + i * spitch, width);
+    }
+
+    vk::SubmitInfo submit_info(wait_semaphores.size(),
+                               wait_semaphores.data(),
+                               nullptr,
+                               1,
+                               &cmd_buffer,
+                               signal_semaphores.size(),
+                               signal_semaphores.data());
+    std::lock_guard<std::mutex> guard(q.mutex);
+    q.queue.submit({ submit_info }, fence);
 }
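// [Illustrative sketch, not part of the patch] The pinned fast path above
// flattens a strided 2D region into a single vkCmdCopyBuffer call. A
// hypothetical helper showing just the slice arithmetic: source rows are
// spitch bytes apart, destination rows are packed tightly at width bytes.
static std::vector<VkBufferCopy> example_row_slices(size_t buf_offset, size_t dst_offset, size_t spitch, size_t width, size_t height) {
    std::vector<VkBufferCopy> slices(height);
    for (size_t i = 0; i < height; i++) {
        slices[i].srcOffset = buf_offset + i * spitch; // strided source row i
        slices[i].dstOffset = dst_offset + i * width;  // packed destination row i
        slices[i].size      = width;
    }
    return slices; // passed as vkCmdCopyBuffer(..., height, slices.data())
}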
@@ -704,90 +732,63 @@ static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void *
     PROFILE("ggml_vk_buffer_write visible",
     for (size_t i = 0; i < height; i++) {
-        memcpy((uint8_t *)dst->info.pMappedData + offset + i * width, (uint8_t *) src + i * spitch, width);
+        memcpy((uint8_t *)dst->info.pMappedData + offset + i * width, (const uint8_t *) src + i * spitch, width);
     }
     );
     } else {
-        // Check if src is pinned memory
-        vk_buffer* buf = nullptr;
-        size_t buf_offset = 0;
-        PROFILE("ggml_vk_buffer_write pinned check",
-        for (size_t i = 0; i < vk_buf_list.size(); i++) {
-            const uint8_t* addr = (const uint8_t*) std::get<0>(vk_buf_list[i]);
-            const uint8_t* endr = addr + std::get<1>(vk_buf_list[i]);
-            if (src >= addr && src < endr) {
-                buf = &std::get<2>(vk_buf_list[i]);
-                buf_offset = ((const uint8_t *)src) - addr;
-                break;
-            }
-        }
-        );
-
-        if (buf != nullptr) {
-            // Memory is pinned, use as staging buffer
-            std::vector<VkBufferCopy> slices(height);
-            for (size_t i = 0; i < height; i++) {
-                slices[i].srcOffset = buf_offset + i * spitch;
-                slices[i].dstOffset = offset + i * width;
-                slices[i].size = width;
-            }
-
-            vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
-            PROFILE("ggml_vk_buffer_write pinned write",
-            vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
-            cmd_buffer.begin(cmd_buffer_begin_info);
-            vkCmdCopyBuffer(cmd_buffer, buf->buffer, dst->buffer, height, slices.data());
-            cmd_buffer.end();
-            );
-
-            vk::SubmitInfo submit_info(0,
-                                       nullptr,
-                                       nullptr,
-                                       1,
-                                       &cmd_buffer);
-            std::lock_guard<std::mutex> guard(q.mutex);
-            q.queue.submit({ submit_info }, VK_NULL_HANDLE);
-
-            return;
-        }
-
-        // Staging buffer required, malloc because of async transfer
-        if (dst->sb_write == nullptr) {
-            dst->sb_write = new vk_buffer;
-            *dst->sb_write = ggml_vk_create_buffer(dst->size, VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_MAPPED_BIT, VMA_MEMORY_USAGE_AUTO_PREFER_HOST, 0);
-        }
-
-        VkMemoryPropertyFlags mpf_staging;
-        vmaGetAllocationMemoryProperties(vk_allocator, dst->sb_write->allocation, &mpf_staging);
-        GGML_ASSERT(mpf_staging & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
-
-        VkBufferCopy buf_copy = {
-            0,
-            offset,
-            width * height};
-
-        PROFILE("ggml_vk_buffer_write staging",
-        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
-        vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
-        cmd_buffer.begin(cmd_buffer_begin_info);
-        vkCmdCopyBuffer(cmd_buffer, dst->sb_write->buffer, dst->buffer, 1, &buf_copy);
-        cmd_buffer.end();
-
-        for (size_t i = 0; i < height; i++) {
-            memcpy((uint8_t *)dst->info.pMappedData + offset + i * width, (uint8_t *) src + i * spitch, width);
-        }
-
-        vk::SubmitInfo submit_info(0,
-                                   nullptr,
-                                   nullptr,
-                                   1,
-                                   &cmd_buffer);
-        std::lock_guard<std::mutex> guard(q.mutex);
-        q.queue.submit({ submit_info }, VK_NULL_HANDLE);
-        );
+        ggml_vk_buffer_write_2d_async(dst, offset, src, spitch, width, height, q, VK_NULL_HANDLE, {}, {});
     }
 }
+static void ggml_vk_buffer_write_async(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+    ggml_vk_buffer_write_2d_async(dst, offset, src, 0, size, 1, q, fence, std::move(wait_semaphores), std::move(signal_semaphores));
+}
+
+static void ggml_vk_buffer_write(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q) {
+    ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1, q);
+}
+
+static void ggml_vk_buffer_read_async(vk_buffer* src, size_t offset, void * dst, size_t size, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+    // Check if dst is pinned memory
+    vk_buffer* buf = nullptr;
+    size_t buf_offset = 0;
+    for (size_t i = 0; i < vk_buf_list.size(); i++) {
+        const uint8_t* addr = (const uint8_t*) std::get<0>(vk_buf_list[i]);
+        const uint8_t* endr = addr + std::get<1>(vk_buf_list[i]);
+        if (dst >= addr && dst < endr) {
+            buf = &std::get<2>(vk_buf_list[i]);
+            buf_offset = ((const uint8_t *)dst) - addr;
+            break;
+        }
+    }
+
+    if (buf == nullptr) {
+        std::cerr << "ggml_vulkan: Error: buffer_read_async only works on pinned memory" << std::endl;
+        GGML_ASSERT(false);
+    }
+
+    // Memory is pinned, use as staging buffer
+    VkBufferCopy buf_copy = {
+        offset, // srcOffset
+        buf_offset, // dstOffset,
+        size}; // size
+
+    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
+    vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
+    cmd_buffer.begin(cmd_buffer_begin_info);
+    vkCmdCopyBuffer(cmd_buffer, src->buffer, buf->buffer, 1, &buf_copy);
+    cmd_buffer.end();
+
+    vk::SubmitInfo submit_info(wait_semaphores.size(),
+                               wait_semaphores.data(),
+                               nullptr,
+                               1,
+                               &cmd_buffer,
+                               signal_semaphores.size(),
+                               signal_semaphores.data());
+    std::lock_guard<std::mutex> guard(q.mutex);
+    q.queue.submit({ submit_info }, fence);
+}
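// [Illustrative sketch, not part of the patch] buffer_read_async rejects
// non-pinned destinations because the copy is only recorded here and executes
// after submission; a pageable malloc() pointer has no backing vk::Buffer the
// GPU could target. Intended usage, assuming ggml_vk_host_malloc() is the
// pinned allocator that registers its range in vk_buf_list:
//
//     float * out = (float *) ggml_vk_host_malloc(sizeof(float) * d_ne);
//     ggml_vk_buffer_read_async(&d_D, 0, out, sizeof(float) * d_ne,
//                               vk_transfer_queues[0], fence, { s_mm }, {});
//     // only touch `out` after waitForFences({ fence }) succeeds
//     ggml_vk_host_free(out);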
 static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_t size, vk_queue& q) {
     VkMemoryPropertyFlags mem_prop_flags;
     vmaGetAllocationMemoryProperties(vk_allocator, src->allocation, &mem_prop_flags);
@@ -876,7 +877,7 @@ static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_
     }
 }
-static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q) {
+static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
     const uint64_t ne0 = src->ne[0];
     const uint64_t ne1 = src->ne[1];
     const uint64_t nb0 = src->nb[0];
@@ -890,19 +891,19 @@ static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct gg
     const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
     if (nb0 == ts && nb1 == row_length) {
-        ggml_vk_buffer_write(dst, offset, x, ne1*nb1, q);
+        ggml_vk_buffer_write_async(dst, offset, x, ne1*nb1, q, fence, std::move(wait_semaphores), std::move(signal_semaphores));
         return;
     }
     if (nb0 == ts) {
         PROFILE("ggml_vk_buffer_write_2d",
-        ggml_vk_buffer_write_2d(dst, offset, x, nb1, row_length, ne1, q);
+        ggml_vk_buffer_write_2d_async(dst, offset, x, nb1, row_length, ne1, q, fence, std::move(wait_semaphores), std::move(signal_semaphores));
         );
         return;
     }
     GGML_ASSERT(false);
     // TODO: also needs handling of staging buffers
     uint8_t* dst_ptr = (uint8_t*) dst->info.pMappedData;
-    uint8_t* xc = (uint8_t*)x;
+    const uint8_t* xc = (const uint8_t*)x;
     for (uint64_t i1 = 0; i1 < ne1; i1++) {
         for (uint64_t i0 = 0; i0 < ne0; i0++) {
             dst_ptr[offset + i1 * row_length + i0 * ts] = xc[i1 * nb1 + i0 * nb0];
@@ -910,6 +911,29 @@ static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct gg
     }
 }
+static int ggml_vk_guess_split_k(int m, int n, int k) {
+    if (k > 64 && (m < 128 || n < 128)) {
+        return 4;
+    }
+
+    return 1;
+}
+
+static void ggml_vk_matmul(vk_pipeline& pipeline, vk_buffer& a, vk_buffer& b, vk_buffer& d, int m, int n, int k, int split_k, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+    if (split_k == 1) {
+        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
+        ggml_vk_dispatch_pipeline(pipeline, {&a, &b, &d}, { m, n, k, k, k, m, k}, { (uint32_t)m, (uint32_t)n, 1}, cmd_buffer, fence, std::move(wait_semaphores), std::move(signal_semaphores));
+    } else {
+        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
+        vk::Semaphore semaphore = vk_device.createSemaphore({});
+        vk_compute_queue.semaphores.push_back(semaphore);
+        ggml_vk_dispatch_pipeline(pipeline, {&a, &b, &d}, { m, n, k, k, k, m, CEIL_DIV(k, split_k)}, { (uint32_t)m * split_k, (uint32_t)n, 1}, cmd_buffer, VK_NULL_HANDLE, std::move(wait_semaphores), { semaphore });
+
+        vk::CommandBuffer cmd_buffer_reduce = ggml_vk_cmd_buffer_create(vk_compute_queue);
+        ggml_vk_dispatch_pipeline(vk_pipeline_matmul_split_k_reduce, {&d}, { m * n, split_k}, { (uint32_t)m * n, 1, 1}, cmd_buffer_reduce, fence, { semaphore }, std::move(signal_semaphores));
+    }
+}
+
 static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
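// [Illustrative sketch, not part of the patch] What the two dispatches in
// ggml_vk_matmul compute, as a host-side reference. Slice ik covers the
// k-range [ik*k_split, min((ik+1)*k_split, k)) and writes its partial result
// into its own m*n region of d; the reduce pass then sums the slices into
// d[0..m*n). The layout assumptions (row-major A with stride k, transposed B
// with stride k, column-major D with stride m) mirror the push constants
// passed above.
static void example_matmul_split_k_ref(const float * a, const float * b, float * d, int m, int n, int k, int split_k) {
    const int k_split = (k + split_k - 1) / split_k;     // CEIL_DIV(k, split_k)
    for (int ik = 0; ik < split_k; ik++) {               // one matmul dispatch slice each
        const int k0 = ik * k_split;
        const int k1 = k0 + k_split < k ? k0 + k_split : k;
        for (int c = 0; c < n; c++) {
            for (int r = 0; r < m; r++) {
                float sum = 0.0f;
                for (int i = k0; i < k1; i++) {
                    sum += a[r * k + i] * b[c * k + i];
                }
                d[ik * m * n + c * m + r] = sum;         // slice-local partial result
            }
        }
    }
    for (int idx = 0; idx < m * n; idx++) {              // the split_k_reduce dispatch
        for (int ik = 1; ik < split_k; ik++) {
            d[idx] += d[ik * m * n + idx];
        }
    }
}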
@@ -926,6 +950,8 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
+    const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+
     vk_buffer d_X;
     vk_buffer d_Y;
     vk_buffer d_D;
@@ -935,42 +961,35 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
         ggml_vk_pool_malloc(ggml_type_size(src0->type) * x_ne, &d_X, 0);
     }
     ggml_vk_pool_malloc(sizeof(float) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
     vk::Fence fence = vk_device.createFence(vk::FenceCreateInfo());
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
+            vk::Semaphore s_x;
+            vk::Semaphore s_y = vk_device.createSemaphore({});
+            std::vector<vk::Semaphore> semaphores = { s_y };
+            vk_compute_queue.semaphores.push_back(s_y);
             // copy data to device
             if (src0->backend != GGML_BACKEND_GPU) {
-                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0]);
+                s_x = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_x);
+                semaphores.push_back(s_x);
+                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], VK_NULL_HANDLE, {}, { s_x });
             }
-            ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1]);
+            ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], VK_NULL_HANDLE, {}, { s_y });
             // compute
-#ifdef VK_CHK_KERNEL
-            auto begin = std::chrono::high_resolution_clock::now();
-#endif
-
-            // Wait for transfers to finish
-            vk_transfer_queues[0].queue.waitIdle();
-            vk_transfer_queues[1].queue.waitIdle();
-
-            ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f32, {&d_X, &d_Y, &d_D}, { (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01 }, { (uint32_t)ne01, (uint32_t)ne11, 1}, cmd_buffer, fence);
-
-            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f32 waitForFences");
-
-#ifdef VK_CHK_KERNEL
-            auto end = std::chrono::high_resolution_clock::now();
-
-            std::cout << "m=" << ne01 << " n=" << ne11 << " k=" << ne10 << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms" << std::endl;
-#endif
+            vk::Semaphore s_mm = vk_device.createSemaphore({});
+            vk_compute_queue.semaphores.push_back(s_mm);
+            ggml_vk_matmul(vk_pipeline_matmul_f32, d_X, d_Y, d_D, ne01, ne11, ne10, split_k, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores), { s_mm });
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0]);
+            ggml_vk_buffer_read_async(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0], fence, { s_mm }, {});
+
+            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f32 waitForFences");
             vk_device.resetFences({fence});
         }
     }
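// [Annotation, not part of the patch] Per (i03, i02) iteration the submissions
// above form this GPU-side dependency chain; the host blocks only on the
// final fence before reusing d_D / dst:
//
//   transfer queue 0: upload src0 --signals s_x--+
//   transfer queue 1: upload src1 --signals s_y--+--> compute: matmul
//   compute: matmul (+ split-k reduce) --signals s_mm--> transfer queue 0:
//   read back d_D --signals fence--> host: waitForFences / resetFences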
@@ -1014,6 +1033,8 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
+    const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+
     vk_buffer d_X;
     vk_buffer d_Y;
     vk_buffer d_D;
@@ -1023,19 +1044,26 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
         ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &d_X, 0);
     }
     ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
     bool src1_cont_rows = nb10 == sizeof(float);
     bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
     vk::Fence fence = vk_device.createFence(vk::FenceCreateInfo());
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
+            vk::Semaphore s_x;
+            vk::Semaphore s_y = vk_device.createSemaphore({});
+            std::vector<vk::Semaphore> semaphores = { s_y };
+            vk_compute_queue.semaphores.push_back(s_y);
+
             // copy data to device
             if (src1->backend != GGML_BACKEND_GPU) {
-                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0]);
+                s_x = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_x);
+                semaphores.push_back(s_x);
+                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], VK_NULL_HANDLE, {}, { s_x });
             }
             // convert src1 to fp16
@@ -1060,30 +1088,18 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
                 }
             }
-            ggml_vk_buffer_write(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * y_ne, vk_transfer_queues[1]);
-
-            // Wait for transfers to finish
-            vk_transfer_queues[0].queue.waitIdle();
-            vk_transfer_queues[1].queue.waitIdle();
+            ggml_vk_buffer_write_async(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * y_ne, vk_transfer_queues[1], VK_NULL_HANDLE, {}, { s_y });
             // compute
-#ifdef VK_CHK_KERNEL
-            auto begin = std::chrono::high_resolution_clock::now();
-#endif
-
-            ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f16, {&d_X, &d_Y, &d_D}, { (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01 }, { (uint32_t)ne01, (uint32_t)ne11, 1}, cmd_buffer, fence);
-            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f16 waitForFences");
-
-#ifdef VK_CHK_KERNEL
-            auto end = std::chrono::high_resolution_clock::now();
-
-            std::cout << "m=" << ne01 << " n=" << ne11 << " k=" << ne10 << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms" << std::endl;
-#endif
+            vk::Semaphore s_mm = vk_device.createSemaphore({});
+            vk_compute_queue.semaphores.push_back(s_mm);
+            ggml_vk_matmul(vk_pipeline_matmul_f16, d_X, d_Y, d_D, ne01, ne11, ne10, split_k, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores), { s_mm });
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0]);
+            ggml_vk_buffer_read_async(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0], fence, { s_mm }, {});
+
+            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f16 waitForFences");
             vk_device.resetFences({fence});
         }
     }
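// [Illustrative sketch, not part of the patch] The elided convert-src1 loop
// fills `tmp` with fp16 values before the async upload above. For the
// contiguous-rows case (nb10 == sizeof(float)) it amounts to the following,
// using the public ggml_fp32_to_fp16() conversion; the nb11/nb12/nb13 strides
// are assumed to be the src1 strides defined in the surrounding function:
//
//     for (int64_t i01 = 0; i01 < ne11; i01++) {
//         const float * row = (const float *) ((const char *) src1->data
//                             + i03*nb13 + i02*nb12 + i01*nb11);
//         for (int64_t i00 = 0; i00 < ne10; i00++) {
//             tmp[i01 * ne10 + i00] = ggml_fp32_to_fp16(row[i00]);
//         }
//     }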
@@ -1120,6 +1136,8 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     const int d_ne = ne11 * ne01;
     const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
+    const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+
     vk_buffer d_X;
     vk_buffer d_Y;
     vk_buffer d_D;
@@ -1127,7 +1145,7 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
         ggml_vk_pool_malloc(sizeof(float) * x_ne, &d_X, 0);
     }
     ggml_vk_pool_malloc(sizeof(float) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
     vk_buffer d_Q;
     if (src0->backend == GGML_BACKEND_CPU) {
         ggml_vk_pool_malloc(q_sz, &d_Q, 0);
@@ -1137,14 +1155,25 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     // vk_pipeline* dmmv = ggml_get_dequantize_mul_mat_vec_vk(type);
     GGML_ASSERT(to_fp32_vk != nullptr);
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
     vk::Fence fence = vk_device.createFence(vk::FenceCreateFlags{});
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
+            vk::Semaphore s_x;
+            vk::Semaphore s_y = vk_device.createSemaphore({});
+            std::vector<vk::Semaphore> semaphores_q;
+            std::vector<vk::Semaphore> semaphores = { s_y };
+            vk_compute_queue.semaphores.push_back(s_y);
+
+            vk::Semaphore s_mm = vk_device.createSemaphore({});
+            vk_compute_queue.semaphores.push_back(s_mm);
+
             // copy src0 to device if necessary
             if (src0->backend == GGML_BACKEND_CPU) {
-                ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, vk_transfer_queues[0]);
+                s_x = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_x);
+                semaphores_q.push_back(s_x);
+                ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, vk_transfer_queues[0], VK_NULL_HANDLE, {}, { s_x });
             } else if (src0->backend == GGML_BACKEND_GPU) {
                 d_Q = *(vk_buffer *) src0->data;
             } else {
@@ -1169,43 +1198,25 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 // VK_CHECK(vkEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
             } else { // general dequantization kernel + VK matrix matrix multiplication
                 // convert src0 to fp32 on device
-                // Wait for transfers to finish
-                vk_transfer_queues[0].queue.waitIdle();
-
-                vk_device.resetFences({ fence });
-                ggml_vk_dispatch_pipeline(*to_fp32_vk, {&d_Q, &d_X}, { (int)x_ne }, { (uint32_t)x_ne, 1, 1}, cmd_buffer, fence);
+                vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
+                vk::Semaphore s_q = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_q);
+                semaphores.push_back(s_q);
+                ggml_vk_dispatch_pipeline(*to_fp32_vk, {&d_Q, &d_X}, { (int)x_ne }, { (uint32_t)x_ne, 1, 1}, cmd_buffer, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores_q), { s_q });
                 // copy src1 to device
-                ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1]);
-
-                // wait for conversion
-                vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_q_f32 src0 convert waitForFences");
-
-                vk_device.resetFences({ fence });
-                cmd_buffer.reset(vk::CommandBufferResetFlags());
-
-                // Wait for transfers to finish
-                vk_transfer_queues[1].queue.waitIdle();
+                ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], VK_NULL_HANDLE, {}, { s_y });
                 // compute
-#ifdef VK_CHK_KERNEL
-                auto begin = std::chrono::high_resolution_clock::now();
-#endif
-
-                ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f32, {&d_X, &d_Y, &d_D}, { (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01 }, { (uint32_t)ne01, (uint32_t)ne11, 1}, cmd_buffer, fence);
-
-                vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_q_f32 matmul waitForFences");
-
-#ifdef VK_CHK_KERNEL
-                auto end = std::chrono::high_resolution_clock::now();
-
-                std::cout << "m=" << ne01 << " n=" << ne11 << " k=" << ne10 << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms" << std::endl;
-#endif
+                ggml_vk_matmul(vk_pipeline_matmul_f32, d_X, d_Y, d_D, ne01, ne11, ne10, split_k, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores), { s_mm });
             }
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0]);
+            ggml_vk_buffer_read_async(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0], fence, { s_mm }, {});
+
+            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_q_f32 matmul waitForFences");
+            vk_device.resetFences({ fence });
         }
     }
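// [Annotation, not part of the patch] The quantized path inserts one more link
// into the chain: the dequantize dispatch waits on the src0 upload
// (semaphores_q = { s_x }) and signals s_q; the matmul waits on { s_y, s_q };
// the readback waits on s_mm and signals the fence. As in the float paths, no
// intermediate waitForFences or queue.waitIdle() remains inside the loop.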
@@ -1336,7 +1347,7 @@ void ggml_vk_test_transfer(size_t ne) {
     free(x);
     free(y);
 }
-void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
+void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int split_k) {
     const size_t x_ne = m * k;
     const size_t y_ne = k * n;
     const size_t d_ne = m * n;
@@ -1346,7 +1357,7 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
     vk_buffer d_D;
     ggml_vk_pool_malloc(sizeof(float) * x_ne, &d_X, 0);
     ggml_vk_pool_malloc(sizeof(float) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
     float* x = (float *) malloc(sizeof(float) * x_ne);
     float* y = (float *) malloc(sizeof(float) * y_ne);
@@ -1366,15 +1377,13 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
     vk_transfer_queues[0].queue.waitIdle();
     vk_transfer_queues[1].queue.waitIdle();
-    vk::Fence fence = vk_device.createFence(vk::FenceCreateFlags{});
-
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
-
     auto begin = std::chrono::high_resolution_clock::now();
-    ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f32, {&d_X, &d_Y, &d_D}, { (int)m, (int)n, (int)k, (int)k, (int)k, (int)m }, { (uint32_t)m, (uint32_t)n, 1}, cmd_buffer, fence);
+    for (size_t i = 0; i < num_it; i++) {
+        ggml_vk_matmul(vk_pipeline_matmul_f32, d_X, d_Y, d_D, m, n, k, split_k, VK_NULL_HANDLE, {}, {});
+    }
-    vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "test_matmul_f32 waitForFences");
+    vk_compute_queue.queue.waitIdle();
     auto end = std::chrono::high_resolution_clock::now();
@@ -1397,12 +1406,10 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
         }
     }
-    std::cout << "TEST FP32 m=" << m << " n=" << n << " k=" << k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms avg_err=" << avg_err / (m * n) << std::endl;
+    std::cout << "TEST FP32 m=" << m << " n=" << n << " k=" << k << " split_k=" << split_k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 / num_it << "ms avg_err=" << avg_err / (m * n) << std::endl;
     free(d_chk);
-    vk_device.destroyFence(fence);
-
     ggml_vk_queue_cleanup(vk_compute_queue);
     ggml_vk_queue_cleanup(vk_transfer_queues[0]);
     ggml_vk_queue_cleanup(vk_transfer_queues[1]);
@@ -1416,7 +1423,7 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
     free(d);
 }
-void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
+void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int split_k) {
     if (!vk_fp16_support) {
         return;
     }
@@ -1429,7 +1436,7 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
     vk_buffer d_D;
     ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &d_X, 0);
     ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
     ggml_fp16_t* x = (ggml_fp16_t *) malloc(sizeof(ggml_fp16_t) * x_ne);
     ggml_fp16_t* y = (ggml_fp16_t *) malloc(sizeof(ggml_fp16_t) * y_ne);
@@ -1448,14 +1455,13 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
     vk_transfer_queues[0].queue.waitIdle();
     vk_transfer_queues[1].queue.waitIdle();
-    vk::Fence fence = vk_device.createFence(vk::FenceCreateFlags{});
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
-
     auto begin = std::chrono::high_resolution_clock::now();
-    ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f16, {&d_X, &d_Y, &d_D}, { (int)m, (int)n, (int)k, (int)k, (int)k, (int)m }, { (uint32_t)m, (uint32_t)n, 1}, cmd_buffer, fence);
+    for (size_t i = 0; i < num_it; i++) {
+        ggml_vk_matmul(vk_pipeline_matmul_f16, d_X, d_Y, d_D, m, n, k, split_k, VK_NULL_HANDLE, {}, {});
+    }
-    vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "test_matmul_f16 waitForFences");
+    vk_compute_queue.queue.waitIdle();
     auto end = std::chrono::high_resolution_clock::now();
@@ -1483,14 +1489,12 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
         }
     }
-    std::cout << "TEST FP16 m=" << m << " n=" << n << " k=" << k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms avg_err=" << avg_err / (m * n) << std::endl;
+    std::cout << "TEST FP16 m=" << m << " n=" << n << " k=" << k << " split_k=" << split_k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 / num_it << "ms avg_err=" << avg_err / (m * n) << std::endl;
     free(fx);
     free(fy);
     free(d_chk);
-    vk_device.destroyFence(fence);
-
     ggml_vk_queue_cleanup(vk_compute_queue);
     ggml_vk_queue_cleanup(vk_transfer_queues[0]);
     ggml_vk_queue_cleanup(vk_transfer_queues[1]);
diff --git a/vk_shaders/f16_to_f32.glsl b/vk_shaders/f16_to_f32.glsl
index 85d7152db..8dddee22c 100644
--- a/vk_shaders/f16_to_f32.glsl
+++ b/vk_shaders/f16_to_f32.glsl
@@ -2,7 +2,7 @@
 #extension GL_EXT_shader_16bit_storage : require
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
 layout (binding = 0) readonly buffer A { float16_t data_a[]; };
 layout (binding = 1) writeonly buffer D { float data_b[]; };
diff --git a/vk_shaders/matmul_f16.glsl b/vk_shaders/matmul_f16.glsl
index 8fa3cda6c..4a56ffcff 100644
--- a/vk_shaders/matmul_f16.glsl
+++ b/vk_shaders/matmul_f16.glsl
@@ -28,13 +28,16 @@ layout (push_constant) uniform parameter
     int stride_a;
     int stride_b;
     int stride_d;
+    int k_split;
 } p;
 shared float16_t buf_a[BM * (BK+1)];
 shared float16_t buf_b[BN * (BK+1)];
 void main() {
-    const int ir = int(gl_WorkGroupID.x);
+    const int blocks_x = (p.M + BM - 1) / BM;
+    const int ir = int(gl_WorkGroupID.x) % blocks_x;
+    const int ik = int(gl_WorkGroupID.x) / blocks_x;
     const int ic = int(gl_WorkGroupID.y);
     const int warp_i = int(gl_LocalInvocationID.x / WARP);
@@ -54,18 +57,21 @@ void main() {
     const int loadstride = int(gl_WorkGroupSize.x);
-    int pos_a = ir * BM * p.stride_a;
-    int pos_b = ic * BN * p.stride_b;
+    const int start_k = ik * p.k_split;
+    const int end_k = (ik + 1) * p.k_split;
+
+    int pos_a = ir * BM * p.stride_a + start_k;
+    int pos_b = ic * BN * p.stride_b + start_k;
     float sums[WMITER * TM * WNITER * TN];
     float16_t cache_a[WMITER * TM];
     float16_t cache_b[WNITER * TN];
     [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0hf;
+        sums[i] = 0.0f;
     }
-    [[unroll]] for (int block = 0; block < p.K; block += BK) {
+    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
         [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
             const int lr = l % BK;
             const int lc = l / BK;
@@ -90,7 +96,7 @@ void main() {
         pos_a += BK;
         pos_b += BK;
-        [[unroll]] for (int i = 0; i < BK; i++) {
+        for (int i = 0; i < min(BK, p.K - block); i++) {
             // Load from shared into cache
             [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (int j = 0; j < TM; j++) {
@@ -120,6 +126,8 @@ void main() {
     const int dr = ir * BM + warp_r * WM;
     const int dc = ic * BN + warp_c * WN;
+    const int k_split_offset = ik * p.M * p.N;
+
     [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
@@ -128,7 +136,7 @@ void main() {
             [[unroll]] for (int cc = 0; cc < TN; cc++) {
                 [[unroll]] for (int cr = 0; cr < TM; cr++) {
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
+                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                     }
                 }
             }
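// [Illustrative sketch, not part of the patch] Both matmul shaders decode the
// flat x dispatch of blocks_x * split_k workgroups the same way. In host
// terms (illustrative values: M = 512, BM = 64, split_k = 4 gives
// blocks_x = 8, so gl_WorkGroupID.x runs over 32 groups):
static int example_slice_output_base(int wg_x, int M, int N, int BM) {
    const int blocks_x = (M + BM - 1) / BM; // row-blocks of the output
    const int ir = wg_x % blocks_x;         // which BM-row block this group writes
    const int ik = wg_x / blocks_x;         // which k-slice this group accumulates
    (void) ir;                              // slice ik covers k in [ik*k_split, (ik+1)*k_split)
    return ik * M * N;                      // offset of this slice's region in data_d
}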
diff --git a/vk_shaders/matmul_f32.glsl b/vk_shaders/matmul_f32.glsl
index a353345af..c461685d0 100644
--- a/vk_shaders/matmul_f32.glsl
+++ b/vk_shaders/matmul_f32.glsl
@@ -27,13 +27,16 @@ layout (push_constant) uniform parameter
     int stride_a;
     int stride_b;
     int stride_d;
+    int k_split;
 } p;
 shared float buf_a[BM * (BK+1)];
 shared float buf_b[BN * (BK+1)];
 void main() {
-    const int ir = int(gl_WorkGroupID.x);
+    const int blocks_x = (p.M + BM - 1) / BM;
+    const int ir = int(gl_WorkGroupID.x) % blocks_x;
+    const int ik = int(gl_WorkGroupID.x) / blocks_x;
     const int ic = int(gl_WorkGroupID.y);
     const int warp_i = int(gl_LocalInvocationID.x / WARP);
@@ -53,8 +56,11 @@ void main() {
     const int loadstride = int(gl_WorkGroupSize.x);
-    int pos_a = ir * BM * p.stride_a;
-    int pos_b = ic * BN * p.stride_b;
+    const int start_k = ik * p.k_split;
+    const int end_k = (ik + 1) * p.k_split;
+
+    int pos_a = ir * BM * p.stride_a + start_k;
+    int pos_b = ic * BN * p.stride_b + start_k;
     float sums[WMITER * TM * WNITER * TN];
     float cache_a[WMITER * TM];
@@ -64,7 +70,7 @@ void main() {
         sums[i] = 0.0f;
     }
-    [[unroll]] for (int block = 0; block < p.K; block += BK) {
+    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
         [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
             const int lr = l % BK;
             const int lc = l / BK;
@@ -89,7 +95,7 @@ void main() {
         pos_a += BK;
         pos_b += BK;
-        [[unroll]] for (int i = 0; i < BK; i++) {
+        for (int i = 0; i < min(BK, p.K - block); i++) {
             // Load from shared into cache
             [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (int j = 0; j < TM; j++) {
@@ -119,6 +125,8 @@ void main() {
     const int dr = ir * BM + warp_r * WM;
     const int dc = ic * BN + warp_c * WN;
+    const int k_split_offset = ik * p.M * p.N;
+
     [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
@@ -127,7 +135,7 @@ void main() {
             [[unroll]] for (int cc = 0; cc < TN; cc++) {
                 [[unroll]] for (int cr = 0; cr < TM; cr++) {
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
+                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                     }
                 }
             }
diff --git a/vk_shaders/matmul_split_k_reduce.glsl b/vk_shaders/matmul_split_k_reduce.glsl
new file mode 100644
index 000000000..710f10e39
--- /dev/null
+++ b/vk_shaders/matmul_split_k_reduce.glsl
@@ -0,0 +1,27 @@
+#version 450
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer A { float data[]; };
+
+layout (push_constant) uniform parameter
+{
+    int N;
+    int k_num;
+} p;
+
+void main() {
+    const int idx = int(gl_GlobalInvocationID.x);
+
+    if (idx >= p.N) {
+        return;
+    }
+
+    float result = 0.0f;
+
+    for (int i = 0; i < p.k_num; i++) {
+        result += data[i * p.N + idx];
+    }
+
+    data[idx] = result;
+}
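// [Illustrative sketch, not part of the patch] A host-side reference of the
// reduce shader above, usable for validation: k_num partial results of N
// floats each lie back-to-back in `data`, and the reduction collapses them
// into the first N entries, one element per shader invocation.
static void example_split_k_reduce_ref(float * data, int N, int k_num) {
    for (int idx = 0; idx < N; idx++) {
        float result = 0.0f;
        for (int i = 0; i < k_num; i++) {
            result += data[i * N + idx];
        }
        data[idx] = result;
    }
}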