vulkan: Further soft_max optimizations

Restore the workgroup size of 512 case, use it for >1024.

Use unrollable loops for more iteration counts.
This commit is contained in:
Jeff Bolz 2024-11-17 18:48:33 -06:00
parent c7b8ab73de
commit 85fc2974f2
2 changed files with 20 additions and 26 deletions

View file

@ -218,6 +218,7 @@ struct vk_device_struct {
vk_pipeline pipeline_tanh_f32;
vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_argsort_f32;
@ -1498,7 +1499,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@ -3933,10 +3936,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
}
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32_f16;
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
}
return nullptr;
case GGML_OP_ROPE:

View file

@ -152,30 +152,21 @@ void main() {
// instantiate the soft_max function for several different
// dimensions, to allow loop unrolling
uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;
switch (num_blocks) {
case 1:
soft_max(1);
break;
case 2:
soft_max(2);
break;
case 3:
soft_max(3);
break;
case 4:
soft_max(4);
break;
case 5:
case 6:
case 7:
case 8:
soft_max(8);
break;
case 16:
soft_max(16);
break;
default:
if (num_blocks > 32) {
soft_max(num_blocks);
break;
} else if (num_blocks > 16) {
soft_max(32);
} else if (num_blocks > 8) {
soft_max(16);
} else if (num_blocks > 4) {
soft_max(8);
} else if (num_blocks == 4) {
soft_max(4);
} else if (num_blocks == 3) {
soft_max(3);
} else if (num_blocks == 2) {
soft_max(2);
} else if (num_blocks == 1) {
soft_max(1);
}
}