Finish merge of ClBlast support
This commit is contained in:
parent
b7143c1a2e
commit
6f66870726
3 changed files with 272 additions and 260 deletions
333
ggml.c
333
ggml.c
|
@ -143,28 +143,213 @@ inline static void* ggml_aligned_malloc(size_t size) {
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#if GGML_USE_CLBLAST
|
|
||||||
#ifndef GGML_USE_OPENBLAS
|
|
||||||
#define GGML_USE_OPENBLAS
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define CL_TARGET_OPENCL_VERSION 110
|
|
||||||
#include <clblast_c.h>
|
|
||||||
|
|
||||||
cl_platform_id platform;
|
|
||||||
cl_device_id device;
|
|
||||||
cl_context context;
|
|
||||||
cl_command_queue queue;
|
|
||||||
cl_event event;
|
|
||||||
bool cl_initialized = false;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE)
|
#if defined(GGML_USE_ACCELERATE)
|
||||||
#include <Accelerate/Accelerate.h>
|
#include <Accelerate/Accelerate.h>
|
||||||
#elif defined(GGML_USE_OPENBLAS)
|
#elif defined(GGML_USE_OPENBLAS)
|
||||||
#include <cblas.h>
|
#include <cblas.h>
|
||||||
#elif defined(GGML_USE_CUBLAS)
|
#elif defined(GGML_USE_CUBLAS)
|
||||||
#include "ggml-cuda.h"
|
#include "ggml-cuda.h"
|
||||||
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
|
#define CL_TARGET_OPENCL_VERSION 110
|
||||||
|
#include <clblast_c.h>
|
||||||
|
#include <ggml_clblast_dequant.cl>
|
||||||
|
#include <cblas.h>
|
||||||
|
|
||||||
|
cl_platform_id platform;
|
||||||
|
cl_device_id device;
|
||||||
|
cl_context context;
|
||||||
|
cl_command_queue queue;
|
||||||
|
cl_program program;
|
||||||
|
cl_kernel kernel_q4_0, kernel_q4_1;
|
||||||
|
bool cl_initialized = false;
|
||||||
|
size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0;
|
||||||
|
|
||||||
|
cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
|
||||||
|
cl_program program;
|
||||||
|
char *program_log;
|
||||||
|
size_t program_size, log_size;
|
||||||
|
int err;
|
||||||
|
|
||||||
|
program_size = strlen(program_buffer);
|
||||||
|
|
||||||
|
program = clCreateProgramWithSource(ctx, 1,
|
||||||
|
(const char**)&program_buffer, &program_size, &err);
|
||||||
|
if(err < 0) {
|
||||||
|
perror("OpenCL error creating program");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
err = clBuildProgram(program, 0, NULL, NULL, NULL, NULL);
|
||||||
|
if(err < 0) {
|
||||||
|
|
||||||
|
clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG,
|
||||||
|
0, NULL, &log_size);
|
||||||
|
program_log = (char*) malloc(log_size + 1);
|
||||||
|
program_log[log_size] = '\0';
|
||||||
|
clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG,
|
||||||
|
log_size + 1, program_log, NULL);
|
||||||
|
printf("%s\n", program_log);
|
||||||
|
free(program_log);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return program;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cl_init() {
|
||||||
|
cl_int err = 0;
|
||||||
|
char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM");
|
||||||
|
char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES");
|
||||||
|
int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM));
|
||||||
|
int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES));
|
||||||
|
printf("\nInitializing CLBlast (First Run)...");
|
||||||
|
printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
|
||||||
|
cl_uint num_platforms;
|
||||||
|
clGetPlatformIDs(0, NULL, &num_platforms);
|
||||||
|
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
|
||||||
|
clGetPlatformIDs(num_platforms, platforms, NULL);
|
||||||
|
platform = platforms[plat_num];
|
||||||
|
char platform_buffer[1024];
|
||||||
|
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
|
||||||
|
cl_uint num_devices;
|
||||||
|
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
|
||||||
|
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
|
||||||
|
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
|
||||||
|
device = devices[dev_num];
|
||||||
|
char device_buffer[1024];
|
||||||
|
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
|
||||||
|
printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
|
||||||
|
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
|
||||||
|
if (err != CL_SUCCESS) {
|
||||||
|
printf("Error creating OpenCL context: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
queue = clCreateCommandQueue(context, device, 0, &err);
|
||||||
|
|
||||||
|
if (err != CL_SUCCESS) {
|
||||||
|
printf("Error creating OpenCL Command Queue: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
free(platforms);
|
||||||
|
free(devices);
|
||||||
|
|
||||||
|
program = build_program_from_source(context, device, clblast_dequant);
|
||||||
|
|
||||||
|
// Prepare dequantize kernels
|
||||||
|
kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
|
||||||
|
if(err < 0) {
|
||||||
|
printf("Error creating OpenCL dequantize q4_0 kernel: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
};
|
||||||
|
kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
|
||||||
|
if(err < 0) {
|
||||||
|
printf("Error creating OpenCL dequantize q4_1 kernel: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
};
|
||||||
|
cl_initialized = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cl_sgemm_wrapper(const CBLAS_ORDER order, const CBLAS_TRANSPOSE trans_a, const CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) {
|
||||||
|
cl_int err = 0;
|
||||||
|
|
||||||
|
cl_event events[4];
|
||||||
|
events[0] = NULL;
|
||||||
|
events[1] = NULL;
|
||||||
|
events[2] = NULL;
|
||||||
|
events[3] = NULL;
|
||||||
|
|
||||||
|
if (!cl_initialized) {
|
||||||
|
ggml_cl_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool dequant = (btype == 2 || btype == 3);
|
||||||
|
cl_kernel kernel = btype == 2 ? kernel_q4_0 : kernel_q4_1;
|
||||||
|
|
||||||
|
size_t global = n * k, local = 16, qb_size;
|
||||||
|
cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
|
||||||
|
|
||||||
|
// Prepare buffers
|
||||||
|
cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, m*k*sizeof(float), NULL, &err);
|
||||||
|
if (err != CL_SUCCESS) {
|
||||||
|
printf("Error creating OpenCL Buffer A: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
if (dequant) {
|
||||||
|
qb_size = global * (sizeof(float) * (btype == 2 ? 1 : 2) + 16) / 32;
|
||||||
|
cl_buffer_qb = clCreateBuffer(context, CL_MEM_READ_ONLY, qb_size, NULL, &err);
|
||||||
|
if (err != CL_SUCCESS) {
|
||||||
|
printf("Error creating OpenCL Buffer QB: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err);
|
||||||
|
if (err != CL_SUCCESS) {
|
||||||
|
printf("Error creating OpenCL Buffer B: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err);
|
||||||
|
if (err != CL_SUCCESS) {
|
||||||
|
printf("Error creating OpenCL Buffer C: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dequant) {
|
||||||
|
err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
|
||||||
|
err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
|
||||||
|
if(err < 0) {
|
||||||
|
printf("Error setting OpenCL kernel args: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1);
|
||||||
|
} else {
|
||||||
|
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events);
|
||||||
|
clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2);
|
||||||
|
if (dequant) {
|
||||||
|
err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3);
|
||||||
|
if(err < 0) {
|
||||||
|
printf("Error enqueueing OpenCL dequantize kernel: %d\n", err);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clWaitForEvents(dequant ? 4 : 3, events);
|
||||||
|
clReleaseEvent(events[0]);
|
||||||
|
clReleaseEvent(events[1]);
|
||||||
|
clReleaseEvent(events[2]);
|
||||||
|
if (dequant) {
|
||||||
|
clReleaseEvent(events[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the SGEMM routine.
|
||||||
|
CLBlastStatusCode status = CLBlastSgemm(order,
|
||||||
|
trans_a, trans_b,
|
||||||
|
m, n, k,
|
||||||
|
alpha,
|
||||||
|
cl_buffer_a, 0, lda,
|
||||||
|
cl_buffer_b, 0, ldb,
|
||||||
|
beta,
|
||||||
|
cl_buffer_c, 0, ldc,
|
||||||
|
&queue, events);
|
||||||
|
|
||||||
|
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1);
|
||||||
|
|
||||||
|
// Wait for completion
|
||||||
|
if (status == CLBlastSuccess) {
|
||||||
|
clWaitForEvents(2, events);
|
||||||
|
clReleaseEvent(events[0]);
|
||||||
|
clReleaseEvent(events[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
clReleaseMemObject(cl_buffer_a);
|
||||||
|
if (dequant) {
|
||||||
|
clReleaseMemObject(cl_buffer_qb);
|
||||||
|
}
|
||||||
|
clReleaseMemObject(cl_buffer_b);
|
||||||
|
clReleaseMemObject(cl_buffer_c);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#undef MIN
|
#undef MIN
|
||||||
|
@ -3705,6 +3890,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
||||||
// initialize cuBLAS
|
// initialize cuBLAS
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
ggml_init_cublas();
|
ggml_init_cublas();
|
||||||
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
|
ggml_cl_init();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
is_first_call = false;
|
is_first_call = false;
|
||||||
|
@ -7464,84 +7651,6 @@ static bool ggml_compute_forward_mul_mat_use_blas(
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GGML_USE_CLBLAST
|
|
||||||
static bool ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) {
|
|
||||||
cl_int err = 0;
|
|
||||||
|
|
||||||
if (!cl_initialized) {
|
|
||||||
cl_uint num_platforms;
|
|
||||||
clGetPlatformIDs(0, NULL, &num_platforms);
|
|
||||||
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
|
|
||||||
clGetPlatformIDs(num_platforms, platforms, NULL);
|
|
||||||
platform = platforms[0];
|
|
||||||
cl_uint num_devices;
|
|
||||||
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
|
|
||||||
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
|
|
||||||
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
|
|
||||||
device = devices[0];
|
|
||||||
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL context: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
queue = clCreateCommandQueue(context, device, 0, &err);
|
|
||||||
event = NULL;
|
|
||||||
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Command Queue: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
free(platforms);
|
|
||||||
free(devices);
|
|
||||||
cl_initialized = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare buffers
|
|
||||||
cl_mem cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_WRITE, m*k*sizeof(float), NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Buffer A: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
cl_mem cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Buffer B: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
cl_mem cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Buffer C: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, m*k*sizeof(float), host_a, 0, NULL, NULL);
|
|
||||||
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, n*k*sizeof(float), host_b, 0, NULL, NULL);
|
|
||||||
clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL);
|
|
||||||
|
|
||||||
// Call the SGEMM routine.
|
|
||||||
CLBlastStatusCode status = CLBlastSgemm(order,
|
|
||||||
trans_a, trans_b,
|
|
||||||
m, n, k,
|
|
||||||
alpha,
|
|
||||||
cl_buffer_a, 0, lda,
|
|
||||||
cl_buffer_b, 0, ldb,
|
|
||||||
beta,
|
|
||||||
cl_buffer_c, 0, ldc,
|
|
||||||
&queue, &event);
|
|
||||||
|
|
||||||
// Wait for completion
|
|
||||||
if (status == CLBlastSuccess) {
|
|
||||||
clWaitForEvents(1, &event);
|
|
||||||
clReleaseEvent(event);
|
|
||||||
}
|
|
||||||
|
|
||||||
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL);
|
|
||||||
|
|
||||||
clReleaseMemObject(cl_buffer_a);
|
|
||||||
clReleaseMemObject(cl_buffer_b);
|
|
||||||
clReleaseMemObject(cl_buffer_c);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static void ggml_compute_forward_mul_mat_f32(
|
static void ggml_compute_forward_mul_mat_f32(
|
||||||
|
@ -7663,14 +7772,14 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||||
|
|
||||||
// copy data to host
|
// copy data to host
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||||
#else
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
#ifdef GGML_USE_CLBLAST
|
|
||||||
ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans,
|
ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
1.0f, y, ne10,
|
1.0f, y, ne10,
|
||||||
x, ne10,
|
x, ne10,
|
||||||
0.0f, d, ne01);
|
0.0f, d, ne01,
|
||||||
|
params->type);
|
||||||
#else
|
#else
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
|
@ -7892,6 +8001,19 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
|
|
||||||
// copy data to host
|
// copy data to host
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||||
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
|
const float * x = wdata;
|
||||||
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||||
|
|
||||||
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
|
// zT = y * xT
|
||||||
|
ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
|
ne11, ne01, ne10,
|
||||||
|
1.0f, y, ne10,
|
||||||
|
x, ne10,
|
||||||
|
0.0f, d, ne01,
|
||||||
|
params->type);
|
||||||
#else
|
#else
|
||||||
const float * x = wdata;
|
const float * x = wdata;
|
||||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||||
|
@ -7899,13 +8021,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
#ifdef GGML_USE_CLBLAST
|
|
||||||
ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans,
|
|
||||||
ne11, ne01, ne10,
|
|
||||||
1.0f, y, ne10,
|
|
||||||
x, ne10,
|
|
||||||
0.0f, d, ne01);
|
|
||||||
#else
|
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
1.0f, y, ne10,
|
1.0f, y, ne10,
|
||||||
|
@ -8135,6 +8250,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
|
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
#else
|
#else
|
||||||
|
#ifndef GGML_USE_CLBLAST
|
||||||
{
|
{
|
||||||
size_t id = 0;
|
size_t id = 0;
|
||||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
||||||
|
@ -8143,6 +8259,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const float * x = wdata;
|
const float * x = wdata;
|
||||||
|
#else
|
||||||
|
const void* x = src0->data + i03*nb03 + i02*nb02;
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
@ -8160,14 +8279,14 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
|
|
||||||
// copy data to host
|
// copy data to host
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||||
#else
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
#ifdef GGML_USE_CLBLAST
|
|
||||||
ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans,
|
ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
1.0f, y, ne10,
|
1.0f, y, ne10,
|
||||||
x, ne10,
|
x, ne10,
|
||||||
0.0f, d, ne01);
|
0.0f, d, ne01,
|
||||||
|
type);
|
||||||
#else
|
#else
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
|
|
|
@ -1,153 +0,0 @@
|
||||||
//this is a drop-in for all CLBlast related code, to keep the main ggml.c unmodified
|
|
||||||
// we will imitate the function definition from OpenBLAS instead, replaced as necessary.
|
|
||||||
|
|
||||||
//windows binaries for clblast obtained from https://github.com/CNugteren/CLBlast (apache license)
|
|
||||||
//windows binaries for opencl obtained from https://github.com/KhronosGroup/OpenCL-SDK (apache license)
|
|
||||||
|
|
||||||
#if GGML_USE_OPENBLAS
|
|
||||||
#include <cblas.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
#if GGML_USE_CLBLAST
|
|
||||||
|
|
||||||
#define CL_TARGET_OPENCL_VERSION 110
|
|
||||||
#include <clblast_c.h>
|
|
||||||
|
|
||||||
cl_platform_id platform;
|
|
||||||
cl_device_id device;
|
|
||||||
cl_context context;
|
|
||||||
cl_command_queue queue;
|
|
||||||
bool cl_initialized = false;
|
|
||||||
size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0;
|
|
||||||
cl_mem cl_buffer_a, cl_buffer_b, cl_buffer_c;
|
|
||||||
|
|
||||||
static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) {
|
|
||||||
cl_int err = 0;
|
|
||||||
|
|
||||||
cl_event events[2];
|
|
||||||
events[0] = NULL;
|
|
||||||
events[1] = NULL;
|
|
||||||
|
|
||||||
if (!cl_initialized) {
|
|
||||||
char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM");
|
|
||||||
char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES");
|
|
||||||
int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM));
|
|
||||||
int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES));
|
|
||||||
printf("\nInitializing CLBlast (First Run)...");
|
|
||||||
printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
|
|
||||||
cl_uint num_platforms;
|
|
||||||
clGetPlatformIDs(0, NULL, &num_platforms);
|
|
||||||
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
|
|
||||||
clGetPlatformIDs(num_platforms, platforms, NULL);
|
|
||||||
platform = platforms[plat_num];
|
|
||||||
char platform_buffer[1024];
|
|
||||||
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
|
|
||||||
cl_uint num_devices;
|
|
||||||
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
|
|
||||||
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
|
|
||||||
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
|
|
||||||
device = devices[dev_num];
|
|
||||||
char device_buffer[1024];
|
|
||||||
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
|
|
||||||
printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
|
|
||||||
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL context: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
queue = clCreateCommandQueue(context, device, 0, &err);
|
|
||||||
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Command Queue: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
free(platforms);
|
|
||||||
free(devices);
|
|
||||||
|
|
||||||
cl_size_a = m * k * sizeof(float);
|
|
||||||
cl_size_b = n * k * sizeof(float);
|
|
||||||
cl_size_c = m * n * sizeof(float);
|
|
||||||
|
|
||||||
// Prepare buffers
|
|
||||||
cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Buffer A: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Buffer B: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, &err);
|
|
||||||
if (err != CL_SUCCESS) {
|
|
||||||
printf("Error creating OpenCL Buffer C: %d\n", err);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
cl_initialized = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resize buffers if too small
|
|
||||||
if (m * k * sizeof(float) > cl_size_a) {
|
|
||||||
clReleaseMemObject(cl_buffer_a);
|
|
||||||
cl_size_a = m * k * sizeof(float);
|
|
||||||
cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, NULL);
|
|
||||||
}
|
|
||||||
if (n * k * sizeof(float) > cl_size_b) {
|
|
||||||
clReleaseMemObject(cl_buffer_b);
|
|
||||||
cl_size_b = n * k * sizeof(float);
|
|
||||||
cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, NULL);
|
|
||||||
}
|
|
||||||
if (m * n * sizeof(float) > cl_size_c) {
|
|
||||||
clReleaseMemObject(cl_buffer_c);
|
|
||||||
cl_size_c = m * n * sizeof(float);
|
|
||||||
cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, NULL);
|
|
||||||
}
|
|
||||||
|
|
||||||
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, cl_size_a, host_a, 0, NULL, events);
|
|
||||||
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, cl_size_b, host_b, 0, NULL, events + 1);
|
|
||||||
// buffer c is not required for this use case
|
|
||||||
// clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, cl_size_c, host_c, 0, NULL, NULL);
|
|
||||||
|
|
||||||
clWaitForEvents(2, events);
|
|
||||||
clReleaseEvent(events[0]);
|
|
||||||
clReleaseEvent(events[1]);
|
|
||||||
|
|
||||||
// Call the SGEMM routine.
|
|
||||||
CLBlastStatusCode status = CLBlastSgemm(order,
|
|
||||||
trans_a, trans_b,
|
|
||||||
m, n, k,
|
|
||||||
alpha,
|
|
||||||
cl_buffer_a, 0, lda,
|
|
||||||
cl_buffer_b, 0, ldb,
|
|
||||||
beta,
|
|
||||||
cl_buffer_c, 0, ldc,
|
|
||||||
&queue, events);
|
|
||||||
|
|
||||||
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1);
|
|
||||||
|
|
||||||
// Wait for completion
|
|
||||||
if (status == CLBlastSuccess) {
|
|
||||||
clWaitForEvents(2, events);
|
|
||||||
clReleaseEvent(events[0]);
|
|
||||||
clReleaseEvent(events[1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
|
||||||
#if GGML_USE_CLBLAST
|
|
||||||
#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\
|
|
||||||
ggml_cl_sgemm_wrapper(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\
|
|
||||||
})
|
|
||||||
#else
|
|
||||||
#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\
|
|
||||||
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\
|
|
||||||
})
|
|
||||||
#endif
|
|
||||||
#endif
|
|
46
ggml_clblast_dequant.cl
Normal file
46
ggml_clblast_dequant.cl
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
||||||
|
const char * clblast_dequant = MULTILINE_QUOTE(
|
||||||
|
|
||||||
|
struct __attribute__ ((packed)) block_q4_0
|
||||||
|
{
|
||||||
|
float d;
|
||||||
|
uchar qs[16];
|
||||||
|
};
|
||||||
|
|
||||||
|
__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
|
||||||
|
uint i, l;
|
||||||
|
i = get_global_id(0) / 32;
|
||||||
|
l = get_local_id(0);
|
||||||
|
|
||||||
|
float d = blocks[i].d;
|
||||||
|
|
||||||
|
uchar vi = blocks[i].qs[l];
|
||||||
|
|
||||||
|
uint index = i*32 + l*2;
|
||||||
|
result[index + 0] = ((vi & 0xf) - 8)*d;
|
||||||
|
result[index + 1] = ((vi >> 4) - 8)*d;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct __attribute__ ((packed)) block_q4_1
|
||||||
|
{
|
||||||
|
float d;
|
||||||
|
float m;
|
||||||
|
uchar qs[16];
|
||||||
|
};
|
||||||
|
|
||||||
|
__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
|
||||||
|
uint i, l;
|
||||||
|
i = get_global_id(0) / 32;
|
||||||
|
l = get_local_id(0);
|
||||||
|
|
||||||
|
float d = blocks[i].d;
|
||||||
|
float m = blocks[i].m;
|
||||||
|
|
||||||
|
uchar vi = blocks[i].qs[l];
|
||||||
|
|
||||||
|
uint index = i*32 + l*2;
|
||||||
|
result[index + 0] = (vi & 0xf) * d + m;
|
||||||
|
result[index + 1] = (vi >> 4) * d + m;
|
||||||
|
}
|
||||||
|
|
||||||
|
);
|
Loading…
Add table
Add a link
Reference in a new issue