GEMM Kernel optimization
This commit is contained in:
parent
a42376e7ec
commit
baf9ff536b
3 changed files with 73 additions and 17 deletions
6
Makefile
6
Makefile
|
@ -215,12 +215,11 @@ endif # LLAMA_METAL
|
|||
|
||||
ifdef LLAMA_VULKAN
|
||||
CFLAGS += -DGGML_USE_VULKAN
|
||||
LDFLAGS += -lvulkan
|
||||
LDFLAGS += -lvulkan -lopenblas -lcblas
|
||||
OBJS += ggml-vulkan.o
|
||||
ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
ggml-vulkan-matmul-shader:
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 -O ggml-vulkan-matmul.glsl -o ggml-vulkan-matmul.spv
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 -O ggml-vulkan-matmul.comp -o ggml-vulkan-matmul.spv
|
||||
endif
|
||||
|
||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||
|
@ -287,7 +286,6 @@ clean:
|
|||
#
|
||||
# Examples
|
||||
#
|
||||
|
||||
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||
@echo
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
#version 450
|
||||
|
||||
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
|
||||
#define BLOCKSIZE 32
|
||||
|
||||
layout (binding = 0) readonly buffer A { float A_data[]; };
|
||||
layout (binding = 1) readonly buffer B { float B_data[]; };
|
||||
layout (binding = 2) writeonly buffer D { float D_data[]; };
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = BLOCKSIZE * BLOCKSIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float data_b[]; };
|
||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
|
@ -16,18 +20,42 @@ layout (push_constant) uniform parameter
|
|||
int stride_d;
|
||||
} p;
|
||||
|
||||
shared float buf_a[(BLOCKSIZE+1) * BLOCKSIZE];
|
||||
shared float buf_b[(BLOCKSIZE+1) * BLOCKSIZE];
|
||||
|
||||
void main()
|
||||
{
|
||||
int i01 = int(gl_GlobalInvocationID.x);
|
||||
int i11 = int(gl_GlobalInvocationID.y);
|
||||
const int lr = int(gl_LocalInvocationID.x % BLOCKSIZE);
|
||||
const int lc = int(gl_LocalInvocationID.x / BLOCKSIZE);
|
||||
|
||||
if (i01 < p.M && i11 < p.N) {
|
||||
float sum = 0.0f;
|
||||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
for (int i = 0; i < p.K; i++) {
|
||||
sum += A_data[i01 * p.stride_a + i] * B_data[i11 * p.stride_b + i];
|
||||
int pos_a = ir * BLOCKSIZE * p.stride_a;
|
||||
int pos_b = ic * BLOCKSIZE * p.stride_b;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
[[unroll]] for (int i = 0; i < p.K; i += BLOCKSIZE) {
|
||||
buf_a[lc * (BLOCKSIZE+1) + lr] = data_a[pos_a + lc * p.stride_a + lr];
|
||||
buf_b[lc * (BLOCKSIZE+1) + lr] = data_b[pos_b + lc * p.stride_b + lr];
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BLOCKSIZE;
|
||||
pos_b += BLOCKSIZE;
|
||||
|
||||
[[unroll]] for (int j = 0; j < BLOCKSIZE; j++) {
|
||||
sum += buf_a[lr * (BLOCKSIZE+1) + j] * buf_b[lc * (BLOCKSIZE+1) + j];
|
||||
}
|
||||
|
||||
D_data[i11 * p.stride_d + i01] = sum;
|
||||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BLOCKSIZE + lr;
|
||||
const int dc = ic * BLOCKSIZE + lc;
|
||||
|
||||
if (dr < p.M && dc < p.N) {
|
||||
data_d[dc * p.stride_d + dr] = sum;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
#include "ggml-vulkan.h"
|
||||
|
||||
#include <cblas.h>
|
||||
#include <cmath>
|
||||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
#define VMA_IMPLEMENTATION
|
||||
#if UINTPTR_MAX == 0xFFFFFFFF
|
||||
|
@ -29,6 +32,7 @@ inline static void* ggml_aligned_malloc(size_t size, size_t alignment) {
|
|||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <chrono>
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
|
@ -199,7 +203,7 @@ static void ggml_vk_pool_malloc(size_t size, vk_buffer* buf) {
|
|||
};
|
||||
|
||||
VmaAllocationCreateInfo allocation_info = {};
|
||||
allocation_info.usage = VMA_MEMORY_USAGE_AUTO;
|
||||
allocation_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE;
|
||||
allocation_info.flags = VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT | VMA_ALLOCATION_CREATE_HOST_ACCESS_ALLOW_TRANSFER_INSTEAD_BIT | VMA_ALLOCATION_CREATE_MAPPED_BIT;
|
||||
|
||||
vmaCreateBuffer(vk_allocator,
|
||||
|
@ -455,6 +459,8 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02);
|
||||
|
||||
// compute
|
||||
auto begin = std::chrono::high_resolution_clock::now();
|
||||
|
||||
vk::CommandBufferBeginInfo cmd_buffer_begin_info(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
|
||||
cmd_buffer.begin(cmd_buffer_begin_info);
|
||||
cmd_buffer.pushConstants<int>(vk_pipeline_matmul_layout, vk::ShaderStageFlagBits::eCompute, 0, push_constants);
|
||||
|
@ -480,10 +486,34 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
true,
|
||||
uint64_t(-1));
|
||||
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
std::cout << "m=" << ne01 << " n=" << ne11 << " k=" << ne10 << " matmul " << std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0 << "ms" << std::endl;
|
||||
|
||||
// copy dst to host
|
||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
float * d_blas = (float *) malloc(sizeof(float) * d_ne);
|
||||
ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
|
||||
|
||||
#ifdef false
|
||||
const float * x = (float *) ((char *) src0->data);
|
||||
const float * y = (float *) ((char *) src1->data);
|
||||
float * d_chk = (float *) malloc(sizeof(float) * d_ne);
|
||||
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
|
||||
ne01, ne11, ne10,
|
||||
1.0f, x, ne00,
|
||||
y, ne10,
|
||||
0.0f, d_chk, ne01);
|
||||
|
||||
for (size_t i = 0; i < d_ne; i++) {
|
||||
if (std::fabs(d[i] - d_chk[i]) > 0.01f) {
|
||||
printf("d[%ld] = %f d_chk[%ld] = %f\n", i, d[i], i, d_chk[i]);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
free(d_chk);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue