Cleaned up code, added comments

This commit is contained in:
JohannesGaessler 2023-06-14 00:00:53 +02:00
parent 51830ee5e6
commit a47072b85d
2 changed files with 30 additions and 26 deletions

View file

@ -1,4 +1,3 @@
#include <climits>
#include <cstddef>
#include <cstdint>
#include <limits>
@ -745,8 +744,6 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y
static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
const half * x = (half *) vx;
// const int col_x = blockDim.x*blockIdx.x + threadIdx.x;
// const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@ -792,7 +789,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
}
}
static __global__ void mul_mat_vec_nc_f16_f32(
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
@ -862,6 +859,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
return;
}
// determine the indices i02/i12, i01/i11, i00/i10 as a function of the index i of the flattened tensor,
// then combine those indices with the corresponding byte offsets to get the total offsets
const int i02 = i / (ne00*ne01);
const int i01 = (i - i02*ne01*ne00) / ne00;
const int i00 = i - i02*ne01*ne00 - i01*ne00;
@ -875,6 +874,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}
// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
@ -909,6 +909,10 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
}
// the CUDA soft max implementation differs from the CPU implementation in two ways:
// - floats are used instead of doubles
// - the values are not normalized by subtracting the maximum value inside the exponential function
// theoretically these changes could cause problems with rounding error and arithmetic overflow, but for LLaMA it seems to be fine
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int block_size = blockDim.x;
@ -1374,7 +1378,7 @@ void ggml_cuda_host_free(void * ptr) {
CUDA_CHECK(cudaFreeHost(ptr));
}
static cudaError_t ggml_cuda_tensor_2d(
static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
cudaMemcpyKind kind;
@ -1405,7 +1409,7 @@ static cudaError_t ggml_cuda_tensor_2d(
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) {
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
} else if (nb0 == ts) {
return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
} else {
@ -1844,12 +1848,15 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
const int64_t i03_max = flatten_rows ? 1 : ne03;
const int64_t i02_max = flatten_rows ? 1 : ne02;
const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
for (int64_t i03 = 0; i03 < i03_max; i03++) {
const int64_t i13 = i03 % ne13;
for (int64_t i02 = 0; i02 < i02_max; i02++) {
const int64_t i12 = i02 % ne12;
const int64_t i0 = i03*ne02 + i02;
// i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
const int64_t i0_offset_low = row_low/rows_per_iter;
const int64_t i0_offset_high = row_high/rows_per_iter;
@ -1910,7 +1917,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
if (src1->backend == GGML_BACKEND_CPU) {
GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
int64_t nrows1 = flatten_rows ? nrows0 : ne11;
CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
} else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
if (id != g_main_device) {
GGML_ASSERT(!flatten_rows);
@ -1921,21 +1928,22 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
} else if (src1_on_device && !src1_is_contiguous) {
GGML_ASSERT(!split);
CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
} else {
GGML_ASSERT(false);
}
}
CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
if (!src0_on_device || !src0_is_contiguous) {
if (src0_is_f32) {
CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
} else {
CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
}
}
// convert src0 to f32 if it's necessary for the ggml_cuda_op
// convert src0 to f32 if it is necessary for the ggml_cuda_op
if (src0_needs_f32 && !src0_is_f32) {
to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
CUDA_CHECK(cudaGetLastError());
@ -2267,6 +2275,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
return;
}
// recursively assign CUDA buffers until a compute tensor is found
if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
const ggml_op src0_op = tensor->src0->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
@ -2310,13 +2319,10 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
}
extra->data_device[g_main_device] = data + g_scratch_offset;
// fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
g_scratch_offset += size;
// fprintf(stderr, "%s: scratch %d, %p - %p\n",
// tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);
GGML_ASSERT(g_scratch_offset <= g_scratch_size);
} else {
} else { // allocate new buffers outside of scratch
void * data;
CUDA_CHECK(cudaMalloc(&data, size));
CUDA_CHECK(cudaMemset(data, 0, size));

View file

@ -1241,11 +1241,11 @@ static void llama_model_load_internal(
} else {
vram_scratch = n_batch * MB;
ggml_cuda_set_scratch_size(vram_scratch);
}
if (n_gpu_layers > 0) {
fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
__func__, vram_scratch / MB);
}
}
#endif // GGML_USE_CUBLAS
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@ -1572,7 +1572,6 @@ static bool llama_eval_internal(
}
lctx.use_buf(ctx0, 1);
//ggml_cuda_set_scratch(1);
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
offload_func(inpFF);
@ -1630,7 +1629,6 @@ static bool llama_eval_internal(
}
lctx.use_buf(ctx0, 0);
//ggml_cuda_set_scratch(0);
// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL;