Add further ops

0cc4m 2023-08-07 06:02:57 +02:00
parent ccd2592782
commit e660943d3d
6 changed files with 249 additions and 179 deletions

Makefile

@@ -247,6 +247,9 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16_f32.glsl -o vk_shaders/dequant_mul_mat_vec_f16_f32.spv & \
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv & \
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/mul_f32.glsl -o vk_shaders/mul_f32.spv & \
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f32.glsl -o vk_shaders/add_f32.spv & \
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f16_f32_f16.glsl -o vk_shaders/add_f16_f32_f16.spv & \
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/scale_f32.glsl -o vk_shaders/scale_f32.spv & \
wait
endif

ggml-vulkan.cpp

@@ -130,6 +130,18 @@ struct vk_device {
typedef std::vector<vk_submission> vk_sequence;
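// Host-side push constant block shared by the element-wise op pipelines (mul, add, scale).
// It mirrors the `layout (push_constant) uniform parameter` block declared in the
// corresponding vk_shaders/*.glsl files; `scale` is only read by the scale shader.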
struct vk_op_push_constants {
int M;
int N;
int stride_x;
int stride_y;
int stride_d;
int x_offset;
int y_offset;
int d_offset;
float scale;
};
vk::Instance vk_instance;
vk_device vk_device;
vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s;
@@ -142,6 +154,8 @@ vk_pipeline vk_pipeline_matmul_split_k_reduce;
vk_pipeline vk_pipeline_dequant_mul_mat_vec_f16, vk_pipeline_dequant_mul_mat_vec_q4_0;
vk_pipeline vk_pipeline_dequant_mul_mat_vec_f16_f32, vk_pipeline_dequant_mul_mat_vec_q4_0_f32;
vk_pipeline vk_pipeline_mul_f32;
vk_pipeline vk_pipeline_add_f32, vk_pipeline_add_f16_f32_f16;
vk_pipeline vk_pipeline_scale_f32;
vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0;
static std::vector<std::tuple<void*, size_t, vk_buffer>> vk_pinned_memory;
@@ -759,7 +773,12 @@ void ggml_vk_init(void) {
vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_f16_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
vk_pipeline_mul_f32 = ggml_vk_create_pipeline("vk_shaders/mul_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
vk_pipeline_add_f32 = ggml_vk_create_pipeline("vk_shaders/add_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline("vk_shaders/add_f16_f32_f16.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
vk_pipeline_scale_f32 = ggml_vk_create_pipeline("vk_shaders/scale_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
// Queues
uint32_t queue_index_offset = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
@@ -1713,148 +1732,6 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
ggml_vk_pool_free(d_D);
}
static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef VK_DEBUG
std::cerr << "ggml_vk_mul_mat_f16((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
#endif
GGML_ASSERT(vk_device.fp16);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int nb10 = src1->nb[0];
const int nb11 = src1->nb[1];
const int nb12 = src1->nb[2];
const int nb13 = src1->nb[3];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, true, ne01, ne11, ne10 == kpad);
const uint32_t x_sz = ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
const uint32_t y_sz = ggml_vk_align_size(sizeof(ggml_fp16_t) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
const uint32_t d_sz = ggml_vk_align_size(sizeof(float) * d_ne * split_k, vk_device.properties.limits.minStorageBufferOffsetAlignment);
vk_buffer d_X;
vk_buffer d_Y;
vk_buffer d_D;
if (src0->backend == GGML_BACKEND_GPU) {
d_X = *(vk_buffer*) src0->data;
} else {
ggml_vk_pool_malloc(x_sz * ne02 * ne03, &d_X, {});
}
ggml_vk_pool_malloc(y_sz * ne02 * ne03, &d_Y, {});
ggml_vk_pool_malloc(d_sz * ne02 * ne03, &d_D, {});
const bool src1_cont_rows = nb10 == sizeof(float);
const bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
std::vector<vk_sequence> compute_seqs;
std::vector<vk_sequence> transfer_0_seqs;
std::vector<vk_sequence> transfer_1_seqs;
const bool load_x = src0->backend != GGML_BACKEND_GPU;
// Allocate descriptor sets
ggml_vk_pipeline_allocate_descriptor_sets(*pipeline, ne02 * ne03);
if (split_k > 1) {
ggml_vk_pipeline_allocate_descriptor_sets(vk_pipeline_matmul_split_k_reduce, ne02 * ne03);
}
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const uint32_t x_offset = load_x ? x_sz * (i03 * ne02 + i02) : 0;
const uint32_t y_offset = y_sz * (i03 * ne02 + i02);
const uint32_t d_offset = d_sz * (i03 * ne02 + i02);
vk::Semaphore s_x;
vk::Semaphore s_y = ggml_vk_create_semaphore(vk_device.compute_queue);
std::vector<vk::Semaphore> semaphores = { s_y };
// copy data to device
if (load_x) {
s_x = ggml_vk_create_semaphore(vk_device.compute_queue);
semaphores.push_back(s_x);
// Wait for previous matmul to be done before writing to the input buffers again
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, x_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { s_x }));
}
ggml_vk_submit(vk_device.transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d_f32_to_f16(&d_Y, y_offset, src1, i03, i02, vk_device.transfer_queues[1], {}, { s_y }));
// compute
vk::Semaphore s_mm = ggml_vk_create_semaphore(vk_device.compute_queue);
compute_seqs.push_back(ggml_vk_matmul(*pipeline, { d_X, x_offset, x_sz }, { d_Y, y_offset, y_sz }, { d_D, d_offset, d_sz }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_device.compute_queue, std::move(semaphores), { s_mm }));
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
float * d_chk = (float *) ggml_vk_host_malloc(sizeof(float) * d_ne);
transfer_0_seqs.push_back(ggml_vk_buffer_read_async(&d_D, d_offset, d_chk, sizeof(float) * d_ne, vk_device.transfer_queues[0], { s_mm }, {}));
ggml_vk_submit(vk_device.transfer_queues[1], transfer_1_seqs, VK_NULL_HANDLE);
ggml_vk_submit(vk_device.compute_queue, compute_seqs, VK_NULL_HANDLE);
//DEBUG
ggml_vk_submit(vk_device.transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
vk_device.transfer_queues[0].queue.waitIdle();
double err = 0.0;
for (int i = 0; i < d_ne; i++) {
double abs_err = fabs(d[i] - d_chk[i]);
err += abs_err;
}
err /= d_ne;
if (err > 0.01) {
std::cerr << "ggml_vk_mul_mat_f16((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
std::cerr << "MUL_MAT_F16 i02=" << i02 << " i03=" << i03 << " avg_err=" << err << std::endl;
}
ggml_vk_host_free(d_chk);
}
}
ggml_vk_submit(vk_device.transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
vk_device.transfer_queues[0].queue.waitIdle();
ggml_vk_queue_cleanup(vk_device.transfer_queues[0]);
ggml_vk_queue_cleanup(vk_device.transfer_queues[1]);
ggml_vk_queue_cleanup(vk_device.compute_queue);
ggml_vk_pipeline_cleanup(*pipeline);
ggml_vk_pipeline_cleanup(vk_pipeline_matmul_split_k_reduce);
if (src0->backend != GGML_BACKEND_GPU) {
ggml_vk_pool_free(d_X);
}
ggml_vk_pool_free(d_Y);
ggml_vk_pool_free(d_D);
}
static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef VK_DEBUG
std::cerr << "ggml_vk_mul_mat_q_f16((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
@@ -2257,76 +2134,140 @@ static void ggml_vk_mul_mat(const struct ggml_tensor * src0, const struct ggml_t
}
}
static vk_pipeline* ggml_vk_op_get_pipeline(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op) {
switch (op) {
case GGML_OP_ADD:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return &vk_pipeline_add_f32;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return &vk_pipeline_add_f16_f32_f16;
}
return nullptr;
case GGML_OP_MUL:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return &vk_pipeline_mul_f32;
}
return nullptr;
case GGML_OP_SCALE:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return &vk_pipeline_scale_f32;
}
return nullptr;
default:
return nullptr;
}
}
static void ggml_vk_op_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, float scale=0.0f) {
#ifdef VK_DEBUG
std::cerr << "ggml_vk_op_f32((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "), " << ggml_op_name(op) << ")" << std::endl;
#endif
GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(src1->type)); // NOLINT
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
const bool use_src1 = src1 != nullptr;
const int64_t ne10 = use_src1 ? src1->ne[0] : 0;
const int64_t ne11 = use_src1 ? src1->ne[1] : 0;
const int64_t ne12 = use_src1 ? src1->ne[2] : 0;
const int64_t ne13 = use_src1 ? src1->ne[3] : 0;
const int64_t ne1 = ne10 * ne11 * ne12 * ne13;
const int64_t nb10 = use_src1 ? src1->nb[0] : 0;
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
GGML_ASSERT(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] == ne0);
GGML_ASSERT(nb10 == sizeof(float));
vk_pipeline* pipeline = ggml_vk_op_get_pipeline(src0, src1, dst, op);
GGML_ASSERT(pipeline != nullptr);
const bool transfer_src0 = src0->backend != GGML_BACKEND_GPU;
const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_GPU;
const uint32_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, vk_device.properties.limits.minStorageBufferOffsetAlignment);
const uint32_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, vk_device.properties.limits.minStorageBufferOffsetAlignment) : 0;
const uint32_t d_sz = ggml_vk_align_size(ggml_type_size(dst->type) * ne0, vk_device.properties.limits.minStorageBufferOffsetAlignment);
vk_buffer d_X;
vk_buffer d_Y;
vk_buffer d_D;
if (transfer_src0) {
ggml_vk_pool_malloc(x_sz * ne02 * ne03, &d_X, {});
} else {
d_X = *(vk_buffer*) src0->data;
}
if (transfer_src1) {
ggml_vk_pool_malloc(y_sz * ne02 * ne03, &d_Y, {});
} else if (use_src1) {
d_Y = *(vk_buffer*) src1->data;
}
ggml_vk_pool_malloc(d_sz * ne02 * ne03, &d_D, {});
std::vector<vk_sequence> compute_seqs;
std::vector<vk_sequence> transfer_0_seqs;
std::vector<vk_sequence> transfer_1_seqs;
// Allocate descriptor sets
ggml_vk_pipeline_allocate_descriptor_sets(*pipeline, ne02 * ne03);
int submit_counter = 0;
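// Push constants for this op: M = ne00 (columns), N = ne01 (rows). The buffers are treated
// as contiguous, so all three strides are ne00; the y offset is patched per 2D slice below,
// and `scale` is only read by the scale shader.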
vk_op_push_constants pc = { (int)ne00, (int)ne01, (int)ne00, (int)ne00, (int)ne00, 0, 0, 0, scale };
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const uint32_t it_idx = (i03 * ne02 + i02);
submit_counter++;
if (ne03 > 1 || ne02 > 1) {
const uint32_t x_offset = transfer_src0 ? x_sz * (i03 * ne02 + i02) : 0;
const uint32_t y_offset = transfer_src1 ? y_sz * (i03 * ne02 + i02) : 0;
const uint32_t d_offset = d_sz * (i03 * ne02 + i02);
vk::Semaphore s_x;
vk::Semaphore s_y;
vk::Semaphore s_mm = ggml_vk_create_semaphore(vk_device.compute_queue);
std::vector<vk::Semaphore> transfer_semaphores;
// copy src0 to device
if (transfer_src0) {
s_x = ggml_vk_create_semaphore(vk_device.transfer_queues[0]);
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, x_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { s_x }));
transfer_semaphores.push_back(s_x);
}
if (transfer_src1) {
s_y = ggml_vk_create_semaphore(vk_device.transfer_queues[1]);
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, y_offset, src1, i03, i02, vk_device.transfer_queues[1], {}, { s_y }));
transfer_semaphores.push_back(s_y);
}
if (it_idx == 0 || submit_counter >= VK_SUBMIT_BATCH) {
ggml_vk_submit(vk_device.transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
ggml_vk_submit(vk_device.transfer_queues[1], transfer_1_seqs, VK_NULL_HANDLE);
}
const int64_t i13 = i03%ne13;
const int64_t i12 = i02%ne12;
pc.y_offset = (i13*ne12*ne11 + i12*ne11) * ne10;
vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_D) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferRead, vk::AccessFlagBits::eShaderWrite, false);
if (use_src1) {
ggml_vk_dispatch_pipeline(s, *pipeline, { { d_X, x_offset, x_sz }, { d_Y, y_offset, y_sz }, { d_D, d_offset, d_sz } }, sizeof(vk_op_push_constants), &pc, { (uint32_t)ne00, (uint32_t)ne01, 1});
} else {
ggml_vk_dispatch_pipeline(s, *pipeline, { { d_X, x_offset, x_sz }, { d_D, d_offset, d_sz } }, sizeof(vk_op_push_constants), &pc, { (uint32_t)ne00, (uint32_t)ne01, 1});
}
ggml_vk_end_submission(s, { s_x }, { s_mm });
compute_seqs.push_back({ s });
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
transfer_1_seqs.push_back(ggml_vk_buffer_read_async(&d_D, d_offset, d, sizeof(float) * ne00 * ne01, vk_device.transfer_queues[1], { s_mm }, {}));
if (it_idx == 0 || submit_counter >= VK_SUBMIT_BATCH) {
ggml_vk_submit(vk_device.compute_queue, compute_seqs, VK_NULL_HANDLE);
@@ -2334,18 +2275,29 @@ static void ggml_vk_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
submit_counter = 0;
}
} else {
// Reduce overhead by only using one command buffer
vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
// copy src0 to device
if (transfer_src0 && transfer_src1) {
ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_device.compute_queue, {}, {}, &s);
ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_device.compute_queue, {}, {}, &s);
ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, true);
} else if (transfer_src0) {
ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_device.compute_queue, {}, {}, &s);
ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_X) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, true);
}
if (transfer_src1) {
ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_device.compute_queue, {}, {}, &s);
ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_Y) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, true);
}
const int64_t i13 = i03%ne13;
const int64_t i12 = i02%ne12;
ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_D) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferRead, vk::AccessFlagBits::eShaderWrite, false);
if (use_src1) {
ggml_vk_dispatch_pipeline(s, *pipeline, { { d_X, 0, x_sz }, { d_Y, 0, y_sz }, { d_D, 0, d_sz } }, sizeof(vk_op_push_constants), &pc, { (uint32_t)ne00, (uint32_t)ne01, 1});
} else {
ggml_vk_dispatch_pipeline(s, *pipeline, { { d_X, 0, x_sz }, { d_D, 0, d_sz } }, sizeof(vk_op_push_constants), &pc, { (uint32_t)ne00, (uint32_t)ne01, 1});
}
ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_D) }, vk_device.compute_queue, vk::AccessFlagBits::eShaderWrite, vk::AccessFlagBits::eTransferRead, true);
// copy dst to host
@@ -2355,10 +2307,7 @@ static void ggml_vk_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
compute_seqs.push_back({ s });
ggml_vk_submit(vk_device.compute_queue, compute_seqs, VK_NULL_HANDLE);
}
}
}
@@ -2375,15 +2324,22 @@ static void ggml_vk_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_vk_queue_cleanup(vk_device.compute_queue);
}
ggml_vk_pipeline_cleanup(*pipeline);
ggml_vk_pool_free(d_X);
ggml_vk_pool_free(d_D);
}
static void ggml_vk_add(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
ggml_vk_op_f32(src0, src1, dst, GGML_OP_ADD);
}
static void ggml_vk_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
ggml_vk_op_f32(src0, src1, dst, GGML_OP_MUL);
}
static void ggml_vk_scale(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
ggml_vk_op_f32(src0, src1, dst, GGML_OP_SCALE, ((float *)src1->data)[0]);
}
void ggml_vk_transform_tensor(void * data, ggml_tensor * tensor) {
@@ -2423,6 +2379,14 @@ bool ggml_vk_compute_forward(struct ggml_compute_params * params, struct ggml_te
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
switch (tensor->op) {
case GGML_OP_ADD:
if (!any_on_device) {
return false;
}
func = ggml_vk_add;
break;
case GGML_OP_MUL:
if (!any_on_device) {
return false;
@@ -2430,6 +2394,14 @@ bool ggml_vk_compute_forward(struct ggml_compute_params * params, struct ggml_te
func = ggml_vk_mul;
break;
case GGML_OP_SCALE:
if (!any_on_device) {
return false;
}
func = ggml_vk_scale;
break;
case GGML_OP_MUL_MAT:
if (!any_on_device && !ggml_vk_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {

vk_shaders/add_f16_f32_f16.glsl Normal file

@@ -0,0 +1,33 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
layout (binding = 0) buffer X { float16_t data_x[]; };
layout (binding = 1) buffer Y { float data_y[]; };
layout (binding = 2) buffer D { float16_t data_d[]; };
layout (push_constant) uniform parameter
{
int M;
int N;
int stride_x;
int stride_y;
int stride_d;
int x_offset;
int y_offset;
int d_offset;
float scale;
} p;
void main() {
const int x = int(gl_GlobalInvocationID.x);
const int y = int(gl_GlobalInvocationID.y);
if (x >= p.M || y >= p.N) {
return;
}
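// data_y is indexed by column only, so a single row of Y (selected via y_offset on the host)
// is broadcast across every row of X before being added.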
data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] + float16_t(data_y[p.y_offset + x]);
}

vk_shaders/add_f32.glsl Normal file

@@ -0,0 +1,31 @@
#version 450
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
layout (binding = 0) buffer X { float data_x[]; };
layout (binding = 1) buffer Y { float data_y[]; };
layout (binding = 2) buffer D { float data_d[]; };
layout (push_constant) uniform parameter
{
int M;
int N;
int stride_x;
int stride_y;
int stride_d;
int x_offset;
int y_offset;
int d_offset;
float scale;
} p;
void main() {
const int x = int(gl_GlobalInvocationID.x);
const int y = int(gl_GlobalInvocationID.y);
if (x >= p.M || y >= p.N) {
return;
}
data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] + data_y[p.y_offset + x];
}

vk_shaders/mul_f32.glsl

@@ -16,6 +16,7 @@ layout (push_constant) uniform parameter
int x_offset;
int y_offset;
int d_offset;
float scale;
} p;
void main() {

vk_shaders/scale_f32.glsl Normal file

@@ -0,0 +1,30 @@
#version 450
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
layout (binding = 0) buffer X { float data_x[]; };
layout (binding = 1) buffer D { float data_d[]; };
layout (push_constant) uniform parameter
{
int M;
int N;
int stride_x;
int stride_y;
int stride_d;
int x_offset;
int y_offset;
int d_offset;
float scale;
} p;
void main() {
const int x = int(gl_GlobalInvocationID.x);
const int y = int(gl_GlobalInvocationID.y);
if (x >= p.M || y >= p.N) {
return;
}
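// Only the X and D bindings and the scale constant are used here; the Y-related push
// constants are unused but keep the parameter layout identical to the other element-wise shaders.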
data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * p.scale;
}