First matmul success

0cc4m 2023-06-22 09:46:00 +02:00
parent 8ce84c2747
commit a42376e7ec
3 changed files with 115 additions and 145 deletions

ggml-vulkan-matmul.comp (new file, 33 lines)

@@ -0,0 +1,33 @@
#version 450

layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { float A_data[]; };
layout (binding = 1) readonly buffer B { float B_data[]; };
layout (binding = 2) writeonly buffer D { float D_data[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
} p;

void main()
{
    int i01 = int(gl_GlobalInvocationID.x);
    int i11 = int(gl_GlobalInvocationID.y);

    if (i01 < p.M && i11 < p.N) {
        float sum = 0.0f;

        for (int i = 0; i < p.K; i++) {
            sum += A_data[i01 * p.stride_a + i] * B_data[i11 * p.stride_b + i];
        }

        D_data[i11 * p.stride_d + i01] = sum;
    }
}

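For reference, each invocation of the shader above produces one output element. A minimal CPU sketch of the same computation (an illustration, not code from this commit; it assumes the same element layout, with both A and B stored contiguously along K) looks like this:

    // CPU reference for ggml-vulkan-matmul.comp (illustrative sketch only).
    // A: M rows of length K, row stride stride_a (in floats)
    // B: N rows of length K, row stride stride_b (in floats)
    // D: N rows of length M, row stride stride_d (in floats)
    static void matmul_f32_reference(const float * A, const float * B, float * D,
                                     int M, int N, int K,
                                     int stride_a, int stride_b, int stride_d) {
        for (int i11 = 0; i11 < N; i11++) {         // gl_GlobalInvocationID.y
            for (int i01 = 0; i01 < M; i01++) {     // gl_GlobalInvocationID.x
                float sum = 0.0f;
                for (int i = 0; i < K; i++) {
                    sum += A[i01 * stride_a + i] * B[i11 * stride_b + i];
                }
                D[i11 * stride_d + i01] = sum;
            }
        }
    }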
(deleted file)

@@ -1,112 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Original at https://github.com/google/uVkCompute/blob/f3180c7e72ae639c0a7bc8cff7ed615b63ced27c/benchmarks/mmt/mmt_i8.glsl
// Modified by 0cc4m for FP32

#version 450 core
#pragma use_vulkan_memory_model

#extension GL_EXT_scalar_block_layout : enable
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_basic : enable

#define WG_X 32
#define WG_Y 2
#define M0 32
#define N0 256
#define K0 16

layout(binding = 0) buffer InputA { vec4 x[]; } inputA;
layout(binding = 1) buffer InputB { vec4 x[]; } inputB;
layout(binding = 2) buffer Output { float x[]; } outputO;

layout(local_size_x = WG_X, local_size_y = WG_Y, local_size_z = 1) in;

layout(constant_id = 0) const uint M = 1;
layout(constant_id = 1) const uint N = 1;
layout(constant_id = 2) const uint K = 1;

const uint VECTORIZE_K = 4;
const uint K_VEC = K / VECTORIZE_K;
const uint K0_VEC = K0 / VECTORIZE_K;

const uint VECTORIZE_N = 4;
const uint N_VEC = N / VECTORIZE_N;
const uint N0_VEC = N0 / VECTORIZE_N;

const uint strideA = K_VEC; // Stride of the `inputA` matrix.
const uint strideB = K_VEC; // Stride of the `inputB` matrix.
const uint strideC = N;     // Stride of the `outputO` matrix.

// Each workgroup processes an output tile of size [M0 x N0], therefore
// each thread processes a [M0/WG_Y x N0/WG_X] subview.
const uint C_ROWS = M0 / WG_Y;
const uint C_COLS = N0 / WG_X;

/// Returns the index of `X[i, j]`, where `X` is a 2D matrix of stride |stride|.
uint coordToOffset(uint i, uint j, uint stride) { return stride * i + j; }

float sdot(vec4 lhs, vec4 rhs) {
  vec4 mul = vec4(lhs) * vec4(rhs);
  return float(mul.x) + float(mul.y) + float(mul.z) + float(mul.w);
}

void main() {
  const uvec2 wgID = gl_WorkGroupID.xy;
  const uvec2 localID = gl_LocalInvocationID.xy;
  const uint threadID = gl_SubgroupInvocationID;
  const uint subgroupID = gl_SubgroupID;
  const uint subgroupSz = gl_SubgroupSize;
  const uint numSubgroups = gl_NumSubgroups;

  // The start offsets of the tile processed by this thread in this workgroup.
  const uint x_offset = wgID.x * N0 + C_COLS * localID.x;
  const uint y_offset = wgID.y * M0 + C_ROWS * localID.y;

  float C[C_ROWS][C_COLS]; // Local data for the output.

  // Initialize result to zero.
  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      C[i][j] = 0;
    }
  }

  for (uint k = 0; k < K_VEC; k += K0_VEC) {
    [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
      [[unroll]] for (uint kk = 0; kk < K0_VEC; ++kk) {
        uint y = y_offset + i;
        uint gk = k + (kk + threadID) % K0_VEC;
        vec4 lhs = inputA.x[coordToOffset(y, gk, strideA)];

        [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
          // Calculate the inner product `C[i, j] := sum(A[i, ..] * B[j, ..])`.
          uint x = x_offset + j;
          vec4 rhs = inputB.x[coordToOffset(x, gk, strideB)];
          C[i][j] += sdot(lhs, rhs);
        }
      }
    }
  }

  // Store the accumulated results in `outputO`.
  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    uint y = y_offset + i;
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      uint x = x_offset + j;
      outputO.x[coordToOffset(y, x, strideC)] = C[i][j];
    }
  }
}

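The file removed above is the uVkCompute-derived tiled kernel that this commit replaces with the much simpler shader added in ggml-vulkan-matmul.comp. For orientation, its tiling constants work out as follows (a restatement of the #defines above, not code from the commit):

    // Tile bookkeeping of the removed uVkCompute-style shader (restated for clarity).
    constexpr unsigned WG_X = 32, WG_Y = 2;  // 64 threads per workgroup
    constexpr unsigned M0 = 32, N0 = 256;    // output tile handled by one workgroup
    constexpr unsigned C_ROWS = M0 / WG_Y;   // = 16 output rows per thread
    constexpr unsigned C_COLS = N0 / WG_X;   // = 8 output columns per thread
    static_assert(C_ROWS == 16 && C_COLS == 8, "each thread accumulates a 16x8 register tile");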
(modified file)

@@ -107,7 +107,13 @@ void ggml_vk_init(void) {
         descriptor_set_layout_binding);
     vk_pipeline_matmul_dsl = vk_device.createDescriptorSetLayout(descriptor_set_layout_create_info);
 
-    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), vk_pipeline_matmul_dsl);
+    vk::PushConstantRange push_constant_range(
+        vk::ShaderStageFlagBits::eCompute,
+        0,
+        6 * sizeof(int)
+    );
+
+    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), vk_pipeline_matmul_dsl, push_constant_range);
     vk_pipeline_matmul_layout = vk_device.createPipelineLayout(pipeline_layout_create_info);
 
     vk::PipelineCache pipeline_cache = vk_device.createPipelineCache(vk::PipelineCacheCreateInfo());
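The 6 * sizeof(int) push-constant range above has to cover the `parameter` block declared in ggml-vulkan-matmul.comp. A host-side mirror of that block (illustrative only; the commit passes a std::array<int, 6> rather than a named struct) makes the size relationship explicit:

    // Host-side mirror of the shader's push_constant block (illustrative sketch).
    struct vk_matmul_push_constants {
        int M;
        int N;
        int K;
        int stride_a;
        int stride_b;
        int stride_d;
    };
    static_assert(sizeof(vk_matmul_push_constants) == 6 * sizeof(int),
                  "matches the vk::PushConstantRange size above");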
@@ -186,7 +192,7 @@ static void ggml_vk_pool_malloc(size_t size, vk_buffer* buf) {
     vk::BufferCreateInfo buffer_create_info{
         vk::BufferCreateFlags(),
         size,
-        vk::BufferUsageFlagBits::eStorageBuffer,
+        vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
         vk::SharingMode::eExclusive,
         1,
         &vk_compute_queue_family_index
@@ -205,10 +211,6 @@ static void ggml_vk_pool_malloc(size_t size, vk_buffer* buf) {
 
     VkMemoryPropertyFlags mem_prop_flags;
     vmaGetAllocationMemoryProperties(vk_allocator, buf->allocation, &mem_prop_flags);
-
-    if(!(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) {
-        printf("Nope\n");
-    }
 }
 
 static void ggml_vk_pool_free(vk_buffer* buffer) {
@@ -226,12 +228,27 @@ static void ggml_vk_pool_free(vk_buffer* buffer) {
     vmaDestroyBuffer(vk_allocator, buffer->buffer, buffer->allocation);
 }
 
-static void ggml_vk_buffer_write(VkCommandBuffer cmd_buf, vk_buffer* dst, size_t offset, const void * src, size_t size) {
+static vk::CommandBuffer ggml_vk_cmd_buffer_create() {
+    vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(), vk_compute_queue_family_index);
+    vk::CommandPool command_pool = vk_device.createCommandPool(command_pool_create_info);
+
+    vk::CommandBufferAllocateInfo command_buffer_alloc_info(
+        command_pool,
+        vk::CommandBufferLevel::ePrimary,
+        1);
+    const std::vector<vk::CommandBuffer> cmd_buffers = vk_device.allocateCommandBuffers(command_buffer_alloc_info);
+    return cmd_buffers.front();
+}
+
+static void ggml_vk_buffer_write(vk_buffer* dst, size_t offset, const void * src, size_t size) {
     VkMemoryPropertyFlags mem_prop_flags;
     vmaGetAllocationMemoryProperties(vk_allocator, dst->allocation, &mem_prop_flags);
 
     if(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) {
         memcpy(dst->info.pMappedData, src, size);
+        if (!(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT)) {
+            vmaFlushAllocation(vk_allocator, dst->allocation, 0, VK_WHOLE_SIZE);
+        }
     } else {
         // Allocation ended up in a non-mappable memory - need to transfer.
         VkBufferCreateInfo staging_buf_create_info = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO };
@@ -261,16 +278,39 @@ static void ggml_vk_buffer_write(VkCommandBuffer cmd_buf, vk_buffer* dst, size_t offset, const void * src, size_t size) {
             0, // srcOffset
             0, // dstOffset,
             size}; // size
-        vkCmdCopyBuffer(cmd_buf, staging_buf, dst->buffer, 1, &buf_copy);
+        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create();
+        vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
+        cmd_buffer.begin(cmd_buffer_begin_info);
+        vkCmdCopyBuffer(cmd_buffer, staging_buf, dst->buffer, 1, &buf_copy);
+        cmd_buffer.end();
+
+        vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
+        vk::Fence fence = vk_device.createFence(vk::FenceCreateInfo());
+
+        vk::SubmitInfo submit_info(0,
+                                   nullptr,
+                                   nullptr,
+                                   1,
+                                   &cmd_buffer);
+        queue.submit({ submit_info }, fence);
+        vk_device.waitForFences({ fence },
+                                true,
+                                uint64_t(-1));
 
         vmaDestroyBuffer(vk_allocator, staging_buf, staging_alloc);
     }
 }
 
-static void ggml_vk_buffer_read(VkCommandBuffer cmd_buf, vk_buffer* src, size_t offset, void * dst, size_t size) {
+static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_t size) {
+    vk::CommandBuffer cmd_buf = ggml_vk_cmd_buffer_create();
+
     VkMemoryPropertyFlags mem_prop_flags;
     vmaGetAllocationMemoryProperties(vk_allocator, src->allocation, &mem_prop_flags);
+
     if(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) {
+        if (!(mem_prop_flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT)) {
+            vmaInvalidateAllocation(vk_allocator, src->allocation, 0, VK_WHOLE_SIZE);
+        }
         memcpy(dst, src->info.pMappedData, size);
     } else {
         // Allocation ended up in a non-mappable memory - need to transfer.
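The record/submit/wait-on-fence sequence now appears verbatim in ggml_vk_buffer_write, ggml_vk_buffer_read and ggml_vk_mul_mat_f32. One way it could be factored out, sketched here as a hypothetical helper (ggml_vk_submit_sync is not part of this commit), would be:

    // Hypothetical helper (not in this commit): submit one recorded command
    // buffer on the compute queue and block until it has finished executing.
    static void ggml_vk_submit_sync(vk::CommandBuffer cmd_buffer) {
        vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
        vk::Fence fence = vk_device.createFence(vk::FenceCreateInfo());

        vk::SubmitInfo submit_info(0, nullptr, nullptr, 1, &cmd_buffer);
        queue.submit({ submit_info }, fence);

        vk_device.waitForFences({ fence }, true, uint64_t(-1));
        vk_device.destroyFence(fence);
    }

The commit inlines this sequence at each call site instead.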
@@ -298,15 +338,32 @@ static void ggml_vk_buffer_read(VkCommandBuffer cmd_buf, vk_buffer* src, size_t offset, void * dst, size_t size) {
             offset, // srcOffset
             0, // dstOffset,
             size}; // size
-        vkCmdCopyBuffer(cmd_buf, src->buffer, staging_buf, 1, &buf_copy);
         vmaInvalidateAllocation(vk_allocator, staging_alloc, 0, VK_WHOLE_SIZE);
 
+        // [Executed in runtime]:
+        vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create();
+        vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
+        cmd_buffer.begin(cmd_buffer_begin_info);
+        vkCmdCopyBuffer(cmd_buffer, src->buffer, staging_buf, 1, &buf_copy);
+        cmd_buffer.end();
+        vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
+        vk::Fence fence = vk_device.createFence(vk::FenceCreateInfo());
+        vk::SubmitInfo submit_info(0,
+                                   nullptr,
+                                   nullptr,
+                                   1,
+                                   &cmd_buffer);
+        queue.submit({ submit_info }, fence);
+        vk_device.waitForFences({ fence },
+                                true,
+                                uint64_t(-1));
+
         memcpy(dst, staging_alloc_info.pMappedData, size);
         vmaDestroyBuffer(vk_allocator, staging_buf, staging_alloc);
     }
 }
 
-static void ggml_vk_h2d_tensor_2d(VkCommandBuffer cmd_buf, vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2) {
+static void ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2) {
     const uint64_t ne0 = src->ne[0];
     const uint64_t ne1 = src->ne[1];
     const uint64_t nb0 = src->nb[0];
@@ -319,13 +376,12 @@ static void ggml_vk_h2d_tensor_2d(VkCommandBuffer cmd_buf, vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2) {
     const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
     if (nb0 == ts && nb1 == ts*ne0/bs) {
-        ggml_vk_buffer_write(cmd_buf, dst, offset, x, ne1*nb1);
+        ggml_vk_buffer_write(dst, offset, x, ne1*nb1);
         return;
     }
 
     if (nb0 == ts) {
-        // Might be better to use vkCmdCopyBuffer here
         for (uint64_t i1 = 0; i1 < ne1; i1++) {
-            ggml_vk_buffer_write(cmd_buf, dst, offset + ne0 * i1, x + ts*ne0/bs, ne0*nb0);
+            ggml_vk_buffer_write(dst, offset + ne0 * i1, (uint8_t *)x + ts*ne0/bs, ne0*nb0);
         }
         return;
     }
@@ -336,7 +392,6 @@ static void ggml_vk_h2d_tensor_2d(VkCommandBuffer cmd_buf, vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2) {
             dst_ptr[offset + i1 * ts*ne0/bs + i0 * ts] = xc[i1 * nb1 + i0 * nb0];
         }
     }
-    return;
 }
 
 static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -386,36 +441,30 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     };
     vk_device.updateDescriptorSets(write_descriptor_sets, {});
 
-    vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(), vk_compute_queue_family_index);
-    vk::CommandPool command_pool = vk_device.createCommandPool(command_pool_create_info);
-
-    vk::CommandBufferAllocateInfo command_buffer_alloc_info(
-        command_pool,
-        vk::CommandBufferLevel::ePrimary,
-        1);
-    const std::vector<vk::CommandBuffer> cmd_buffers = vk_device.allocateCommandBuffers(command_buffer_alloc_info);
-    vk::CommandBuffer cmd_buffer = cmd_buffers.front();
+    std::array<int, 6> push_constants = { (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01 };
+    assert( ( sizeof( push_constants ) <= vk_physical_device.getProperties().limits.maxPushConstantsSize ) && "Too many push constants" );
+
+    vk::CommandBuffer cmd_buffer = ggml_vk_cmd_buffer_create();
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             // copy data to device
             if (src0->backend != GGML_BACKEND_GPU) {
-                ggml_vk_h2d_tensor_2d(cmd_buffer, &d_X, 0, src0, i03, i02);
+                ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02);
             }
-            ggml_vk_h2d_tensor_2d(cmd_buffer, &d_Y, 0, src1, i03, i02);
-
-            printf("Beginning Vulkan kernel call\n");
 
             // compute
             vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
             cmd_buffer.begin(cmd_buffer_begin_info);
+            cmd_buffer.pushConstants<int>(vk_pipeline_matmul_layout, vk::ShaderStageFlagBits::eCompute, 0, push_constants);
             cmd_buffer.bindPipeline(vk::PipelineBindPoint::eCompute, vk_pipeline_matmul);
             cmd_buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
                                           vk_pipeline_matmul_layout,
                                           0,
                                           { descriptor_set },
                                           {});
-            cmd_buffer.dispatch(d_ne, 1, 1);
+            cmd_buffer.dispatch((ne01 + 31) / 32, (ne11 + 31) / 32, 1);
             cmd_buffer.end();
 
             vk::Queue queue = vk_device.getQueue(vk_compute_queue_family_index, 0);
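Positionally, the push constants set up in the hunk above map onto the shader's block as M = ne01, N = ne11, K = ne10, stride_a = ne00, stride_b = ne10, stride_d = ne01, and the new dispatch rounds the M x N output grid up to whole 32 x 32 workgroups; the shader's bounds check discards the overshoot. A small restatement of that arithmetic (illustrative only):

    // Workgroup count for local_size_x = local_size_y = 32 (illustrative sketch).
    // e.g. ne01 = 4096, ne11 = 512  ->  dispatch(128, 16, 1)
    const uint32_t groups_x = (ne01 + 31) / 32;  // covers i01 = 0 .. ne01-1
    const uint32_t groups_y = (ne11 + 31) / 32;  // covers i11 = 0 .. ne11-1
    cmd_buffer.dispatch(groups_x, groups_y, 1);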
@@ -431,11 +480,10 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
                                     true,
                                     uint64_t(-1));
 
-            printf("Vulkan kernel call done\n");
-
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_vk_buffer_read(cmd_buffer, &d_D, 0, d, sizeof(float) * d_ne);
+            float * d_blas = (float *) malloc(sizeof(float) * d_ne);
+            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
         }
     }
 
@@ -447,6 +495,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 }
 
 static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    assert(false);
 // const int64_t ne00 = src0->ne[0];
 // const int64_t ne01 = src0->ne[1];
 // const int64_t ne02 = src0->ne[2];
@@ -577,7 +626,7 @@ bool ggml_vk_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
     const int64_t ne1 = dst->ne[1];
 
     // TODO: find the optimal values for these
-    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+    if ((src0->type == GGML_TYPE_F32 /*|| src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)*/) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
         ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {