diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index c8ab76132..962b5bdc2 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -488,6 +488,58 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
+// Asynchronously copy the 2D slice (i2, i3) of a host buffer to the device buffer dst.
+//
+// The bytes are read from wdata rather than from src->data; src only supplies the
+// shape (ne), the strides (nb) and the type that are used to address wdata.
+//
+//   dst    - device destination; receives ne1 tightly packed rows of ts*ne0/bs bytes
+//   src    - tensor whose ne/nb/type describe the slice layout
+//   i3, i2 - slice indices along dimensions 3 and 2
+//   stream - stream on which the copies are enqueued
+//   wdata  - host buffer the bytes are read from
+//
+// Returns the first error reported while enqueueing a copy, cudaSuccess otherwise.
+// The copies are issued with the *Async APIs, so their completion is ordered by stream.
+//
+// NOTE(review): offsets and sizes come from src, not from the contents of wdata. If wdata
+// holds src converted to another type (the call site hints at quantized data - TODO
+// confirm), the copied extent does not match that data; verify wdata is large enough.
+//
+static cudaError_t ggml_cuda_h2d_tensor_2d_hack(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream, void * wdata) {
+    const uint64_t ne0 = src->ne[0];
+    const uint64_t ne1 = src->ne[1];
+    const uint64_t nb0 = src->nb[0];
+    const uint64_t nb1 = src->nb[1];
+    const uint64_t nb2 = src->nb[2];
+    const uint64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const size_t ts = ggml_type_size(type);
+    const size_t bs = ggml_blck_size(type);
+
+    // start of the requested slice inside wdata
+    const void * x = (const void *) ((const char *) wdata + i2*nb2 + i3*nb3);
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        // rows are packed back to back: a single linear copy covers the whole slice
+        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
+    } else if (nb0 == ts) {
+        // each row is packed but rows are strided by nb1: one pitched copy
+        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
+    } else {
+        // the row itself is strided by nb0: gather every row with its own pitched copy
+        for (uint64_t i1 = 0; i1 < ne1; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1: one "row" per block of ts bytes.
+            // A row holds ne0/bs blocks (one per element when bs == 1); a width of ts/bs
+            // would truncate to a partial or empty block for types with bs > 1.
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0/bs, cudaMemcpyHostToDevice, stream);
+            if (r != cudaSuccess) return r;
+        }
+        return cudaSuccess;
+    }
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -695,13 +747,13 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
                 // copy src1 to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d_hack(c_Y, src1, i03, i02, cudaStream, wdata));
 
                 // wait for data
                 CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
 
                 // compute
-                dequantize_mul_mat_q4_0_cuda(c_Q, wdata + i * QK8_0, c_D, ne00, ne01, cudaStream);
+                dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 
             } else {