vulkan: implement GGML_OP_SUB
parent deb15e3f53
commit 148f58681b

4 changed files with 68 additions and 1 deletion
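Summary of the change, as reflected in the hunks below: an element-wise subtraction path is added to the Vulkan backend, mirroring the existing ADD/MUL/DIV plumbing. It consists of two new f32 pipelines (a broadcasting variant and a same-shape "norepeat" variant), a new sub.comp compute shader, dispatch and support-check hooks in ggml-vulkan.cpp, a CPU reference path via ggml_sub for result checking, and test coverage through test_bin_bcast.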
|
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -221,6 +221,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
     vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
     vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
     vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -2100,6 +2101,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -5126,6 +5129,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SUB:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        }
+        return nullptr;
    case GGML_OP_MUL:
        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
@@ -5330,6 +5338,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CPY:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -5614,6 +5623,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { N * OC * OH * OW, 1, 1};
         } break;
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
@@ -5745,6 +5755,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
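The 29 scalars in the initializer above line up with the vk_op_binary_push_constants fields: the total element count, then shape plus element strides (the nb[] byte strides divided by the type size) for src0, src1 and dst, and finally an offset and three parameter slots that SUB leaves zeroed. A rough sketch of that layout, inferred from the initializer rather than copied from the header:

    // Sketch only: field names are illustrative, inferred from the initializer
    // above, not copied from ggml-vulkan.cpp.
    struct vk_op_binary_push_constants_sketch {
        uint32_t ne;                                              // total elements to process
        uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;  // src0 shape + element strides
        uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;  // src1 shape + element strides
        uint32_t ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23;  // dst shape + element strides
        uint32_t offset;                                          // 0 for SUB
        float    param1, param2;                                  // 0.0f, 0.0f (unused by SUB)
        int32_t  param3;                                          // 0 (unused by SUB)
    };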
@@ -7029,6 +7054,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
     case GGML_OP_ACC:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7083,6 +7109,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7139,6 +7166,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ADD:
         ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SUB:
+        ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_MUL:
         ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7323,6 +7354,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_ADD:
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -8271,6 +8303,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_ADD:
        case GGML_OP_ACC:
+       case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_CONCAT:
@@ -8762,6 +8795,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
         tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+    } else if (tensor->op == GGML_OP_SUB) {
+        tensor_clone = ggml_sub(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_MUL) {
         tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_DIV) {
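For validation, ggml_vk_check_results_0 rebuilds the node on the CPU with the regular ggml API and compares the outputs. A minimal sketch of what that CPU-side reference amounts to (hypothetical sizes and context parameters):

    // Hypothetical standalone reference: d = a - b via the core ggml API.
    struct ggml_init_params params = { /*mem_size=*/ 16*1024*1024, /*mem_buffer=*/ NULL, /*no_alloc=*/ false };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * d = ggml_sub(ctx, a, b);   // creates a GGML_OP_SUB node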
ggml/src/ggml-vulkan/vulkan-shaders/sub.comp (new file, 29 lines)

@@ -0,0 +1,29 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
+}
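The shader processes two elements per thread: 256 threads times num_iter = 2 covers the 512 elements per workgroup that the host side assumed when the pipelines were created with wg_denoms {512, 1, 1}. An illustrative sketch of the resulting dispatch arithmetic (ceil_div is a stand-in helper, not the backend's actual rounding code):

    // Illustrative only: one workgroup of 256 threads covers 512 elements.
    uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    uint32_t ne = 1000;                        // hypothetical total element count (p.ne)
    uint32_t workgroups_x = ceil_div(ne, 512); // -> 2 workgroups
    // Thread t of workgroup w handles idx = w*512 + t, then idx + 256; the
    // `idx >= p.ne` guard skips the tail past the end of the tensor.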
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -434,6 +434,8 @@ void process_shaders() {
    string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 
+   string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
    string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
    string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
tests/test-backend-ops.cpp

@@ -1511,6 +1511,7 @@ struct test_cont : public test_case {
 };
 
 // GGML_OP_ADD
+// GGML_OP_SUB
 // GGML_OP_MUL
 // GGML_OP_DIV
 struct test_bin_bcast : public test_case {
@@ -3938,7 +3939,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
 
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
-        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
        }
    };
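With ggml_sub included in the op list, every existing broadcast configuration now exercises the new pipelines: same-shape cases hit sub_f32_norepeat and broadcasting cases hit sub_f32. Hypothetical standalone cases in the same style, for illustration only (not part of this commit):

    test_cases.emplace_back(new test_bin_bcast(ggml_sub, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1})); // same shape -> sub_f32_norepeat
    test_cases.emplace_back(new test_bin_bcast(ggml_sub, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1})); // broadcast  -> sub_f32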