Handle stage flags during command buffer submission properly

This commit is contained in:
0cc4m 2023-07-15 22:00:47 +02:00
parent ad3d28ee0a
commit 22a4cb7f03

View file

@ -88,13 +88,18 @@ struct vk_pipeline {
struct vk_queue { struct vk_queue {
vk_queue() {}; vk_queue() {};
vk_queue(const vk_queue& b) : queue_family_index(b.queue_family_index), queue(b.queue), pool(b.pool) {} vk_queue(const vk_queue& b) : queue_family_index(b.queue_family_index), queue(b.queue), pool(b.pool), cmd_buffer_idx(b.cmd_buffer_idx), cmd_buffers(b.cmd_buffers), semaphore_idx(b.semaphore_idx), semaphores(b.semaphores), stage_flags(b.stage_flags) {}
vk_queue& operator=(const vk_queue& b) { vk_queue& operator=(const vk_queue& b) {
if (this != &b) { if (this != &b) {
queue_family_index = b.queue_family_index; queue_family_index = b.queue_family_index;
queue = b.queue; queue = b.queue;
pool = b.pool; pool = b.pool;
cmd_buffer_idx = b.cmd_buffer_idx;
cmd_buffers = b.cmd_buffers;
semaphore_idx = b.semaphore_idx;
semaphores = b.semaphores;
stage_flags = b.stage_flags;
} }
return *this; return *this;
} }
@ -106,6 +111,9 @@ struct vk_queue {
std::vector<vk::CommandBuffer> cmd_buffers; std::vector<vk::CommandBuffer> cmd_buffers;
uint32_t semaphore_idx; uint32_t semaphore_idx;
std::vector<vk::Semaphore> semaphores; std::vector<vk::Semaphore> semaphores;
vk::PipelineStageFlags stage_flags;
std::mutex mutex; std::mutex mutex;
}; };
@ -124,7 +132,6 @@ uint32_t vk_device_vendor_id;
vk_queue vk_compute_queue; vk_queue vk_compute_queue;
vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT]; vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT];
VmaAllocator vk_allocator; VmaAllocator vk_allocator;
vk::PipelineStageFlags vk_stage_flags[8] = { vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eComputeShader };
vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s, vk_pipeline_matmul_f16_l, vk_pipeline_matmul_f16_m, vk_pipeline_matmul_f16_s; vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s, vk_pipeline_matmul_f16_l, vk_pipeline_matmul_f16_m, vk_pipeline_matmul_f16_s;
vk_pipeline vk_pipeline_matmul_split_k_reduce; vk_pipeline vk_pipeline_matmul_split_k_reduce;
vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0; vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
@ -275,20 +282,27 @@ static vk_sequence ggml_vk_create_sequence_1(vk_queue& q, std::vector<vk::Semaph
static void ggml_vk_submit(vk_queue& q, std::vector<vk_sequence>& sequences, vk::Fence fence) { static void ggml_vk_submit(vk_queue& q, std::vector<vk_sequence>& sequences, vk::Fence fence) {
#ifdef VK_DEBUG #ifdef VK_DEBUG
std::cerr << "ggml_vk_submit()" << std::endl; std::cerr << "ggml_vk_submit(" << q.queue_family_index << ", " << sequences.size() << ")" << std::endl;
#endif #endif
if (sequences.empty()) { if (sequences.empty()) {
return; return;
} }
std::vector<vk::SubmitInfo> submit_infos; std::vector<vk::SubmitInfo> submit_infos;
int idx = -1;
std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
for (const auto& sequence : sequences) { for (const auto& sequence : sequences) {
for (const auto& submission : sequence) { for (const auto& submission : sequence) {
stage_flags.push_back({});
idx++;
for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
stage_flags[idx].push_back(q.stage_flags);
}
submit_infos.push_back({ submit_infos.push_back({
(uint32_t) submission.wait_semaphores.size(), (uint32_t) submission.wait_semaphores.size(),
submission.wait_semaphores.data(), submission.wait_semaphores.data(),
vk_stage_flags, stage_flags[idx].data(),
1, 1,
&submission.buffer, &submission.buffer,
(uint32_t) submission.signal_semaphores.size(), (uint32_t) submission.signal_semaphores.size(),
@ -346,7 +360,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
abort(); abort();
} }
static vk_queue ggml_vk_create_queue(uint32_t queue_family_index, uint32_t queue_index) { static vk_queue ggml_vk_create_queue(uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags) {
#ifdef VK_DEBUG #ifdef VK_DEBUG
std::cerr << "ggml_vk_create_queue()" << std::endl; std::cerr << "ggml_vk_create_queue()" << std::endl;
#endif #endif
@ -361,6 +375,8 @@ static vk_queue ggml_vk_create_queue(uint32_t queue_family_index, uint32_t queue
q.queue = vk_device.getQueue(queue_family_index, queue_index); q.queue = vk_device.getQueue(queue_family_index, queue_index);
q.stage_flags = stage_flags;
return q; return q;
} }
@ -430,7 +446,10 @@ static vk_buffer ggml_vk_create_buffer(size_t size, VmaAllocationCreateFlags all
return buf; return buf;
} }
static void ggml_vk_sync_buffers(vk::CommandBuffer& cmd_buffer, std::vector<vk_buffer>&& buffers, vk_queue& q, vk::AccessFlags src_mask, vk::AccessFlags dst_mask) { static void ggml_vk_sync_buffers(vk::CommandBuffer& cmd_buffer, std::vector<vk_buffer>&& buffers, vk_queue& q, vk::AccessFlags&& src_mask, vk::AccessFlags&& dst_mask) {
#ifdef VK_DEBUG
std::cerr << "ggml_vk_sync_buffers()" << std::endl;
#endif
std::vector<vk::BufferMemoryBarrier> bmem_barriers; std::vector<vk::BufferMemoryBarrier> bmem_barriers;
uint32_t sfi; uint32_t sfi;
@ -453,8 +472,8 @@ static void ggml_vk_sync_buffers(vk::CommandBuffer& cmd_buffer, std::vector<vk_b
} }
cmd_buffer.pipelineBarrier( cmd_buffer.pipelineBarrier(
vk::PipelineStageFlagBits::eComputeShader, q.stage_flags,
vk::PipelineStageFlagBits::eComputeShader, q.stage_flags,
{}, {},
{}, {},
bmem_barriers, bmem_barriers,
@ -501,7 +520,21 @@ void ggml_vk_init(void) {
"VK_LAYER_KHRONOS_validation", "VK_LAYER_KHRONOS_validation",
#endif #endif
}; };
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags(), &app_info, layers.size(), layers.data()); const std::vector<const char*> extensions = {
#ifdef VK_VALIDATE
"VK_EXT_validation_features",
#endif
};
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags(), &app_info, layers, extensions);
#ifdef VK_VALIDATE
const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
vk::ValidationFeaturesEXT validation_features = {
features_enable,
{},
};
validation_features.setPNext(nullptr);
instance_create_info.setPNext(&validation_features);
#endif
vk_instance = vk::createInstance(instance_create_info); vk_instance = vk::createInstance(instance_create_info);
vk_physical_device = vk_instance.enumeratePhysicalDevices()[dev_num]; vk_physical_device = vk_instance.enumeratePhysicalDevices()[dev_num];
@ -618,9 +651,9 @@ void ggml_vk_init(void) {
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {256*32, 1, 1}, {}); // Group size * values per quant group vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {256*32, 1, 1}, {}); // Group size * values per quant group
// Queues // Queues
vk_compute_queue = ggml_vk_create_queue(compute_queue_family_index, 0); vk_compute_queue = ggml_vk_create_queue(compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader });
for (int i = 0; i < VK_TRANSFER_QUEUE_COUNT; i++) { for (int i = 0; i < VK_TRANSFER_QUEUE_COUNT; i++) {
vk_transfer_queues[i] = ggml_vk_create_queue(transfer_queue_family_index, i); vk_transfer_queues[i] = ggml_vk_create_queue(transfer_queue_family_index, i, { vk::PipelineStageFlagBits::eTransfer });
} }
#if defined(VK_CHK_KERNEL) #if defined(VK_CHK_KERNEL)
@ -1164,9 +1197,10 @@ static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_buffer& a, vk_buffer
// Synchronize the two submissions // Synchronize the two submissions
ggml_vk_dispatch_pipeline(s, pipeline, { a, b, d }, { m, n, k, k, k, m, CEIL_DIV(k, split_k) }, { (uint32_t)m * split_k, (uint32_t)n, 1 }, q); ggml_vk_dispatch_pipeline(s, pipeline, { a, b, d }, { m, n, k, k, k, m, CEIL_DIV(k, split_k) }, { (uint32_t)m * split_k, (uint32_t)n, 1 }, q);
s.buffer.pipelineBarrier( s.buffer.pipelineBarrier(
vk::PipelineStageFlagBits::eComputeShader, q.stage_flags,
vk::PipelineStageFlagBits::eComputeShader, q.stage_flags,
{}, {},
{}, {},
{ {