diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 9c1787609..4920d7968 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,4 +1,3 @@
-#include <cstddef>
 #include <cstdint>
 #include <stdint.h>
 #include <stdio.h>
@@ -745,8 +744,6 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y
 static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (half *) vx;
 
-    // const int col_x = blockDim.x*blockIdx.x + threadIdx.x;
-    // const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -792,7 +789,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
     }
 }
 
-static __global__ void mul_mat_vec_nc_f16_f32(
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
     const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
 
@@ -862,6 +859,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
         return;
     }
 
+    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
     const int i02 = i / (ne00*ne01);
     const int i01 = (i - i02*ne01*ne00) / ne00;
     const int i00 = i - i02*ne01*ne00 - i01*ne00;
@@ -875,6 +874,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+// rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
@@ -909,6 +909,10 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
+// the CUDA soft max implementation differs from the CPU implementation
+// instead of doubles floats are used
+// values are also not normalized to the maximum value by subtracting it in the exponential function
+// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int block_size = blockDim.x;
@@ -1374,7 +1378,7 @@ void ggml_cuda_host_free(void * ptr) {
     CUDA_CHECK(cudaFreeHost(ptr));
 }
 
-static cudaError_t ggml_cuda_tensor_2d(
+static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
     cudaMemcpyKind kind;
@@ -1405,7 +1409,7 @@ static cudaError_t ggml_cuda_tensor_2d(
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
-        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
+        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
     } else if (nb0 == ts) {
         return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
     } else {
@@ -1755,9 +1759,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
 
-    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
     struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    struct ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *) dst->extra;
 
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
@@ -1844,12 +1848,15 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const int64_t i03_max = flatten_rows ? 1 : ne03;
     const int64_t i02_max = flatten_rows ? 1 : ne02;
     const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+
     for (int64_t i03 = 0; i03 < i03_max; i03++) {
         const int64_t i13 = i03 % ne13;
         for (int64_t i02 = 0; i02 < i02_max; i02++) {
             const int64_t i12 = i02 % ne12;
 
             const int64_t i0 = i03*ne02 + i02;
+
+            // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
             const int64_t i0_offset_low = row_low/rows_per_iter;
             const int64_t i0_offset_high = row_high/rows_per_iter;
@@ -1880,15 +1887,15 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             }
 
             const int64_t i11 = i13*ne12 + i12;
 
-            cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream_main        = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
             cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+            cudaEvent_t  cudaEvent_memcpy_src1  = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
 
             // for split tensors the data begins at i0 == i0_offset_low
             char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
             float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
             float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
-            float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+            float * dst_ddf_i  =  dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
 
             // for split tensors the data pointer needs to be rounded down
             // to the bin edge for i03, i02 bins beyond the first
@@ -1910,7 +1917,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 if (src1->backend == GGML_BACKEND_CPU) {
                     GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
                     int64_t nrows1 = flatten_rows ? nrows0 : ne11;
-                    CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
                 } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                     if (id != g_main_device) {
                         GGML_ASSERT(!flatten_rows);
@@ -1921,21 +1928,22 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     }
                 } else if (src1_on_device && !src1_is_contiguous) {
                     GGML_ASSERT(!split);
-                    CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
                 } else {
                     GGML_ASSERT(false);
                 }
             }
             CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
+
             if (!src0_on_device || !src0_is_contiguous) {
                 if (src0_is_f32) {
-                    CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                 } else {
-                    CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                 }
             }
 
-            // convert src0 to f32 if it's necessary for the ggml_cuda_op
+            // convert src0 to f32 if it is necessary for the ggml_cuda_op
             if (src0_needs_f32 && !src0_is_f32) {
                 to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
                 CUDA_CHECK(cudaGetLastError());
@@ -2267,6 +2275,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
         return;
     }
 
+    // recursively assign CUDA buffers until a compute tensor is found
    if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
        const ggml_op src0_op = tensor->src0->op;
        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
@@ -2310,13 +2319,10 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
         }
         extra->data_device[g_main_device] = data + g_scratch_offset;
 
-        // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
         g_scratch_offset += size;
-        // fprintf(stderr, "%s: scratch %d, %p - %p\n",
-        //         tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);
 
         GGML_ASSERT(g_scratch_offset <= g_scratch_size);
-    } else {
+    } else { // allocate new buffers outside of scratch
         void * data;
         CUDA_CHECK(cudaMalloc(&data, size));
         CUDA_CHECK(cudaMemset(data, 0, size));
diff --git a/llama.cpp b/llama.cpp
index b2a6369e6..2027662c0 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1241,10 +1241,10 @@ static void llama_model_load_internal(
         } else {
             vram_scratch = n_batch * MB;
             ggml_cuda_set_scratch_size(vram_scratch);
-        }
-        if (n_gpu_layers > 0) {
-            fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
-                    __func__, vram_scratch / MB);
+            if (n_gpu_layers > 0) {
+                fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
+                        __func__, vram_scratch / MB);
+            }
         }
 #endif // GGML_USE_CUBLAS
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -1572,7 +1572,6 @@ static bool llama_eval_internal(
         }
 
         lctx.use_buf(ctx0, 1);
-        //ggml_cuda_set_scratch(1);
 
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
         offload_func(inpFF);
@@ -1630,7 +1629,6 @@ static bool llama_eval_internal(
     }
 
     lctx.use_buf(ctx0, 0);
-    //ggml_cuda_set_scratch(0);
 
     // used at the end to optionally extract the embeddings
     struct ggml_tensor * embeddings = NULL;
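
Notes on the changes above. The sketches below are illustrative and not part of the patch; any names or numbers not visible in the diff are assumptions.

The comment added to cpy_f32_f16 describes a standard unflattening: i = (i02*ne01 + i01)*ne00 + i00 is inverted with a division, a remainder-style subtraction, and another division. A small host-side check of that arithmetic, with hypothetical extents:

#include <cstdio>

int main() {
    const int ne00 = 4, ne01 = 3, ne02 = 2; // hypothetical tensor extents
    for (int i = 0; i < ne00*ne01*ne02; i++) {
        // same arithmetic as the kernel comment describes
        const int i02 = i / (ne00*ne01);
        const int i01 = (i - i02*ne01*ne00) / ne00;
        const int i00 = i - i02*ne01*ne00 - i01*ne00;
        if (i != (i02*ne01 + i01)*ne00 + i00) {
            printf("mismatch at i=%d\n", i); // never reached
        }
    }
    printf("all %d indices round-trip\n", ne00*ne01*ne02);
    return 0;
}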
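The new "rope == RoPE" comment names what rope_f32 computes: rotary positional embedding rotates each consecutive pair of values in a row by an angle that decays geometrically with the column index. A sketch of one per-pair rotation under the standard RoPE definition (written from that definition, not copied from the kernel):

#include <cmath>

// rotate one (x0, x1) pair in place; col is the even column index of the pair,
// p the position-derived angle factor, theta_scale the geometric decay base
void rope_pair(float & x0, float & x1, int col, float p, float theta_scale) {
    const float theta = p * powf(theta_scale, col/2);
    const float c = cosf(theta);
    const float s = sinf(theta);
    const float r0 = x0*c - x1*s;
    const float r1 = x0*s + x1*c;
    x0 = r0;
    x1 = r1;
}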
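The soft max comment is worth unpacking: a CPU implementation typically subtracts the row maximum inside the exponential so every argument to exp is <= 0 and cannot overflow; the CUDA kernel skips that normalization. The contrast, as a host-side sketch rather than the kernel itself:

#include <cmath>
#include <vector>

// guarded variant: exp arguments are <= 0, so expf cannot overflow
std::vector<float> softmax_guarded(const std::vector<float> & x) {
    float max_val = -INFINITY;
    for (float v : x) max_val = fmaxf(max_val, v);
    float sum = 0.0f;
    std::vector<float> y(x.size());
    for (size_t j = 0; j < x.size(); j++) { y[j] = expf(x[j] - max_val); sum += y[j]; }
    for (float & v : y) v /= sum;
    return y;
}

// unguarded variant, like the CUDA kernel: expf overflows to inf for inputs
// above roughly 88.7 in float; the patch comment says this is fine in practice for LLaMa
std::vector<float> softmax_unguarded(const std::vector<float> & x) {
    float sum = 0.0f;
    std::vector<float> y(x.size());
    for (size_t j = 0; j < x.size(); j++) { y[j] = expf(x[j]); sum += y[j]; }
    for (float & v : y) v /= sum;
    return y;
}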
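The one-line change in ggml_cuda_cpy_tensor_2d is the functional bugfix in this patch: the contiguous fast path hardcoded cudaMemcpyHostToDevice instead of the kind computed from the tensor's backend, which is presumably wrong whenever the source already lives in VRAM. A minimal sketch of the corrected pattern, with illustrative names rather than the actual ggml code:

#include <cuda_runtime.h>

// pick the memcpy direction from where the source actually lives,
// as ggml_cuda_cpy_tensor_2d does after this patch
cudaError_t copy_rows(void * dst, const void * src, size_t nbytes,
                      bool src_on_device, cudaStream_t stream) {
    const cudaMemcpyKind kind = src_on_device ? cudaMemcpyDeviceToDevice
                                              : cudaMemcpyHostToDevice;
    // with the old hardcoded kind, a device-resident src would be
    // misinterpreted as a host pointer
    return cudaMemcpyAsync(dst, src, nbytes, kind, stream);
}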
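The comment on i0_offset_low/i0_offset_high concerns multi-GPU splits: each GPU owns rows [row_low, row_high) of src0, and dividing a row bound by rows_per_iter converts it into the first/last flattened (i03, i02) iteration that touches those rows, which is why the local buffers are indexed with (i0 - i0_offset_low). A runnable example with hypothetical numbers:

#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical split: 2 GPUs, 192 rows total, ne01 = 64 rows per (i03, i02) slice
    const int64_t rows_per_iter = 64;
    const int64_t row_low  =  96; // first row owned by this GPU
    const int64_t row_high = 192; // one past the last row owned by this GPU
    const int64_t i0_offset_low  = row_low  / rows_per_iter; // -> 1
    const int64_t i0_offset_high = row_high / rows_per_iter; // -> 3
    // this GPU's rows fall in iterations i0 = 1 and 2
    printf("i0_offset_low=%lld i0_offset_high=%lld\n",
           (long long) i0_offset_low, (long long) i0_offset_high);
    return 0;
}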
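In the llama.cpp hunk, the allocation message moves inside the else branch so it is only printed on the code path that actually sets the scratch size from n_batch; the arithmetic itself is just batch size times 1 MiB. A worked example, assuming n_batch = 512 purely for illustration:

#include <cstdio>

int main() {
    const size_t MB = 1024*1024;    // same constant llama.cpp uses
    const size_t n_batch = 512;     // hypothetical batch size
    const size_t vram_scratch = n_batch * MB;
    printf("allocating batch_size x 1 MB = %zu MB VRAM for the scratch buffer\n",
           vram_scratch / MB);      // prints 512
    return 0;
}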