diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 21b7a1f9b..680888188 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -252,6 +252,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; @@ -2149,6 +2150,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); @@ -5282,6 +5285,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_ARGMAX: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argmax_f32; + } + return nullptr; case GGML_OP_IM2COL: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_im2col_f32; @@ -5545,6 +5553,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_RMS_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: { const uint32_t nr = ggml_nrows(src0); if (nr > 262144) { @@ -6149,6 +6158,10 @@ static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -7040,6 +7053,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: @@ -7092,6 +7106,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: @@ -7219,6 +7234,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_ARGMAX: + ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); @@ -7331,6 +7350,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: @@ -8266,6 +8286,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: @@ -8840,6 +8861,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_sum(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src0_clone); } else if (tensor->op == GGML_OP_IM2COL) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp new file mode 100644 index 000000000..eaf4da341 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmpmax[BLOCK_SIZE]; +shared uint tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (col >= p.KX) { + return; + } + A_TYPE amax = data_a[row*p.KX + col]; + tmp[col] = col; + + for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { + A_TYPE val = data_a[row*p.KX + i]; + if (val > amax) { + amax = val; + tmp[col] = i; + } + } + tmpmax[col] = amax; + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s && col + s < p.KX) { + if (tmpmax[col] < tmpmax[col + s]) { + tmpmax[col] = tmpmax[col + s]; + tmp[col] = tmp[col + s]; + } + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77e7e1148..3594fef53 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -484,6 +484,7 @@ void process_shaders() { string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));