Add Vulkan sum_rows and div ops
This commit is contained in:
parent
579f059a15
commit
b4abdbb881
3 changed files with 45920 additions and 45948 deletions
91740
ggml-vulkan-shaders.hpp
91740
ggml-vulkan-shaders.hpp
File diff suppressed because it is too large
Load diff
|
@ -137,6 +137,7 @@ struct vk_device {
|
||||||
vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
|
vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
|
||||||
vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
|
vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
|
||||||
vk_pipeline pipeline_mul_f32;
|
vk_pipeline pipeline_mul_f32;
|
||||||
|
vk_pipeline pipeline_div_f32;
|
||||||
vk_pipeline pipeline_add_f32;
|
vk_pipeline pipeline_add_f32;
|
||||||
vk_pipeline pipeline_scale_f32;
|
vk_pipeline pipeline_scale_f32;
|
||||||
vk_pipeline pipeline_sqr_f32;
|
vk_pipeline pipeline_sqr_f32;
|
||||||
|
@ -152,6 +153,7 @@ struct vk_device {
|
||||||
vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
|
vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
|
||||||
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
||||||
vk_pipeline pipeline_argsort_f32;
|
vk_pipeline pipeline_argsort_f32;
|
||||||
|
vk_pipeline pipeline_sum_rows_f32;
|
||||||
|
|
||||||
std::vector<vk_pipeline_ref> pipelines;
|
std::vector<vk_pipeline_ref> pipelines;
|
||||||
|
|
||||||
|
@ -1522,6 +1524,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
||||||
|
|
||||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
@ -1544,6 +1548,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
|
||||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vk_print_gpu_info(size_t idx) {
|
static void ggml_vk_print_gpu_info(size_t idx) {
|
||||||
|
@ -3823,6 +3829,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_mul_f32;
|
return ctx->device->pipeline_mul_f32;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_DIV:
|
||||||
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_div_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_scale_f32;
|
return ctx->device->pipeline_scale_f32;
|
||||||
|
@ -3920,6 +3931,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_argsort_f32;
|
return ctx->device->pipeline_argsort_f32;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_sum_rows_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
default:
|
default:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -3942,6 +3958,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
@ -3964,7 +3981,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
|
||||||
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
|
GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
|
||||||
GGML_ASSERT(op == GGML_OP_CPY || ggml_vk_dim01_contiguous(src0)); // NOLINT
|
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
|
||||||
GGML_ASSERT(dst->extra != nullptr);
|
GGML_ASSERT(dst->extra != nullptr);
|
||||||
const uint64_t ne00 = src0->ne[0];
|
const uint64_t ne00 = src0->ne[0];
|
||||||
const uint64_t ne01 = src0->ne[1];
|
const uint64_t ne01 = src0->ne[1];
|
||||||
|
@ -3987,6 +4004,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
|
||||||
const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
|
const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
|
||||||
const uint64_t ne2 = ne20 * ne21;
|
const uint64_t ne2 = ne20 * ne21;
|
||||||
|
|
||||||
|
const uint64_t ned0 = dst->ne[0];
|
||||||
|
const uint64_t ned1 = dst->ne[1];
|
||||||
|
const uint64_t ned2 = dst->ne[2];
|
||||||
|
const uint64_t ned3 = dst->ne[3];
|
||||||
|
const uint64_t ned = ned0 * ned1;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
|
||||||
ggml_vk_func_t op_func;
|
ggml_vk_func_t op_func;
|
||||||
|
|
||||||
|
@ -4036,10 +4059,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
|
uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0;
|
||||||
uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
|
uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0;
|
||||||
uint64_t z_sz = use_src2 ? ggml_vk_align_size(ggml_type_size(src2->type) * ne2, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
|
uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
|
||||||
uint64_t d_sz = ggml_type_size(dst->type) * ne0;
|
uint64_t d_sz = ggml_type_size(dst->type) * ned;
|
||||||
|
|
||||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||||
|
|
||||||
|
@ -4097,6 +4120,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
elements = { (uint32_t)ggml_nrows(src0), 1, 1 };
|
elements = { (uint32_t)ggml_nrows(src0), 1, 1 };
|
||||||
break;
|
break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
|
@ -4125,7 +4149,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
|
||||||
z_sz *= ne22 * ne23;
|
z_sz *= ne22 * ne23;
|
||||||
}
|
}
|
||||||
if (d_sz != VK_WHOLE_SIZE) {
|
if (d_sz != VK_WHOLE_SIZE) {
|
||||||
d_sz *= ne02 * ne03;
|
d_sz *= ned2 * ned3;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4262,6 +4286,21 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, cons
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Element-wise division: forwards src0 / src1 -> dst to ggml_vk_op_f32 as a
// GGML_OP_DIV with a vk_op_binary_push_constants payload.
//
// The payload is filled positionally with:
//   - the total element count of src0,
//   - for src0, src1 and dst in turn: the four ne[] extents followed by the
//     four nb[] byte strides divided by that tensor's type size,
//   - a trailing 0 and two 0.0f values (presumably the destination offset and
//     two unused float parameters - TODO confirm against the struct definition).
//
// NOTE(review): in expressions such as `(uint32_t)src0->nb[0] / src0_type_size`
// the cast binds tighter than the division, so the byte stride is narrowed to
// 32 bits *before* it is divided.
static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||||
|
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||||
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||||
|

||||||
|
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, {
|
||||||
|
(uint32_t)ggml_nelements(src0),
|
||||||
|
// src0: extents ne[0..3], then strides nb[0..3] scaled by the type size
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||||
|
// src1: same layout as src0
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||||
|
// dst: same layout as src0
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||||
|
0,
|
||||||
|
0.0f, 0.0f,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
float * op_params = (float *)dst->op_params;
|
float * op_params = (float *)dst->op_params;
|
||||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||||
|
@ -4411,10 +4450,6 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx,
|
||||||
|
|
||||||
GGML_ASSERT(ncols_pad <= 1024);
|
GGML_ASSERT(ncols_pad <= 1024);
|
||||||
|
|
||||||
std::cerr << "ncols=" << ncols << " ncols_pad=" << ncols_pad << " ascending=" << op_params[0] << std::endl;
|
|
||||||
|
|
||||||
std::cerr << ((ggml_sort_order) op_params[0]) << " " << GGML_SORT_ORDER_ASC << std::endl;
|
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
||||||
ncols,
|
ncols,
|
||||||
ncols_pad,
|
ncols_pad,
|
||||||
|
@ -4422,6 +4457,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Row-wise sum: forwards src0 -> dst to ggml_vk_op_f32 as a GGML_OP_SUM_ROWS
// with a vk_op_push_constants payload. No src1/src2 tensors are passed.
//
// The first payload value is src0->ne[0]; the remaining values are zero.
// NOTE(review): the sum_rows shader reads a `p.KX` field as both loop bound
// and row stride, so the first value is presumably KX - confirm against the
// vk_op_push_constants definition.
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f });
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef GGML_VULKAN_RUN_TESTS
|
#ifdef GGML_VULKAN_RUN_TESTS
|
||||||
static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
|
static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
|
||||||
if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
|
if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
|
||||||
|
@ -5306,12 +5345,14 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
break;
|
break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(node)) {
|
switch (ggml_get_unary_op(node)) {
|
||||||
|
@ -5548,6 +5589,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
@ -5567,6 +5609,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
||||||
|
@ -5595,6 +5638,10 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
ggml_vk_mul(ctx, ctx->compute_ctx, src0, src1, node);
|
ggml_vk_mul(ctx, ctx->compute_ctx, src0, src1, node);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_DIV:
|
||||||
|
ggml_vk_div(ctx, ctx->compute_ctx, src0, src1, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
ggml_vk_scale(ctx, ctx->compute_ctx, src0, node);
|
ggml_vk_scale(ctx, ctx->compute_ctx, src0, node);
|
||||||
|
@ -5653,6 +5700,11 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
|
ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
|
ggml_vk_sum_rows(ctx, ctx->compute_ctx, src0, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
|
ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
|
||||||
|
@ -5689,6 +5741,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
@ -5706,6 +5759,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
|
||||||
case GGML_OP_TRANSPOSE:
|
case GGML_OP_TRANSPOSE:
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||||
|
|
||||||
break;
|
break;
|
||||||
|
@ -6442,6 +6496,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
|
@ -6450,6 +6505,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -6915,6 +6971,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
|
||||||
tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
|
tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
|
||||||
} else if (tensor->op == GGML_OP_MUL) {
|
} else if (tensor->op == GGML_OP_MUL) {
|
||||||
tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
|
tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
|
||||||
|
} else if (tensor->op == GGML_OP_DIV) {
|
||||||
|
tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
|
||||||
} else if (tensor->op == GGML_OP_SCALE) {
|
} else if (tensor->op == GGML_OP_SCALE) {
|
||||||
tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
|
tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
|
||||||
} else if (tensor->op == GGML_OP_SQR) {
|
} else if (tensor->op == GGML_OP_SQR) {
|
||||||
|
@ -6984,6 +7042,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
|
||||||
tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
|
tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
|
||||||
} else if (tensor->op == GGML_OP_ARGSORT) {
|
} else if (tensor->op == GGML_OP_ARGSORT) {
|
||||||
tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
|
tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
|
||||||
|
} else if (tensor->op == GGML_OP_SUM_ROWS) {
|
||||||
|
tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
|
||||||
} else {
|
} else {
|
||||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
|
|
@ -2000,12 +2000,18 @@ void main() {
|
||||||
|
|
||||||
generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
|
generic_binary_op_combined = f"{generic_binary_op_head}\n{generic_binary_op_layout}\n{generic_binary_op_funcs}\n{generic_binary_op_main}"
|
||||||
|
|
||||||
# MUL F32
|
# MUL
|
||||||
mul_body = """
|
mul_body = """
|
||||||
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
|
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# DIV
|
||||||
|
# GLSL tail for the element-wise division shader. It is appended after
# generic_binary_op_combined when building "div_f32"; the closing brace in this
# string presumably ends a main() opened by that shared prefix (not visible here).
# Each invocation converts the src0 and src1 elements selected by
# src0_idx()/src1_idx() to FLOAT_TYPE, divides them, and stores the result as
# D_TYPE at p.d_offset + dst_idx().
div_body = """
|
||||||
|
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
# ADD
|
# ADD
|
||||||
add_body = """
|
add_body = """
|
||||||
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
|
data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
|
||||||
|
@ -2618,6 +2624,41 @@ void main() {
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
sum_rows_src = """
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||||
|
|
||||||
|
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint row = gl_WorkGroupID.x;
|
||||||
|
const uint col = gl_LocalInvocationID.x;
|
||||||
|
|
||||||
|
tmp[col] = FLOAT_TYPE(0.0f);
|
||||||
|
|
||||||
|
for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
|
||||||
|
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
|
||||||
|
if (col < s) {
|
||||||
|
tmp[col] += tmp[col + s];
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (col == 0) {
|
||||||
|
data_d[row] = D_TYPE(tmp[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
GLSLC = "glslc"
|
GLSLC = "glslc"
|
||||||
|
|
||||||
VK_NUM_TYPES = 16
|
VK_NUM_TYPES = 16
|
||||||
|
@ -2976,8 +3017,11 @@ async def main():
|
||||||
tasks.append(string_to_spv("add_f32", f"{generic_binary_op_combined}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
tasks.append(string_to_spv("add_f32", f"{generic_binary_op_combined}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||||
|
|
||||||
tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
|
tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
|
||||||
|
|
||||||
tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_combined}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
tasks.append(string_to_spv("mul_f32", f"{generic_binary_op_combined}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||||
|
|
||||||
|
tasks.append(string_to_spv("div_f32", f"{generic_binary_op_combined}\n{div_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||||
|
|
||||||
tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_combined}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
tasks.append(string_to_spv("scale_f32", f"{generic_unary_op_combined}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||||
|
|
||||||
tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_combined}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
tasks.append(string_to_spv("sqr_f32", f"{generic_unary_op_combined}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||||
|
@ -3001,6 +3045,8 @@ async def main():
|
||||||
|
|
||||||
tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"}))
|
tasks.append(string_to_spv("argsort_f32", argsort_src, {"A_TYPE": "float"}))
|
||||||
|
|
||||||
|
tasks.append(string_to_spv("sum_rows_f32", f"{generic_head}\n{shader_f32}\n{sum_rows_src}", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||||
|
|
||||||
# Helper to decorate tasks with semaphore acquisition.
|
# Helper to decorate tasks with semaphore acquisition.
|
||||||
async def withSemaphore(sem, task):
|
async def withSemaphore(sem, task):
|
||||||
async with sem:
|
async with sem:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue