vulkan: implement GGML_OP_COUNT_EQUAL
parent 148f58681b
commit 095f8d17ac
4 changed files with 61 additions and 2 deletions
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -254,6 +254,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
+    vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
@@ -2157,6 +2158,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
@@ -5298,6 +5301,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_argmax_f32;
         }
         return nullptr;
+    case GGML_OP_COUNT_EQUAL:
+        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
+            return ctx->device->pipeline_count_equal_i32;
+        }
+        return nullptr;
     case GGML_OP_IM2COL:
         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_im2col_f32;
@@ -6187,6 +6195,11 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_backend_tensor_memset(dst, 0, 0, ggml_nbytes(dst));
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const int32_t s0 = dst->op_params[0];
     const int32_t s1 = dst->op_params[1];
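Note: the new host function zeroes the destination buffer before dispatch because the shader (shown later in this commit) accumulates partial counts with atomicAdd rather than writing a final value. A minimal CPU sketch of the semantics the Vulkan path has to reproduce, with an illustrative helper name that is not part of the patch:

    // Reference semantics of GGML_OP_COUNT_EQUAL for contiguous I32 inputs:
    // the destination is a single I64 scalar holding the number of positions
    // where a[i] == b[i]. It must start at zero because the GPU shader adds
    // per-workgroup partial counts into it.
    static int64_t count_equal_ref(const int32_t * a, const int32_t * b, int64_t n) {
        int64_t count = 0;
        for (int64_t i = 0; i < n; ++i) {
            count += (a[i] == b[i]);
        }
        return count;
    }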
@@ -7080,6 +7093,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -7134,6 +7148,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -7269,6 +7284,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ARGMAX:
         ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_COUNT_EQUAL:
+        ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_IM2COL:
         ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7383,6 +7402,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -8320,6 +8340,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -8898,6 +8919,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_ARGMAX) {
         tensor_clone = ggml_argmax(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
+        tensor_clone = ggml_count_equal(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_IM2COL) {
         const int32_t s0 = tensor->op_params[0];
         const int32_t s1 = tensor->op_params[1];
@@ -9017,6 +9040,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     } else if (tensor->type == GGML_TYPE_I32) {
         correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
         result  = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+    } else if (tensor->type == GGML_TYPE_I64) {
+        correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+        result  = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
     } else {
         std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
     }
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp (new file, 31 lines)
@@ -0,0 +1,31 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.comp"
+#include "generic_head.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+const uint CHUNK_SIZE = 512;
+
+void main() {
+    const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
+    const uint col = gl_LocalInvocationID.x;
+
+    uint count = 0;
+    [[unroll]]
+    for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
+        const uint idx = base + i + col;
+        if (idx >= p.KX) {
+            break;
+        }
+        count += uint(data_a[idx] == data_b[idx]);
+    }
+
+    atomicAdd(data_d[0], D_TYPE(count));
+}
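Each workgroup scans a fixed 512-element chunk of the flattened input, with invocations striding by the specialization-constant workgroup size, and folds its partial count into data_d[0] with a single atomicAdd. Since D_TYPE here is a 32-bit int while the ggml destination tensor is I64, the op appears to rely on the host-side memset above zeroing the full 8-byte result so the atomic updates land in the low word. A hedged sketch of the expected grid size follows; the helper is hypothetical, and the real value is derived inside ggml_vk_op_f32 from the {512, 1, 1} workgroup denominators used when creating the pipeline:

    // One workgroup per CHUNK_SIZE-element slice of the flattened input.
    static uint32_t count_equal_workgroups(uint32_t n_elements) {
        const uint32_t CHUNK_SIZE = 512;                    // matches the shader constant
        return (n_elements + CHUNK_SIZE - 1) / CHUNK_SIZE;  // ceil-divide
    }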
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -488,6 +488,7 @@ void process_shaders() {
 
     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
 
     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
tests/test-backend-ops.cpp

@@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case {
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(b, "b");
 
-        ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+        ggml_tensor * b_argmax = ggml_argmax(ctx, b);
         ggml_set_name(b_argmax, "b_argmax");
 
         ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
@@ -3861,7 +3861,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_count_equal());
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
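For context, test_count_equal exercises the op through argmax, so the Vulkan path sees I32 inputs and an I64 output. A hedged, graph-construction-only sketch of the same pattern against the public ggml API; the function name and buffer size are illustrative, and evaluating the result still requires building and running a ggml compute graph on the chosen backend:

    #include "ggml.h"

    static void count_equal_example(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // two f32 matrices with 500 rows of width 4, as in the first test case
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 500);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 500);

        // count_equal(argmax(a), argmax(b)): an I64 scalar counting the rows
        // whose argmax indices agree
        struct ggml_tensor * out = ggml_count_equal(ctx,
                ggml_argmax(ctx, a),
                ggml_argmax(ctx, b));
        (void) out;

        ggml_free(ctx);
    }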