Multithreaded dequantization in mul_mat when using a BLAS library
This commit is contained in:
parent
20fefdfe2b
commit
2c0ed7a638
1 changed files with 33 additions and 16 deletions
49
ggml.c
49
ggml.c
|
@ -9828,11 +9828,30 @@ static void ggml_compute_forward_mul_mat(
|
|||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(dst)) {
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t ne_plane = ne01*ne00;
|
||||
const int64_t desired_wsize = ne13*ne12*ne_plane*sizeof(float);
|
||||
UNUSED(desired_wsize);
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
if (type != GGML_TYPE_F32) {
|
||||
assert(params->wsize >= desired_wsize);
|
||||
// parallelize by src0 rows
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
// broadcast src0 into src1 across 2nd,3rd dimension
|
||||
const int64_t i03 = i13/r3;
|
||||
const int64_t i02 = i12/r2;
|
||||
|
||||
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
|
||||
float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
|
||||
ggml_to_float_t const to_float = type_traits[type].to_float;
|
||||
|
||||
for (int64_t i01 = ith; i01 < ne01; i01+=nth) {
|
||||
to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -9840,9 +9859,14 @@ static void ggml_compute_forward_mul_mat(
|
|||
return;
|
||||
}
|
||||
|
||||
// perform sgemm, parallelization controlled by blas lib
|
||||
if (ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t tgemm0 = ggml_perf_time_us();
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
// broadcast src0 into src1 across 2nd,3rd dimension
|
||||
const int64_t i03 = i13/r3;
|
||||
const int64_t i02 = i12/r2;
|
||||
|
||||
|
@ -9851,17 +9875,7 @@ static void ggml_compute_forward_mul_mat(
|
|||
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
||||
|
||||
if (type != GGML_TYPE_F32) {
|
||||
float * const wdata = params->wdata;
|
||||
ggml_to_float_t const to_float = type_traits[type].to_float;
|
||||
|
||||
size_t id = 0;
|
||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
||||
to_float((const char *) x + i01*nb01, wdata + id, ne00);
|
||||
id += ne00;
|
||||
}
|
||||
|
||||
assert(id*sizeof(float) <= params->wsize);
|
||||
x = wdata;
|
||||
x = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
|
||||
}
|
||||
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
|
@ -9871,6 +9885,7 @@ static void ggml_compute_forward_mul_mat(
|
|||
0.0f, d, ne01);
|
||||
}
|
||||
}
|
||||
//printf("cblas_sgemm = %.3f ms, %lld flops\n", (ggml_perf_time_us() - tgemm0)/1000.0, ne13*ne12*ne1*ne01*ne10*2);
|
||||
|
||||
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||
|
||||
|
@ -16782,7 +16797,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
|||
if (ggml_compute_forward_mul_mat_use_blas(node)) {
|
||||
if (node->src[0]->type != GGML_TYPE_F32) {
|
||||
// here we need memory just for single 2D matrix from src0
|
||||
cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
|
||||
cur = ggml_type_size(GGML_TYPE_F32)
|
||||
* node->src[0]->ne[0]*node->src[0]->ne[1]
|
||||
* node->src[1]->ne[2]*node->src[1]->ne[3];
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue