From 8b47fce07a5643a02e4f022979e3784f962282ff Mon Sep 17 00:00:00 2001 From: Sergio Lopez Date: Wed, 29 Jan 2025 06:30:52 +0100 Subject: [PATCH 1/3] vulkan: enable the use of simpler matmul shaders Import simpler matmul shaders from the kompute backend and use them on GPUs know to not be able to use the regular ones. Signed-off-by: Sergio Lopez --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 242 +++++++++++++++++- .../vulkan-shaders/simpler_common.comp | 112 ++++++++ .../vulkan-shaders/simpler_mul_mat_f16.comp | 69 +++++ .../simpler_mul_mat_mat_f32.comp | 51 ++++ .../vulkan-shaders/simpler_mul_mat_q4_0.comp | 33 +++ .../vulkan-shaders/simpler_mul_mat_q4_1.comp | 35 +++ .../vulkan-shaders/simpler_mul_mat_q4_k.comp | 140 ++++++++++ .../vulkan-shaders/simpler_mul_mat_q6_k.comp | 106 ++++++++ .../vulkan-shaders/simpler_mul_mat_q8_0.comp | 73 ++++++ .../vulkan-shaders/simpler_mul_mv_q_n.comp | 52 ++++ .../simpler_mul_mv_q_n_pre.comp | 28 ++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 8 + 12 files changed, 943 insertions(+), 6 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_common.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_f16.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_mat_f32.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_0.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_1.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_k.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q6_k.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q8_0.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n_pre.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bffe95086..5c81d51bc 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -35,6 +35,7 @@ #define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de +#define VK_VENDOR_ID_QUALCOMM 0x5143 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 @@ -168,6 +169,7 @@ struct vk_device_struct { uint32_t shader_core_count; bool uma; bool prefer_host_memory; + bool simpler_shaders; bool float_controls_rte_fp16; bool subgroup_size_control; @@ -217,6 +219,15 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; + + vk_pipeline pipeline_simpler_mul_mat_mat_f32; + vk_pipeline pipeline_simpler_mul_mat_f16; + vk_pipeline pipeline_simpler_mul_mat_q4_0; + vk_pipeline pipeline_simpler_mul_mat_q4_1; + vk_pipeline pipeline_simpler_mul_mat_q4_k; + vk_pipeline pipeline_simpler_mul_mat_q6_k; + vk_pipeline pipeline_simpler_mul_mat_q8_0; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; @@ -2069,6 +2080,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_mat_f32, "simpler_mul_mat_mat_f32", simpler_mul_mat_mat_f32_len, simpler_mul_mat_mat_f32_data, "main", 3, 14 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_f16, "simpler_mul_mat_f16", simpler_mul_mat_f16_len, simpler_mul_mat_f16_data, "main", 3, 21 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size * 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q4_0, "simpler_mul_mat_q4_0", simpler_mul_mat_q4_0_len, simpler_mul_mat_q4_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q4_1, "simpler_mul_mat_q4_1", simpler_mul_mat_q4_1_len, simpler_mul_mat_q4_1_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q4_k, "simpler_mul_mat_q4_k", simpler_mul_mat_q4_k_len, simpler_mul_mat_q4_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2175,6 +2194,18 @@ static void ggml_vk_load_shaders(vk_device& device) { static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); +static bool ggml_vk_device_use_simpler_shaders(vk::PhysicalDeviceProperties *props) { + switch (props->vendorID) { + case VK_VENDOR_ID_APPLE: + // Vulkan on Apple implies using MoltenVK, which doesn't support + // the regular MAT_MUL shaders, among others. + case VK_VENDOR_ID_QUALCOMM: + return true; + default: + return false; + } +} + static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -2314,6 +2345,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_size = subgroup_props.subgroupSize; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + device->simpler_shaders = ggml_vk_device_use_simpler_shaders(&device->properties); if (sm_builtins) { device->shader_core_count = sm_props.shaderSMCount; } else if (amd_shader_core_properties2) { @@ -4512,6 +4544,167 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c } } +static void ggml_vk_simpler_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_simpler_mul_mat(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne0 = dst->ne[0]; + const uint64_t ne1 = dst->ne[1]; + const uint64_t ne2 = dst->ne[2]; + const uint64_t ne3 = dst->ne[3]; + + const uint64_t nb00 = src0->nb[0]; + const uint64_t nb01 = src0->nb[1]; + const uint64_t nb02 = src0->nb[2]; + const uint64_t nb03 = src0->nb[3]; + + const uint64_t nb10 = src1->nb[0]; + const uint64_t nb11 = src1->nb[1]; + const uint64_t nb12 = src1->nb[2]; + const uint64_t nb13 = src1->nb[3]; + + const uint64_t nb1 = dst->nb[1]; + const uint64_t nb2 = dst->nb[2]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const uint64_t x_ne = ne00 * ne01 * ne02 * ne03; + const uint64_t y_ne = ne10 * ne11 * ne12 * ne13; + const uint64_t d_ne = ne0 * ne1 * ne2 * ne3; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline pipeline; + switch (src0->type) { + case GGML_TYPE_F32: + pipeline = ctx->device->pipeline_simpler_mul_mat_mat_f32; + break; + case GGML_TYPE_F16: + pipeline = ctx->device->pipeline_simpler_mul_mat_f16; + break; + case GGML_TYPE_Q4_0: + pipeline = ctx->device->pipeline_simpler_mul_mat_q4_0; + break; + case GGML_TYPE_Q4_1: + pipeline = ctx->device->pipeline_simpler_mul_mat_q4_1; + break; + case GGML_TYPE_Q4_K: + pipeline = ctx->device->pipeline_simpler_mul_mat_q4_k; + break; + case GGML_TYPE_Q6_K: + pipeline = ctx->device->pipeline_simpler_mul_mat_q6_k; + break; + case GGML_TYPE_Q8_0: + pipeline = ctx->device->pipeline_simpler_mul_mat_q8_0; + break; + default: + GGML_ABORT("vk_simpler_mul_mat: unsupported quantization type: %d", src0->type); + } + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz); + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + + const uint64_t qx_buffer_offset = (qx_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qx_shader_offset = qx_buf_offset - qx_buffer_offset; + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + ggml_vk_sync_buffers(subctx); + switch (src0->type) { + case GGML_TYPE_F32: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb1, (uint32_t)nb2 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 14 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)ne11, (uint32_t)std::max(ne12, ne02) }); + break; + } + case GGML_TYPE_F16: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)nb00, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)ne10, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb10, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 21 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)((ne11 + 3) / 4), (uint32_t)(ne12 * ne13) }); + break; + } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne10, (uint32_t)ne12, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 7) / 8), (uint32_t)ne11, (uint32_t)(ne12 * ne13) }); + break; + } + case GGML_TYPE_Q4_K: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 3) / 4), (uint32_t)ne11, (uint32_t)(ne12 * ne13) }); + break; + } + case GGML_TYPE_Q6_K: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 1) / 2), (uint32_t)ne11, (uint32_t)(ne12 * ne13) }); + break; + } + default: + GGML_ABORT("vk_simpler_mul_mat: unsupported quantization type: %d", src0->type); + } +} + static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -7227,8 +7420,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; case GGML_OP_MUL_MAT: - ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); - + if (ctx->device->simpler_shaders) { + ggml_vk_simpler_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + } else { + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + } break; case GGML_OP_MUL_MAT_ID: ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); @@ -8019,7 +8215,7 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const return ggml_backend_vk_init(ctx->device); } -static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { +static bool ggml_backend_vk_complete_device_supports_op(const vk_device& device, const ggml_tensor * op) { switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -8088,8 +8284,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } break; case GGML_OP_FLASH_ATTN_EXT: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - if (!ggml_vk_get_device(ctx->device)->coopmat2) { + if (!device->coopmat2) { return false; } switch (op->src[0]->ne[0]) { @@ -8254,8 +8449,43 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } +} - UNUSED(dev); +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + if (!device->simpler_shaders) { + return ggml_backend_vk_complete_device_supports_op(device, op); + } + + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1])) + return false; + + switch (op->src[0]->type) { + case GGML_TYPE_F32: + return op->ne[3] == 1; + case GGML_TYPE_Q8_0: + // TODO (slp) - Fix Q8_0 with permutations + if (ggml_is_permuted(op->src[0]) || ggml_is_permuted(op->src[1])) { + return false; + } + case GGML_TYPE_Q6_K: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + return true; + default: + return false; + } + case GGML_OP_MUL_MAT_ID: + return false; + default: + return ggml_backend_vk_complete_device_supports_op(device, op); + } } static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_common.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_common.comp new file mode 100644 index 000000000..dbe4cf804 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_common.comp @@ -0,0 +1,112 @@ +#extension GL_EXT_shader_16bit_storage: require +#extension GL_EXT_shader_8bit_storage: require +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#extension GL_EXT_shader_explicit_arithmetic_types_int8: require +#extension GL_EXT_shader_explicit_arithmetic_types_int16: require +#extension GL_EXT_shader_explicit_arithmetic_types_int64: require +#extension GL_EXT_control_flow_attributes: enable +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_EXT_debug_printf : enable + +#define QK4_0 32 +#define QK4_1 32 + +#define GELU_COEF_A 0.044715 +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876 +#define TWOPI_F 6.283185307179586f + +#define QK_K 256 +#define K_SCALE_SIZE 12 + +#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx]) +#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx) +#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx]) +#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx) + +#define sizeof_block_q4_0 0x12 +struct block_q4_0 { + float16_t d; + uint8_t qs[QK4_0 / 2]; +}; +mat4 dequantize_q4_0(const block_q4_0 xb, uint il) { + const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; + const float d2 = d1 / 256.f; + const float md = -8.f * xb.d; + const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); + const uint16_t mask1 = mask0 << 8; + + mat4 reg; + for (int i=0;i<8;i++) { + uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); + reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md; + } + return reg; +} + +#define sizeof_block_q4_1 0x14 +struct block_q4_1 { + float16_t d; + float16_t m; + uint8_t qs[QK4_1 / 2]; +}; +mat4 dequantize_q4_1(const block_q4_1 xb, uint il) { + const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; + const float d2 = d1 / 256.f; + const float m = xb.m; + const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); + const uint16_t mask1 = mask0 << 8; + + mat4 reg; + for (int i=0;i<8;i++) { + uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); + reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m; + } + return reg; +} + +#define sizeof_block_q4_k 144 +struct block_q4_k { + float16_t d; + float16_t dmin; + uint8_t scales[K_SCALE_SIZE]; + uint8_t qs[QK_K/2]; +}; + +#define sizeof_block_q6_k 210 +struct block_q6_k { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + float16_t d; // super-block scale +}; +mat4 dequantize_q6_k(const block_q6_k xb, uint il) { + const float16_t d_all = xb.d; + + const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + const uint qhIndex = 32*(il/8) + 16*(il&1); + float16_t sc = xb.scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F); + const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f); + const float16_t ml = float16_t(d_all * sc * 32.f); + const float16_t dl = float16_t(d_all * sc * coef); + mat4 reg; + for (int i = 0; i < 16; ++i) { + const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2)) + : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } + return reg; +} + + +#define QK8_0 32 +// struct block_q8_0 { +// float16_t d; // delta +// int8_t qs[QK8_0]; // quants +// }; +#define sizeof_block_q8_0 34 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_f16.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_f16.comp new file mode 100644 index 000000000..38416eefd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_f16.comp @@ -0,0 +1,69 @@ +#version 450 + +#include "simpler_common.comp" + +#extension GL_KHR_shader_subgroup_arithmetic : require + +layout(local_size_x_id = 0) in; + +layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne01; + int ne02; + uint nb00; + uint nb01; + uint nb02; + uint nb03; + int ne10; + int ne11; + int ne12; + uint nb10; + uint nb11; + uint nb12; + uint nb13; + int ne0; + int ne1; + uint r2; + uint r3; +} pcs; + +#define N_F16_F32 4 + +void main() { + const uint r0 = gl_WorkGroupID.x; + const uint rb = gl_WorkGroupID.y*N_F16_F32; + const uint im = gl_WorkGroupID.z; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03; + + const uint x = offset0 / 2 + pcs.inAOff; // Based from inA + + for (uint row = 0; row < N_F16_F32; ++row) { + uint r1 = rb + row; + if (r1 >= pcs.ne11) { + break; + } + + const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; + + float sumf = 0; + for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { + sumf += float(inA[x+i]) * float(inB[y+i]); + } + + const float all_sum = subgroupAdd(sumf); + if (subgroupElect()) { + out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_mat_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_mat_f32.comp new file mode 100644 index 000000000..dc048f430 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_mat_f32.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "simpler_common.comp" + +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_EXT_debug_printf : enable + +// device subgroup size +layout (local_size_x_id = 0) in; + +layout(binding = 0) readonly buffer tensorInA { float inA[]; }; +layout(binding = 1) readonly buffer tensorInB { float inB[]; }; +layout(binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout(push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne01; + int ne02; + int ne11; + int ne12; + uint nb01; + uint nb02; + uint nb11; + uint nb12; + uint nb1; + uint nb2; +} +pcs; + + +void main() { + uvec3 gid = gl_WorkGroupID; + + uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z; + uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z; + + const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA + const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB + float sum = 0.0f; + for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { + sum += float(inA[x+i]) * float(inB[y+i]); + } + + const float all_sum = subgroupAdd(sum); + if (subgroupElect()) { + out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_0.comp new file mode 100644 index 000000000..db9dc612b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_0.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "simpler_common.comp" + +#define BLOCKS_IN_QUANT QK4_0 +#define SIZE_OF_BLOCK sizeof_block_q4_0 +#define N_ROWS 4 + +#include "simpler_mul_mv_q_n_pre.comp" + +// The q4_0 version of this function +float block_q_n_dot_y(uint block_index, uint yb, uint il) { + vec2 acc = vec2(0.0, 0.0); + const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff; + float d = float(u8BufToFloat16(inA, index)); + float sumy = 0.0f; + for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) { + const uint16_t b = u8BufToU16(inA, index + 2 + il + i); + + const float yl0 = inB[yb + i]; + const float yl1 = inB[yb + i + 1]; + const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2]; + const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1]; + + sumy += yl0 + yl1 + yl8 + yl9; + + acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00); + acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]); +} + +#include "simpler_mul_mv_q_n.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_1.comp new file mode 100644 index 000000000..75449fa7b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_1.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "simpler_common.comp" + +#define BLOCKS_IN_QUANT QK4_1 +#define SIZE_OF_BLOCK sizeof_block_q4_1 +#define N_ROWS 4 + +#include "simpler_mul_mv_q_n_pre.comp" + +// The q4_1 version of this function +float block_q_n_dot_y(uint block_index, uint yb, uint il) { + vec2 acc = vec2(0.0, 0.0); + const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff; + float d = float(u8BufToFloat16(inA, index)); + float m = float(u8BufToFloat16(inA, index+2)); + + float sumy = 0.0f; + for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) { + const uint16_t b = u8BufToU16(inA, index + 4 + il + i); + + const float yl0 = inB[yb + i]; + const float yl1 = inB[yb + i + 1]; + const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2]; + const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1]; + + sumy += yl0 + yl1 + yl8 + yl9; + + acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00); + acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +#include "simpler_mul_mv_q_n.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_k.comp new file mode 100644 index 000000000..7b767ac10 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_k.comp @@ -0,0 +1,140 @@ +#version 450 + +#include "simpler_common.comp" + +#define N_DST 4 +#define SIZE_OF_BLOCK sizeof_block_q4_k + +layout(local_size_x = 4) in; +layout(local_size_y = 8) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne10; + int ne0; + int ne1; + int ne01; + int ne02; + int ne12; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; +} pcs; + +void main() { + const uint16_t kmask1 = uint16_t(0x3f3f); + const uint16_t kmask2 = uint16_t(0x0f0f); + const uint16_t kmask3 = uint16_t(0xc0c0); + + const uint ix = gl_SubgroupInvocationID/8; // 0...3 + const uint it = gl_SubgroupInvocationID%8; // 0...7 + const uint iq = it/4; // 0 or 1 + const uint ir = it%4; // 0...3 + + const uint nb = pcs.ne00/QK_K; + + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint first_row = r0 * N_DST; + const uint ib_row = first_row * nb; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13; + + const uint xblk = offset0 + pcs.inAOff; + const uint y = (offset1 / 4) + pcs.inBOff; + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f}; + float all_sum = 0.f; + + uint y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + for (uint ib = ix; ib < nb; ib += 4) { + const uint blk_idx = ib + xblk; + + float sumy[4] = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0]; + yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8]; + } + + for (int row = 0; row < N_DST; row++) { + uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK); + + uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0); + uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2); + uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4); + uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6); + uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8); + + uint16_t sc16[4]; + sc16[0] = sc_0 & kmask1; + sc16[1] = sc_2 & kmask1; + sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2); + sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2); + + float acc1[4] = {0.f, 0.f, 0.f, 0.f}; + float acc2[4] = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i); + uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i); + acc1[0] += yl[i+0] * (q1 & 0x000F); + acc1[1] += yl[i+1] * (q1 & 0x0F00); + acc1[2] += yl[i+8] * (q1 & 0x00F0); + acc1[3] += yl[i+9] * (q1 & 0xF000); + acc2[0] += yh[i+0] * (q2 & 0x000F); + acc2[1] += yh[i+1] * (q2 & 0x0F00); + acc2[2] += yh[i+8] * (q2 & 0x00F0); + acc2[3] += yh[i+9] * (q2 & 0xF000); + } + + uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF); + uint8_t sc8_1 = uint8_t(sc16[0] >> 8 ); + uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF); + uint8_t sc8_3 = uint8_t(sc16[1] >> 8 ); + uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF); + uint8_t sc8_5 = uint8_t(sc16[2] >> 8 ); + uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF); + uint8_t sc8_7 = uint8_t(sc16[3] >> 8 ); + + float dall = float(inA[blk_idx + row_idx].d); + float dmin = float(inA[blk_idx + row_idx].dmin); + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) - + dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7); + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = subgroupAdd(sumf[row]); + if (subgroupElect()) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q6_k.comp new file mode 100644 index 000000000..32527c46a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q6_k.comp @@ -0,0 +1,106 @@ +#version 450 + +#include "simpler_common.comp" + +#define SIZE_OF_BLOCK sizeof_block_q6_k + +layout(local_size_x_id = 0) in; +layout(local_size_y_id = 1) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne10; + int ne0; + int ne1; + int ne01; + int ne02; + int ne12; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; +} pcs; + +void main() { + const uint8_t kmask1 = uint8_t(0x03); + const uint8_t kmask2 = uint8_t(0x0C); + const uint8_t kmask3 = uint8_t(0x30); + const uint8_t kmask4 = uint8_t(0xC0); + + const uint nb = pcs.ne00/QK_K; + + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID); + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; + + float sumf = 0; + + // bits of invocation ID for gl_SubgroupSize=32: + // x x x x x + // 4 3 2 1 0 + // ( tid ) ix + // ip ( il ) + + const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes + const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0 + const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1 + const uint ip = tid/8; // first or second half of block (0 or 1) + const uint il = tid%8; // each half has 8 parts, one per scale + const uint n = 4; // 4 scales at a time (and 4 sums) + const uint l0 = n*il; // offset into half-block, 0..28 + const uint is = 8*ip + l0/16; // 0, 1, 8, 9 + + const uint y_offset = 128*ip + l0; + const uint q_offset_l = 64*ip + l0; + const uint q_offset_h = 32*ip + l0; + + for (uint i = ix; i < nb; i += block_stride) { + + const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; + + const uint qlIndex = q_offset_l; + const uint q2Index = qlIndex + QK_K/8; + const uint qhIndex = q_offset_h; + const uint y = yy + i * QK_K + y_offset; + + float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (uint l = 0; l < n; ++l) { + const uint8_t currentQ1 = inA[baseIndex + qlIndex + l]; + const uint8_t currentQ2 = inA[baseIndex + q2Index + l]; + const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l]; + + sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32); + sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32); + sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32); + sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32); + } + + float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16); + sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is])); + } + + const float tot = subgroupAdd(sumf); + if (subgroupElect()) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q8_0.comp new file mode 100644 index 000000000..dcb9a1f6d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q8_0.comp @@ -0,0 +1,73 @@ +#version 450 + +#include "simpler_common.comp" + +#include "simpler_mul_mv_q_n_pre.comp" + +#define SIZE_OF_D 2 + +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +#define NB_Q8_0 8 + +void main() { + // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64 + if (gl_SubgroupInvocationID > 31) + return; + + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int nb = pcs.ne00/QK8_0; + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint first_row = (r0 * nsg + gl_SubgroupID) * nr; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02); + + const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA + const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB + + float yl[NB_Q8_0]; + float sumf[N_DST]={0.f, 0.f, 0.f, 0.f}; + + const uint ix = gl_SubgroupInvocationID.x/4; + const uint il = gl_SubgroupInvocationID.x%4; + + uint yb = y + ix * QK8_0 + NB_Q8_0*il; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (uint ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { + yl[i] = inB[yb + i]; + } + + for (int row = 0; row < nr; row++) { + const uint block_offset = (ib+row*nb) * sizeof_block_q8_0; + float sumq = 0.f; + for (int iq = 0; iq < NB_Q8_0; ++iq) { + const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]); + sumq += qs_iq * yl[iq]; + } + const float16_t d = u8BufToFloat16(inA, x + block_offset); + sumf[row] += sumq*d; + } + + yb += NB_Q8_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = subgroupAdd(sumf[row]); + if (subgroupElect() && first_row + row < pcs.ne01) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n.comp new file mode 100644 index 000000000..a6517cc1f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n.comp @@ -0,0 +1,52 @@ +void main() { + // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64 + if (gl_SubgroupInvocationID > 31) + return; + + const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); + + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + // pointers to src0 rows + uint ax[N_ROWS]; + for (int row = 0; row < N_ROWS; ++row) { + const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + + ax[row] = offset0 + pcs.inAOff; + } + + const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; + + float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f}; + + const uint ix = gl_SubgroupInvocationID/2; + const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2); + + uint yb = y + ix * BLOCKS_IN_QUANT + il; + + //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n", + // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, + // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); + + for (uint ib = ix; ib < nb; ib += 16) { + for (int row = 0; row < N_ROWS; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il); + } + + yb += BLOCKS_IN_QUANT * 16; + } + + for (int row = 0; row < N_ROWS; ++row) { + const float tot = subgroupAdd(sumf[row]); + if (first_row + row < pcs.ne01 && subgroupElect()) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n_pre.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n_pre.comp new file mode 100644 index 000000000..a9a2f2218 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n_pre.comp @@ -0,0 +1,28 @@ +layout(local_size_x_id = 0) in; +layout(local_size_y = 8) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne01; + int ne02; + int ne10; + int ne12; + int ne0; + int ne1; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; +} pcs; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77e7e1148..ec357977e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -496,6 +496,14 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("simpler_mul_mat_mat_f32", "simpler_mul_mat_mat_f32.comp", {}); + string_to_spv("simpler_mul_mat_f16", "simpler_mul_mat_f16.comp", {}); + string_to_spv("simpler_mul_mat_q4_0", "simpler_mul_mat_q4_0.comp", {}); + string_to_spv("simpler_mul_mat_q4_1", "simpler_mul_mat_q4_1.comp", {}); + string_to_spv("simpler_mul_mat_q4_k", "simpler_mul_mat_q4_k.comp", {}); + string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {}); + string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {}); + for (auto &c : compiles) { c.wait(); } From bae9a58f5df8fe4e6fe7c8b71a4383c15a5f1cda Mon Sep 17 00:00:00 2001 From: Sergio Lopez Date: Fri, 7 Feb 2025 14:57:03 +0100 Subject: [PATCH 2/3] vulkan: enable the use of simpler softmax shaders Even though the regular softmax shaders successfully pass test-backend-ops with Apple GPUs, running long inference tests has shown the models end derailing with softmax OPs being the root cause. With this commit, we use simpler softmax shaders borrowed from the Kompute backend (which are basically reimplementations of the Metal shaders) on certain GPUs know to have problem with the regular ones. Signed-off-by: Sergio Lopez --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 60 +++++++++++++--- .../vulkan-shaders/simpler_soft_max.comp | 69 +++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 + 3 files changed, 123 insertions(+), 9 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5c81d51bc..1eb205524 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -228,6 +228,9 @@ struct vk_device_struct { vk_pipeline pipeline_simpler_mul_mat_q6_k; vk_pipeline pipeline_simpler_mul_mat_q8_0; + vk_pipeline pipeline_simpler_soft_max_f16; + vk_pipeline pipeline_simpler_soft_max_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; @@ -518,6 +521,18 @@ struct vk_op_soft_max_push_constants { uint32_t nrows_x; }; +struct vk_op_simpler_soft_max_push_constants { + int32_t ne00; + int32_t ne01; + int32_t ne02; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + int32_t mask; +}; + struct vk_op_argsort_push_constants { uint32_t ncols; uint32_t ncols_pad; @@ -2088,6 +2103,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1); ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -5440,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + if (ctx->device->simpler_shaders) { + if (src1 && src1->type == GGML_TYPE_F16) { + return ctx->device->pipeline_simpler_soft_max_f16; + } + return ctx->device->pipeline_simpler_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } @@ -5738,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); switch (op) { + case GGML_OP_SOFT_MAX: + if (ctx->device->simpler_shaders) { + elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] }; + break; + } + // fall-through case GGML_OP_NORM: case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: { const uint32_t nr = ggml_nrows(src0); @@ -6281,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { - ncols, - src1 != nullptr ? nrows_y : (uint32_t)0, - scale, max_bias, - m0, m1, - n_head_log2, - nrows_x, - }, dryrun); + if (ctx->device->simpler_shaders) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + (int32_t) src0->ne[0], + (int32_t) src0->ne[1], + (int32_t) src0->ne[2], + scale, max_bias, + m0, m1, + n_head_log2, + src1 == nullptr ? 0 : 1, + }, dryrun); + } else { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? nrows_y : (uint32_t)0, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + }, dryrun); + } } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp new file mode 100644 index 000000000..d2707c504 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp @@ -0,0 +1,69 @@ +// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4) + +#version 450 + +#include "simpler_common.comp" + +layout(local_size_x = 32) in; + +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; + +layout(push_constant) uniform PushConstants { + int ne00; + int ne01; + int ne02; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + int mask; +} pcs; + +void main() { + if (gl_SubgroupInvocationID > 31) + return; + + const uint i03 = gl_WorkGroupID.z; + const uint i02 = gl_WorkGroupID.y; + const uint i01 = gl_WorkGroupID.x; + + const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00; + const uint psrc0 = extra_off; + const uint pmask = i01*pcs.ne00; + const uint pdst = extra_off; + + float slope = 1.0f; + + // ALiBi + if (pcs.max_bias > 0.0f) { + int64_t h = i02; + + float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1; + int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1; + + slope = pow(base, float(exp)); + } + + // parallel max + float localMax = uintBitsToFloat(0xFF800000); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f)); + } + float max_ = subgroupMax(localMax); + + // parallel sum + float localSum = 0.0f; + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_); + localSum += exp_psrc0; + out_[pdst + i00] = exp_psrc0; + } + + const float sum = subgroupAdd(localSum); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + out_[pdst + i00] /= sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ec357977e..d616795e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -504,6 +504,9 @@ void process_shaders() { string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {}); string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {}); + string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}}); + string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}}); + for (auto &c : compiles) { c.wait(); } From da1e744efd1b2d945cf509abd8a10db3d6f9562d Mon Sep 17 00:00:00 2001 From: Sergio Lopez Date: Mon, 10 Feb 2025 11:54:51 +0100 Subject: [PATCH 3/3] tests: test softmax with ne00 == 2048 It was found that softmax vulkan shaders may fail on some contexts when ne00 > 1024, so let's add a test for it. Signed-off-by: Sergio Lopez --- tests/test-backend-ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1bfd41254..5c562aa90 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4166,7 +4166,7 @@ static std::vector> make_test_cases_eval() { for (float max_bias : {0.0f, 8.0f}) { if (!mask && max_bias > 0.0f) continue; for (float scale : {1.0f, 0.1f}) { - for (int64_t ne0 : {16, 1024}) { + for (int64_t ne0 : {16, 1024, 2048}) { for (int64_t ne1 : {16, 1024}) { if (mask) { for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {