Vulkan loader code

0cc4m committed 2023-05-07 07:22:12 +02:00
parent b8c8dda75f
commit 061246fb07
5 changed files with 346 additions and 0 deletions


@@ -213,6 +213,16 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
	$(CC) $(CFLAGS) -c $< -o $@
endif # LLAMA_METAL
ifdef LLAMA_VULKAN
	CFLAGS += -DGGML_USE_VULKAN
	LDFLAGS += -lvulkan
	OBJS += ggml-vulkan.o ggml-vulkan-matmul-shader
ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
	$(CXX) $(CXXFLAGS) -c $< -o $@
ggml-vulkan-matmul-shader:
	glslc -fshader-stage=compute --target-env=vulkan1.2 -O ggml-vulkan-matmul.glsl -o ggml-vulkan-matmul.spv
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
	# Apple M1, M2, etc.
	# Raspberry Pi 3, 4, Zero 2 (64-bit)
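Since the new rules are guarded by ifdef, the two targets added here can be exercised on their own; a minimal sketch of the invocations, assuming glslc and the Vulkan SDK headers are installed (both are external prerequisites, not part of this commit):

make LLAMA_VULKAN=1 ggml-vulkan-matmul-shader
make LLAMA_VULKAN=1 ggml-vulkan.o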

ggml-vulkan-matmul.glsl (new file, 105 lines)

@@ -0,0 +1,105 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Modified by 0cc4m for FP32
#version 450 core
#pragma use_vulkan_memory_model
#extension GL_EXT_scalar_block_layout : enable
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(binding = 0) buffer InputA { vec4 x[]; } inputA;
layout(binding = 1) buffer InputB { vec4 x[]; } inputB;
layout(binding = 2) buffer Output { float x[]; } outputO;
layout(local_size_x = WG_X, local_size_y = WG_Y, local_size_z = 1) in;
layout(constant_id = 0) const uint M = 1;
layout(constant_id = 1) const uint N = 1;
layout(constant_id = 2) const uint K = 1;
const uint VECTORIZE_K = 4;
const uint K_VEC = K / VECTORIZE_K;
const uint K0_VEC = K0 / VECTORIZE_K;
const uint VECTORIZE_N = 4;
const uint N_VEC = N / VECTORIZE_N;
const uint N0_VEC = N0 / VECTORIZE_N;
const uint strideA = K_VEC; // Stride of the `inputA` matrix.
const uint strideB = K_VEC; // Stride of the `inputB` matrix.
const uint strideC = N; // Stride of the `outputO` matrix.
// Each workgroup processes an output tile of size [M0 x N0], therefore
// each thread processes a [M0/WG_Y x N0/WG_X] subview.
const uint C_ROWS = M0 / WG_Y;
const uint C_COLS = N0 / WG_X;
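// Illustrative sizes (set at shader-compile time, not in this file): with
// M0 = 64, N0 = 64, WG_Y = 16, WG_X = 16, each thread owns a 4 x 4 subtile.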
/// Returns the index of `X[i, j]`, where `X` is a 2D matrix of stride |stride|.
uint coordToOffset(uint i, uint j, uint stride) { return stride * i + j; }
float sdot(vec4 lhs, vec4 rhs) {
  vec4 mul = vec4(lhs) * vec4(rhs);
  return float(mul.x) + float(mul.y) + float(mul.z) + float(mul.w);
}

void main() {
  const uvec2 wgID = gl_WorkGroupID.xy;
  const uvec2 localID = gl_LocalInvocationID.xy;
  const uint threadID = gl_SubgroupInvocationID;
  const uint subgroupID = gl_SubgroupID;
  const uint subgroupSz = gl_SubgroupSize;
  const uint numSubgroups = gl_NumSubgroups;

  // The start offsets of the tile processed by this thread in this workgroup.
  const uint x_offset = wgID.x * N0 + C_COLS * localID.x;
  const uint y_offset = wgID.y * M0 + C_ROWS * localID.y;

  float C[C_ROWS][C_COLS]; // Local data for the output.

  // Initialize result to zero.
  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      C[i][j] = 0;
    }
  }

  for (uint k = 0; k < K_VEC; k += K0_VEC) {
    [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
      [[unroll]] for (uint kk = 0; kk < K0_VEC; ++kk) {
        uint y = y_offset + i;
        uint gk = k + (kk + threadID) % K0_VEC;
        vec4 lhs = inputA.x[coordToOffset(y, gk, strideA)];
        [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
          // Calculate the inner product `C[i, j] := sum(A[i, ..] * B[j, ..])`.
          uint x = x_offset + j;
          vec4 rhs = inputB.x[coordToOffset(x, gk, strideB)];
          C[i][j] += sdot(lhs, rhs);
        }
      }
    }
  }

  // Store the accumulated results in `outputO`.
  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    uint y = y_offset + i;
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      uint x = x_offset + j;
      outputO.x[coordToOffset(y, x, strideC)] = C[i][j];
    }
  }
}
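Note that WG_X, WG_Y, M0, N0 and K0 are not defined anywhere in this file: they are preprocessor macros that must be injected when the shader is compiled, while M, N and K are specialization constants fixed later, at pipeline-creation time. The Makefile target above passes no -D flags yet; a sketch of a compile line that would satisfy the shader, with illustrative tile sizes not taken from this commit:

glslc -fshader-stage=compute --target-env=vulkan1.2 -O \
    -DWG_X=16 -DWG_Y=16 -DM0=64 -DN0=64 -DK0=16 \
    ggml-vulkan-matmul.glsl -o ggml-vulkan-matmul.spv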

ggml-vulkan.cpp (new file, 203 lines)

@@ -0,0 +1,203 @@
#include "ggml-vulkan.h"
#include <vulkan/vulkan.hpp>
#include "external/vk_mem_alloc.h"
#include <iostream>
#include <fstream>
#include "ggml.h"
// static cl_platform_id platform;
// static cl_device_id device;
// static cl_context context;
// static cl_command_queue queue;
// static cl_program program;
// static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0;
// static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
// static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
vk::Instance instance;
vk::PhysicalDevice physical_device;
vk::Device device;
VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c;
void ggml_vk_init(void) {
    char* GGML_VULKAN_DEVICE = getenv("GGML_VULKAN_DEVICE");
    int dev_num = (GGML_VULKAN_DEVICE == NULL ? 0 : atoi(GGML_VULKAN_DEVICE));

    printf("\nInitializing Vulkan...");
    printf("\nAttempting to use: Device=%d\n", dev_num);

    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION_1_2 };
    const std::vector<const char*> layers = { "VK_LAYER_KHRONOS_validation" };
    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags(), &app_info, layers.size(), layers.data());
    instance = vk::createInstance(instance_create_info);

    physical_device = instance.enumeratePhysicalDevices()[dev_num];
    vk::PhysicalDeviceProperties device_props = physical_device.getProperties();
    std::cout << "Picked: " << device_props.deviceName << std::endl;

    std::vector<vk::QueueFamilyProperties> queue_family_props = physical_device.getQueueFamilyProperties();
    auto prop_it = std::find_if(queue_family_props.begin(), queue_family_props.end(), [](const vk::QueueFamilyProperties& prop) {
        return prop.queueFlags & vk::QueueFlagBits::eCompute;
    });
    const uint32_t compute_queue_family_index = std::distance(queue_family_props.begin(), prop_it);

    const float queue_priority = 1.0f;
    vk::DeviceQueueCreateInfo device_queue_create_info(vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, &queue_priority);
    vk::DeviceCreateInfo device_create_info(vk::DeviceCreateFlags(), device_queue_create_info);
    device = physical_device.createDevice(device_create_info);
}
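ggml_vk_init() stops at device creation, but the vk_mem_alloc.h include and the VmaAllocation globals above imply a VMA allocator is the next step. A minimal sketch of that step, assuming a hypothetical vk_allocator global and helper that this commit does not declare:

VmaAllocator vk_allocator; // hypothetical, not part of this commit

static void ggml_vk_init_allocator(void) {
    // vk_mem_alloc.h also needs VMA_IMPLEMENTATION defined in exactly one
    // translation unit before it provides these function bodies.
    VmaAllocatorCreateInfo allocator_info = {};
    allocator_info.vulkanApiVersion = VK_API_VERSION_1_2;
    allocator_info.physicalDevice = physical_device; // implicit vk:: to Vk handle conversion
    allocator_info.device = device;
    allocator_info.instance = instance;
    vmaCreateAllocator(&allocator_info, &vk_allocator);
}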
// static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
// if (req_size <= *cur_size) {
// return;
// }
//
// // Reallocate buffer with enough space
// if (*cur_size > 0) {
// clReleaseMemObject(*buf);
// }
// cl_int err;
// *buf = clCreateBuffer(context, flags, req_size, NULL, &err);
// *cur_size = req_size;
// CL_CHECK(err, "clCreateBuffer");
// }
//
// void ggml_cl_sgemm_wrapper(
// const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b,
// const int m, const int n, const int k,
// const float alpha, const void *host_a, const int lda,
// const float *host_b, const int ldb, const float beta,
// float *host_c, const int ldc, const int btype) {
// cl_int err = 0;
//
// cl_kernel kernel;
// size_t global = n * k, local, size_qb;
// bool dequant;
// cl_block_q5_0* cl_host_b;
//
// switch (btype) {
// case GGML_TYPE_F32:
// dequant = false;
// break;
// case GGML_TYPE_Q4_0:
// dequant = true;
// kernel = kernel_q4_0;
// local = 16;
// size_qb = global * (sizeof(float) + local) / 32;
// break;
// case GGML_TYPE_Q4_1:
// dequant = true;
// kernel = kernel_q4_1;
// local = 16;
// size_qb = global * (sizeof(float) * 2 + local) / 32;
// break;
// case GGML_TYPE_Q4_2:
// dequant = true;
// kernel = kernel_q4_2;
// local = 8;
// size_qb = global * (sizeof(ggml_fp16_t) + local) / 16;
// break;
// case GGML_TYPE_Q5_0:
// dequant = true;
// kernel = kernel_q5_0;
// local = 16;
// // For some reason OpenCL seems to be incapable of working with structs of size 22.
// // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
// // TODO Find the reason, fix and remove workaround.
// const block_q5_0* b = (const block_q5_0*) host_b;
// cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
// for (size_t i = 0; i < global / 32; i++) {
// cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
// memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
// memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
// }
// host_b = (const float*) cl_host_b;
// size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
// break;
// case GGML_TYPE_Q5_1:
// dequant = true;
// kernel = kernel_q5_1;
// local = 16;
// size_qb = global * (sizeof(ggml_fp16_t) * 2 + sizeof(uint32_t) + local) / 32;
// break;
// case GGML_TYPE_Q8_0:
// dequant = true;
// kernel = kernel_q8_0;
// local = 32;
// size_qb = global * (sizeof(float) + local) / 32;
// break;
// default:
// fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
// abort();
// }
//
// const size_t size_a = m * k * sizeof(float);
// const size_t size_b = n * k * sizeof(float);
// const size_t size_c = m * n * sizeof(float);
//
// // Prepare buffers
// ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
// if (dequant) {
// ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
// }
// ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
// ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
//
// cl_event ev_a, ev_qb, ev_b;
//
// if (dequant) {
// err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
// err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
// CL_CHECK(err, "clSetKernelArg");
// err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
// CL_CHECK(err, "clEnqueueWriteBuffer qb");
// } else {
// err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
// CL_CHECK(err, "clEnqueueWriteBuffer b");
// }
//
// err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
// CL_CHECK(err, "clEnqueueWriteBuffer a");
// if (dequant) {
// err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
// CL_CHECK(err, "clEnqueueNDRangeKernel");
// clReleaseEvent(ev_qb);
// }
// clWaitForEvents(1, &ev_a);
// clWaitForEvents(1, &ev_b);
// clReleaseEvent(ev_a);
// clReleaseEvent(ev_b);
//
// cl_event ev_sgemm;
// CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
// (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
// m, n, k,
// alpha,
// cl_buffer_a, 0, lda,
// cl_buffer_b, 0, ldb,
// beta,
// cl_buffer_c, 0, ldc,
// &queue, &ev_sgemm);
//
// if (status != CLBlastSuccess) {
// fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
// abort();
// }
//
// cl_event ev_c;
// clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
//
// // Wait for completion
// clWaitForEvents(1, &ev_c);
// clReleaseEvent(ev_sgemm);
// clReleaseEvent(ev_c);
// if (btype == GGML_TYPE_Q5_0) {
// free((void*) cl_host_b);
// }
// }
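The file ends before the SPIR-V produced by the Makefile target is ever consumed; the OpenCL wrapper above is kept commented out as a porting reference. A hedged sketch of the follow-up step, loading the blob into a vk::ShaderModule and binding M/N/K through the specialization constants the shader declares (the helper names here are hypothetical, not part of this commit):

#include <array>
#include <cstdint>
#include <fstream>
#include <vector>
#include <vulkan/vulkan.hpp>

// Hypothetical helper: read ggml-vulkan-matmul.spv and wrap it in a shader module.
static vk::ShaderModule ggml_vk_load_shader_module(vk::Device dev, const char* path) {
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    std::vector<char> spv((size_t) f.tellg());
    f.seekg(0);
    f.read(spv.data(), spv.size());
    vk::ShaderModuleCreateInfo module_info(vk::ShaderModuleCreateFlags(), spv.size(),
                                           reinterpret_cast<const uint32_t*>(spv.data()));
    return dev.createShaderModule(module_info);
}

// Hypothetical helper: map constant_id 0/1/2 to the matrix dimensions M/N/K.
// The entries and mnk arrays must outlive the returned vk::SpecializationInfo.
static vk::SpecializationInfo ggml_vk_matmul_spec(const std::array<uint32_t, 3>& mnk,
                                                  std::array<vk::SpecializationMapEntry, 3>& entries) {
    for (uint32_t i = 0; i < 3; i++) {
        entries[i] = vk::SpecializationMapEntry(i, i * sizeof(uint32_t), sizeof(uint32_t));
    }
    return vk::SpecializationInfo(3, entries.data(), sizeof(uint32_t) * mnk.size(), mnk.data());
}

The returned vk::SpecializationInfo would then be attached to the vk::PipelineShaderStageCreateInfo used to build the compute pipeline.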

ggml-vulkan.h (new file, 24 lines)

@@ -0,0 +1,24 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
void ggml_vk_init(void);
// enum ggml_blas_order {
// GGML_BLAS_ORDER_ROW_MAJOR = 101,
// GGML_BLAS_ORDER_COLUMN_MAJOR = 102,
// };
//
// enum ggml_blas_op {
// GGML_BLAS_OP_N = 111,
// GGML_BLAS_OP_T = 112,
// GGML_BLAS_OP_C = 113,
// };
//
// void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype);
#ifdef __cplusplus
}
#endif

ggml.c (4 additions)

@@ -234,6 +234,8 @@ inline static void* ggml_aligned_malloc(size_t size) {
#include "ggml-cuda.h" #include "ggml-cuda.h"
#elif defined(GGML_USE_CLBLAST) #elif defined(GGML_USE_CLBLAST)
#include "ggml-opencl.h" #include "ggml-opencl.h"
#elif defined(GGML_USE_VULKAN)
#include "ggml-vulkan.h"
#endif
#undef MIN
@@ -4265,6 +4267,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
        ggml_init_cublas();
#elif defined(GGML_USE_CLBLAST)
        ggml_cl_init();
#elif defined(GGML_USE_VULKAN)
        ggml_vk_init();
#endif

        is_first_call = false;