Fix matmul kernel, continue implementation
This commit is contained in:
parent
061246fb07
commit
4a96d0eb7f
3 changed files with 49 additions and 3 deletions
2
Makefile
2
Makefile
|
@ -216,7 +216,7 @@ endif # LLAMA_METAL
|
|||
ifdef LLAMA_VULKAN
|
||||
CFLAGS += -DGGML_USE_VULKAN
|
||||
LDFLAGS += -lvulkan
|
||||
OBJS += ggml-vulkan.o ggml-vulkan-matmul-shader
|
||||
OBJS += ggml-vulkan.o
|
||||
ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
ggml-vulkan-matmul-shader:
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Original at https://github.com/google/uVkCompute/blob/f3180c7e72ae639c0a7bc8cff7ed615b63ced27c/benchmarks/mmt/mmt_i8.glsl
|
||||
// Modified by 0cc4m for FP32
|
||||
|
||||
#version 450 core
|
||||
|
@ -22,6 +23,12 @@
|
|||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
|
||||
#define WG_X 32
|
||||
#define WG_Y 2
|
||||
#define M0 32
|
||||
#define N0 256
|
||||
#define K0 16
|
||||
|
||||
layout(binding = 0) buffer InputA { vec4 x[]; } inputA;
|
||||
layout(binding = 1) buffer InputB { vec4 x[]; } inputB;
|
||||
layout(binding = 2) buffer Output { float x[]; } outputO;
|
||||
|
|
|
@ -19,7 +19,8 @@
|
|||
|
||||
vk::Instance instance;
|
||||
vk::PhysicalDevice physical_device;
|
||||
vk::Device device;
|
||||
vk::Device vk_device;
|
||||
vk::Pipeline vk_pipeline_matmul;
|
||||
VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
|
||||
vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c;
|
||||
|
||||
|
@ -48,9 +49,47 @@ void ggml_vk_init(void) {
|
|||
const float queue_priority = 1.0f;
|
||||
vk::DeviceQueueCreateInfo device_queue_create_info(vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, &queue_priority);
|
||||
vk::DeviceCreateInfo device_create_info(vk::DeviceCreateFlags(), device_queue_create_info);
|
||||
device = physical_device.createDevice(device_create_info);
|
||||
vk_device = physical_device.createDevice(device_create_info);
|
||||
|
||||
std::vector<char> matmul_shader_contents;
|
||||
if (std::ifstream shader_file{ "ggml-vulkan-matmul.spv", std::ios::binary | std::ios::ate }) {
|
||||
const size_t file_size = shader_file.tellg();
|
||||
shader_file.seekg(0);
|
||||
matmul_shader_contents.resize(file_size, '\0');
|
||||
shader_file.read(matmul_shader_contents.data(), file_size);
|
||||
}
|
||||
|
||||
vk::ShaderModuleCreateInfo shader_module_create_info(
|
||||
vk::ShaderModuleCreateFlags(),
|
||||
matmul_shader_contents.size(),
|
||||
reinterpret_cast<const uint32_t*>(matmul_shader_contents.data())
|
||||
);
|
||||
vk::ShaderModule shader_module = vk_device.createShaderModule(shader_module_create_info);
|
||||
|
||||
const std::vector<vk::DescriptorSetLayoutBinding> descriptor_set_layout_binding = {
|
||||
{0, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute},
|
||||
{1, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute},
|
||||
{2, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}
|
||||
};
|
||||
vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
|
||||
vk::DescriptorSetLayoutCreateFlags(),
|
||||
descriptor_set_layout_binding);
|
||||
vk::DescriptorSetLayout descriptor_set_layout = vk_device.createDescriptorSetLayout(descriptor_set_layout_create_info);
|
||||
|
||||
vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), descriptor_set_layout);
|
||||
vk::PipelineLayout pipeline_layout = vk_device.createPipelineLayout(pipeline_layout_create_info);
|
||||
vk::PipelineCache pipeline_cache = vk_device.createPipelineCache(vk::PipelineCacheCreateInfo());
|
||||
|
||||
vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
|
||||
vk::PipelineShaderStageCreateFlags(),
|
||||
vk::ShaderStageFlagBits::eCompute,
|
||||
shader_module,
|
||||
"main");
|
||||
vk::ComputePipelineCreateInfo compute_pipeline_create_info(
|
||||
vk::PipelineCreateFlags(), // Flags
|
||||
pipeline_shader_create_info, // Shader Create Info struct
|
||||
pipeline_layout); // Pipeline Layout
|
||||
vk_pipeline_matmul = vk_device.createComputePipeline(pipeline_cache, compute_pipeline_create_info).value;
|
||||
}
|
||||
|
||||
// static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue