From 061246fb07c1acdc299c9acb1f0692e27fe0f469 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 7 May 2023 07:22:12 +0200 Subject: [PATCH] Vulkan loader code --- Makefile | 10 ++ ggml-vulkan-matmul.glsl | 105 +++++++++++++++++++++ ggml-vulkan.cpp | 203 ++++++++++++++++++++++++++++++++++++++++ ggml-vulkan.h | 24 +++++ ggml.c | 4 + 5 files changed, 346 insertions(+) create mode 100644 ggml-vulkan-matmul.glsl create mode 100644 ggml-vulkan.cpp create mode 100644 ggml-vulkan.h diff --git a/Makefile b/Makefile index 03f38bdba..2af16554e 100644 --- a/Makefile +++ b/Makefile @@ -213,6 +213,16 @@ ggml-metal.o: ggml-metal.m ggml-metal.h $(CC) $(CFLAGS) -c $< -o $@ endif # LLAMA_METAL +ifdef LLAMA_VULKAN + CFLAGS += -DGGML_USE_VULKAN + LDFLAGS += -lvulkan + OBJS += ggml-vulkan.o ggml-vulkan-matmul-shader +ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h + $(CXX) $(CXXFLAGS) -c $< -o $@ +ggml-vulkan-matmul-shader: + glslc -fshader-stage=compute --target-env=vulkan1.2 -O ggml-vulkan-matmul.glsl -o ggml-vulkan-matmul.spv +endif + ifneq ($(filter aarch64%,$(UNAME_M)),) # Apple M1, M2, etc. # Raspberry Pi 3, 4, Zero 2 (64-bit) diff --git a/ggml-vulkan-matmul.glsl b/ggml-vulkan-matmul.glsl new file mode 100644 index 000000000..7570e75e1 --- /dev/null +++ b/ggml-vulkan-matmul.glsl @@ -0,0 +1,105 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Modified by 0cc4m for FP32 + +#version 450 core +#pragma use_vulkan_memory_model + +#extension GL_EXT_scalar_block_layout : enable +#extension GL_EXT_control_flow_attributes : enable + +#extension GL_KHR_shader_subgroup_basic : enable + +layout(binding = 0) buffer InputA { vec4 x[]; } inputA; +layout(binding = 1) buffer InputB { vec4 x[]; } inputB; +layout(binding = 2) buffer Output { float x[]; } outputO; + +layout(local_size_x = WG_X, local_size_y = WG_Y, local_size_z = 1) in; + +layout(constant_id = 0) const uint M = 1; +layout(constant_id = 1) const uint N = 1; +layout(constant_id = 2) const uint K = 1; + +const uint VECTORIZE_K = 4; +const uint K_VEC = K / VECTORIZE_K; +const uint K0_VEC = K0 / VECTORIZE_K; + +const uint VECTORIZE_N = 4; +const uint N_VEC = N / VECTORIZE_N; +const uint N0_VEC = N0 / VECTORIZE_N; + +const uint strideA = K_VEC; // Stride of the `inputA` matrix. +const uint strideB = K_VEC; // Stride of the `inputB` matrix. +const uint strideC = N; // Stride of the `outputO` matrix. + +// Each workgroup processes an output tile of size [M0 x N0], therefore +// each thread processes a [M0/WG_Y x N0/WG_X] subview. +const uint C_ROWS = M0 / WG_Y; +const uint C_COLS = N0 / WG_X; + +/// Returns the index of `X[i, j]`, where `X` is a 2D matrix of stride |stride|. +uint coordToOffset(uint i, uint j, uint stride) { return stride * i + j; } + +float sdot(vec4 lhs, vec4 rhs) { + vec4 mul = vec4(lhs) * vec4(rhs); + return float(mul.x) + float(mul.y) + float(mul.z) + float(mul.w); +} + +void main() { + const uvec2 wgID = gl_WorkGroupID.xy; + const uvec2 localID = gl_LocalInvocationID.xy; + const uint threadID = gl_SubgroupInvocationID; + const uint subgroupID = gl_SubgroupID; + const uint subgroupSz = gl_SubgroupSize; + const uint numSubgroups = gl_NumSubgroups; + + // The start offsets of the tile processed by this thread in this workgroup. + const uint x_offset = wgID.x * N0 + C_COLS * localID.x; + const uint y_offset = wgID.y * M0 + C_ROWS * localID.y; + + float C[C_ROWS][C_COLS]; // Local data for the output. + + // Initialize result to zero. + [[unroll]] for (uint i = 0; i < C_ROWS; ++i) { + [[unroll]] for (uint j = 0; j < C_COLS; ++j) { + C[i][j] = 0; + } + } + + for (uint k = 0; k < K_VEC; k += K0_VEC) { + [[unroll]] for (uint i = 0; i < C_ROWS; ++i) { + [[unroll]] for (uint kk = 0; kk < K0_VEC; ++kk) { + uint y = y_offset + i; + uint gk = k + (kk + threadID) % K0_VEC; + vec4 lhs = inputA.x[coordToOffset(y, gk, strideA)]; + [[unroll]] for (uint j = 0; j < C_COLS; ++j) { + // Calculate the inner product `C[i, j] := sum(A[i, ..] * B[j, ..])`. + uint x = x_offset + j; + vec4 rhs = inputB.x[coordToOffset(x, gk, strideB)]; + C[i][j] += sdot(lhs, rhs); + } + } + } + } + + // Store the accumulated results in `outputO`. + [[unroll]] for (uint i = 0; i < C_ROWS; ++i) { + uint y = y_offset + i; + [[unroll]] for (uint j = 0; j < C_COLS; ++j) { + uint x = x_offset + j; + outputO.x[coordToOffset(y, x, strideC)] = C[i][j]; + } + } +} diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp new file mode 100644 index 000000000..7777154f2 --- /dev/null +++ b/ggml-vulkan.cpp @@ -0,0 +1,203 @@ +#include "ggml-vulkan.h" + +#include +#include "external/vk_mem_alloc.h" + +#include +#include + +#include "ggml.h" + +// static cl_platform_id platform; +// static cl_device_id device; +// static cl_context context; +// static cl_command_queue queue; +// static cl_program program; +// static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0; +// static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; +// static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0; + +vk::Instance instance; +vk::PhysicalDevice physical_device; +vk::Device device; +VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc; +vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c; + +void ggml_vk_init(void) { + char* GGML_VULKAN_DEVICE = getenv("GGML_VULKAN_DEVICE"); + int dev_num = (GGML_VULKAN_DEVICE == NULL ? 0 : atoi(GGML_VULKAN_DEVICE)); + printf("\nInitializing Vulkan..."); + printf("\nAttempting to use: Device=%d\n", dev_num); + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION_1_2 }; + const std::vector layers = { "VK_LAYER_KHRONOS_validation" }; + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags(), &app_info, layers.size(), layers.data()); + instance = vk::createInstance(instance_create_info); + + physical_device = instance.enumeratePhysicalDevices()[dev_num]; + vk::PhysicalDeviceProperties device_props = physical_device.getProperties(); + std::cout << "Picked: " << device_props.deviceName << std::endl; + + std::vector queue_family_props = physical_device.getQueueFamilyProperties(); + auto prop_it = std::find_if(queue_family_props.begin(), queue_family_props.end(), [](const vk::QueueFamilyProperties& prop) + { + return prop.queueFlags & vk::QueueFlagBits::eCompute; + }); + const uint32_t compute_queue_family_index = std::distance(queue_family_props.begin(), prop_it); + + const float queue_priority = 1.0f; + vk::DeviceQueueCreateInfo device_queue_create_info(vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, &queue_priority); + vk::DeviceCreateInfo device_create_info(vk::DeviceCreateFlags(), device_queue_create_info); + device = physical_device.createDevice(device_create_info); + + +} + +// static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) { +// if (req_size <= *cur_size) { +// return; +// } +// +// // Reallocate buffer with enough space +// if (*cur_size > 0) { +// clReleaseMemObject(*buf); +// } +// cl_int err; +// *buf = clCreateBuffer(context, flags, req_size, NULL, &err); +// *cur_size = req_size; +// CL_CHECK(err, "clCreateBuffer"); +// } +// +// void ggml_cl_sgemm_wrapper( +// const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, +// const int m, const int n, const int k, +// const float alpha, const void *host_a, const int lda, +// const float *host_b, const int ldb, const float beta, +// float *host_c, const int ldc, const int btype) { +// cl_int err = 0; +// +// cl_kernel kernel; +// size_t global = n * k, local, size_qb; +// bool dequant; +// cl_block_q5_0* cl_host_b; +// +// switch (btype) { +// case GGML_TYPE_F32: +// dequant = false; +// break; +// case GGML_TYPE_Q4_0: +// dequant = true; +// kernel = kernel_q4_0; +// local = 16; +// size_qb = global * (sizeof(float) + local) / 32; +// break; +// case GGML_TYPE_Q4_1: +// dequant = true; +// kernel = kernel_q4_1; +// local = 16; +// size_qb = global * (sizeof(float) * 2 + local) / 32; +// break; +// case GGML_TYPE_Q4_2: +// dequant = true; +// kernel = kernel_q4_2; +// local = 8; +// size_qb = global * (sizeof(ggml_fp16_t) + local) / 16; +// break; +// case GGML_TYPE_Q5_0: +// dequant = true; +// kernel = kernel_q5_0; +// local = 16; +// // For some reason OpenCL seems to be incapable of working with structs of size 22. +// // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU... +// // TODO Find the reason, fix and remove workaround. +// const block_q5_0* b = (const block_q5_0*) host_b; +// cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32); +// for (size_t i = 0; i < global / 32; i++) { +// cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d); +// memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t)); +// memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2); +// } +// host_b = (const float*) cl_host_b; +// size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32; +// break; +// case GGML_TYPE_Q5_1: +// dequant = true; +// kernel = kernel_q5_1; +// local = 16; +// size_qb = global * (sizeof(ggml_fp16_t) * 2 + sizeof(uint32_t) + local) / 32; +// break; +// case GGML_TYPE_Q8_0: +// dequant = true; +// kernel = kernel_q8_0; +// local = 32; +// size_qb = global * (sizeof(float) + local) / 32; +// break; +// default: +// fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype); +// abort(); +// } +// +// const size_t size_a = m * k * sizeof(float); +// const size_t size_b = n * k * sizeof(float); +// const size_t size_c = m * n * sizeof(float); +// +// // Prepare buffers +// ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a); +// if (dequant) { +// ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb); +// } +// ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b); +// ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c); +// +// cl_event ev_a, ev_qb, ev_b; +// +// if (dequant) { +// err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); +// err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); +// CL_CHECK(err, "clSetKernelArg"); +// err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb); +// CL_CHECK(err, "clEnqueueWriteBuffer qb"); +// } else { +// err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b); +// CL_CHECK(err, "clEnqueueWriteBuffer b"); +// } +// +// err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a); +// CL_CHECK(err, "clEnqueueWriteBuffer a"); +// if (dequant) { +// err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b); +// CL_CHECK(err, "clEnqueueNDRangeKernel"); +// clReleaseEvent(ev_qb); +// } +// clWaitForEvents(1, &ev_a); +// clWaitForEvents(1, &ev_b); +// clReleaseEvent(ev_a); +// clReleaseEvent(ev_b); +// +// cl_event ev_sgemm; +// CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order, +// (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b, +// m, n, k, +// alpha, +// cl_buffer_a, 0, lda, +// cl_buffer_b, 0, ldb, +// beta, +// cl_buffer_c, 0, ldc, +// &queue, &ev_sgemm); +// +// if (status != CLBlastSuccess) { +// fprintf(stderr, "Error: CLBlast SGEMM %d\n", status); +// abort(); +// } +// +// cl_event ev_c; +// clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c); +// +// // Wait for completion +// clWaitForEvents(1, &ev_c); +// clReleaseEvent(ev_sgemm); +// clReleaseEvent(ev_c); +// if (btype == GGML_TYPE_Q5_0) { +// free((void*) cl_host_b); +// } +// } diff --git a/ggml-vulkan.h b/ggml-vulkan.h new file mode 100644 index 000000000..8dfda90a2 --- /dev/null +++ b/ggml-vulkan.h @@ -0,0 +1,24 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +void ggml_vk_init(void); + +// enum ggml_blas_order { +// GGML_BLAS_ORDER_ROW_MAJOR = 101, +// GGML_BLAS_ORDER_COLUMN_MAJOR = 102, +// }; +// +// enum ggml_blas_op { +// GGML_BLAS_OP_N = 111, +// GGML_BLAS_OP_T = 112, +// GGML_BLAS_OP_C = 113, +// }; +// +// void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); + +#ifdef __cplusplus +} +#endif diff --git a/ggml.c b/ggml.c index 684caaa37..6071eabd2 100644 --- a/ggml.c +++ b/ggml.c @@ -234,6 +234,8 @@ inline static void* ggml_aligned_malloc(size_t size) { #include "ggml-cuda.h" #elif defined(GGML_USE_CLBLAST) #include "ggml-opencl.h" +#elif defined(GGML_USE_VULKAN) +#include "ggml-vulkan.h" #endif #undef MIN @@ -4265,6 +4267,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { ggml_init_cublas(); #elif defined(GGML_USE_CLBLAST) ggml_cl_init(); +#elif defined(GGML_USE_VULKAN) + ggml_vk_init(); #endif is_first_call = false;