ggml : remove obsolete assert + refactor n_tasks section
This commit is contained in:
parent
9c9bdaf0b8
commit
8dc7f104f8
1 changed files with 287 additions and 290 deletions
577
ggml.c
577
ggml.c
|
@ -10717,8 +10717,6 @@ static void ggml_compute_forward_mul_mat(
|
||||||
|
|
||||||
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
|
||||||
assert(ne00 % 32 == 0);
|
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < ne11; ++ic) {
|
for (int64_t ic = 0; ic < ne11; ++ic) {
|
||||||
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
||||||
}
|
}
|
||||||
|
@ -16078,328 +16076,327 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
n_threads = GGML_DEFAULT_N_THREADS;
|
n_threads = GGML_DEFAULT_N_THREADS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t work_size = 0;
|
||||||
|
|
||||||
struct ggml_cplan cplan;
|
struct ggml_cplan cplan;
|
||||||
memset(&cplan, 0, sizeof(struct ggml_cplan));
|
memset(&cplan, 0, sizeof(struct ggml_cplan));
|
||||||
|
|
||||||
int * n_tasks = cplan.n_tasks;
|
// thread scheduling for the different operations + work buffer size estimation
|
||||||
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
|
int n_tasks = 1;
|
||||||
|
|
||||||
size_t work_size = 0;
|
struct ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
||||||
// initialize tasks + work buffer
|
switch (node->op) {
|
||||||
{
|
case GGML_OP_CPY:
|
||||||
// thread scheduling for the different operations
|
case GGML_OP_DUP:
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
{
|
||||||
struct ggml_tensor * node = cgraph->nodes[i];
|
n_tasks = n_threads;
|
||||||
|
|
||||||
switch (node->op) {
|
size_t cur = 0;
|
||||||
case GGML_OP_CPY:
|
if (ggml_is_quantized(node->type)) {
|
||||||
case GGML_OP_DUP:
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
|
||||||
{
|
}
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = 0;
|
work_size = MAX(work_size, cur);
|
||||||
if (ggml_is_quantized(node->type)) {
|
} break;
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks[i];
|
case GGML_OP_ADD:
|
||||||
}
|
case GGML_OP_ADD1:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
size_t cur = 0;
|
||||||
} break;
|
|
||||||
case GGML_OP_ADD:
|
|
||||||
case GGML_OP_ADD1:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = 0;
|
if (ggml_is_quantized(node->src0->type)) {
|
||||||
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks;
|
||||||
|
}
|
||||||
|
|
||||||
if (ggml_is_quantized(node->src0->type)) {
|
work_size = MAX(work_size, cur);
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks[i];
|
} break;
|
||||||
}
|
case GGML_OP_ACC:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
size_t cur = 0;
|
||||||
} break;
|
|
||||||
case GGML_OP_ACC:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = 0;
|
if (ggml_is_quantized(node->src0->type)) {
|
||||||
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks;
|
||||||
|
}
|
||||||
|
|
||||||
if (ggml_is_quantized(node->src0->type)) {
|
work_size = MAX(work_size, cur);
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks[i];
|
} break;
|
||||||
}
|
case GGML_OP_SUB:
|
||||||
|
case GGML_OP_DIV:
|
||||||
|
case GGML_OP_SQR:
|
||||||
|
case GGML_OP_SQRT:
|
||||||
|
case GGML_OP_LOG:
|
||||||
|
case GGML_OP_SUM:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
|
case GGML_OP_MEAN:
|
||||||
|
case GGML_OP_ARGMAX:
|
||||||
|
case GGML_OP_REPEAT:
|
||||||
|
case GGML_OP_REPEAT_BACK:
|
||||||
|
case GGML_OP_ABS:
|
||||||
|
case GGML_OP_SGN:
|
||||||
|
case GGML_OP_NEG:
|
||||||
|
case GGML_OP_STEP:
|
||||||
|
case GGML_OP_TANH:
|
||||||
|
case GGML_OP_ELU:
|
||||||
|
case GGML_OP_RELU:
|
||||||
|
{
|
||||||
|
n_tasks = 1;
|
||||||
|
} break;
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_GELU:
|
||||||
|
case GGML_OP_GELU_QUICK:
|
||||||
|
case GGML_OP_SILU:
|
||||||
|
case GGML_OP_SILU_BACK:
|
||||||
|
case GGML_OP_NORM:
|
||||||
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_RMS_NORM_BACK:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
} break;
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
case GGML_OP_OUT_PROD:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
// TODO: use different scheduling for different matrix sizes
|
||||||
} break;
|
//const int nr0 = ggml_nrows(node->src0);
|
||||||
case GGML_OP_SUB:
|
//const int nr1 = ggml_nrows(node->src1);
|
||||||
case GGML_OP_DIV:
|
|
||||||
case GGML_OP_SQR:
|
|
||||||
case GGML_OP_SQRT:
|
|
||||||
case GGML_OP_LOG:
|
|
||||||
case GGML_OP_SUM:
|
|
||||||
case GGML_OP_SUM_ROWS:
|
|
||||||
case GGML_OP_MEAN:
|
|
||||||
case GGML_OP_ARGMAX:
|
|
||||||
case GGML_OP_REPEAT:
|
|
||||||
case GGML_OP_REPEAT_BACK:
|
|
||||||
case GGML_OP_ABS:
|
|
||||||
case GGML_OP_SGN:
|
|
||||||
case GGML_OP_NEG:
|
|
||||||
case GGML_OP_STEP:
|
|
||||||
case GGML_OP_TANH:
|
|
||||||
case GGML_OP_ELU:
|
|
||||||
case GGML_OP_RELU:
|
|
||||||
{
|
|
||||||
n_tasks[i] = 1;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_MUL:
|
|
||||||
case GGML_OP_GELU:
|
|
||||||
case GGML_OP_GELU_QUICK:
|
|
||||||
case GGML_OP_SILU:
|
|
||||||
case GGML_OP_SILU_BACK:
|
|
||||||
case GGML_OP_NORM:
|
|
||||||
case GGML_OP_RMS_NORM:
|
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_MUL_MAT:
|
|
||||||
case GGML_OP_OUT_PROD:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
// TODO: use different scheduling for different matrix sizes
|
//n_tasks = MIN(n_threads, MAX(1, nr0/128));
|
||||||
//const int nr0 = ggml_nrows(node->src0);
|
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
|
||||||
//const int nr1 = ggml_nrows(node->src1);
|
|
||||||
|
|
||||||
//n_tasks[i] = MIN(n_threads, MAX(1, nr0/128));
|
size_t cur = 0;
|
||||||
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, n_tasks[i]);
|
const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
|
||||||
|
|
||||||
size_t cur = 0;
|
|
||||||
const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
|
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
||||||
n_tasks[i] = 1; // TODO: this actually is doing nothing
|
n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
// the threads are still spinning
|
// the threads are still spinning
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
|
if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
|
||||||
n_tasks[i] = 1; // TODO: this actually is doing nothing
|
n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
// the threads are still spinning
|
// the threads are still spinning
|
||||||
cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
|
cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
#endif
|
#endif
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
n_tasks[i] = 1; // TODO: this actually is doing nothing
|
n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
// the threads are still spinning
|
// the threads are still spinning
|
||||||
if (node->src0->type != GGML_TYPE_F32) {
|
if (node->src0->type != GGML_TYPE_F32) {
|
||||||
// here we need memory just for single 2D matrix from src0
|
// here we need memory just for single 2D matrix from src0
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
||||||
}
|
}
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
if (node->src1->type != vec_dot_type) {
|
if (node->src1->type != vec_dot_type) {
|
||||||
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
|
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
|
||||||
} else {
|
} else {
|
||||||
cur = 0;
|
cur = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
work_size = MAX(work_size, cur);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
{
|
{
|
||||||
n_tasks[i] = 1;
|
n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SET:
|
case GGML_OP_SET:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
case GGML_OP_VIEW:
|
case GGML_OP_VIEW:
|
||||||
case GGML_OP_PERMUTE:
|
case GGML_OP_PERMUTE:
|
||||||
case GGML_OP_TRANSPOSE:
|
case GGML_OP_TRANSPOSE:
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
case GGML_OP_GET_ROWS_BACK:
|
case GGML_OP_GET_ROWS_BACK:
|
||||||
case GGML_OP_DIAG:
|
case GGML_OP_DIAG:
|
||||||
case GGML_OP_DIAG_MASK_ZERO:
|
case GGML_OP_DIAG_MASK_ZERO:
|
||||||
{
|
{
|
||||||
n_tasks[i] = 1;
|
n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_ROPE_BACK:
|
case GGML_OP_ROPE_BACK:
|
||||||
{
|
{
|
||||||
n_tasks[i] = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ALIBI:
|
case GGML_OP_ALIBI:
|
||||||
{
|
{
|
||||||
n_tasks[i] = 1; //TODO
|
n_tasks = 1; //TODO
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
{
|
{
|
||||||
n_tasks[i] = 1; //TODO
|
n_tasks = 1; //TODO
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONV_1D:
|
case GGML_OP_CONV_1D:
|
||||||
{
|
{
|
||||||
n_tasks[i] = n_threads;
|
n_tasks = n_threads;
|
||||||
|
|
||||||
GGML_ASSERT(node->src0->ne[3] == 1);
|
GGML_ASSERT(node->src0->ne[3] == 1);
|
||||||
GGML_ASSERT(node->src1->ne[2] == 1);
|
GGML_ASSERT(node->src1->ne[2] == 1);
|
||||||
GGML_ASSERT(node->src1->ne[3] == 1);
|
GGML_ASSERT(node->src1->ne[3] == 1);
|
||||||
|
|
||||||
size_t cur = 0;
|
size_t cur = 0;
|
||||||
const int nk = node->src0->ne[0];
|
const int nk = node->src0->ne[0];
|
||||||
|
|
||||||
if (node->src0->type == GGML_TYPE_F16 &&
|
if (node->src0->type == GGML_TYPE_F16 &&
|
||||||
node->src1->type == GGML_TYPE_F32) {
|
node->src1->type == GGML_TYPE_F32) {
|
||||||
cur = sizeof(ggml_fp16_t)*(
|
cur = sizeof(ggml_fp16_t)*(
|
||||||
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
|
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
|
||||||
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
|
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
|
||||||
);
|
);
|
||||||
} else if (node->src0->type == GGML_TYPE_F32 &&
|
} else if (node->src0->type == GGML_TYPE_F32 &&
|
||||||
node->src1->type == GGML_TYPE_F32) {
|
|
||||||
cur = sizeof(float)*(
|
|
||||||
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
|
|
||||||
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CONV_2D:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
GGML_ASSERT(node->src1->ne[3] == 1);
|
|
||||||
|
|
||||||
const int64_t ne00 = node->src0->ne[0]; // W
|
|
||||||
const int64_t ne01 = node->src0->ne[1]; // H
|
|
||||||
const int64_t ne02 = node->src0->ne[2]; // C
|
|
||||||
const int64_t ne03 = node->src0->ne[3]; // N
|
|
||||||
|
|
||||||
const int64_t ne10 = node->src1->ne[0]; // W
|
|
||||||
const int64_t ne11 = node->src1->ne[1]; // H
|
|
||||||
const int64_t ne12 = node->src1->ne[2]; // C
|
|
||||||
|
|
||||||
const int64_t nk = ne00*ne01;
|
|
||||||
|
|
||||||
UNUSED(ne02);
|
|
||||||
UNUSED(ne03);
|
|
||||||
UNUSED(nk);
|
|
||||||
|
|
||||||
size_t cur = 0;
|
|
||||||
|
|
||||||
if (node->src0->type == GGML_TYPE_F16 &&
|
|
||||||
node->src1->type == GGML_TYPE_F32) {
|
node->src1->type == GGML_TYPE_F32) {
|
||||||
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
|
cur = sizeof(float)*(
|
||||||
} else if (node->src0->type == GGML_TYPE_F32 &&
|
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
|
||||||
node->src1->type == GGML_TYPE_F32) {
|
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
|
||||||
cur = sizeof(float)* (ne10*ne11*ne12);
|
);
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = 0;
|
|
||||||
|
|
||||||
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
|
|
||||||
|
|
||||||
if (node->src1->type == GGML_TYPE_F32) {
|
|
||||||
cur = sizeof(float)*ne11*n_tasks[i]; // TODO: this can become (n_tasks[i]-1)
|
|
||||||
cur += sizeof(float)*ne11*n_tasks[i]; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->src1->type == GGML_TYPE_F16) {
|
|
||||||
cur = sizeof(float)*ne11*n_tasks[i]; // TODO: this can become (n_tasks[i]-1)
|
|
||||||
cur += sizeof(float)*ne11*n_tasks[i]; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_FF:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = 0;
|
|
||||||
|
|
||||||
if (node->src1->type == GGML_TYPE_F32) {
|
|
||||||
cur = sizeof(float)*node->src1->ne[1]*n_tasks[i]; // TODO: this can become (n_tasks[i]-1)
|
|
||||||
cur += sizeof(float)*node->src1->ne[1]*n_tasks[i]; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->src1->type == GGML_TYPE_F16) {
|
|
||||||
cur = sizeof(float)*node->src1->ne[1]*n_tasks[i]; // TODO: this can become (n_tasks[i]-1)
|
|
||||||
cur += sizeof(float)*node->src1->ne[1]*n_tasks[i]; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = 0;
|
|
||||||
|
|
||||||
const int64_t D = node->src0->ne[0];
|
|
||||||
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
|
|
||||||
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
|
|
||||||
if (node->src1->type == GGML_TYPE_F32) {
|
|
||||||
cur = sizeof(float)*mxDn*n_tasks[i]; // TODO: this can become (n_tasks[i]-1)
|
|
||||||
cur += sizeof(float)*mxDn*n_tasks[i]; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->src1->type == GGML_TYPE_F16) {
|
|
||||||
cur = sizeof(float)*mxDn*n_tasks[i]; // TODO: this can become (n_tasks[i]-1)
|
|
||||||
cur += sizeof(float)*mxDn*n_tasks[i]; // this is overestimated by x2
|
|
||||||
}
|
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_WIN_PART:
|
|
||||||
case GGML_OP_WIN_UNPART:
|
|
||||||
case GGML_OP_MAP_UNARY:
|
|
||||||
case GGML_OP_MAP_BINARY:
|
|
||||||
case GGML_OP_MAP_CUSTOM1:
|
|
||||||
case GGML_OP_MAP_CUSTOM2:
|
|
||||||
case GGML_OP_MAP_CUSTOM3:
|
|
||||||
{
|
|
||||||
n_tasks[i] = 1;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = ggml_type_size(node->type)*(n_tasks[i] + node->src0->ne[0]*n_tasks[i]);
|
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
|
||||||
{
|
|
||||||
n_tasks[i] = n_threads;
|
|
||||||
|
|
||||||
size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks[i];
|
|
||||||
|
|
||||||
work_size = MAX(work_size, cur);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_NONE:
|
|
||||||
{
|
|
||||||
n_tasks[i] = 1;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_COUNT:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
} break;
|
}
|
||||||
}
|
|
||||||
|
work_size = MAX(work_size, cur);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_CONV_2D:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
|
GGML_ASSERT(node->src1->ne[3] == 1);
|
||||||
|
|
||||||
|
const int64_t ne00 = node->src0->ne[0]; // W
|
||||||
|
const int64_t ne01 = node->src0->ne[1]; // H
|
||||||
|
const int64_t ne02 = node->src0->ne[2]; // C
|
||||||
|
const int64_t ne03 = node->src0->ne[3]; // N
|
||||||
|
|
||||||
|
const int64_t ne10 = node->src1->ne[0]; // W
|
||||||
|
const int64_t ne11 = node->src1->ne[1]; // H
|
||||||
|
const int64_t ne12 = node->src1->ne[2]; // C
|
||||||
|
|
||||||
|
const int64_t nk = ne00*ne01;
|
||||||
|
|
||||||
|
UNUSED(ne02);
|
||||||
|
UNUSED(ne03);
|
||||||
|
UNUSED(nk);
|
||||||
|
|
||||||
|
size_t cur = 0;
|
||||||
|
|
||||||
|
if (node->src0->type == GGML_TYPE_F16 &&
|
||||||
|
node->src1->type == GGML_TYPE_F32) {
|
||||||
|
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
|
||||||
|
} else if (node->src0->type == GGML_TYPE_F32 &&
|
||||||
|
node->src1->type == GGML_TYPE_F32) {
|
||||||
|
cur = sizeof(float)* (ne10*ne11*ne12);
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
work_size = MAX(work_size, cur);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_FLASH_ATTN:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
|
size_t cur = 0;
|
||||||
|
|
||||||
|
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
|
||||||
|
|
||||||
|
if (node->src1->type == GGML_TYPE_F32) {
|
||||||
|
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
|
||||||
|
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->src1->type == GGML_TYPE_F16) {
|
||||||
|
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
|
||||||
|
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
|
||||||
|
}
|
||||||
|
|
||||||
|
work_size = MAX(work_size, cur);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_FLASH_FF:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
|
size_t cur = 0;
|
||||||
|
|
||||||
|
if (node->src1->type == GGML_TYPE_F32) {
|
||||||
|
cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
|
||||||
|
cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->src1->type == GGML_TYPE_F16) {
|
||||||
|
cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
|
||||||
|
cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
|
||||||
|
}
|
||||||
|
|
||||||
|
work_size = MAX(work_size, cur);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
|
size_t cur = 0;
|
||||||
|
|
||||||
|
const int64_t D = node->src0->ne[0];
|
||||||
|
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
|
||||||
|
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
|
||||||
|
if (node->src1->type == GGML_TYPE_F32) {
|
||||||
|
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
||||||
|
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->src1->type == GGML_TYPE_F16) {
|
||||||
|
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
|
||||||
|
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
||||||
|
}
|
||||||
|
|
||||||
|
work_size = MAX(work_size, cur);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_WIN_PART:
|
||||||
|
case GGML_OP_WIN_UNPART:
|
||||||
|
case GGML_OP_MAP_UNARY:
|
||||||
|
case GGML_OP_MAP_BINARY:
|
||||||
|
case GGML_OP_MAP_CUSTOM1:
|
||||||
|
case GGML_OP_MAP_CUSTOM2:
|
||||||
|
case GGML_OP_MAP_CUSTOM3:
|
||||||
|
{
|
||||||
|
n_tasks = 1;
|
||||||
|
} break;
|
||||||
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
|
size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks);
|
||||||
|
|
||||||
|
work_size = MAX(work_size, cur);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
|
||||||
|
size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks;
|
||||||
|
|
||||||
|
work_size = MAX(work_size, cur);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_NONE:
|
||||||
|
{
|
||||||
|
n_tasks = 1;
|
||||||
|
} break;
|
||||||
|
case GGML_OP_COUNT:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cplan.n_tasks[i] = n_tasks;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (work_size > 0) {
|
if (work_size > 0) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue