Minor MUL_MAT fix and DIAG_MASK_INF implementation
This commit is contained in:
parent
964fe8c546
commit
f093bf2e5e
1 changed file with 66 additions and 18 deletions
|
@ -223,8 +223,8 @@ static const std::string program_source_head = R"(#version 450
|
|||
#define QR4_0 2
|
||||
#define QK4_1 32
|
||||
|
||||
#define GELU_COEF_A 0.044715;
|
||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876;
|
||||
#define GELU_COEF_A 0.044715
|
||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
|
||||
|
||||
#ifndef QK_K
|
||||
#define QK_K 256
|
||||
|
@ -235,6 +235,12 @@ static const std::string program_source_head = R"(#version 450
|
|||
#else
|
||||
#define K_SCALE_SIZE 4
|
||||
#endif
|
||||
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 8
|
||||
#define TM 8
|
||||
#define TN 8
|
||||
)";
|
||||
|
||||
|
||||
|
@ -651,13 +657,56 @@ void ggml_vk_soft_max(kp::Sequence& seq,
|
|||
}
|
||||
|
||||
|
||||
static const std::string program_mul_mat_f16 = R"(
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 8
|
||||
#define TM 8
|
||||
#define TN 8
|
||||
)" MULTILINE_QUOTE(
|
||||
// GLSL compute shader body for the DIAG_MASK_INF op.
//
// Push constants (block `pcs`): ne00 and ne01 as uint64_t, followed by the
// uint offsets inAOff, inBOff and outOff.
// Buffers: binding 0 = inA (float, read-only), binding 1 = inB (int,
// read-only), binding 2 = out_ (float, write-only).
//
// Each invocation (local_size_x = 1) handles one element, taking
// i00 / i01 / i02 from gl_GlobalInvocationID.x / .y / .z. n_past is read from
// inB[pcs.inBOff]. The destination index is
//   i02*ne01*ne00 + i01*ne00 + i00 + outOff
// and the element is set to the bit pattern 0xFF800000 reinterpreted as a
// float (IEEE-754 single-precision negative infinity) when
// i00 > n_past + i01; otherwise it is copied from inA at the same index
// computed with inAOff in place of outOff.
//
// NOTE(review): MULTILINE_QUOTE is defined outside this view; presumably it
// stringifies its argument. Confirm before adding preprocessor directives
// to the shader text.
static const std::string program_diag_mask_inf =
|
||||
MULTILINE_QUOTE(
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint64_t ne00;
|
||||
uint64_t ne01;
|
||||
uint inAOff;
|
||||
uint inBOff;
|
||||
uint outOff;
|
||||
} pcs;
|
||||
|
||||
layout(local_size_x = 1) in;
|
||||
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
|
||||
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
|
||||
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
|
||||
|
||||
void main() {
|
||||
const uint64_t i02 = uint64_t(gl_GlobalInvocationID.z);
|
||||
const uint64_t i01 = uint64_t(gl_GlobalInvocationID.y);
|
||||
const uint64_t i00 = uint64_t(gl_GlobalInvocationID.x);
|
||||
|
||||
const int n_past = inB[pcs.inBOff];
|
||||
|
||||
if (i00 > n_past + i01) {
|
||||
out_[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.outOff)] = uintBitsToFloat(0xFF800000);
|
||||
} else {
|
||||
out_[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.outOff)] = inA[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.inAOff)];
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Records one dispatch of the diag_mask_inf shader on `seq`.
//
// Parameters:
//   seq            sequence the OpAlgoDispatch is recorded on.
//   inA, inAOff    first tensor and its offset (forwarded as a push constant).
//   inB, inBOff    second tensor and its offset (forwarded as a push constant).
//   out, outOff    third tensor and its offset (forwarded as a push constant).
//   ne00, ne01     forwarded as push constants and used as the x / y
//                  workgroup counts.
//   ne02           used only as the z workgroup count.
//
// The tensors are handed to the algorithm in the order {inA, inB, out};
// presumably that maps to shader bindings 0, 1 and 2 -- confirm against the
// Kompute algorithm API.
void ggml_vk_diag_mask_inf(kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
int64_t ne00, int64_t ne01, int64_t ne02) {
|
||||
// Function-local static: glsl_compile_source is invoked once and its result
// is reused by later calls. The source passed in is the shared head
// concatenated with this op's shader text.
const static auto spirv = glsl_compile_source(program_source_head+program_diag_mask_inf, __func__);
|
||||
|
||||
// Field order and widths follow the push_constant block declared in
// program_diag_mask_inf (two 64-bit values, then three 32-bit offsets).
// The shader declares ne00/ne01 as uint64_t while they are int64_t here.
struct PushConstants {
|
||||
int64_t ne00, ne01;
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
} pushConsts {
|
||||
ne00, ne01, inAOff, inBOff, outOff
|
||||
};
|
||||
|
||||
// Workgroup counts are (ne00, ne01, ne02) converted to unsigned.
// NOTE(review): the conversion narrows int64_t values; confirm callers never
// pass dimensions that do not fit in unsigned.
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts}));
|
||||
}
|
||||
|
||||
|
||||
static const std::string program_mul_mat_f16 =
|
||||
MULTILINE_QUOTE(
|
||||
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
|
||||
|
@ -800,13 +849,8 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
|
|||
}
|
||||
|
||||
|
||||
static const std::string program_mul_mat_f32 = R"(
|
||||
#define BM 128
|
||||
#define BN 128
|
||||
#define BK 8
|
||||
#define TM 8
|
||||
#define TN 8
|
||||
)" MULTILINE_QUOTE(
|
||||
static const std::string program_mul_mat_f32 =
|
||||
MULTILINE_QUOTE(
|
||||
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer tensorInA { float inA[]; };
|
||||
|
@ -1041,14 +1085,18 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
{
|
||||
ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03);
|
||||
} break;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
{
|
||||
ggml_vk_diag_mask_inf(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
if (src0->type == GGML_TYPE_F32
|
||||
&& src1->type == GGML_TYPE_F32) {
|
||||
ggml_vk_mul_mat_f32(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb2, nb3);
|
||||
break;
|
||||
} else if (src0->type == GGML_TYPE_F32
|
||||
&& src1->type == GGML_TYPE_F16) {
|
||||
} else if (src0->type == GGML_TYPE_F16
|
||||
&& src1->type == GGML_TYPE_F32) {
|
||||
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue