Add split-k optimization for small matrix multiplication
Use semaphores for synchronization instead of fences or waitIdle. Rework async write/read for synchronization.
parent c3d947510b
commit c7c761a2b7
6 changed files with 321 additions and 273 deletions
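The split-k idea in brief: for small m or n the matmul grid does not produce enough workgroups to keep the GPU busy, so the K dimension is divided into split_k slices, each slice writes its own m*n block of partial sums into the (enlarged) destination buffer, and a small reduce pass sums the blocks. A hypothetical CPU reference of that scheme follows; the function name, pointer layout and strides are illustrative and not part of this commit.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Hypothetical CPU reference for the split-k scheme (illustrative only).
    // a is M x K, b is N x K (i.e. B transposed), d receives the M x N result.
    void matmul_split_k_reference(const float* a, const float* b, float* d,
                                  int M, int N, int K, int split_k) {
        const int k_split = (K + split_k - 1) / split_k;             // CEIL_DIV(K, split_k)
        std::vector<float> partials((size_t)split_k * M * N, 0.0f);  // one M*N block per slice
        for (int ik = 0; ik < split_k; ik++) {
            const int start_k = ik * k_split;                        // like start_k/end_k in the shaders
            const int end_k   = std::min(K, start_k + k_split);
            for (int m = 0; m < M; m++) {
                for (int n = 0; n < N; n++) {
                    float sum = 0.0f;
                    for (int k = start_k; k < end_k; k++) {
                        sum += a[(size_t)m * K + k] * b[(size_t)n * K + k];
                    }
                    partials[(size_t)ik * M * N + (size_t)n * M + m] = sum;
                }
            }
        }
        for (size_t i = 0; i < (size_t)M * N; i++) {                 // reduce pass, cf. matmul_split_k_reduce.glsl
            float sum = 0.0f;
            for (int ik = 0; ik < split_k; ik++) {
                sum += partials[(size_t)ik * M * N + i];
            }
            d[i] = sum;
        }
    }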
Makefile (+1)
@@ -222,6 +222,7 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32.glsl -o vk_shaders/matmul_f32.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16.glsl -o vk_shaders/matmul_f16.spv
+	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv
 	glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv
 endif
ggml-vulkan.cpp (446 lines changed)
@@ -97,6 +97,7 @@ struct vk_queue {
     vk::Queue queue;
     vk::CommandPool pool;
     std::vector<vk::CommandBuffer> cmd_buffers;
+    std::vector<vk::Semaphore> semaphores;
     std::mutex mutex;
 };
 
@@ -106,7 +107,7 @@ vk::Device vk_device;
 vk_queue vk_compute_queue;
 vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT];
 VmaAllocator vk_allocator;
-vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16;
+vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16, vk_pipeline_matmul_split_k_reduce;
 vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
 VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
 vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c;
@@ -185,7 +186,7 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::s
     return pipeline;
 }
 
-static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buffer *> buffers, const std::vector<int>&& push_constants, std::array<uint32_t, 3> elements, vk::CommandBuffer& cmd_buffer, vk::Fence& fence) {
+static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buffer *> buffers, const std::vector<int>&& push_constants, std::array<uint32_t, 3> elements, vk::CommandBuffer& cmd_buffer, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
     PROFILE("ggml_vk_dispatch_pipeline",
     std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
     std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
@@ -213,11 +214,13 @@ static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buff
 
     std::lock_guard<std::mutex> guard(vk_compute_queue.mutex);
 
-    vk::SubmitInfo submit_info(0,
-                               nullptr,
+    vk::SubmitInfo submit_info(wait_semaphores.size(),
+                               wait_semaphores.data(),
                                nullptr,
                                1,
-                               &cmd_buffer);
+                               &cmd_buffer,
+                               signal_semaphores.size(),
+                               signal_semaphores.data());
 
     vk_compute_queue.queue.submit({ submit_info }, fence);
     );
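The wait/signal semaphore arguments added above replace the fence round-trips and queue.waitIdle() calls used before. A minimal, hypothetical sketch of the same chain with Vulkan-Hpp follows; the helper name and its parameters are illustrative and, unlike the hunk above, it passes an explicit wait-stage mask.

    #include <vulkan/vulkan.hpp>

    // Hypothetical helper (not from this commit): the transfer submit signals a
    // semaphore, the compute submit waits on it, so the copy and the dispatch can
    // be queued back to back without waitIdle() in between.
    void submit_chained(vk::Device device, vk::Queue transfer_queue, vk::Queue compute_queue,
                        vk::CommandBuffer transfer_cmd, vk::CommandBuffer compute_cmd, vk::Fence fence) {
        vk::Semaphore sem = device.createSemaphore({});    // the real code keeps these in vk_queue::semaphores

        vk::SubmitInfo transfer_submit(0, nullptr, nullptr,
                                       1, &transfer_cmd,
                                       1, &sem);           // signaled when the copy finishes
        transfer_queue.submit({ transfer_submit }, VK_NULL_HANDLE);

        vk::PipelineStageFlags wait_stage = vk::PipelineStageFlagBits::eComputeShader;
        vk::SubmitInfo compute_submit(1, &sem, &wait_stage, // wait for the copy before the dispatch runs
                                      1, &compute_cmd,
                                      0, nullptr);
        compute_queue.submit({ compute_submit }, fence);    // the fence only guards final completion
    }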
@@ -250,7 +253,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
     std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
 
     for(auto &q_family : queue_family_props) {
-        std::cout << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
+        std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
     }
     abort();
 }
@@ -284,6 +287,11 @@ static void ggml_vk_queue_cleanup(vk_queue& q) {
     q.queue.waitIdle();
     vk_device.freeCommandBuffers(q.pool, q.cmd_buffers);
     q.cmd_buffers.clear();
+
+    for (auto semaphore : q.semaphores) {
+        vk_device.destroySemaphore(semaphore);
+    }
+    q.semaphores.clear();
 }
 
 static vk_buffer ggml_vk_create_buffer(size_t size, VmaAllocationCreateFlags alloc_flags, VmaMemoryUsage vma_usage, VkMemoryPropertyFlags req_flags = 0) {
@@ -339,8 +347,8 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
 }
 
 void ggml_vk_test_transfer(size_t ne);
-void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k);
-void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k);
+void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int split_k);
+void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int split_k);
 
 void ggml_vk_init(void) {
     char* GGML_VULKAN_DEVICE = getenv("GGML_VULKAN_DEVICE");
@@ -378,6 +386,13 @@ void ggml_vk_init(void) {
     uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
     uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics | vk::QueueFlagBits::eVideoDecodeKHR | vk::QueueFlagBits::eProtected | vk::QueueFlagBits::eOpticalFlowNV, compute_queue_family_index, 2);
 
+    std::cerr << "Queue Families:" << std::endl;
+    for(size_t i = 0; i < queue_family_props.size(); i++) {
+        std::cerr << i << ": Queues: " + std::to_string(queue_family_props[i].queueCount) << " flags: " + to_string(queue_family_props[i].queueFlags) << std::endl;
+    }
+
+    std::cerr << "Using compute queue family " << compute_queue_family_index << " and transfer queue family " << transfer_queue_family_index << std::endl;
+
     const float compute_queue_priority = 1.0f;
     const float transfer_queue_priority[] = { 1.0f, 1.0f, 1.0f };
     std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
@@ -439,12 +454,13 @@ void ggml_vk_init(void) {
     vmaCreateAllocator(&allocator_info, &vk_allocator);
 
     // Shaders
-    vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {64, 64, 1});
+    vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7, {64, 64, 1});
     if (vk_fp16_support) {
-        vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {64, 64, 1});
+        vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7, {64, 64, 1});
     }
+    vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 2, {64, 1, 1});
 
-    vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1});
+    vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {64, 1, 1});
     vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {32, 1, 1});
 
     // Queues
@@ -466,10 +482,21 @@ void ggml_vk_init(void) {
         4096, 49, 11008,
         4096, 49, 4096,
         32000, 49, 4096,
+        512, 512, 128,
+        128, 512, 512,
+        4096, 512, 4096,
+        11008, 512, 4096,
+        4096, 512, 11008,
+        4096, 512, 4096,
+        32000, 512, 4096,
+        512, 512, 128,
+        128, 512, 512,
     };
     for (size_t i = 0; i < vals.size(); i += 3) {
-        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2]);
-        ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2]);
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 10, 1);
+        ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 10, 1);
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 10, 4);
+        ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 10, 4);
     }
 #endif
 }
@@ -606,108 +633,14 @@ void ggml_vk_host_free(void* ptr) {
     ggml_vk_destroy_buffer(*buf);
 }
 
-static void ggml_vk_buffer_write(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q) {
+static void ggml_vk_buffer_write_2d_async(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
     VkMemoryPropertyFlags mem_prop_flags;
     vmaGetAllocationMemoryProperties(vk_allocator, dst->allocation, &mem_prop_flags);
 
     // Buffer is already mapped
     if(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) {
-        GGML_ASSERT(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
-
-        PROFILE("ggml_vk_buffer_write visible",
-        memcpy((uint8_t *)dst->info.pMappedData + offset, src, size);
-        );
-    } else {
-        // Check if src is pinned memory
-        vk_buffer* buf = nullptr;
-        size_t buf_offset = 0;
-        PROFILE("ggml_vk_buffer_write pinned check",
-        for (size_t i = 0; i < vk_buf_list.size(); i++) {
-            const uint8_t* addr = (const uint8_t*) std::get<0>(vk_buf_list[i]);
-            const uint8_t* endr = addr + std::get<1>(vk_buf_list[i]);
-            if (src >= addr && src < endr) {
-                buf = &std::get<2>(vk_buf_list[i]);
-                buf_offset = ((const uint8_t *)src) - addr;
-                break;
-            }
-        }
-        );
-
-        if (buf != nullptr) {
-            // Memory is pinned, use as staging buffer
-            VkBufferCopy buf_copy = {
-                buf_offset, // srcOffset
-                offset, // dstOffset,
-                size}; // size
-
-            vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
-            PROFILE("ggml_vk_buffer_write pinned write",
-            vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
-            cmd_buffer.begin(cmd_buffer_begin_info);
-            vkCmdCopyBuffer(cmd_buffer, buf->buffer, dst->buffer, 1, &buf_copy);
-            cmd_buffer.end();
-            );
-
-            vk::SubmitInfo submit_info(0,
-                                       nullptr,
-                                       nullptr,
-                                       1,
-                                       &cmd_buffer);
-            std::lock_guard<std::mutex> guard(q.mutex);
-            q.queue.submit({ submit_info }, VK_NULL_HANDLE);
-
-            return;
-        }
-
-        // Staging buffer required, malloc because of async transfer
-        if (dst->sb_write == nullptr) {
-            dst->sb_write = new vk_buffer;
-            *dst->sb_write = ggml_vk_create_buffer(dst->size, VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_MAPPED_BIT, VMA_MEMORY_USAGE_AUTO_PREFER_HOST, 0);
-        }
-
-        VkMemoryPropertyFlags mpf_staging;
-        vmaGetAllocationMemoryProperties(vk_allocator, dst->sb_write->allocation, &mpf_staging);
-        GGML_ASSERT(mpf_staging & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
-
-        VkBufferCopy buf_copy = {
-            0, // srcOffset
-            offset, // dstOffset,
-            size}; // size
-
-        PROFILE("ggml_vk_buffer_write staging",
-        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
-        vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
-        cmd_buffer.begin(cmd_buffer_begin_info);
-        vkCmdCopyBuffer(cmd_buffer, dst->sb_write->buffer, dst->buffer, 1, &buf_copy);
-        cmd_buffer.end();
-
-        memcpy(dst->sb_write->info.pMappedData, src, size);
-
-        vk::SubmitInfo submit_info(0,
-                                   nullptr,
-                                   nullptr,
-                                   1,
-                                   &cmd_buffer);
-        std::lock_guard<std::mutex> guard(q.mutex);
-        q.queue.submit({ submit_info }, VK_NULL_HANDLE);
-        );
-    }
-}
-
-static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_queue& q) {
-    VkMemoryPropertyFlags mem_prop_flags;
-    vmaGetAllocationMemoryProperties(vk_allocator, dst->allocation, &mem_prop_flags);
-
-    // Buffer is already mapped
-    if(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) {
-        GGML_ASSERT(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
-
-        PROFILE("ggml_vk_buffer_write visible",
-        for (size_t i = 0; i < height; i++) {
-            memcpy((uint8_t *)dst->info.pMappedData + offset + i * width, (uint8_t *) src + i * spitch, width);
-        }
-        );
-    } else {
+        std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
+    }
     // Check if src is pinned memory
     vk_buffer* buf = nullptr;
     size_t buf_offset = 0;
@@ -740,13 +673,15 @@ static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void *
         cmd_buffer.end();
         );
 
-        vk::SubmitInfo submit_info(0,
-                                   nullptr,
+        vk::SubmitInfo submit_info(wait_semaphores.size(),
+                                   wait_semaphores.data(),
                                    nullptr,
                                    1,
-                                   &cmd_buffer);
+                                   &cmd_buffer,
+                                   signal_semaphores.size(),
+                                   signal_semaphores.data());
         std::lock_guard<std::mutex> guard(q.mutex);
-        q.queue.submit({ submit_info }, VK_NULL_HANDLE);
+        q.queue.submit({ submit_info }, fence);
 
         return;
     }
@@ -759,14 +694,13 @@ static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void *
 
     VkMemoryPropertyFlags mpf_staging;
     vmaGetAllocationMemoryProperties(vk_allocator, dst->sb_write->allocation, &mpf_staging);
-    GGML_ASSERT(mpf_staging & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
+    GGML_ASSERT(mpf_staging & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
 
     VkBufferCopy buf_copy = {
         0,
         offset,
         width * height};
 
-    PROFILE("ggml_vk_buffer_write staging",
     vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
     vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
     cmd_buffer.begin(cmd_buffer_begin_info);
@@ -774,18 +708,85 @@ static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void *
     cmd_buffer.end();
 
     for (size_t i = 0; i < height; i++) {
-        memcpy((uint8_t *)dst->info.pMappedData + offset + i * width, (uint8_t *) src + i * spitch, width);
+        memcpy((uint8_t *)dst->sb_write->info.pMappedData + offset + i * width, (const uint8_t *) src + i * spitch, width);
     }
 
-    vk::SubmitInfo submit_info(0,
-                               nullptr,
+    vk::SubmitInfo submit_info(wait_semaphores.size(),
+                               wait_semaphores.data(),
                                nullptr,
                                1,
-                               &cmd_buffer);
+                               &cmd_buffer,
+                               signal_semaphores.size(),
+                               signal_semaphores.data());
     std::lock_guard<std::mutex> guard(q.mutex);
-    q.queue.submit({ submit_info }, VK_NULL_HANDLE);
-    );
+    q.queue.submit({ submit_info }, fence);
+}
+
+static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_queue& q) {
+    VkMemoryPropertyFlags mem_prop_flags;
+    vmaGetAllocationMemoryProperties(vk_allocator, dst->allocation, &mem_prop_flags);
+
+    // Buffer is already mapped
+    if(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) {
+        GGML_ASSERT(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
+
+        PROFILE("ggml_vk_buffer_write visible",
+        for (size_t i = 0; i < height; i++) {
+            memcpy((uint8_t *)dst->info.pMappedData + offset + i * width, (const uint8_t *) src + i * spitch, width);
+        }
+        );
+    } else {
+        ggml_vk_buffer_write_2d_async(dst, offset, src, spitch, width, height, q, VK_NULL_HANDLE, {}, {});
+    }
+}
+
+static void ggml_vk_buffer_write_async(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+    ggml_vk_buffer_write_2d_async(dst, offset, src, 0, size, 1, q, fence, std::move(wait_semaphores), std::move(signal_semaphores));
+}
+
+static void ggml_vk_buffer_write(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q) {
+    ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1, q);
+}
+
+static void ggml_vk_buffer_read_async(vk_buffer* src, size_t offset, void * dst, size_t size, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+    // Check if dst is pinned memory
+    vk_buffer* buf = nullptr;
+    size_t buf_offset = 0;
+    for (size_t i = 0; i < vk_buf_list.size(); i++) {
+        const uint8_t* addr = (const uint8_t*) std::get<0>(vk_buf_list[i]);
+        const uint8_t* endr = addr + std::get<1>(vk_buf_list[i]);
+        if (dst >= addr && dst < endr) {
+            buf = &std::get<2>(vk_buf_list[i]);
+            buf_offset = ((const uint8_t *)dst) - addr;
+            break;
+        }
+    }
+
+    if (buf == nullptr) {
+        std::cerr << "ggml_vulkan: Error: buffer_read_async only works on pinned memory" << std::endl;
+        GGML_ASSERT(false);
+    }
+    // Memory is pinned, use as staging buffer
+    VkBufferCopy buf_copy = {
+        offset, // srcOffset
+        buf_offset, // dstOffset,
+        size}; // size
+
+    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(q);
+    vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
+    cmd_buffer.begin(cmd_buffer_begin_info);
+    vkCmdCopyBuffer(cmd_buffer, src->buffer, buf->buffer, 1, &buf_copy);
+    cmd_buffer.end();
+
+    vk::SubmitInfo submit_info(wait_semaphores.size(),
+                               wait_semaphores.data(),
+                               nullptr,
+                               1,
+                               &cmd_buffer,
+                               signal_semaphores.size(),
+                               signal_semaphores.data());
+    std::lock_guard<std::mutex> guard(q.mutex);
+    q.queue.submit({ submit_info }, fence);
 }
 
 static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_t size, vk_queue& q) {
@@ -876,7 +877,7 @@ static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_
     }
 }
 
-static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q) {
+static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
     const uint64_t ne0 = src->ne[0];
     const uint64_t ne1 = src->ne[1];
     const uint64_t nb0 = src->nb[0];
@@ -890,19 +891,19 @@ static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct gg
 
     const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
     if (nb0 == ts && nb1 == row_length) {
-        ggml_vk_buffer_write(dst, offset, x, ne1*nb1, q);
+        ggml_vk_buffer_write_async(dst, offset, x, ne1*nb1, q, fence, std::move(wait_semaphores), std::move(signal_semaphores));
         return;
     }
     if (nb0 == ts) {
         PROFILE("ggml_vk_buffer_write_2d",
-        ggml_vk_buffer_write_2d(dst, offset, x, nb1, row_length, ne1, q);
+        ggml_vk_buffer_write_2d_async(dst, offset, x, nb1, row_length, ne1, q, fence, std::move(wait_semaphores), std::move(signal_semaphores));
         );
         return;
     }
     GGML_ASSERT(false);
     // TODO: also needs handling of staging buffers
     uint8_t* dst_ptr = (uint8_t*) dst->info.pMappedData;
-    uint8_t* xc = (uint8_t*)x;
+    const uint8_t* xc = (const uint8_t*)x;
     for (uint64_t i1 = 0; i1 < ne1; i1++) {
         for (uint64_t i0 = 0; i0 < ne0; i0++) {
             dst_ptr[offset + i1 * row_length + i0 * ts] = xc[i1 * nb1 + i0 * nb0];
@@ -910,6 +911,29 @@ static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct gg
     }
 }
+
+static int ggml_vk_guess_split_k(int m, int n, int k) {
+    if (k > 64 && (m < 128 || n < 128)) {
+        return 4;
+    }
+
+    return 1;
+}
+
+static void ggml_vk_matmul(vk_pipeline& pipeline, vk_buffer& a, vk_buffer& b, vk_buffer& d, int m, int n, int k, int split_k, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+    if (split_k == 1) {
+        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
+        ggml_vk_dispatch_pipeline(pipeline, {&a, &b, &d}, { m, n, k, k, k, m, k}, { (uint32_t)m, (uint32_t)n, 1}, cmd_buffer, fence, std::move(wait_semaphores), std::move(signal_semaphores));
+    } else {
+        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
+        vk::Semaphore semaphore = vk_device.createSemaphore({});
+        vk_compute_queue.semaphores.push_back(semaphore);
+        ggml_vk_dispatch_pipeline(pipeline, {&a, &b, &d}, { m, n, k, k, k, m, CEIL_DIV(k, split_k)}, { (uint32_t)m * split_k, (uint32_t)n, 1}, cmd_buffer, VK_NULL_HANDLE, std::move(wait_semaphores), { semaphore });
+
+        vk::CommandBuffer cmd_buffer_reduce = ggml_vk_cmd_buffer_create(vk_compute_queue);
+        ggml_vk_dispatch_pipeline(vk_pipeline_matmul_split_k_reduce, {&d}, { m * n, split_k}, { (uint32_t)m * n, 1, 1}, cmd_buffer_reduce, fence, { semaphore }, std::move(signal_semaphores));
+    }
+}
 
 static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
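A worked example of the heuristic above, using shapes from the test list in ggml_vk_init: for a decode-shaped multiplication with m=4096, n=49, k=4096, k > 64 and n < 128, so ggml_vk_guess_split_k returns 4; ggml_vk_matmul then launches the matmul over m*4 x n output elements (each of the four k-slices writing its own m*n block of d), signals an internal semaphore, and chains the matmul_split_k_reduce dispatch on it. For the prompt-shaped case m=4096, n=512, k=4096 both m and n are at least 128, so split_k stays 1 and the reduce pass is skipped.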
@@ -926,6 +950,8 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
+    const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+
     vk_buffer d_X;
     vk_buffer d_Y;
     vk_buffer d_D;
@@ -935,42 +961,35 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
         ggml_vk_pool_malloc(ggml_type_size(src0->type) * x_ne, &d_X, 0);
     }
     ggml_vk_pool_malloc(sizeof(float) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
 
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
     vk::Fence fence = vk_device.createFence(vk::FenceCreateInfo());
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
+            vk::Semaphore s_x;
+            vk::Semaphore s_y = vk_device.createSemaphore({});
+            std::vector<vk::Semaphore> semaphores = { s_y };
+            vk_compute_queue.semaphores.push_back(s_y);
             // copy data to device
             if (src0->backend != GGML_BACKEND_GPU) {
-                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0]);
+                s_x = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_x);
+                semaphores.push_back(s_x);
+                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], VK_NULL_HANDLE, {}, { s_x });
             }
-            ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1]);
+            ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], VK_NULL_HANDLE, {}, { s_y });
 
             // compute
-#ifdef VK_CHK_KERNEL
-            auto begin = std::chrono::high_resolution_clock::now();
-#endif
+            vk::Semaphore s_mm = vk_device.createSemaphore({});
+            vk_compute_queue.semaphores.push_back(s_mm);
+            ggml_vk_matmul(vk_pipeline_matmul_f32, d_X, d_Y, d_D, ne01, ne11, ne10, split_k, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores), { s_mm });
 
-            // Wait for transfers to finish
-            vk_transfer_queues[0].queue.waitIdle();
-            vk_transfer_queues[1].queue.waitIdle();
-
-            ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f32, {&d_X, &d_Y, &d_D}, { (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01 }, { (uint32_t)ne01, (uint32_t)ne11, 1}, cmd_buffer, fence);
-
-            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f32 waitForFences");
-
-#ifdef VK_CHK_KERNEL
-            auto end = std::chrono::high_resolution_clock::now();
-
-            std::cout << "m=" << ne01 << " n=" << ne11 << " k=" << ne10 << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms" << std::endl;
-#endif
-
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0]);
+            ggml_vk_buffer_read_async(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0], fence, { s_mm }, {});
 
+            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f32 waitForFences");
             vk_device.resetFences({fence});
         }
     }
@@ -1014,6 +1033,8 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
+    const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+
     vk_buffer d_X;
     vk_buffer d_Y;
     vk_buffer d_D;
@@ -1023,19 +1044,26 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
         ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &d_X, 0);
     }
     ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
 
     bool src1_cont_rows = nb10 == sizeof(float);
     bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
 
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
     vk::Fence fence = vk_device.createFence(vk::FenceCreateInfo());
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
+            vk::Semaphore s_x;
+            vk::Semaphore s_y = vk_device.createSemaphore({});
+            std::vector<vk::Semaphore> semaphores = { s_y };
+            vk_compute_queue.semaphores.push_back(s_y);
+
             // copy data to device
             if (src1->backend != GGML_BACKEND_GPU) {
-                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0]);
+                s_x = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_x);
+                semaphores.push_back(s_x);
+                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], VK_NULL_HANDLE, {}, { s_x });
             }
 
             // convert src1 to fp16
@@ -1060,30 +1088,18 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
                     }
                 }
             }
-            ggml_vk_buffer_write(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * y_ne, vk_transfer_queues[1]);
+            ggml_vk_buffer_write_async(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * y_ne, vk_transfer_queues[1], VK_NULL_HANDLE, {}, { s_y });
 
-            // Wait for transfers to finish
-            vk_transfer_queues[0].queue.waitIdle();
-            vk_transfer_queues[1].queue.waitIdle();
-
             // compute
-#ifdef VK_CHK_KERNEL
-            auto begin = std::chrono::high_resolution_clock::now();
-#endif
+            vk::Semaphore s_mm = vk_device.createSemaphore({});
+            vk_compute_queue.semaphores.push_back(s_mm);
+            ggml_vk_matmul(vk_pipeline_matmul_f16, d_X, d_Y, d_D, ne01, ne11, ne10, split_k, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores), { s_mm });
 
-            ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f16, {&d_X, &d_Y, &d_D}, { (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01 }, { (uint32_t)ne01, (uint32_t)ne11, 1}, cmd_buffer, fence);
-
-            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f16 waitForFences");
-
-#ifdef VK_CHK_KERNEL
-            auto end = std::chrono::high_resolution_clock::now();
-
-            std::cout << "m=" << ne01 << " n=" << ne11 << " k=" << ne10 << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms" << std::endl;
-#endif
-
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0]);
+            ggml_vk_buffer_read_async(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0], fence, { s_mm }, {});
 
+            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_f16 waitForFences");
             vk_device.resetFences({fence});
         }
     }
@@ -1120,6 +1136,8 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     const int d_ne = ne11 * ne01;
     const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
 
+    const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+
     vk_buffer d_X;
     vk_buffer d_Y;
     vk_buffer d_D;
@@ -1127,7 +1145,7 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
         ggml_vk_pool_malloc(sizeof(float) * x_ne, &d_X, 0);
     }
     ggml_vk_pool_malloc(sizeof(float) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
     vk_buffer d_Q;
     if (src0->backend == GGML_BACKEND_CPU) {
         ggml_vk_pool_malloc(q_sz, &d_Q, 0);
@@ -1137,14 +1155,25 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     // vk_pipeline* dmmv = ggml_get_dequantize_mul_mat_vec_vk(type);
     GGML_ASSERT(to_fp32_vk != nullptr);
 
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
     vk::Fence fence = vk_device.createFence(vk::FenceCreateFlags{});
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
+            vk::Semaphore s_x;
+            vk::Semaphore s_y = vk_device.createSemaphore({});
+            std::vector<vk::Semaphore> semaphores_q;
+            std::vector<vk::Semaphore> semaphores = { s_y };
+            vk_compute_queue.semaphores.push_back(s_y);
+
+            vk::Semaphore s_mm = vk_device.createSemaphore({});
+            vk_compute_queue.semaphores.push_back(s_mm);
+
             // copy src0 to device if necessary
             if (src0->backend == GGML_BACKEND_CPU) {
-                ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, vk_transfer_queues[0]);
+                s_x = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_x);
+                semaphores_q.push_back(s_x);
+                ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, vk_transfer_queues[0], VK_NULL_HANDLE, {}, { s_x });
             } else if (src0->backend == GGML_BACKEND_GPU) {
                 d_Q = *(vk_buffer *) src0->data;
             } else {
@@ -1169,43 +1198,25 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 // VK_CHECK(vkEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
             } else { // general dequantization kernel + VK matrix matrix multiplication
                 // convert src0 to fp32 on device
-                // Wait for transfers to finish
-                vk_transfer_queues[0].queue.waitIdle();
-
-                vk_device.resetFences({ fence });
-                ggml_vk_dispatch_pipeline(*to_fp32_vk, {&d_Q, &d_X}, { (int)x_ne }, { (uint32_t)x_ne, 1, 1}, cmd_buffer, fence);
+                vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
+                vk::Semaphore s_q = vk_device.createSemaphore({});
+                vk_compute_queue.semaphores.push_back(s_q);
+                semaphores.push_back(s_q);
+                ggml_vk_dispatch_pipeline(*to_fp32_vk, {&d_Q, &d_X}, { (int)x_ne }, { (uint32_t)x_ne, 1, 1}, cmd_buffer, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores_q), { s_q });
 
                 // copy src1 to device
-                ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1]);
-
-                // wait for conversion
-                vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_q_f32 src0 convert waitForFences");
-
-                vk_device.resetFences({ fence });
-                cmd_buffer.reset(vk::CommandBufferResetFlags());
-
-                // Wait for transfers to finish
-                vk_transfer_queues[1].queue.waitIdle();
+                ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], VK_NULL_HANDLE, {}, { s_y });
 
                 // compute
-#ifdef VK_CHK_KERNEL
-                auto begin = std::chrono::high_resolution_clock::now();
-#endif
-
-                ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f32, {&d_X, &d_Y, &d_D}, { (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01 }, { (uint32_t)ne01, (uint32_t)ne11, 1}, cmd_buffer, fence);
-
-                vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_q_f32 matmul waitForFences");
-
-#ifdef VK_CHK_KERNEL
-                auto end = std::chrono::high_resolution_clock::now();
-
-                std::cout << "m=" << ne01 << " n=" << ne11 << " k=" << ne10 << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms" << std::endl;
-#endif
+                ggml_vk_matmul(vk_pipeline_matmul_f32, d_X, d_Y, d_D, ne01, ne11, ne10, split_k, VK_NULL_HANDLE, std::vector<vk::Semaphore>(semaphores), { s_mm });
             }
 
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0]);
+            ggml_vk_buffer_read_async(&d_D, 0, d, sizeof(float) * d_ne, vk_transfer_queues[0], fence, { s_mm }, {});
 
+            vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "matmul_q_f32 matmul waitForFences");
+            vk_device.resetFences({ fence });
         }
     }
 
@@ -1336,7 +1347,7 @@ void ggml_vk_test_transfer(size_t ne) {
     free(x);
     free(y);
 }
-void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
+void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int split_k) {
     const size_t x_ne = m * k;
     const size_t y_ne = k * n;
     const size_t d_ne = m * n;
@@ -1346,7 +1357,7 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
     vk_buffer d_D;
     ggml_vk_pool_malloc(sizeof(float) * x_ne, &d_X, 0);
     ggml_vk_pool_malloc(sizeof(float) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
 
     float* x = (float *) malloc(sizeof(float) * x_ne);
     float* y = (float *) malloc(sizeof(float) * y_ne);
@@ -1366,15 +1377,13 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
     vk_transfer_queues[0].queue.waitIdle();
     vk_transfer_queues[1].queue.waitIdle();
 
-    vk::Fence fence = vk_device.createFence(vk::FenceCreateFlags{});
-
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
-
     auto begin = std::chrono::high_resolution_clock::now();
 
-    ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f32, {&d_X, &d_Y, &d_D}, { (int)m, (int)n, (int)k, (int)k, (int)k, (int)m }, { (uint32_t)m, (uint32_t)n, 1}, cmd_buffer, fence);
+    for (size_t i = 0; i < num_it; i++) {
+        ggml_vk_matmul(vk_pipeline_matmul_f32, d_X, d_Y, d_D, m, n, k, split_k, VK_NULL_HANDLE, {}, {});
+    }
 
-    vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "test_matmul_f32 waitForFences");
+    vk_compute_queue.queue.waitIdle();
 
     auto end = std::chrono::high_resolution_clock::now();
 
@@ -1397,12 +1406,10 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
         }
     }
 
-    std::cout << "TEST FP32 m=" << m << " n=" << n << " k=" << k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms avg_err=" << avg_err / (m * n) << std::endl;
+    std::cout << "TEST FP32 m=" << m << " n=" << n << " k=" << k << " split_k=" << split_k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 / num_it << "ms avg_err=" << avg_err / (m * n) << std::endl;
 
     free(d_chk);
 
-    vk_device.destroyFence(fence);
-
     ggml_vk_queue_cleanup(vk_compute_queue);
     ggml_vk_queue_cleanup(vk_transfer_queues[0]);
     ggml_vk_queue_cleanup(vk_transfer_queues[1]);
@@ -1416,7 +1423,7 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k) {
     free(d);
 }
 
-void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
+void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int split_k) {
     if (!vk_fp16_support) {
         return;
     }
@@ -1429,7 +1436,7 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
     vk_buffer d_D;
     ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &d_X, 0);
     ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * y_ne, &d_Y, 0);
-    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, 0);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, 0);
 
     ggml_fp16_t* x = (ggml_fp16_t *) malloc(sizeof(ggml_fp16_t) * x_ne);
     ggml_fp16_t* y = (ggml_fp16_t *) malloc(sizeof(ggml_fp16_t) * y_ne);
@@ -1448,14 +1455,13 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
     vk_transfer_queues[0].queue.waitIdle();
     vk_transfer_queues[1].queue.waitIdle();
 
-    vk::Fence fence = vk_device.createFence(vk::FenceCreateFlags{});
-    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create(vk_compute_queue);
-
     auto begin = std::chrono::high_resolution_clock::now();
 
-    ggml_vk_dispatch_pipeline(vk_pipeline_matmul_f16, {&d_X, &d_Y, &d_D}, { (int)m, (int)n, (int)k, (int)k, (int)k, (int)m }, { (uint32_t)m, (uint32_t)n, 1}, cmd_buffer, fence);
+    for (size_t i = 0; i < num_it; i++) {
+        ggml_vk_matmul(vk_pipeline_matmul_f16, d_X, d_Y, d_D, m, n, k, split_k, VK_NULL_HANDLE, {}, {});
+    }
 
-    vk::resultCheck(vk_device.waitForFences({ fence }, true, uint64_t(-1)), "test_matmul_f16 waitForFences");
+    vk_compute_queue.queue.waitIdle();
 
     auto end = std::chrono::high_resolution_clock::now();
 
@@ -1483,14 +1489,12 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k) {
         }
     }
 
-    std::cout << "TEST FP16 m=" << m << " n=" << n << " k=" << k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms avg_err=" << avg_err / (m * n) << std::endl;
+    std::cout << "TEST FP16 m=" << m << " n=" << n << " k=" << k << " split_k=" << split_k << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 / num_it << "ms avg_err=" << avg_err / (m * n) << std::endl;
 
     free(fx);
     free(fy);
     free(d_chk);
 
-    vk_device.destroyFence(fence);
-
     ggml_vk_queue_cleanup(vk_compute_queue);
     ggml_vk_queue_cleanup(vk_transfer_queues[0]);
     ggml_vk_queue_cleanup(vk_transfer_queues[1]);
vk_shaders/f16_to_f32.glsl
@@ -2,7 +2,7 @@
 #extension GL_EXT_shader_16bit_storage : require
 
-layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A { float16_t data_a[]; };
 layout (binding = 1) writeonly buffer D { float data_b[]; };
 
vk_shaders/matmul_f16.glsl
@@ -28,13 +28,16 @@ layout (push_constant) uniform parameter
     int stride_a;
     int stride_b;
     int stride_d;
+    int k_split;
 } p;
 
 shared float16_t buf_a[BM * (BK+1)];
 shared float16_t buf_b[BN * (BK+1)];
 
 void main() {
-    const int ir = int(gl_WorkGroupID.x);
+    const int blocks_x = (p.M + BM - 1) / BM;
+    const int ir = int(gl_WorkGroupID.x) % blocks_x;
+    const int ik = int(gl_WorkGroupID.x) / blocks_x;
     const int ic = int(gl_WorkGroupID.y);
 
     const int warp_i = int(gl_LocalInvocationID.x / WARP);
@@ -54,18 +57,21 @@ void main() {
 
     const int loadstride = int(gl_WorkGroupSize.x);
 
-    int pos_a = ir * BM * p.stride_a;
-    int pos_b = ic * BN * p.stride_b;
+    const int start_k = ik * p.k_split;
+    const int end_k = (ik + 1) * p.k_split;
+
+    int pos_a = ir * BM * p.stride_a + start_k;
+    int pos_b = ic * BN * p.stride_b + start_k;
 
     float sums[WMITER * TM * WNITER * TN];
     float16_t cache_a[WMITER * TM];
     float16_t cache_b[WNITER * TN];
 
     [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0hf;
+        sums[i] = 0.0f;
     }
 
-    [[unroll]] for (int block = 0; block < p.K; block += BK) {
+    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
         [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
             const int lr = l % BK;
             const int lc = l / BK;
@@ -90,7 +96,7 @@ void main() {
         pos_a += BK;
         pos_b += BK;
 
-        [[unroll]] for (int i = 0; i < BK; i++) {
+        for (int i = 0; i < min(BK, p.K - block); i++) {
             // Load from shared into cache
             [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (int j = 0; j < TM; j++) {
@@ -120,6 +126,8 @@ void main() {
     const int dr = ir * BM + warp_r * WM;
     const int dc = ic * BN + warp_c * WN;
 
+    const int k_split_offset = ik * p.M * p.N;
+
     [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
 
@@ -128,7 +136,7 @@ void main() {
             [[unroll]] for (int cc = 0; cc < TN; cc++) {
                 [[unroll]] for (int cr = 0; cr < TM; cr++) {
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
+                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                     }
                 }
             }
vk_shaders/matmul_f32.glsl
@@ -27,13 +27,16 @@ layout (push_constant) uniform parameter
     int stride_a;
     int stride_b;
     int stride_d;
+    int k_split;
 } p;
 
 shared float buf_a[BM * (BK+1)];
 shared float buf_b[BN * (BK+1)];
 
 void main() {
-    const int ir = int(gl_WorkGroupID.x);
+    const int blocks_x = (p.M + BM - 1) / BM;
+    const int ir = int(gl_WorkGroupID.x) % blocks_x;
+    const int ik = int(gl_WorkGroupID.x) / blocks_x;
     const int ic = int(gl_WorkGroupID.y);
 
     const int warp_i = int(gl_LocalInvocationID.x / WARP);
@@ -53,8 +56,11 @@ void main() {
 
     const int loadstride = int(gl_WorkGroupSize.x);
 
-    int pos_a = ir * BM * p.stride_a;
-    int pos_b = ic * BN * p.stride_b;
+    const int start_k = ik * p.k_split;
+    const int end_k = (ik + 1) * p.k_split;
+
+    int pos_a = ir * BM * p.stride_a + start_k;
+    int pos_b = ic * BN * p.stride_b + start_k;
 
     float sums[WMITER * TM * WNITER * TN];
     float cache_a[WMITER * TM];
@@ -64,7 +70,7 @@ void main() {
         sums[i] = 0.0f;
     }
 
-    [[unroll]] for (int block = 0; block < p.K; block += BK) {
+    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
         [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
             const int lr = l % BK;
             const int lc = l / BK;
@@ -89,7 +95,7 @@ void main() {
         pos_a += BK;
        pos_b += BK;
 
-        [[unroll]] for (int i = 0; i < BK; i++) {
+        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
@@ -119,6 +125,8 @@ void main() {
     const int dr = ir * BM + warp_r * WM;
     const int dc = ic * BN + warp_c * WN;
 
+    const int k_split_offset = ik * p.M * p.N;
+
     [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
 
@@ -127,7 +135,7 @@ void main() {
             [[unroll]] for (int cc = 0; cc < TN; cc++) {
                 [[unroll]] for (int cr = 0; cr < TM; cr++) {
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
+                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                     }
                 }
             }
vk_shaders/matmul_split_k_reduce.glsl (new file, +27)
@@ -0,0 +1,27 @@
+#version 450
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer A { float data[]; };
+
+layout (push_constant) uniform parameter
+{
+    int N;
+    int k_num;
+} p;
+
+void main() {
+    const int idx = int(gl_GlobalInvocationID.x);
+
+    if (idx >= p.N) {
+        return;
+    }
+
+    float result = 0.0f;
+
+    for (int i = 0; i < p.k_num; i++) {
+        result += data[i * p.N + idx];
+    }
+
+    data[idx] = result;
+}
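Usage note: ggml_vk_matmul dispatches this shader over m * n elements with push constants { m * n, split_k }, so p.N is the number of output elements and p.k_num is the number of partial blocks laid out back to back in the same buffer; the idx >= p.N guard discards any overshoot from rounding the grid up to the 64-wide workgroups. The summed result overwrites the first block in place, which is the region ggml_vk_buffer_read_async then copies back to the host.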