OpenCL Token Generation Acceleration (#1459)
* Move back to C++ for OpenCL * Refactor OpenCL code to work more like the CUDA code, add missing functions * Deduplicate dequant kernels * Add OpenCL compile options * Use compile args for preprocessing constants * Restore default platform + device selection by id behavior --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: Henri Vasserman <henv@hot.ee>
This commit is contained in:
parent
7e4ea5beff
commit
2e6cd4b025
8 changed files with 1113 additions and 536 deletions
83
ggml.c
83
ggml.c
|
@ -9431,7 +9431,7 @@ static void ggml_compute_forward_rms_norm_back(
|
|||
|
||||
// ggml_compute_forward_mul_mat
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
// helper function to determine if it is better to use BLAS or not
|
||||
// for large matrices, BLAS is faster
|
||||
static bool ggml_compute_forward_mul_mat_use_blas(
|
||||
|
@ -9472,7 +9472,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
#endif
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
|
@ -9536,9 +9536,16 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
|
@ -9558,21 +9565,11 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
// zT = y * xT
|
||||
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
||||
ne11, ne01, ne10,
|
||||
1.0f, y, ne10,
|
||||
x, ne10,
|
||||
0.0f, d, ne01,
|
||||
GGML_TYPE_F32);
|
||||
#else
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
ne11, ne01, ne10,
|
||||
1.0f, y, ne10,
|
||||
x, ne00,
|
||||
0.0f, d, ne01);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||
|
@ -9711,9 +9708,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
|
@ -9743,20 +9747,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
assert(id*sizeof(float) <= params->wsize);
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
const float * x = wdata;
|
||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||
|
||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
|
||||
// zT = y * xT
|
||||
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
||||
ne11, ne01, ne10,
|
||||
1.0f, y, ne10,
|
||||
x, ne10,
|
||||
0.0f, d, ne01,
|
||||
GGML_TYPE_F32);
|
||||
#else
|
||||
const float * x = wdata;
|
||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||
|
||||
|
@ -9768,7 +9758,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
1.0f, y, ne10,
|
||||
x, ne00,
|
||||
0.0f, d, ne01);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9931,9 +9920,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
|
@ -9956,9 +9952,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
|
||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
||||
#else
|
||||
{
|
||||
size_t id = 0;
|
||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
||||
|
@ -9970,23 +9963,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
}
|
||||
|
||||
const float * x = wdata;
|
||||
#endif
|
||||
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
// zT = y * xT
|
||||
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
||||
ne11, ne01, ne10,
|
||||
1.0f, y, ne10,
|
||||
x, ne10,
|
||||
0.0f, d, ne01,
|
||||
type);
|
||||
#else
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
ne11, ne01, ne10,
|
||||
1.0f, y, ne10,
|
||||
x, ne00,
|
||||
0.0f, d, ne01);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14165,9 +14147,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
|
||||
}
|
||||
else
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
|
||||
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||
// the threads are still spinning
|
||||
cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||
// the threads are still spinning
|
||||
|
@ -14181,13 +14170,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
#endif
|
||||
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
||||
cur = 0;
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||
node->n_tasks = 1;
|
||||
}
|
||||
#endif
|
||||
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||
node->n_tasks = 1;
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue