diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bffe95086..1eb205524 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -35,6 +35,7 @@ #define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de +#define VK_VENDOR_ID_QUALCOMM 0x5143 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 @@ -168,6 +169,7 @@ struct vk_device_struct { uint32_t shader_core_count; bool uma; bool prefer_host_memory; + bool simpler_shaders; bool float_controls_rte_fp16; bool subgroup_size_control; @@ -217,6 +219,18 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; + + vk_pipeline pipeline_simpler_mul_mat_mat_f32; + vk_pipeline pipeline_simpler_mul_mat_f16; + vk_pipeline pipeline_simpler_mul_mat_q4_0; + vk_pipeline pipeline_simpler_mul_mat_q4_1; + vk_pipeline pipeline_simpler_mul_mat_q4_k; + vk_pipeline pipeline_simpler_mul_mat_q6_k; + vk_pipeline pipeline_simpler_mul_mat_q8_0; + + vk_pipeline pipeline_simpler_soft_max_f16; + vk_pipeline pipeline_simpler_soft_max_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; @@ -507,6 +521,18 @@ struct vk_op_soft_max_push_constants { uint32_t nrows_x; }; +struct vk_op_simpler_soft_max_push_constants { + int32_t ne00; + int32_t ne01; + int32_t ne02; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + int32_t mask; +}; + struct vk_op_argsort_push_constants { uint32_t ncols; uint32_t ncols_pad; @@ -2069,6 +2095,17 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_mat_f32, "simpler_mul_mat_mat_f32", simpler_mul_mat_mat_f32_len, simpler_mul_mat_mat_f32_data, "main", 3, 14 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_f16, "simpler_mul_mat_f16", simpler_mul_mat_f16_len, simpler_mul_mat_f16_data, "main", 3, 21 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size * 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q4_0, "simpler_mul_mat_q4_0", simpler_mul_mat_q4_0_len, simpler_mul_mat_q4_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q4_1, "simpler_mul_mat_q4_1", simpler_mul_mat_q4_1_len, simpler_mul_mat_q4_1_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q4_k, "simpler_mul_mat_q4_k", simpler_mul_mat_q4_k_len, simpler_mul_mat_q4_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2175,6 +2212,18 @@ static void ggml_vk_load_shaders(vk_device& device) { static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); +static bool ggml_vk_device_use_simpler_shaders(vk::PhysicalDeviceProperties *props) { + switch (props->vendorID) { + case VK_VENDOR_ID_APPLE: + // Vulkan on Apple implies using MoltenVK, which doesn't support + // the regular MAT_MUL shaders, among others. + case VK_VENDOR_ID_QUALCOMM: + return true; + default: + return false; + } +} + static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -2314,6 +2363,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_size = subgroup_props.subgroupSize; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + device->simpler_shaders = ggml_vk_device_use_simpler_shaders(&device->properties); if (sm_builtins) { device->shader_core_count = sm_props.shaderSMCount; } else if (amd_shader_core_properties2) { @@ -4512,6 +4562,167 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c } } +static void ggml_vk_simpler_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_simpler_mul_mat(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne0 = dst->ne[0]; + const uint64_t ne1 = dst->ne[1]; + const uint64_t ne2 = dst->ne[2]; + const uint64_t ne3 = dst->ne[3]; + + const uint64_t nb00 = src0->nb[0]; + const uint64_t nb01 = src0->nb[1]; + const uint64_t nb02 = src0->nb[2]; + const uint64_t nb03 = src0->nb[3]; + + const uint64_t nb10 = src1->nb[0]; + const uint64_t nb11 = src1->nb[1]; + const uint64_t nb12 = src1->nb[2]; + const uint64_t nb13 = src1->nb[3]; + + const uint64_t nb1 = dst->nb[1]; + const uint64_t nb2 = dst->nb[2]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const uint64_t x_ne = ne00 * ne01 * ne02 * ne03; + const uint64_t y_ne = ne10 * ne11 * ne12 * ne13; + const uint64_t d_ne = ne0 * ne1 * ne2 * ne3; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline pipeline; + switch (src0->type) { + case GGML_TYPE_F32: + pipeline = ctx->device->pipeline_simpler_mul_mat_mat_f32; + break; + case GGML_TYPE_F16: + pipeline = ctx->device->pipeline_simpler_mul_mat_f16; + break; + case GGML_TYPE_Q4_0: + pipeline = ctx->device->pipeline_simpler_mul_mat_q4_0; + break; + case GGML_TYPE_Q4_1: + pipeline = ctx->device->pipeline_simpler_mul_mat_q4_1; + break; + case GGML_TYPE_Q4_K: + pipeline = ctx->device->pipeline_simpler_mul_mat_q4_k; + break; + case GGML_TYPE_Q6_K: + pipeline = ctx->device->pipeline_simpler_mul_mat_q6_k; + break; + case GGML_TYPE_Q8_0: + pipeline = ctx->device->pipeline_simpler_mul_mat_q8_0; + break; + default: + GGML_ABORT("vk_simpler_mul_mat: unsupported quantization type: %d", src0->type); + } + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz); + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + + const uint64_t qx_buffer_offset = (qx_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qx_shader_offset = qx_buf_offset - qx_buffer_offset; + + const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset; + + const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; + const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; + + // compute + ggml_vk_sync_buffers(subctx); + switch (src0->type) { + case GGML_TYPE_F32: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb1, (uint32_t)nb2 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 14 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)ne11, (uint32_t)std::max(ne12, ne02) }); + break; + } + case GGML_TYPE_F16: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)nb00, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)ne10, (uint32_t)ne11, (uint32_t)ne12, (uint32_t)nb10, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 21 * sizeof(uint32_t), &pc, { (uint32_t)ne01, (uint32_t)((ne11 + 3) / 4), (uint32_t)(ne12 * ne13) }); + break; + } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne10, (uint32_t)ne12, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 7) / 8), (uint32_t)ne11, (uint32_t)(ne12 * ne13) }); + break; + } + case GGML_TYPE_Q4_K: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 3) / 4), (uint32_t)ne11, (uint32_t)(ne12 * ne13) }); + break; + } + case GGML_TYPE_Q6_K: + { + const std::array pc = { (uint32_t)qx_shader_offset, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)nb01, (uint32_t)nb02, (uint32_t)nb03, (uint32_t)nb11, (uint32_t)nb12, (uint32_t)nb13, (uint32_t)r2, (uint32_t)r3 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Qx, qx_buffer_offset, qx_sz + qx_shader_offset }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 18 * sizeof(uint32_t), &pc, { (uint32_t)((ne01 + 1) / 2), (uint32_t)ne11, (uint32_t)(ne12 * ne13) }); + break; + } + default: + GGML_ABORT("vk_simpler_mul_mat: unsupported quantization type: %d", src0->type); + } +} + static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -5247,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + if (ctx->device->simpler_shaders) { + if (src1 && src1->type == GGML_TYPE_F16) { + return ctx->device->pipeline_simpler_soft_max_f16; + } + return ctx->device->pipeline_simpler_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } @@ -5545,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); switch (op) { + case GGML_OP_SOFT_MAX: + if (ctx->device->simpler_shaders) { + elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] }; + break; + } + // fall-through case GGML_OP_NORM: case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: { const uint32_t nr = ggml_nrows(src0); @@ -6088,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { - ncols, - src1 != nullptr ? nrows_y : (uint32_t)0, - scale, max_bias, - m0, m1, - n_head_log2, - nrows_x, - }, dryrun); + if (ctx->device->simpler_shaders) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + (int32_t) src0->ne[0], + (int32_t) src0->ne[1], + (int32_t) src0->ne[2], + scale, max_bias, + m0, m1, + n_head_log2, + src1 == nullptr ? 0 : 1, + }, dryrun); + } else { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? nrows_y : (uint32_t)0, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + }, dryrun); + } } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { @@ -7227,8 +7462,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; case GGML_OP_MUL_MAT: - ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); - + if (ctx->device->simpler_shaders) { + ggml_vk_simpler_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + } else { + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + } break; case GGML_OP_MUL_MAT_ID: ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); @@ -8019,7 +8257,7 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const return ggml_backend_vk_init(ctx->device); } -static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { +static bool ggml_backend_vk_complete_device_supports_op(const vk_device& device, const ggml_tensor * op) { switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -8088,8 +8326,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } break; case GGML_OP_FLASH_ATTN_EXT: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - if (!ggml_vk_get_device(ctx->device)->coopmat2) { + if (!device->coopmat2) { return false; } switch (op->src[0]->ne[0]) { @@ -8254,8 +8491,43 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } +} - UNUSED(dev); +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + if (!device->simpler_shaders) { + return ggml_backend_vk_complete_device_supports_op(device, op); + } + + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1])) + return false; + + switch (op->src[0]->type) { + case GGML_TYPE_F32: + return op->ne[3] == 1; + case GGML_TYPE_Q8_0: + // TODO (slp) - Fix Q8_0 with permutations + if (ggml_is_permuted(op->src[0]) || ggml_is_permuted(op->src[1])) { + return false; + } + case GGML_TYPE_Q6_K: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + return true; + default: + return false; + } + case GGML_OP_MUL_MAT_ID: + return false; + default: + return ggml_backend_vk_complete_device_supports_op(device, op); + } } static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_common.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_common.comp new file mode 100644 index 000000000..dbe4cf804 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_common.comp @@ -0,0 +1,112 @@ +#extension GL_EXT_shader_16bit_storage: require +#extension GL_EXT_shader_8bit_storage: require +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#extension GL_EXT_shader_explicit_arithmetic_types_int8: require +#extension GL_EXT_shader_explicit_arithmetic_types_int16: require +#extension GL_EXT_shader_explicit_arithmetic_types_int64: require +#extension GL_EXT_control_flow_attributes: enable +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_EXT_debug_printf : enable + +#define QK4_0 32 +#define QK4_1 32 + +#define GELU_COEF_A 0.044715 +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876 +#define TWOPI_F 6.283185307179586f + +#define QK_K 256 +#define K_SCALE_SIZE 12 + +#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx]) +#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx) +#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx]) +#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx) + +#define sizeof_block_q4_0 0x12 +struct block_q4_0 { + float16_t d; + uint8_t qs[QK4_0 / 2]; +}; +mat4 dequantize_q4_0(const block_q4_0 xb, uint il) { + const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; + const float d2 = d1 / 256.f; + const float md = -8.f * xb.d; + const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); + const uint16_t mask1 = mask0 << 8; + + mat4 reg; + for (int i=0;i<8;i++) { + uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); + reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md; + } + return reg; +} + +#define sizeof_block_q4_1 0x14 +struct block_q4_1 { + float16_t d; + float16_t m; + uint8_t qs[QK4_1 / 2]; +}; +mat4 dequantize_q4_1(const block_q4_1 xb, uint il) { + const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; + const float d2 = d1 / 256.f; + const float m = xb.m; + const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); + const uint16_t mask1 = mask0 << 8; + + mat4 reg; + for (int i=0;i<8;i++) { + uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); + reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m; + } + return reg; +} + +#define sizeof_block_q4_k 144 +struct block_q4_k { + float16_t d; + float16_t dmin; + uint8_t scales[K_SCALE_SIZE]; + uint8_t qs[QK_K/2]; +}; + +#define sizeof_block_q6_k 210 +struct block_q6_k { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + float16_t d; // super-block scale +}; +mat4 dequantize_q6_k(const block_q6_k xb, uint il) { + const float16_t d_all = xb.d; + + const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + const uint qhIndex = 32*(il/8) + 16*(il&1); + float16_t sc = xb.scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F); + const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f); + const float16_t ml = float16_t(d_all * sc * 32.f); + const float16_t dl = float16_t(d_all * sc * coef); + mat4 reg; + for (int i = 0; i < 16; ++i) { + const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2)) + : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } + return reg; +} + + +#define QK8_0 32 +// struct block_q8_0 { +// float16_t d; // delta +// int8_t qs[QK8_0]; // quants +// }; +#define sizeof_block_q8_0 34 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_f16.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_f16.comp new file mode 100644 index 000000000..38416eefd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_f16.comp @@ -0,0 +1,69 @@ +#version 450 + +#include "simpler_common.comp" + +#extension GL_KHR_shader_subgroup_arithmetic : require + +layout(local_size_x_id = 0) in; + +layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne01; + int ne02; + uint nb00; + uint nb01; + uint nb02; + uint nb03; + int ne10; + int ne11; + int ne12; + uint nb10; + uint nb11; + uint nb12; + uint nb13; + int ne0; + int ne1; + uint r2; + uint r3; +} pcs; + +#define N_F16_F32 4 + +void main() { + const uint r0 = gl_WorkGroupID.x; + const uint rb = gl_WorkGroupID.y*N_F16_F32; + const uint im = gl_WorkGroupID.z; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03; + + const uint x = offset0 / 2 + pcs.inAOff; // Based from inA + + for (uint row = 0; row < N_F16_F32; ++row) { + uint r1 = rb + row; + if (r1 >= pcs.ne11) { + break; + } + + const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; + + float sumf = 0; + for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { + sumf += float(inA[x+i]) * float(inB[y+i]); + } + + const float all_sum = subgroupAdd(sumf); + if (subgroupElect()) { + out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_mat_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_mat_f32.comp new file mode 100644 index 000000000..dc048f430 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_mat_f32.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "simpler_common.comp" + +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_EXT_debug_printf : enable + +// device subgroup size +layout (local_size_x_id = 0) in; + +layout(binding = 0) readonly buffer tensorInA { float inA[]; }; +layout(binding = 1) readonly buffer tensorInB { float inB[]; }; +layout(binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout(push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne01; + int ne02; + int ne11; + int ne12; + uint nb01; + uint nb02; + uint nb11; + uint nb12; + uint nb1; + uint nb2; +} +pcs; + + +void main() { + uvec3 gid = gl_WorkGroupID; + + uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z; + uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z; + + const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA + const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB + float sum = 0.0f; + for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { + sum += float(inA[x+i]) * float(inB[y+i]); + } + + const float all_sum = subgroupAdd(sum); + if (subgroupElect()) { + out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_0.comp new file mode 100644 index 000000000..db9dc612b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_0.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "simpler_common.comp" + +#define BLOCKS_IN_QUANT QK4_0 +#define SIZE_OF_BLOCK sizeof_block_q4_0 +#define N_ROWS 4 + +#include "simpler_mul_mv_q_n_pre.comp" + +// The q4_0 version of this function +float block_q_n_dot_y(uint block_index, uint yb, uint il) { + vec2 acc = vec2(0.0, 0.0); + const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff; + float d = float(u8BufToFloat16(inA, index)); + float sumy = 0.0f; + for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) { + const uint16_t b = u8BufToU16(inA, index + 2 + il + i); + + const float yl0 = inB[yb + i]; + const float yl1 = inB[yb + i + 1]; + const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2]; + const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1]; + + sumy += yl0 + yl1 + yl8 + yl9; + + acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00); + acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]); +} + +#include "simpler_mul_mv_q_n.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_1.comp new file mode 100644 index 000000000..75449fa7b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_1.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "simpler_common.comp" + +#define BLOCKS_IN_QUANT QK4_1 +#define SIZE_OF_BLOCK sizeof_block_q4_1 +#define N_ROWS 4 + +#include "simpler_mul_mv_q_n_pre.comp" + +// The q4_1 version of this function +float block_q_n_dot_y(uint block_index, uint yb, uint il) { + vec2 acc = vec2(0.0, 0.0); + const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff; + float d = float(u8BufToFloat16(inA, index)); + float m = float(u8BufToFloat16(inA, index+2)); + + float sumy = 0.0f; + for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) { + const uint16_t b = u8BufToU16(inA, index + 4 + il + i); + + const float yl0 = inB[yb + i]; + const float yl1 = inB[yb + i + 1]; + const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2]; + const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1]; + + sumy += yl0 + yl1 + yl8 + yl9; + + acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00); + acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +#include "simpler_mul_mv_q_n.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_k.comp new file mode 100644 index 000000000..7b767ac10 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q4_k.comp @@ -0,0 +1,140 @@ +#version 450 + +#include "simpler_common.comp" + +#define N_DST 4 +#define SIZE_OF_BLOCK sizeof_block_q4_k + +layout(local_size_x = 4) in; +layout(local_size_y = 8) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne10; + int ne0; + int ne1; + int ne01; + int ne02; + int ne12; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; +} pcs; + +void main() { + const uint16_t kmask1 = uint16_t(0x3f3f); + const uint16_t kmask2 = uint16_t(0x0f0f); + const uint16_t kmask3 = uint16_t(0xc0c0); + + const uint ix = gl_SubgroupInvocationID/8; // 0...3 + const uint it = gl_SubgroupInvocationID%8; // 0...7 + const uint iq = it/4; // 0 or 1 + const uint ir = it%4; // 0...3 + + const uint nb = pcs.ne00/QK_K; + + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint first_row = r0 * N_DST; + const uint ib_row = first_row * nb; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13; + + const uint xblk = offset0 + pcs.inAOff; + const uint y = (offset1 / 4) + pcs.inBOff; + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f}; + float all_sum = 0.f; + + uint y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + for (uint ib = ix; ib < nb; ib += 4) { + const uint blk_idx = ib + xblk; + + float sumy[4] = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0]; + yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8]; + } + + for (int row = 0; row < N_DST; row++) { + uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK); + + uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0); + uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2); + uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4); + uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6); + uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8); + + uint16_t sc16[4]; + sc16[0] = sc_0 & kmask1; + sc16[1] = sc_2 & kmask1; + sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2); + sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2); + + float acc1[4] = {0.f, 0.f, 0.f, 0.f}; + float acc2[4] = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i); + uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i); + acc1[0] += yl[i+0] * (q1 & 0x000F); + acc1[1] += yl[i+1] * (q1 & 0x0F00); + acc1[2] += yl[i+8] * (q1 & 0x00F0); + acc1[3] += yl[i+9] * (q1 & 0xF000); + acc2[0] += yh[i+0] * (q2 & 0x000F); + acc2[1] += yh[i+1] * (q2 & 0x0F00); + acc2[2] += yh[i+8] * (q2 & 0x00F0); + acc2[3] += yh[i+9] * (q2 & 0xF000); + } + + uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF); + uint8_t sc8_1 = uint8_t(sc16[0] >> 8 ); + uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF); + uint8_t sc8_3 = uint8_t(sc16[1] >> 8 ); + uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF); + uint8_t sc8_5 = uint8_t(sc16[2] >> 8 ); + uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF); + uint8_t sc8_7 = uint8_t(sc16[3] >> 8 ); + + float dall = float(inA[blk_idx + row_idx].d); + float dmin = float(inA[blk_idx + row_idx].dmin); + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) - + dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7); + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = subgroupAdd(sumf[row]); + if (subgroupElect()) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q6_k.comp new file mode 100644 index 000000000..32527c46a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q6_k.comp @@ -0,0 +1,106 @@ +#version 450 + +#include "simpler_common.comp" + +#define SIZE_OF_BLOCK sizeof_block_q6_k + +layout(local_size_x_id = 0) in; +layout(local_size_y_id = 1) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne10; + int ne0; + int ne1; + int ne01; + int ne02; + int ne12; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; +} pcs; + +void main() { + const uint8_t kmask1 = uint8_t(0x03); + const uint8_t kmask2 = uint8_t(0x0C); + const uint8_t kmask3 = uint8_t(0x30); + const uint8_t kmask4 = uint8_t(0xC0); + + const uint nb = pcs.ne00/QK_K; + + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID); + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; + + float sumf = 0; + + // bits of invocation ID for gl_SubgroupSize=32: + // x x x x x + // 4 3 2 1 0 + // ( tid ) ix + // ip ( il ) + + const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes + const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0 + const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1 + const uint ip = tid/8; // first or second half of block (0 or 1) + const uint il = tid%8; // each half has 8 parts, one per scale + const uint n = 4; // 4 scales at a time (and 4 sums) + const uint l0 = n*il; // offset into half-block, 0..28 + const uint is = 8*ip + l0/16; // 0, 1, 8, 9 + + const uint y_offset = 128*ip + l0; + const uint q_offset_l = 64*ip + l0; + const uint q_offset_h = 32*ip + l0; + + for (uint i = ix; i < nb; i += block_stride) { + + const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; + + const uint qlIndex = q_offset_l; + const uint q2Index = qlIndex + QK_K/8; + const uint qhIndex = q_offset_h; + const uint y = yy + i * QK_K + y_offset; + + float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (uint l = 0; l < n; ++l) { + const uint8_t currentQ1 = inA[baseIndex + qlIndex + l]; + const uint8_t currentQ2 = inA[baseIndex + q2Index + l]; + const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l]; + + sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32); + sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32); + sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32); + sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32); + } + + float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16); + sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is])); + } + + const float tot = subgroupAdd(sumf); + if (subgroupElect()) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q8_0.comp new file mode 100644 index 000000000..dcb9a1f6d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mat_q8_0.comp @@ -0,0 +1,73 @@ +#version 450 + +#include "simpler_common.comp" + +#include "simpler_mul_mv_q_n_pre.comp" + +#define SIZE_OF_D 2 + +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +#define NB_Q8_0 8 + +void main() { + // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64 + if (gl_SubgroupInvocationID > 31) + return; + + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int nb = pcs.ne00/QK8_0; + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint first_row = (r0 * nsg + gl_SubgroupID) * nr; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02); + + const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA + const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB + + float yl[NB_Q8_0]; + float sumf[N_DST]={0.f, 0.f, 0.f, 0.f}; + + const uint ix = gl_SubgroupInvocationID.x/4; + const uint il = gl_SubgroupInvocationID.x%4; + + uint yb = y + ix * QK8_0 + NB_Q8_0*il; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (uint ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { + yl[i] = inB[yb + i]; + } + + for (int row = 0; row < nr; row++) { + const uint block_offset = (ib+row*nb) * sizeof_block_q8_0; + float sumq = 0.f; + for (int iq = 0; iq < NB_Q8_0; ++iq) { + const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]); + sumq += qs_iq * yl[iq]; + } + const float16_t d = u8BufToFloat16(inA, x + block_offset); + sumf[row] += sumq*d; + } + + yb += NB_Q8_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = subgroupAdd(sumf[row]); + if (subgroupElect() && first_row + row < pcs.ne01) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n.comp new file mode 100644 index 000000000..a6517cc1f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n.comp @@ -0,0 +1,52 @@ +void main() { + // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64 + if (gl_SubgroupInvocationID > 31) + return; + + const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); + + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + // pointers to src0 rows + uint ax[N_ROWS]; + for (int row = 0; row < N_ROWS; ++row) { + const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + + ax[row] = offset0 + pcs.inAOff; + } + + const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; + + float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f}; + + const uint ix = gl_SubgroupInvocationID/2; + const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2); + + uint yb = y + ix * BLOCKS_IN_QUANT + il; + + //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n", + // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, + // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); + + for (uint ib = ix; ib < nb; ib += 16) { + for (int row = 0; row < N_ROWS; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il); + } + + yb += BLOCKS_IN_QUANT * 16; + } + + for (int row = 0; row < N_ROWS; ++row) { + const float tot = subgroupAdd(sumf[row]); + if (first_row + row < pcs.ne01 && subgroupElect()) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n_pre.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n_pre.comp new file mode 100644 index 000000000..a9a2f2218 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_mul_mv_q_n_pre.comp @@ -0,0 +1,28 @@ +layout(local_size_x_id = 0) in; +layout(local_size_y = 8) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne01; + int ne02; + int ne10; + int ne12; + int ne0; + int ne1; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; +} pcs; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp new file mode 100644 index 000000000..d2707c504 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp @@ -0,0 +1,69 @@ +// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4) + +#version 450 + +#include "simpler_common.comp" + +layout(local_size_x = 32) in; + +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; + +layout(push_constant) uniform PushConstants { + int ne00; + int ne01; + int ne02; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + int mask; +} pcs; + +void main() { + if (gl_SubgroupInvocationID > 31) + return; + + const uint i03 = gl_WorkGroupID.z; + const uint i02 = gl_WorkGroupID.y; + const uint i01 = gl_WorkGroupID.x; + + const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00; + const uint psrc0 = extra_off; + const uint pmask = i01*pcs.ne00; + const uint pdst = extra_off; + + float slope = 1.0f; + + // ALiBi + if (pcs.max_bias > 0.0f) { + int64_t h = i02; + + float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1; + int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1; + + slope = pow(base, float(exp)); + } + + // parallel max + float localMax = uintBitsToFloat(0xFF800000); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f)); + } + float max_ = subgroupMax(localMax); + + // parallel sum + float localSum = 0.0f; + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_); + localSum += exp_psrc0; + out_[pdst + i00] = exp_psrc0; + } + + const float sum = subgroupAdd(localSum); + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { + out_[pdst + i00] /= sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77e7e1148..d616795e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -496,6 +496,17 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("simpler_mul_mat_mat_f32", "simpler_mul_mat_mat_f32.comp", {}); + string_to_spv("simpler_mul_mat_f16", "simpler_mul_mat_f16.comp", {}); + string_to_spv("simpler_mul_mat_q4_0", "simpler_mul_mat_q4_0.comp", {}); + string_to_spv("simpler_mul_mat_q4_1", "simpler_mul_mat_q4_1.comp", {}); + string_to_spv("simpler_mul_mat_q4_k", "simpler_mul_mat_q4_k.comp", {}); + string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {}); + string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {}); + + string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}}); + string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}}); + for (auto &c : compiles) { c.wait(); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1bfd41254..5c562aa90 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4166,7 +4166,7 @@ static std::vector> make_test_cases_eval() { for (float max_bias : {0.0f, 8.0f}) { if (!mask && max_bias > 0.0f) continue; for (float scale : {1.0f, 0.1f}) { - for (int64_t ne0 : {16, 1024}) { + for (int64_t ne0 : {16, 1024, 2048}) { for (int64_t ne1 : {16, 1024}) { if (mask) { for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {