Implemented ggml_vk_soft_max
This commit is contained in:
parent
e2b721db65
commit
de7d1823ed
1 changed files with 127 additions and 17 deletions
144
ggml-vulkan.cpp
144
ggml-vulkan.cpp
|
@ -212,15 +212,28 @@ std::vector<uint8_t> getVecBlockQ4_0QS(T *x, unsigned nb, unsigned qk) {
|
|||
};
|
||||
|
||||
|
||||
static const std::string program_source_head = R"(
|
||||
#version 450
|
||||
static const std::string program_source_head = R"(#version 450
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16: enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8: enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int64: enable
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
|
||||
#define GELU_COEF_A 0.044715;
|
||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876;
|
||||
|
||||
#ifndef QK_K
|
||||
#define QK_K 256
|
||||
#endif
|
||||
|
||||
#if QK_K == 256
|
||||
#define K_SCALE_SIZE 12
|
||||
#else
|
||||
#define K_SCALE_SIZE 4
|
||||
#endif
|
||||
)";
|
||||
|
||||
|
||||
|
@ -366,16 +379,6 @@ void ggml_vk_abmath(kp::Sequence& seq,
|
|||
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {size}, {}, {pushConsts}));
|
||||
}
|
||||
|
||||
template <bool with_row = false, typename... Args>
|
||||
void ggml_vk_add(Args&&... args) {
|
||||
return ggml_vk_abmath<'+', with_row>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <bool with_row = false, typename... Args>
|
||||
void ggml_vk_mul(Args&&... args) {
|
||||
return ggml_vk_abmath<'*', with_row>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
static const std::string program_scale =
|
||||
MULTILINE_QUOTE(
|
||||
|
@ -456,8 +459,8 @@ void ggml_vk_silu(Args&&... args) {
|
|||
static const std::string program_relu =
|
||||
MULTILINE_QUOTE(
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint inAOff;
|
||||
uint inOff;
|
||||
uint outOff;
|
||||
} pcs;
|
||||
|
||||
layout(local_size_x = 1) in;
|
||||
|
@ -482,8 +485,8 @@ void ggml_vk_relu(Args&&... args) {
|
|||
static const std::string program_gelu =
|
||||
MULTILINE_QUOTE(
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint inAOff;
|
||||
uint inOff;
|
||||
uint outOff;
|
||||
} pcs;
|
||||
|
||||
layout(local_size_x = 1) in;
|
||||
|
@ -506,6 +509,109 @@ void ggml_vk_gelu(Args&&... args) {
|
|||
}
|
||||
|
||||
|
||||
static const std::string program_soft_max =
|
||||
MULTILINE_QUOTE(
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint64_t ne00;
|
||||
uint64_t ne01;
|
||||
uint64_t ne02;
|
||||
uint inOff;
|
||||
uint outOff;
|
||||
} pcs;
|
||||
|
||||
layout(local_size_x = nth) in;
|
||||
layout(binding = 0) buffer restrict readonly tensorInA { float in_[]; };
|
||||
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
|
||||
|
||||
shared float buf[nth];
|
||||
|
||||
void main() {
|
||||
const uint64_t i03 = uint64_t(gl_GlobalInvocationID.z);
|
||||
const uint64_t i02 = uint64_t(gl_GlobalInvocationID.y);
|
||||
const uint64_t i01 = uint64_t(gl_GlobalInvocationID.x);
|
||||
|
||||
const uint extra_off = uint(i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00);
|
||||
const uint in_off = pcs.inOff + extra_off;
|
||||
const uint out_off = pcs.outOff + extra_off;
|
||||
|
||||
// parallel max
|
||||
buf[gl_LocalInvocationID.x] = uintBitsToFloat(0xFF800000);
|
||||
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
|
||||
buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], in_[in_off + i00]);
|
||||
}
|
||||
|
||||
// reduce
|
||||
barrier();
|
||||
memoryBarrierShared();
|
||||
for (uint i = nth/2; i > 0; i /= 2) {
|
||||
if (gl_LocalInvocationID.x < i) {
|
||||
buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], buf[gl_LocalInvocationID.x + i]);
|
||||
}
|
||||
barrier();
|
||||
memoryBarrierShared();
|
||||
}
|
||||
|
||||
// broadcast (no effect?)
|
||||
if (gl_LocalInvocationID.x == 0) {
|
||||
buf[0] = buf[0]; // ???
|
||||
}
|
||||
|
||||
barrier();
|
||||
memoryBarrierShared();
|
||||
|
||||
const float max_ = buf[0];
|
||||
|
||||
// parallel sum
|
||||
buf[gl_LocalInvocationID.x] = 0.0;
|
||||
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
|
||||
buf[gl_LocalInvocationID.x] += exp(in_[in_off + i00] - max_);
|
||||
}
|
||||
|
||||
// reduce
|
||||
barrier();
|
||||
memoryBarrierShared();
|
||||
for (uint i = nth/2; i > 0; i /= 2) {
|
||||
if (gl_LocalInvocationID.x < i) {
|
||||
buf[gl_LocalInvocationID.x] += buf[gl_LocalInvocationID.x + i];
|
||||
}
|
||||
barrier();
|
||||
memoryBarrierShared();
|
||||
}
|
||||
|
||||
// broadcast (no effect?)
|
||||
if (gl_LocalInvocationID.x == 0) {
|
||||
buf[0] = buf[0]; // ???
|
||||
}
|
||||
|
||||
barrier();
|
||||
memoryBarrierShared();
|
||||
|
||||
const float sum = buf[0];
|
||||
|
||||
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
|
||||
out_[out_off + i00] = exp(in_[in_off + i00] - max_) / sum;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
void ggml_vk_soft_max(kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
|
||||
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03) {
|
||||
const static unsigned nth = 32;
|
||||
const static auto spirv = compileSource(program_source_head+"#define nth "+std::to_string(nth)+"\n"+program_soft_max, __func__);
|
||||
|
||||
struct PushConstants {
|
||||
int64_t ne00, ne01, ne02;
|
||||
uint32_t inOff, outOff;
|
||||
} pushConsts {
|
||||
ne00, ne01, ne02, inOff, outOff
|
||||
};
|
||||
|
||||
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}));
|
||||
}
|
||||
|
||||
|
||||
void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
|
||||
printf("%s: evaluating graph\n", __func__);
|
||||
|
||||
|
@ -585,15 +691,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
} break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
ggml_vk_add(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
|
||||
ggml_vk_abmath<'+'>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
ggml_vk_mul<true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00);
|
||||
ggml_vk_abmath<'*', true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00);
|
||||
} else {
|
||||
ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
|
||||
ggml_vk_abmath<'*'>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
|
@ -613,6 +719,10 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
{
|
||||
ggml_vk_gelu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03);
|
||||
} break;
|
||||
default:
|
||||
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
||||
//GGML_ASSERT(false);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue