diff --git a/ggml.c b/ggml.c index 64ecd0867..6d2fb8d61 100644 --- a/ggml.c +++ b/ggml.c @@ -11571,10 +11571,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning - cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0)); - //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); - //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); - //printf("cur = %zu\n", cur); +#if defined(GGML_USE_CUBLAS) + // with cuBLAS, we need memory for the full 3D / 4D data of src1 + cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); +#else + // here we need memory just for single 2D matrix from src0 + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); +#endif } else { cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); }