diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dc0a48bfb..25cd86d23 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -151,6 +151,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_ADD_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 @@ -737,6 +738,73 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y } } +static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y) { + const half * x = (half *) vx; + // const int col_x = blockDim.x*blockIdx.x + threadIdx.x; + // const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + + const int nrows_y = ncols_x; + const int ncols_dst = ncols_y; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + return; + } + + // x is transposed and permuted + const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; + const float xi = __half2float(x[ix]); + + const int row_y = col_x; + + + // y is not transposed but permuted + const int iy = channel*nrows_y + row_y; + + // dst is not transposed and not permuted + + tmp += xi * y[iy]; + } + const int idst = channel*ncols_dst*nrows_dst + row_dst; + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +static __global__ void cpy_f32_f16(const float * x, void * vdst, const int ne0, const int ne1, const int stride_1, const int stride_2) { + const int i0 = blockDim.x*blockIdx.x + threadIdx.x; + + if (i0 >= ne0) { + return; + } + + const int i1 = blockDim.y*blockIdx.y + threadIdx.y; + const int i2 = blockDim.z*blockIdx.z + threadIdx.z; + + half * dst = (half *) vdst; + + const int ix = i0 + i1*stride_1 + i2*stride_2; + const int idst = i0 + i1*ne0 + i2*ne1*ne0; + dst[idst] = __float2half(x[ix]); +} + static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) { const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); @@ -942,6 +1010,21 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } +static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y, cudaStream_t stream) { + const int block_num_x = (ncols_x + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, nrows_x, nchannels_x); + const dim3 block_dims(WARP_SIZE, 1, 1); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, ncols_y); +} + +static void ggml_cpy_f32_f16_cuda(const float * x, void * vdst, const int ne0, const int ne1, const int ne2, + const int stride_1, const int stride_2, cudaStream_t stream) { + const int block_num_x = (ne0 + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne1, ne2); + const dim3 block_dims(CUDA_CPY_BLOCK_SIZE, 1, 1); + cpy_f32_f16<<>>(x, vdst, ne0, ne1, stride_1, stride_2); +} + static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(nrows % 2 == 0); const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); @@ -1663,18 +1746,17 @@ void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml } bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU); const int64_t ne10 = src1->ne[0]; const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; - // if (strcmp(dst->name, "KQ") == 0 || strcmp(dst->name, "KQV") == 0) { + // if (strcmp(dst->name, "KQ") == 0) { // fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n", // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], // src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); - // return false; + // return true; // } // TODO: find the optimal values for these @@ -1688,8 +1770,101 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te return false; } +void ggml_cuda_mul_mat_p021_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(!ggml_is_contiguous(src0) && !ggml_is_contiguous(src1)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + GGML_ASSERT(ne00 % 8 == 0); + GGML_ASSERT(ne03 == 1); + const size_t src0_size = ggml_nbytes(src0); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + GGML_ASSERT(ne10 % 4 == 0); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne12 == ne02); + const size_t src1_size = ggml_nbytes(src1); + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + GGML_ASSERT(ne2 == ne02); + const size_t dst_size = ggml_nbytes(dst); + + const int64_t nb1 = dst->nb[1]; + const int64_t nb2 = dst->nb[2]; + const int64_t nb3 = dst->nb[3]; + + cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; + + void * src0_ddq; + float * src1_ddf; + float * dst_ddf; + if (src0->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size)); + CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main)); + } else { + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + src0_ddq = src0_extra->data_device[g_main_device]; + // CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0))); + } + if (src1->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size)); + CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main)); + } else { + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + src1_ddf = (float *) src1_extra->data_device[g_main_device]; + // CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1))); + } + CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size)); + + for (int64_t i11 = 0; i11 < ne11; ++i11) { + float * src1_ddf_i = src1_ddf + i11 * ne10*ne12; + float * dst_ddf_i = dst_ddf + i11 * ne0; + + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf_i, dst_ddf_i, ne00, ne01, ne02, ne11, cudaStream_main); + } + + CUDA_CHECK(cudaDeviceSynchronize()); + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i = i3*ne2*ne1 + i2*ne1 + i1; + float * dst_ddf_i = dst_ddf + i*ne0; + float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main)); + } + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + if (src0->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaFree(src0_ddq)); + } + if (src1->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaFree(src1_ddf)); + } + CUDA_CHECK(cudaFree(dst_ddf)); +} + void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - if (src0->type == GGML_TYPE_F32) { + if (!ggml_is_contiguous(src0) && !ggml_is_contiguous(src1)) { + ggml_cuda_mul_mat_p021_f16_f32(src0, src1, dst); + // ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true); + } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { if (src1->ne[1] == 1) { @@ -1702,6 +1877,43 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ } } +void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + GGML_ASSERT(src0->ne[3] == src1->ne[3]); + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; + + const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + + float * src0_ddf = (float *) src0_extra->data_device[g_main_device]; + void * src1_ddv = src1_extra->data_device[g_main_device]; + + const int64_t stride_1 = nb01 / sizeof(float); + GGML_ASSERT(nb01 % sizeof(float) == 0); + const int64_t stride_2 = nb02 / sizeof(float); + GGML_ASSERT(nb02 % sizeof(float) == 0); + + ggml_cpy_f32_f16_cuda(src0_ddf, src1_ddv, ne00, ne01, ne02, stride_1, stride_2, cudaStream_main); + // test<<>>(src0_ddf, src1_ddv); + + CUDA_CHECK(cudaDeviceSynchronize()); + + (void) dst; +} + void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true); @@ -1718,10 +1930,9 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { const size_t nb1 = tensor->nb[1]; ggml_backend backend = tensor->backend; struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; + memset(extra, 0, sizeof(*extra)); for (int id = 0; id < g_device_count; ++id) { - extra->data_device[id] = nullptr; - if (backend == GGML_BACKEND_GPU && id != g_main_device) { continue; } @@ -1781,44 +1992,67 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { delete extra; } -void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { - if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) { - ggml_cuda_assign_buffers(tensor->src0); - } - - const size_t size = ggml_nbytes(tensor); - GGML_ASSERT(size <= g_scratch_size); - if (g_scratch_offset + size > g_scratch_size) { - g_scratch_offset = 0; +void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { + if (tensor->src0 != nullptr) { + const ggml_op src0_op = tensor->src0->op; + if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) { + ggml_cuda_assign_buffers_impl(tensor->src0, scratch); + } } tensor->backend = GGML_BACKEND_GPU; struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu; - bool inplace = tensor->src0 != nullptr && tensor->src0->data == tensor->data; + const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) || + tensor->op == GGML_OP_VIEW; + const size_t size = ggml_nbytes(tensor); CUDA_CHECK(cudaSetDevice(g_main_device)); if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) { struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra; - extra->data_device[g_main_device] = src0_extra->data_device[g_main_device]; - } else { + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + size_t offset = 0; + if (tensor->op == GGML_OP_VIEW) { + memcpy(&offset, tensor->opt[0]->data, sizeof(size_t)); + } + extra->data_device[g_main_device] = src0_ddc + offset; + } else if (scratch) { + GGML_ASSERT(size <= g_scratch_size); + if (g_scratch_offset + size > g_scratch_size) { + g_scratch_offset = 0; + } + char * data = (char *) g_scratch_buffer; if (data == nullptr) { CUDA_CHECK(cudaMalloc(&data, g_scratch_size)); g_scratch_buffer = data; } extra->data_device[g_main_device] = data + g_scratch_offset; + + // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]); + g_scratch_offset += size; + // fprintf(stderr, "%s: scratch %d, %p - %p\n", + // tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size); + + GGML_ASSERT(g_scratch_offset <= g_scratch_size); + } else { + void * data; + CUDA_CHECK(cudaMalloc(&data, size)); + CUDA_CHECK(cudaMemset(data, 0, size)); + extra->data_device[g_main_device] = data; } - // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]); - g_scratch_offset += size; - // fprintf(stderr, "%s: scratch %d, %p - %p\n", - // tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size); - - GGML_ASSERT(g_scratch_offset <= g_scratch_size); tensor->extra = extra; } +void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, true); +} + +void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false); +} + void ggml_cuda_set_main_device(int main_device) { if (main_device > g_device_count) { fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n", @@ -1874,7 +2108,16 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_mul_mat; break; + case GGML_OP_CPY: + if (!any_on_device) { + return false; + } + func = ggml_cuda_cpy; + break; case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: if (!any_on_device) { return false; } diff --git a/ggml-cuda.h b/ggml-cuda.h index fde6d4085..e964f9a48 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -28,6 +28,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor); void ggml_cuda_free_data(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); +void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor); void ggml_cuda_set_main_device(int main_device); void ggml_cuda_set_scratch_size(size_t scratch_size); bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor); diff --git a/llama.cpp b/llama.cpp index 9249a48d5..0f86dd09d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -893,6 +893,10 @@ static bool kv_cache_init( ggml_set_name(cache.k, "cache_k"); ggml_set_name(cache.v, "cache_v"); +#ifdef GGML_USE_CUBLAS + ggml_cuda_assign_buffers_no_scratch(cache.k); +#endif // GGML_USE_CUBLAS + return true; } @@ -1375,12 +1379,12 @@ static bool llama_eval_internal( struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0); offload_func(Kcur); - Kcur->backend = GGML_BACKEND_CPU; + // Kcur->backend = GGML_BACKEND_CPU; ggml_set_name(Kcur, "Kcur"); struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0); offload_func(Qcur); - Qcur->backend = GGML_BACKEND_CPU; + // Qcur->backend = GGML_BACKEND_CPU; ggml_set_name(Qcur, "Qcur"); // store key and value to memory @@ -1390,6 +1394,7 @@ static bool llama_eval_internal( ggml_set_name(Vcur, "Vcur"); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + offload_func(k); ggml_set_name(k, "k"); struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, ( n_ctx)*ggml_element_size(kv_self.v), @@ -1405,6 +1410,7 @@ static bool llama_eval_internal( ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func(Q); ggml_set_name(Q, "Q"); struct ggml_tensor * K = @@ -1413,6 +1419,7 @@ static bool llama_eval_internal( ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), n_embd/n_head, n_head, n_past + N), 0, 2, 1, 3); + offload_func(K); ggml_set_name(K, "K"); // K * Q