vulkan: enable the use of simpler softmax shaders
Even though the regular softmax shaders successfully pass test-backend-ops with Apple GPUs, running long inference tests has shown the models end derailing with softmax OPs being the root cause. With this commit, we use simpler softmax shaders borrowed from the Kompute backend (which are basically reimplementations of the Metal shaders) on certain GPUs know to have problem with the regular ones. Signed-off-by: Sergio Lopez <slp@redhat.com>
This commit is contained in:
parent
8b47fce07a
commit
bae9a58f5d
3 changed files with 123 additions and 9 deletions
|
@ -228,6 +228,9 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_simpler_mul_mat_q6_k;
|
||||
vk_pipeline pipeline_simpler_mul_mat_q8_0;
|
||||
|
||||
vk_pipeline pipeline_simpler_soft_max_f16;
|
||||
vk_pipeline pipeline_simpler_soft_max_f32;
|
||||
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
|
@ -518,6 +521,18 @@ struct vk_op_soft_max_push_constants {
|
|||
uint32_t nrows_x;
|
||||
};
|
||||
|
||||
struct vk_op_simpler_soft_max_push_constants {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
uint32_t n_head_log2;
|
||||
int32_t mask;
|
||||
};
|
||||
|
||||
struct vk_op_argsort_push_constants {
|
||||
uint32_t ncols;
|
||||
uint32_t ncols_pad;
|
||||
|
@ -2088,6 +2103,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
@ -5440,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
case GGML_OP_SOFT_MAX:
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||
|
||||
if (ctx->device->simpler_shaders) {
|
||||
if (src1 && src1->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_simpler_soft_max_f16;
|
||||
}
|
||||
return ctx->device->pipeline_simpler_soft_max_f32;
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
||||
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
|
||||
}
|
||||
|
@ -5738,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_SOFT_MAX:
|
||||
if (ctx->device->simpler_shaders) {
|
||||
elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] };
|
||||
break;
|
||||
}
|
||||
// fall-through
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
{
|
||||
const uint32_t nr = ggml_nrows(src0);
|
||||
|
@ -6281,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
|
||||
ncols,
|
||||
src1 != nullptr ? nrows_y : (uint32_t)0,
|
||||
scale, max_bias,
|
||||
m0, m1,
|
||||
n_head_log2,
|
||||
nrows_x,
|
||||
}, dryrun);
|
||||
if (ctx->device->simpler_shaders) {
|
||||
ggml_vk_op_f32<vk_op_simpler_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
|
||||
(int32_t) src0->ne[0],
|
||||
(int32_t) src0->ne[1],
|
||||
(int32_t) src0->ne[2],
|
||||
scale, max_bias,
|
||||
m0, m1,
|
||||
n_head_log2,
|
||||
src1 == nullptr ? 0 : 1,
|
||||
}, dryrun);
|
||||
} else {
|
||||
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
|
||||
ncols,
|
||||
src1 != nullptr ? nrows_y : (uint32_t)0,
|
||||
scale, max_bias,
|
||||
m0, m1,
|
||||
n_head_log2,
|
||||
nrows_x,
|
||||
}, dryrun);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
|
||||
|
|
69
ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp
Normal file
69
ggml/src/ggml-vulkan/vulkan-shaders/simpler_soft_max.comp
Normal file
|
@ -0,0 +1,69 @@
|
|||
// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
|
||||
|
||||
#version 450
|
||||
|
||||
#include "simpler_common.comp"
|
||||
|
||||
layout(local_size_x = 32) in;
|
||||
|
||||
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
|
||||
layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; };
|
||||
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
|
||||
|
||||
layout(push_constant) uniform PushConstants {
|
||||
int ne00;
|
||||
int ne01;
|
||||
int ne02;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
uint n_head_log2;
|
||||
int mask;
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
if (gl_SubgroupInvocationID > 31)
|
||||
return;
|
||||
|
||||
const uint i03 = gl_WorkGroupID.z;
|
||||
const uint i02 = gl_WorkGroupID.y;
|
||||
const uint i01 = gl_WorkGroupID.x;
|
||||
|
||||
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
|
||||
const uint psrc0 = extra_off;
|
||||
const uint pmask = i01*pcs.ne00;
|
||||
const uint pdst = extra_off;
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (pcs.max_bias > 0.0f) {
|
||||
int64_t h = i02;
|
||||
|
||||
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
|
||||
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, float(exp));
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float localMax = uintBitsToFloat(0xFF800000);
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
|
||||
}
|
||||
float max_ = subgroupMax(localMax);
|
||||
|
||||
// parallel sum
|
||||
float localSum = 0.0f;
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
|
||||
localSum += exp_psrc0;
|
||||
out_[pdst + i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
const float sum = subgroupAdd(localSum);
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||
out_[pdst + i00] /= sum;
|
||||
}
|
||||
}
|
|
@ -504,6 +504,9 @@ void process_shaders() {
|
|||
string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {});
|
||||
string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {});
|
||||
|
||||
string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}});
|
||||
string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue