Added mul_mat (needs fixes)
This commit is contained in:
parent
749d6179a8
commit
964fe8c546
1 changed files with 343 additions and 14 deletions
357
ggml-vulkan.cpp
357
ggml-vulkan.cpp
|
@ -217,6 +217,7 @@ static const std::string program_source_head = R"(#version 450
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16: enable
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16: enable
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8: enable
|
#extension GL_EXT_shader_explicit_arithmetic_types_int8: enable
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int64: enable
|
#extension GL_EXT_shader_explicit_arithmetic_types_int64: enable
|
||||||
|
#extension GL_EXT_control_flow_attributes: enable
|
||||||
|
|
||||||
#define QK4_0 32
|
#define QK4_0 32
|
||||||
#define QR4_0 2
|
#define QR4_0 2
|
||||||
|
@ -336,6 +337,44 @@ void ggml_vk_dequantize_row_q4_1(const void *x_, float *y, int k) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static const std::string program_fpx_to_fpx =
|
||||||
|
MULTILINE_QUOTE(
|
||||||
|
layout(push_constant) uniform PushConstants {
|
||||||
|
uint inOff;
|
||||||
|
uint outOff;
|
||||||
|
uint row;
|
||||||
|
} pcs;
|
||||||
|
|
||||||
|
layout(local_size_x = 1) in;
|
||||||
|
layout(binding = 0) buffer restrict readonly tensorIn { IN_TYPE in_[]; };
|
||||||
|
layout(binding = 1) buffer restrict writeonly tensorOut { OUT_TYPE out_[]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint i = gl_GlobalInvocationID.x;
|
||||||
|
|
||||||
|
out_[pcs.outOff + i] = OUT_TYPE(in_[pcs.inOff + i]);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
void ggml_vk_fp32_to_fp16_row(kp::Sequence& seq,
|
||||||
|
const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
|
||||||
|
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||||
|
uint32_t size) {
|
||||||
|
const static auto spirv = glsl_compile_source(program_source_head+
|
||||||
|
"#define IN_TYPE float\n"
|
||||||
|
"#define OUT_TYPE float16_t\n"+
|
||||||
|
program_fpx_to_fpx, __func__);
|
||||||
|
|
||||||
|
struct PushConstants {
|
||||||
|
uint32_t inOff, outOff;
|
||||||
|
} const pushConsts {
|
||||||
|
inOff, outOff
|
||||||
|
};
|
||||||
|
|
||||||
|
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({in, out}, spirv, {size}, {}, {pushConsts}));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static const std::string program_abmath =
|
static const std::string program_abmath =
|
||||||
MULTILINE_QUOTE(
|
MULTILINE_QUOTE(
|
||||||
layout(push_constant) uniform PushConstants {
|
layout(push_constant) uniform PushConstants {
|
||||||
|
@ -535,24 +574,24 @@ void main() {
|
||||||
const uint out_off = pcs.outOff + extra_off;
|
const uint out_off = pcs.outOff + extra_off;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
buf[gl_LocalInvocationID.x] = uintBitsToFloat(0xFF800000);
|
buf[gl_WorkGroupID.x] = uintBitsToFloat(0xFF800000);
|
||||||
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
|
for (uint i00 = gl_WorkGroupID.x; i00 < pcs.ne00; i00 += nth) {
|
||||||
buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], in_[in_off + i00]);
|
buf[gl_WorkGroupID.x] = max(buf[gl_WorkGroupID.x], in_[in_off + i00]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce
|
// reduce
|
||||||
barrier();
|
barrier();
|
||||||
memoryBarrierShared();
|
memoryBarrierShared();
|
||||||
for (uint i = nth/2; i > 0; i /= 2) {
|
[[unroll]] for (uint i = nth/2; i > 0; i /= 2) {
|
||||||
if (gl_LocalInvocationID.x < i) {
|
if (gl_WorkGroupID.x < i) {
|
||||||
buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], buf[gl_LocalInvocationID.x + i]);
|
buf[gl_WorkGroupID.x] = max(buf[gl_WorkGroupID.x], buf[gl_WorkGroupID.x + i]);
|
||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
memoryBarrierShared();
|
memoryBarrierShared();
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcast (no effect?)
|
// broadcast (no effect?)
|
||||||
if (gl_LocalInvocationID.x == 0) {
|
if (gl_WorkGroupID.x == 0) {
|
||||||
buf[0] = buf[0]; // ???
|
buf[0] = buf[0]; // ???
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -562,24 +601,24 @@ void main() {
|
||||||
const float max_ = buf[0];
|
const float max_ = buf[0];
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
buf[gl_LocalInvocationID.x] = 0.0;
|
buf[gl_WorkGroupID.x] = 0.0;
|
||||||
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
|
for (uint i00 = gl_WorkGroupID.x; i00 < pcs.ne00; i00 += nth) {
|
||||||
buf[gl_LocalInvocationID.x] += exp(in_[in_off + i00] - max_);
|
buf[gl_WorkGroupID.x] += exp(in_[in_off + i00] - max_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce
|
// reduce
|
||||||
barrier();
|
barrier();
|
||||||
memoryBarrierShared();
|
memoryBarrierShared();
|
||||||
for (uint i = nth/2; i > 0; i /= 2) {
|
for (uint i = nth/2; i > 0; i /= 2) {
|
||||||
if (gl_LocalInvocationID.x < i) {
|
if (gl_WorkGroupID.x < i) {
|
||||||
buf[gl_LocalInvocationID.x] += buf[gl_LocalInvocationID.x + i];
|
buf[gl_WorkGroupID.x] += buf[gl_WorkGroupID.x + i];
|
||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
memoryBarrierShared();
|
memoryBarrierShared();
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcast (no effect?)
|
// broadcast (no effect?)
|
||||||
if (gl_LocalInvocationID.x == 0) {
|
if (gl_WorkGroupID.x == 0) {
|
||||||
buf[0] = buf[0]; // ???
|
buf[0] = buf[0]; // ???
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -588,7 +627,7 @@ void main() {
|
||||||
|
|
||||||
const float sum = buf[0];
|
const float sum = buf[0];
|
||||||
|
|
||||||
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
|
for (uint i00 = gl_WorkGroupID.x; i00 < pcs.ne00; i00 += nth) {
|
||||||
out_[out_off + i00] = exp(in_[in_off + i00] - max_) / sum;
|
out_[out_off + i00] = exp(in_[in_off + i00] - max_) / sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -612,6 +651,285 @@ void ggml_vk_soft_max(kp::Sequence& seq,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static const std::string program_mul_mat_f16 = R"(
|
||||||
|
#define BM 128
|
||||||
|
#define BN 128
|
||||||
|
#define BK 8
|
||||||
|
#define TM 8
|
||||||
|
#define TN 8
|
||||||
|
)" MULTILINE_QUOTE(
|
||||||
|
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
|
||||||
|
layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; };
|
||||||
|
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
int M;
|
||||||
|
int N;
|
||||||
|
int K;
|
||||||
|
int inAStride;
|
||||||
|
int inBStride;
|
||||||
|
int outStride;
|
||||||
|
uint inAOff;
|
||||||
|
uint inBOff;
|
||||||
|
uint outOff;
|
||||||
|
} pcs;
|
||||||
|
|
||||||
|
shared float16_t bufA[BM * (BK+1)];
|
||||||
|
shared float16_t bufB[BN * (BK+1)];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const int ir = int(gl_WorkGroupID.x);
|
||||||
|
const int ic = int(gl_WorkGroupID.y);
|
||||||
|
|
||||||
|
const int rstride = BM / TM;
|
||||||
|
|
||||||
|
const int lr = int(gl_LocalInvocationID.x % rstride);
|
||||||
|
const int lc = int(gl_LocalInvocationID.x / rstride);
|
||||||
|
|
||||||
|
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||||
|
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||||
|
|
||||||
|
const int loadstride = int(gl_WorkGroupSize.x);
|
||||||
|
|
||||||
|
int posA = ir * BM * pcs.inAStride;
|
||||||
|
int posB = ic * BN * pcs.inBStride;
|
||||||
|
|
||||||
|
float sums[TM * TN];
|
||||||
|
float16_t cacheA[TM];
|
||||||
|
float16_t cacheB[TN];
|
||||||
|
|
||||||
|
[[unroll]] for (int i = 0; i < TM*TN; i++) {
|
||||||
|
sums[i] = 0.0hf;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int block = 0; block < pcs.K; block += BK) {
|
||||||
|
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||||
|
const int lr = l % BK;
|
||||||
|
const int lc = l / BK;
|
||||||
|
bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr];
|
||||||
|
}
|
||||||
|
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||||
|
const int lr = l % BK;
|
||||||
|
const int lc = l / BK;
|
||||||
|
bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr];
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
posA += BK;
|
||||||
|
posB += BK;
|
||||||
|
|
||||||
|
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||||
|
// Load from shared into cache
|
||||||
|
[[unroll]] for (int j = 0; j < BM; j++) {
|
||||||
|
cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||||
|
cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
sums[cc * TM + cr] += float(cacheA[cr]) * float(cacheB[cc]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int dr = ir * BM + lr;
|
||||||
|
const int dc = ic * BN + lc * TN;
|
||||||
|
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
out_[(dc + cc) * pcs.outStride + dr + cr*rstride] = sums[cc * TM + cr];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
void ggml_vk_mul_mat_f16(kp::Sequence& seq,
|
||||||
|
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||||
|
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||||
|
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||||
|
int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
|
||||||
|
int64_t ne10, int64_t ne11,
|
||||||
|
int nb10, int nb11, int nb12, int nb13,
|
||||||
|
int nb2, int nb3) {
|
||||||
|
const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
|
||||||
|
|
||||||
|
const bool inB_cont_rows = nb10 == sizeof(float);
|
||||||
|
const bool inB_cont_cols = (size_t)nb11 == ne11 * sizeof(float);
|
||||||
|
|
||||||
|
struct PushConstants {
|
||||||
|
int32_t M, N, K, inAStride, inBStride, outStride;
|
||||||
|
uint32_t inAOff, inBOff, outOff;
|
||||||
|
} pushConsts {
|
||||||
|
(int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
|
||||||
|
inAOff, inBOff, outOff
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
auto tmp = mgr.tensorT<half>(std::vector<half>(ne10*ne11));
|
||||||
|
|
||||||
|
if (inB_cont_rows) {
|
||||||
|
if (inB_cont_cols) {
|
||||||
|
ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12)/sizeof(float), tmp, 0, ne10*ne11);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
||||||
|
ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11)/sizeof(float), tmp, i01*ne10, ne10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
||||||
|
for (int64_t i00 = 0; i00 < ne10; i00++) {
|
||||||
|
// Extremely slow because of single shader invocation
|
||||||
|
ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)/sizeof(float), tmp, i01*ne10 + i00, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, tmp, out}, spirv, {(uint32_t)ne01, (uint32_t)ne11}, {}, {pushConsts}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static const std::string program_mul_mat_f32 = R"(
|
||||||
|
#define BM 128
|
||||||
|
#define BN 128
|
||||||
|
#define BK 8
|
||||||
|
#define TM 8
|
||||||
|
#define TN 8
|
||||||
|
)" MULTILINE_QUOTE(
|
||||||
|
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer tensorInA { float inA[]; };
|
||||||
|
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
|
||||||
|
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
int M;
|
||||||
|
int N;
|
||||||
|
int K;
|
||||||
|
int inAStride;
|
||||||
|
int inBStride;
|
||||||
|
int outStride;
|
||||||
|
uint inAOff;
|
||||||
|
uint inBOff;
|
||||||
|
uint outOff;
|
||||||
|
} pcs;
|
||||||
|
|
||||||
|
shared float bufA[BM * (BK+1)];
|
||||||
|
shared float bufB[BN * (BK+1)];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const int ir = int(gl_WorkGroupID.x);
|
||||||
|
const int ic = int(gl_WorkGroupID.y);
|
||||||
|
|
||||||
|
const int rstride = BM / TM;
|
||||||
|
|
||||||
|
const int lr = int(gl_WorkGroupID.x % rstride);
|
||||||
|
const int lc = int(gl_WorkGroupID.x / rstride);
|
||||||
|
|
||||||
|
const int loadr = int(gl_WorkGroupID.x % BK);
|
||||||
|
const int loadc = int(gl_WorkGroupID.x / BK);
|
||||||
|
|
||||||
|
const int loadstride = int(gl_WorkGroupSize.x);
|
||||||
|
|
||||||
|
int posA = ir * BM * pcs.inAStride;
|
||||||
|
int posB = ic * BN * pcs.inBStride;
|
||||||
|
|
||||||
|
float sums[TM * TN];
|
||||||
|
float cacheA[TM];
|
||||||
|
float cacheB[TN];
|
||||||
|
|
||||||
|
[[unroll]] for (int i = 0; i < TM*TN; i++) {
|
||||||
|
sums[i] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int block = 0; block < pcs.K; block += BK) {
|
||||||
|
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||||
|
const int lr = l % BK;
|
||||||
|
const int lc = l / BK;
|
||||||
|
bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr + pcs.inAOff];
|
||||||
|
}
|
||||||
|
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||||
|
const int lr = l % BK;
|
||||||
|
const int lc = l / BK;
|
||||||
|
bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr + pcs.inBOff];
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
memoryBarrierShared();
|
||||||
|
|
||||||
|
posA += BK;
|
||||||
|
posB += BK;
|
||||||
|
|
||||||
|
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||||
|
// Load from shared into cache
|
||||||
|
[[unroll]] for (int j = 0; j < BM; j++) {
|
||||||
|
cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||||
|
cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
sums[cc * TM + cr] += cacheA[cr] * cacheB[cc];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int dr = ir * BM + lr;
|
||||||
|
const int dc = ic * BN + lc * TN;
|
||||||
|
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
out_[(dc + cc) * pcs.outStride + dr + cr*rstride + pcs.outOff] = sums[cc * TM + cr];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
void ggml_vk_mul_mat_f32(kp::Sequence& seq,
|
||||||
|
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||||
|
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||||
|
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||||
|
int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
|
||||||
|
int64_t ne10, int64_t ne11,
|
||||||
|
int nb2, int nb3) {
|
||||||
|
const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f32, __func__);
|
||||||
|
|
||||||
|
struct PushConstants {
|
||||||
|
int32_t M, N, K, inAStride, inBStride, outStride;
|
||||||
|
uint32_t inAOff, inBOff, outOff;
|
||||||
|
} pushConsts {
|
||||||
|
(int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
|
||||||
|
inAOff, inBOff, outOff
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
auto off = i02*nb2 + i03*nb3;
|
||||||
|
pushConsts.inAOff = inAOff + off;
|
||||||
|
pushConsts.inBOff = inBOff + off;
|
||||||
|
pushConsts.outOff = outOff + off;
|
||||||
|
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {(uint32_t)ne01, (uint32_t)ne11}, {}, {pushConsts}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
|
void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
|
||||||
printf("%s: evaluating graph\n", __func__);
|
printf("%s: evaluating graph\n", __func__);
|
||||||
|
|
||||||
|
@ -723,6 +1041,17 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||||
{
|
{
|
||||||
ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03);
|
ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
{
|
||||||
|
if (src0->type == GGML_TYPE_F32
|
||||||
|
&& src1->type == GGML_TYPE_F32) {
|
||||||
|
ggml_vk_mul_mat_f32(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb2, nb3);
|
||||||
|
break;
|
||||||
|
} else if (src0->type == GGML_TYPE_F32
|
||||||
|
&& src1->type == GGML_TYPE_F16) {
|
||||||
|
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
|
||||||
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
||||||
//GGML_ASSERT(false);
|
//GGML_ASSERT(false);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue