vulkan: enable the use of simpler softmax shaders

Even though the regular softmax shaders successfully pass
test-backend-ops with Apple GPUs, running long inference tests has shown
the models end derailing with softmax OPs being the root cause.

With this commit, we use simpler softmax shaders borrowed from the
Kompute backend (which are basically reimplementations of the Metal
shaders) on certain GPUs know to have problem with the regular ones.

Signed-off-by: Sergio Lopez <slp@redhat.com>
This commit is contained in:
Sergio Lopez 2025-02-07 14:57:03 +01:00
parent 8b47fce07a
commit bae9a58f5d
3 changed files with 123 additions and 9 deletions

View file

@ -228,6 +228,9 @@ struct vk_device_struct {
vk_pipeline pipeline_simpler_mul_mat_q6_k;
vk_pipeline pipeline_simpler_mul_mat_q8_0;
vk_pipeline pipeline_simpler_soft_max_f16;
vk_pipeline pipeline_simpler_soft_max_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
@ -518,6 +521,18 @@ struct vk_op_soft_max_push_constants {
uint32_t nrows_x;
};
struct vk_op_simpler_soft_max_push_constants {
int32_t ne00;
int32_t ne01;
int32_t ne02;
float scale;
float max_bias;
float m0;
float m1;
uint32_t n_head_log2;
int32_t mask;
};
struct vk_op_argsort_push_constants {
uint32_t ncols;
uint32_t ncols_pad;
@ -2088,6 +2103,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1);
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@ -5440,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
if (ctx->device->simpler_shaders) {
if (src1 && src1->type == GGML_TYPE_F16) {
return ctx->device->pipeline_simpler_soft_max_f16;
}
return ctx->device->pipeline_simpler_soft_max_f32;
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
}
@ -5738,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
switch (op) {
case GGML_OP_SOFT_MAX:
if (ctx->device->simpler_shaders) {
elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] };
break;
}
// fall-through
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_SOFT_MAX:
case GGML_OP_SUM_ROWS:
{
const uint32_t nr = ggml_nrows(src0);
@ -6281,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
scale, max_bias,
m0, m1,
n_head_log2,
nrows_x,
}, dryrun);
if (ctx->device->simpler_shaders) {
ggml_vk_op_f32<vk_op_simpler_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
(int32_t) src0->ne[0],
(int32_t) src0->ne[1],
(int32_t) src0->ne[2],
scale, max_bias,
m0, m1,
n_head_log2,
src1 == nullptr ? 0 : 1,
}, dryrun);
} else {
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
scale, max_bias,
m0, m1,
n_head_log2,
nrows_x,
}, dryrun);
}
}
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {

View file

@ -0,0 +1,69 @@
// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
#version 450
#include "simpler_common.comp"
layout(local_size_x = 32) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
int ne00;
int ne01;
int ne02;
float scale;
float max_bias;
float m0;
float m1;
uint n_head_log2;
int mask;
} pcs;
void main() {
if (gl_SubgroupInvocationID > 31)
return;
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
const uint psrc0 = extra_off;
const uint pmask = i01*pcs.ne00;
const uint pdst = extra_off;
float slope = 1.0f;
// ALiBi
if (pcs.max_bias > 0.0f) {
int64_t h = i02;
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
slope = pow(base, float(exp));
}
// parallel max
float localMax = uintBitsToFloat(0xFF800000);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
}
float max_ = subgroupMax(localMax);
// parallel sum
float localSum = 0.0f;
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
localSum += exp_psrc0;
out_[pdst + i00] = exp_psrc0;
}
const float sum = subgroupAdd(localSum);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
out_[pdst + i00] /= sum;
}
}

View file

@ -504,6 +504,9 @@ void process_shaders() {
string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {});
string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {});
string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}});
string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}});
for (auto &c : compiles) {
c.wait();
}