diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index fc051b314..921b5a03b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -897,7 +897,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int n_past) {
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
 
@@ -907,7 +907,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     const int i = row*ncols + col;
 
     // dst[i] = col > n_past + row ? -INFINITY : x[i];
-    dst[i] = x[i] - (col > n_past + row) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
@@ -1192,11 +1192,11 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
 }
 
-static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int n_past, cudaStream_t stream) {
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, nrows_x, 1);
-    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, n_past);
+    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
@@ -1659,12 +1659,13 @@ inline void ggml_cuda_op_diag_mask_inf(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
 
     const int n_past = ((int32_t *) src1->data)[0];
 
     // compute
-    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_past, cudaStream_main);
+    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
     CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
@@ -1723,7 +1724,7 @@ inline void ggml_cuda_op_scale(
 }
 
 static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                         ggml_cuda_op_t op, bool src0_needs_f32) {
+                         ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
@@ -1746,10 +1747,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
 
     // strides for iteration over dims 3 and 2
-    const int64_t src0_stride = ne00 * ne01;
-    const int64_t src1_stride = ne10 * ne11;
-    const int64_t dst_stride = ne0 * ne1;
-    const int64_t num_iters = ne02 * ne03;
+    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
+    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t src0_stride = ne00 * ne01 * stride_mod;
+    const int64_t src1_stride = ne10 * ne11 * stride_mod;
+    const int64_t dst_stride = ne0 * ne1 * stride_mod;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
@@ -1763,6 +1765,8 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
 
     const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
+    const bool src1_stays_on_host = use_src1 && (
+        dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
 
@@ -1772,13 +1776,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
     float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
     float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
-    float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
+    float *  dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
     // asq = actual size quantized, asf = actual size float
     size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
     size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
-    size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
     for (int id = 0; id < g_device_count; ++id) {
         if (!split && id != g_main_device) {
@@ -1824,7 +1828,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
         }
 
-        if (use_src1) {
+        if (use_src1 && !src1_stays_on_host) {
            if (src1_on_device && src1_is_contiguous) {
                src1_ddf[id] = (float *) src1_extra->data_device[id];
            } else {
@@ -1838,26 +1842,29 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
            dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
        }
 
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
+        const int64_t i03_max = flatten_rows ? 1 : ne03;
+        const int64_t i02_max = flatten_rows ? 1 : ne02;
+        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+
+        for (int64_t i03 = 0; i03 < i03_max; i03++) {
            const int64_t i13 = i03 % ne13;
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i02 = 0; i02 < i02_max; i02++) {
                const int64_t i12 = i02 % ne12;
 
                const int64_t i0 = i03*ne02 + i02;
-                const int64_t i0_offset_low = row_low/ne01;
-                const int64_t i0_offset_high = row_high/ne01;
+                const int64_t i0_offset_low = row_low/rows_per_iter;
+                const int64_t i0_offset_high = row_high/rows_per_iter;
 
                int64_t i01_low = 0;
-                int64_t i01_high = ne01;
+                int64_t i01_high = rows_per_iter;
                if (split) {
                    if (i0 < i0_offset_low || i0 > i0_offset_high) {
                        continue;
                    }
                    if (i0 == i0_offset_low) {
-                        i01_low = row_low % ne01;
+                        i01_low = row_low % rows_per_iter;
                    }
                    if (i0 == i0_offset_high) {
-                        i01_high = row_high % ne01;
+                        i01_high = row_high % rows_per_iter;
                    }
                }
 
@@ -1866,7 +1873,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
                // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
                GGML_ASSERT(i01_low == 0 || g_device_count > 1);
-                GGML_ASSERT(i01_high == ne01 || g_device_count > 1);
+                GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
 
                const int64_t i01_diff = i01_high - i01_low;
                if (i01_diff == 0) {
@@ -1887,11 +1894,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                // for split tensors the data pointer needs to be rounded down
                // to the bin edge for i03, i02 bins beyond the first
                if (i0 - i0_offset_low > 0) {
+                    GGML_ASSERT(!flatten_rows);
                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
                    src0_ddf_i -= (row_low % ne01)*ne00;
-                }
-                if (i0 - i0_offset_low > 0) {
-                    dst_ddf_i -= (row_low % ne0)*ne1;
+                    dst_ddf_i -= (row_low % ne0)*ne1;
                }
 
                // the main device memory buffer can be on VRAM scratch, with space for all partial results
@@ -1901,11 +1907,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                }
 
                // copy src0, src1 to device if necessary
-                if (use_src1) {
+                if (use_src1 && !src1_stays_on_host) {
                    if (src1->backend == GGML_BACKEND_CPU) {
-                        CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
+                        GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
+                        int64_t nrows1 = flatten_rows ? nrows0 : ne11;
+                        CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
                    } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                        if (id != g_main_device) {
+                            GGML_ASSERT(!flatten_rows);
                            float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                            src1_ddf_i_source += i11*src1_stride;
                            CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
@@ -1991,22 +2000,22 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
 }
 
 void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
 }
 
 void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -2236,12 +2245,12 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     } else if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_nc_f16_f32(src0, src1, dst);
     }else if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
+        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
         } else {
-            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
         }
     } else {
         GGML_ASSERT(false);
@@ -2250,7 +2259,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
 }
 
 void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
 }
 
 void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2304,17 +2313,17 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
 
 void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
 }
 
 void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
 }
 
 void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
 }
 
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
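
Note on the new rows_per_channel argument: with flatten_rows enabled, ggml_cuda_op launches a single kernel over all ne01*ne02*ne03 rows instead of one launch per (i02, i03) pair, so diag_mask_inf_f32 can no longer use the raw row index as its position on the causal diagonal. The sketch below is a minimal host-side illustration, not part of the patch; the helper name diag_mask_inf_ref and the toy shapes are made up. It shows why the kernel masks with row % rows_per_channel (rows_per_channel is ne01 at the call site): every channel of the flattened batch gets the same lower-triangular pattern.

#include <climits>
#include <cstdio>
#include <vector>

// Host reference of the masking condition used by diag_mask_inf_f32, applied to
// rows that have been flattened across channels (rows_per_channel == ne01).
static void diag_mask_inf_ref(const float * x, float * dst, int ncols,
                              int rows_per_channel, int n_past, int nrows_total) {
    for (int row = 0; row < nrows_total; ++row) {
        for (int col = 0; col < ncols; ++col) {
            const int i = row*ncols + col;
            // same arithmetic trick as the kernel: subtract INT_MAX where the mask applies
            dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX;
        }
    }
}

int main() {
    const int ncols = 4, rows_per_channel = 4, n_channels = 2, n_past = 0;
    const int nrows_total = rows_per_channel * n_channels; // rows flattened over channels
    std::vector<float> x(nrows_total*ncols, 1.0f), dst(x.size());

    diag_mask_inf_ref(x.data(), dst.data(), ncols, rows_per_channel, n_past, nrows_total);

    // prints the same lower-triangular pattern for each of the two channels
    for (int row = 0; row < nrows_total; ++row) {
        for (int col = 0; col < ncols; ++col) {
            std::putchar(dst[row*ncols + col] > -1.0e9f ? '#' : '.');
        }
        std::putchar('\n');
    }
    return 0;
}

Without the modulo, rows belonging to the second and later channels (row >= rows_per_channel) would satisfy col > n_past + row for no column at all and would be left completely unmasked once the i02/i03 loops are collapsed into one launch.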