Fix typos, use GGML_TYPE defines, improve code

This commit is contained in:
0cc4m 2023-04-25 18:43:31 +02:00
parent daa5df51f7
commit 36bfb3c158
3 changed files with 9 additions and 13 deletions

View file

@ -34,8 +34,8 @@ endif
#
# keep standard at C11 and C++11
CFLAGS = -I. -O3 -DNODEBUG -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNODEBUG -std=c++11 -fPIC
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
LDFLAGS =
# warnings

View file

@ -111,11 +111,7 @@ void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_me
void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) {
cl_int err = 0;
cl_event events[4];
events[0] = NULL;
events[1] = NULL;
events[2] = NULL;
events[3] = NULL;
cl_event events[4] = { NULL };
cl_kernel kernel;
size_t global, local, size_qb;
@ -124,22 +120,22 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
global = n * k;
switch (btype) {
case 2:
case GGML_TYPE_Q4_0:
kernel = kernel_q4_0;
local = 16;
size_qb = global * (sizeof(float) + local) / 32;
break;
case 3:
case GGML_TYPE_Q4_1:
kernel = kernel_q4_1;
local = 16;
size_qb = global * (sizeof(float) * 2 + local) / 32;
break;
case 4:
case GGML_TYPE_Q4_2:
kernel = kernel_q4_2;
local = 8;
size_qb = global * (sizeof(short) + local) / 16;
break;
case 5:
case GGML_TYPE_Q4_3:
kernel = kernel_q4_3;
local = 8;
size_qb = global * (sizeof(short) * 2 + local) / 16;

4
ggml.c
View file

@ -7580,7 +7580,7 @@ static void ggml_compute_forward_mul_mat_f32(
1.0f, y, ne10,
x, ne10,
0.0f, d, ne01,
params->type);
GGML_TYPE_F32);
#else
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
@ -7814,7 +7814,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
1.0f, y, ne10,
x, ne10,
0.0f, d, ne01,
params->type);
GGML_TYPE_F32);
#else
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);