Fix validation errors, improve compatibility with AMD GPUs

This commit is contained in:
0cc4m 2023-07-08 20:40:19 +02:00
parent c7c761a2b7
commit 0ef62f511a
2 changed files with 33 additions and 16 deletions

View file

@ -59,6 +59,10 @@ inline static void* ggml_aligned_malloc(size_t size, size_t alignment) {
#define VK_TRANSFER_QUEUE_COUNT 2 #define VK_TRANSFER_QUEUE_COUNT 2
#define VK_VENDOR_ID_AMD 0x1002
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de
struct vk_buffer { struct vk_buffer {
vk::Buffer buffer; vk::Buffer buffer;
VmaAllocation allocation; VmaAllocation allocation;
@ -104,9 +108,11 @@ struct vk_queue {
vk::Instance vk_instance; vk::Instance vk_instance;
vk::PhysicalDevice vk_physical_device; vk::PhysicalDevice vk_physical_device;
vk::Device vk_device; vk::Device vk_device;
uint32_t vk_device_vendor_id;
vk_queue vk_compute_queue; vk_queue vk_compute_queue;
vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT]; vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT];
VmaAllocator vk_allocator; VmaAllocator vk_allocator;
vk::PipelineStageFlags vk_stage_flags[8] = { vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader };
vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16, vk_pipeline_matmul_split_k_reduce; vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16, vk_pipeline_matmul_split_k_reduce;
vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0; vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc; VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
@ -145,10 +151,18 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::s
vk::ShaderModule shader_module = vk_device.createShaderModule(shader_module_create_info); vk::ShaderModule shader_module = vk_device.createShaderModule(shader_module_create_info);
std::vector<vk::DescriptorSetLayoutBinding> dsl_binding; std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
std::vector<VkDescriptorBindingFlags> dsl_binding_flags;
for (uint32_t i = 0; i < parameter_count; i++) { for (uint32_t i = 0; i < parameter_count; i++) {
dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
dsl_binding_flags.push_back(VK_DESCRIPTOR_BINDING_UPDATE_AFTER_BIND_BIT);
} }
VkDescriptorSetLayoutBindingFlagsCreateInfo dslbfci;
dslbfci.pNext = nullptr;
dslbfci.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_BINDING_FLAGS_CREATE_INFO;
dslbfci.bindingCount = dsl_binding_flags.size();
dslbfci.pBindingFlags = dsl_binding_flags.data();
vk::PushConstantRange pcr( vk::PushConstantRange pcr(
vk::ShaderStageFlagBits::eCompute, vk::ShaderStageFlagBits::eCompute,
0, 0,
@ -156,12 +170,13 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::s
); );
vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
vk::DescriptorSetLayoutCreateFlags(), vk::DescriptorSetLayoutCreateFlags(VK_DESCRIPTOR_SET_LAYOUT_CREATE_UPDATE_AFTER_BIND_POOL_BIT),
dsl_binding); dsl_binding);
descriptor_set_layout_create_info.setPNext(&dslbfci);
pipeline.dsl = vk_device.createDescriptorSetLayout(descriptor_set_layout_create_info); pipeline.dsl = vk_device.createDescriptorSetLayout(descriptor_set_layout_create_info);
vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count); vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
vk::DescriptorPoolCreateInfo descriptor_pool_create_info(vk::DescriptorPoolCreateFlags(VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT), 1, descriptor_pool_size); vk::DescriptorPoolCreateInfo descriptor_pool_create_info(vk::DescriptorPoolCreateFlags(VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT | VK_DESCRIPTOR_POOL_CREATE_UPDATE_AFTER_BIND_BIT), 1, descriptor_pool_size);
pipeline.descriptor_pool = vk_device.createDescriptorPool(descriptor_pool_create_info); pipeline.descriptor_pool = vk_device.createDescriptorPool(descriptor_pool_create_info);
vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pool, 1, &pipeline.dsl); vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pool, 1, &pipeline.dsl);
@ -187,7 +202,6 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::s
} }
static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buffer *> buffers, const std::vector<int>&& push_constants, std::array<uint32_t, 3> elements, vk::CommandBuffer& cmd_buffer, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) { static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buffer *> buffers, const std::vector<int>&& push_constants, std::array<uint32_t, 3> elements, vk::CommandBuffer& cmd_buffer, vk::Fence fence, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
PROFILE("ggml_vk_dispatch_pipeline",
std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos; std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
std::vector<vk::WriteDescriptorSet> write_descriptor_sets; std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
for (uint32_t i = 0; i < pipeline.parameter_count; i++) { for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
@ -207,7 +221,6 @@ static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buff
pipeline.layout, pipeline.layout,
0, 0,
{ pipeline.descriptor_set }, { pipeline.descriptor_set },
{}); {});
cmd_buffer.dispatch(CEIL_DIV(elements[0], pipeline.wg_denoms[0]), CEIL_DIV(elements[1], pipeline.wg_denoms[1]), CEIL_DIV(elements[2], pipeline.wg_denoms[2])); cmd_buffer.dispatch(CEIL_DIV(elements[0], pipeline.wg_denoms[0]), CEIL_DIV(elements[1], pipeline.wg_denoms[1]), CEIL_DIV(elements[2], pipeline.wg_denoms[2]));
cmd_buffer.end(); cmd_buffer.end();
@ -216,14 +229,13 @@ static void ggml_vk_dispatch_pipeline(vk_pipeline& pipeline, std::vector<vk_buff
vk::SubmitInfo submit_info(wait_semaphores.size(), vk::SubmitInfo submit_info(wait_semaphores.size(),
wait_semaphores.data(), wait_semaphores.data(),
nullptr, vk_stage_flags,
1, 1,
&cmd_buffer, &cmd_buffer,
signal_semaphores.size(), signal_semaphores.size(),
signal_semaphores.data()); signal_semaphores.data());
vk_compute_queue.queue.submit({ submit_info }, fence); vk_compute_queue.queue.submit({ submit_info }, fence);
);
} }
static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, int32_t min_num_queues) { static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, int32_t min_num_queues) {
@ -365,6 +377,8 @@ void ggml_vk_init(void) {
vk::PhysicalDeviceProperties device_props = vk_physical_device.getProperties(); vk::PhysicalDeviceProperties device_props = vk_physical_device.getProperties();
std::cerr << "ggml_vulkan: Using " << device_props.deviceName << std::endl; std::cerr << "ggml_vulkan: Using " << device_props.deviceName << std::endl;
vk_device_vendor_id = device_props.vendorID;
std::vector<vk::ExtensionProperties> ext_props = vk_physical_device.enumerateDeviceExtensionProperties(); std::vector<vk::ExtensionProperties> ext_props = vk_physical_device.enumerateDeviceExtensionProperties();
bool fp16_storage = false; bool fp16_storage = false;
@ -458,7 +472,7 @@ void ggml_vk_init(void) {
if (vk_fp16_support) { if (vk_fp16_support) {
vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7, {64, 64, 1}); vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7, {64, 64, 1});
} }
vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 2, {64, 1, 1}); vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3, {32, 32, 1});
vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {64, 1, 1}); vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {64, 1, 1});
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {32, 1, 1}); vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {32, 1, 1});
@ -675,7 +689,7 @@ static void ggml_vk_buffer_write_2d_async(vk_buffer* dst, size_t offset, const v
vk::SubmitInfo submit_info(wait_semaphores.size(), vk::SubmitInfo submit_info(wait_semaphores.size(),
wait_semaphores.data(), wait_semaphores.data(),
nullptr, vk_stage_flags,
1, 1,
&cmd_buffer, &cmd_buffer,
signal_semaphores.size(), signal_semaphores.size(),
@ -713,7 +727,7 @@ static void ggml_vk_buffer_write_2d_async(vk_buffer* dst, size_t offset, const v
vk::SubmitInfo submit_info(wait_semaphores.size(), vk::SubmitInfo submit_info(wait_semaphores.size(),
wait_semaphores.data(), wait_semaphores.data(),
nullptr, vk_stage_flags,
1, 1,
&cmd_buffer, &cmd_buffer,
signal_semaphores.size(), signal_semaphores.size(),
@ -780,7 +794,7 @@ static void ggml_vk_buffer_read_async(vk_buffer* src, size_t offset, void * dst,
vk::SubmitInfo submit_info(wait_semaphores.size(), vk::SubmitInfo submit_info(wait_semaphores.size(),
wait_semaphores.data(), wait_semaphores.data(),
nullptr, vk_stage_flags,
1, 1,
&cmd_buffer, &cmd_buffer,
signal_semaphores.size(), signal_semaphores.size(),
@ -930,7 +944,7 @@ static void ggml_vk_matmul(vk_pipeline& pipeline, vk_buffer& a, vk_buffer& b, vk
ggml_vk_dispatch_pipeline(pipeline, {&a, &b, &d}, { m, n, k, k, k, m, CEIL_DIV(k, split_k)}, { (uint32_t)m * split_k, (uint32_t)n, 1}, cmd_buffer, VK_NULL_HANDLE, std::move(wait_semaphores), { semaphore }); ggml_vk_dispatch_pipeline(pipeline, {&a, &b, &d}, { m, n, k, k, k, m, CEIL_DIV(k, split_k)}, { (uint32_t)m * split_k, (uint32_t)n, 1}, cmd_buffer, VK_NULL_HANDLE, std::move(wait_semaphores), { semaphore });
vk::CommandBuffer cmd_buffer_reduce = ggml_vk_cmd_buffer_create(vk_compute_queue); vk::CommandBuffer cmd_buffer_reduce = ggml_vk_cmd_buffer_create(vk_compute_queue);
ggml_vk_dispatch_pipeline(vk_pipeline_matmul_split_k_reduce, {&d}, { m * n, split_k}, { (uint32_t)m * n, 1, 1}, cmd_buffer_reduce, fence, { semaphore }, std::move(signal_semaphores)); ggml_vk_dispatch_pipeline(vk_pipeline_matmul_split_k_reduce, {&d}, { m, n, split_k}, { (uint32_t)m, (uint32_t)n, 1}, cmd_buffer_reduce, fence, { semaphore }, std::move(signal_semaphores));
} }
} }
@ -1502,7 +1516,6 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int sp
ggml_vk_pool_free(d_X); ggml_vk_pool_free(d_X);
ggml_vk_pool_free(d_Y); ggml_vk_pool_free(d_Y);
size_t ev_idx = 0;
ggml_vk_pool_free(d_D); ggml_vk_pool_free(d_D);
free(x); free(x);

View file

@ -1,26 +1,30 @@
#version 450 #version 450
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
layout (binding = 0) buffer A { float data[]; }; layout (binding = 0) buffer A { float data[]; };
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
int M;
int N; int N;
int k_num; int k_num;
} p; } p;
void main() { void main() {
const int idx = int(gl_GlobalInvocationID.x); const int glr = int(gl_GlobalInvocationID.x);
const int glc = int(gl_GlobalInvocationID.y);
if (idx >= p.N) { if (glr >= p.M || glc >= p.N) {
return; return;
} }
const int idx = glc * p.M + glr;
float result = 0.0f; float result = 0.0f;
for (int i = 0; i < p.k_num; i++) { for (int i = 0; i < p.k_num; i++) {
result += data[i * p.N + idx]; result += data[i * p.M * p.N + idx];
} }
data[idx] = result; data[idx] = result;