[fix] Add missing parameters.

Signed-off-by: Changyeon Kim <cyzero.kim@samsung.com>
This commit is contained in:
Changyeon Kim 2024-08-15 22:10:41 +09:00
parent e6e018dafe
commit 12ab18bba0
2 changed files with 10 additions and 4 deletions

View file

@ -4489,13 +4489,18 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t dst_type_size = ggml_type_size(dst->type);
const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
d_offset,
0.0f, 0.0f, 0,
0.0f, 0.0f, offset,
});
}

View file

@ -4,12 +4,13 @@
#include "generic_binary_head.comp"
void main() {
const uint idx = get_idx();
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
const uint src1_i = src1_idx(idx);
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;