Cleaned up code, added comments
parent 51830ee5e6
commit a47072b85d

2 changed files with 30 additions and 26 deletions
ggml-cuda.cu: 46 changed lines
@@ -1,4 +1,3 @@
 #include <climits>
 #include <cstddef>
 #include <cstdint>
-#include <limits>
@@ -745,8 +744,6 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y

 static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (half *) vx;
-    // const int col_x = blockDim.x*blockIdx.x + threadIdx.x;
-    // const int row_x = blockDim.y*blockIdx.y + threadIdx.y;

     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -792,7 +789,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
     }
 }

-static __global__ void mul_mat_vec_nc_f16_f32(
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
     const int row_stride_x, const int nchannels_x, const int channel_stride_x) {

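Note: the comment added to the signature above spells out that "nc" stands for non-contiguous. The kernel receives explicit row and channel strides (row_stride_x, channel_stride_x) instead of assuming densely packed data. As a rough illustration of what such strided addressing looks like, here is a small hypothetical helper and kernel (not taken from ggml-cuda.cu):

    #include <cuda_fp16.h>

    // Read one element of a non-contiguous f16 tensor using explicit element
    // strides per row and per channel, analogous to the row_stride_x /
    // channel_stride_x parameters of mul_mat_vec_nc_f16_f32. Hypothetical helper.
    __device__ static inline half get_nc_element(
            const half * x, const int col, const int row, const int channel,
            const int row_stride_x, const int channel_stride_x) {
        return x[channel*channel_stride_x + row*row_stride_x + col];
    }

    // Trivial kernel showing the helper in use: copy one strided element to dst.
    __global__ void read_one(const half * x, half * dst,
                             const int row_stride_x, const int channel_stride_x) {
        dst[0] = get_nc_element(x, /*col=*/1, /*row=*/2, /*channel=*/0,
                                row_stride_x, channel_stride_x);
    }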
@@ -862,6 +859,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
         return;
     }

+    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
     const int i02 = i / (ne00*ne01);
     const int i01 = (i - i02*ne01*ne00) / ne00;
     const int i00 = i - i02*ne01*ne00 - i01*ne00;
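Note: the two comments added above document the index arithmetic in cpy_f32_f16: the flattened element index i is decomposed into per-dimension indices, which are then combined with byte strides to address the source and destination. A self-contained sketch of the same arithmetic, with made-up sizes and contiguous f32 strides (nb00/nb01/nb02 are assumptions for the example, not values from the diff):

    #include <cstdint>
    #include <cstdio>

    // Decompose a flattened index i over an (ne02, ne01, ne00) tensor into its
    // per-dimension indices, then combine them with byte strides into an offset.
    int main() {
        const int ne00 = 4, ne01 = 3, ne02 = 2;                     // logical sizes (example values)
        const size_t nb00 = 4, nb01 = nb00*ne00, nb02 = nb01*ne01;  // byte strides for contiguous f32

        const int i = 17;                              // flattened element index
        const int i02 = i / (ne00*ne01);               // slowest dimension
        const int i01 = (i - i02*ne01*ne00) / ne00;
        const int i00 = i - i02*ne01*ne00 - i01*ne00;  // fastest dimension

        const size_t offset = i02*nb02 + i01*nb01 + i00*nb00;
        printf("i=%d -> (i02=%d, i01=%d, i00=%d), byte offset=%zu\n", i, i02, i01, i00, offset);
        return 0;
    }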
@@ -875,6 +874,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }

+// rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);

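Note: the added comment only expands the acronym: rope_f32 implements RoPE, the rotary positional embedding. The hunk shows just the signature and the column index (each thread handles one pair of values at columns col and col+1); the rotation itself is not visible here. The standard RoPE step rotates each pair by an angle derived from the position p and theta_scale, roughly as in this CPU sketch (an assumed formulation, not copied from the kernel):

    #include <cmath>
    #include <cstdio>

    // Rotate one (x0, x1) pair by angle theta, the core operation of rotary
    // positional embeddings. How theta is derived from p, theta_scale and the
    // column index is an assumption here, since the kernel body is not shown.
    static void rope_pair(const float x0, const float x1, const float theta,
                          float & y0, float & y1) {
        const float c = cosf(theta);
        const float s = sinf(theta);
        y0 = x0*c - x1*s;
        y1 = x0*s + x1*c;
    }

    int main() {
        float y0, y1;
        rope_pair(1.0f, 0.0f, 0.5f, y0, y1);
        printf("%f %f\n", y0, y1); // prints cos(0.5), sin(0.5)
        return 0;
    }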
@@ -909,6 +909,10 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }

+// the CUDA soft max implementation differs from the CPU implementation
+// instead of doubles floats are used
+// values are also not normalized to the maximum value by subtracting it in the exponential function
+// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int block_size = blockDim.x;
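Note: the INT_MAX line above is pre-existing context: subtracting INT_MAX from masked positions pushes them so far negative that the following exponential underflows to zero, which is effectively the same as writing -INFINITY there and is what the "equivalent within rounding error" remark refers to. The four new comments describe how the CUDA soft max deviates from the CPU path: float instead of double accumulation, and no subtraction of the row maximum inside the exponential. A small CPU sketch contrasting the two variants (illustrative only, not the actual kernel):

    #include <cmath>
    #include <cstdio>

    // Reference soft max in the style the comments attribute to the CPU path:
    // double accumulator and subtraction of the row maximum before exponentiating.
    static void soft_max_ref(const float * x, float * dst, const int n) {
        float max_val = -INFINITY;
        for (int i = 0; i < n; ++i) max_val = fmaxf(max_val, x[i]);
        double sum = 0.0;
        for (int i = 0; i < n; ++i) { dst[i] = expf(x[i] - max_val); sum += dst[i]; }
        for (int i = 0; i < n; ++i) dst[i] = (float)(dst[i] / sum);
    }

    // Variant matching the new comments: float accumulation, no max subtraction.
    // This can overflow for large inputs, but per the comments it works for LLaMa.
    static void soft_max_no_max(const float * x, float * dst, const int n) {
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) { dst[i] = expf(x[i]); sum += dst[i]; }
        for (int i = 0; i < n; ++i) dst[i] /= sum;
    }

    int main() {
        const float x[4] = {0.5f, 1.0f, -2.0f, 3.0f};
        float a[4], b[4];
        soft_max_ref(x, a, 4);
        soft_max_no_max(x, b, 4);
        for (int i = 0; i < 4; ++i) printf("%f %f\n", a[i], b[i]); // the two columns should agree closely
        return 0;
    }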
@@ -1374,7 +1378,7 @@ void ggml_cuda_host_free(void * ptr) {
     CUDA_CHECK(cudaFreeHost(ptr));
 }

-static cudaError_t ggml_cuda_tensor_2d(
+static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {

     cudaMemcpyKind kind;
@@ -1405,7 +1409,7 @@ static cudaError_t ggml_cuda_tensor_2d(

     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
-        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
+        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
     } else if (nb0 == ts) {
         return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
     } else {
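Note: the renamed ggml_cuda_cpy_tensor_2d chooses between two copy paths: when the source rows are fully contiguous (nb0 == ts && nb1 == ts*ne0/bs) a single cudaMemcpyAsync suffices, and when only the elements within a row are contiguous (nb0 == ts) it falls back to cudaMemcpy2DAsync with the row stride as the source pitch. Replacing the hard-coded cudaMemcpyHostToDevice with kind presumably lets the same helper copy from either host or device memory depending on where src lives. The same decision in a standalone sketch, with hypothetical row sizes:

    #include <cuda_runtime.h>

    // Copy `rows` rows of `row_bytes` each from a possibly padded source buffer
    // (pitch src_pitch >= row_bytes) to a densely packed destination, choosing
    // between a flat copy and a 2D strided copy. Hypothetical standalone example
    // of the same pattern used by the renamed helper.
    static cudaError_t copy_rows(void * dst, const void * src, size_t row_bytes,
                                 size_t src_pitch, int rows, cudaMemcpyKind kind,
                                 cudaStream_t stream) {
        if (src_pitch == row_bytes) {
            // rows are back to back: one contiguous async copy
            return cudaMemcpyAsync(dst, src, row_bytes*rows, kind, stream);
        }
        // rows are padded: strided 2D copy, destination pitch equals the packed row size
        return cudaMemcpy2DAsync(dst, row_bytes, src, src_pitch, row_bytes, rows, kind, stream);
    }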
@@ -1755,9 +1759,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);

-    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;

     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
@@ -1844,12 +1848,15 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const int64_t i03_max = flatten_rows ? 1 : ne03;
     const int64_t i02_max = flatten_rows ? 1 : ne02;
     const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;

     for (int64_t i03 = 0; i03 < i03_max; i03++) {
         const int64_t i13 = i03 % ne13;
         for (int64_t i02 = 0; i02 < i02_max; i02++) {
             const int64_t i12 = i02 % ne12;
+
             const int64_t i0 = i03*ne02 + i02;
+
+            // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
             const int64_t i0_offset_low = row_low/rows_per_iter;
             const int64_t i0_offset_high = row_high/rows_per_iter;

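Note: the comment added above explains i0_offset_low and i0_offset_high: when a tensor is split by rows across GPUs, they identify where the current GPU's row range [row_low, row_high) falls within the flattened i0 iterations, each of which covers rows_per_iter rows. A tiny standalone example of that arithmetic (all numeric values are invented for illustration):

    #include <cstdint>
    #include <cstdio>

    // Split-tensor bookkeeping: the GPU owns rows [row_low, row_high), each i0
    // iteration covers rows_per_iter rows, so dividing the row bounds by
    // rows_per_iter gives the i0 values where the owned range begins and ends.
    int main() {
        const int64_t rows_per_iter = 32;
        const int64_t row_low  = 48;
        const int64_t row_high = 160;

        const int64_t i0_offset_low  = row_low  / rows_per_iter;  // 1
        const int64_t i0_offset_high = row_high / rows_per_iter;  // 5

        printf("i0_offset_low=%lld i0_offset_high=%lld\n",
               (long long) i0_offset_low, (long long) i0_offset_high);
        return 0;
    }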
@@ -1880,15 +1887,15 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             }
             const int64_t i11 = i13*ne12 + i12;

-            cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
             cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+            cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];

             // for split tensors the data begins at i0 == i0_offset_low
             char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
             float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
             float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
-            float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+            float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;

             // for split tensors the data pointer needs to be rounded down
             // to the bin edge for i03, i02 bins beyond the first
@@ -1910,7 +1917,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             if (src1->backend == GGML_BACKEND_CPU) {
                 GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
                 int64_t nrows1 = flatten_rows ? nrows0 : ne11;
-                CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+                CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
             } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                 if (id != g_main_device) {
                     GGML_ASSERT(!flatten_rows);
@@ -1921,21 +1928,22 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
             } else if (src1_on_device && !src1_is_contiguous) {
                 GGML_ASSERT(!split);
-                CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
+                CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
             } else {
                 GGML_ASSERT(false);
             }
         }
         CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));

         if (!src0_on_device || !src0_is_contiguous) {
             if (src0_is_f32) {
-                CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
             } else {
-                CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
             }
         }

-        // convert src0 to f32 if it's necessary for the ggml_cuda_op
+        // convert src0 to f32 if it is necessary for the ggml_cuda_op
         if (src0_needs_f32 && !src0_is_f32) {
             to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
             CUDA_CHECK(cudaGetLastError());
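Note: the cudaEventRecord call visible above is part of a copy/compute overlap pattern: src1 is uploaded on a dedicated stream (cudaStream_memcpy_src1), an event is recorded on that stream right after the copy, and another stream can wait on the event before kernels consume the data. A minimal self-contained sketch of that pattern (stream, event and buffer names here are illustrative, not the ggml globals):

    #include <cuda_runtime.h>
    #include <cstdio>

    #define CHECK(call) do { cudaError_t e = (call); if (e != cudaSuccess) { \
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(e)); return 1; } } while (0)

    int main() {
        cudaStream_t stream_main, stream_copy;
        cudaEvent_t copy_done;
        CHECK(cudaStreamCreate(&stream_main));
        CHECK(cudaStreamCreate(&stream_copy));
        CHECK(cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));

        const size_t n = 1 << 20;
        float *h_src = nullptr, *d_dst = nullptr;
        CHECK(cudaMallocHost(&h_src, n*sizeof(float)));  // pinned, so the async copy can overlap
        CHECK(cudaMalloc(&d_dst, n*sizeof(float)));

        CHECK(cudaMemcpyAsync(d_dst, h_src, n*sizeof(float), cudaMemcpyHostToDevice, stream_copy));
        CHECK(cudaEventRecord(copy_done, stream_copy));         // mark "copy finished" on the copy stream
        CHECK(cudaStreamWaitEvent(stream_main, copy_done, 0));  // main stream waits before using d_dst
        // ... kernels consuming d_dst would be launched on stream_main here ...

        CHECK(cudaStreamSynchronize(stream_main));
        CHECK(cudaFree(d_dst));
        CHECK(cudaFreeHost(h_src));
        CHECK(cudaEventDestroy(copy_done));
        CHECK(cudaStreamDestroy(stream_copy));
        CHECK(cudaStreamDestroy(stream_main));
        return 0;
    }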
@@ -2267,6 +2275,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
         return;
     }

+    // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
@@ -2310,13 +2319,10 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
         }
         extra->data_device[g_main_device] = data + g_scratch_offset;

-        // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
         g_scratch_offset += size;
-        // fprintf(stderr, "%s: scratch %d, %p - %p\n",
-        //        tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);

         GGML_ASSERT(g_scratch_offset <= g_scratch_size);
-    } else {
+    } else { // allocate new buffers outside of scratch
         void * data;
         CUDA_CHECK(cudaMalloc(&data, size));
         CUDA_CHECK(cudaMemset(data, 0, size));
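Note: this hunk drops two commented-out debug fprintfs and labels the else branch that allocates a dedicated buffer outside the scratch region. The scratch path itself behaves like a bump allocator: the tensor gets data + g_scratch_offset, the offset advances by the tensor size, and GGML_ASSERT checks it never exceeds g_scratch_size. A minimal standalone sketch of that allocation pattern (names are hypothetical, not the ggml globals):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // Minimal bump allocator mirroring the scratch path above: each request is
    // served at the current offset into one pre-allocated buffer, the offset then
    // advances, and an assert enforces the same invariant as
    // GGML_ASSERT(g_scratch_offset <= g_scratch_size).
    struct ScratchBuffer {
        char * data;    // base of the scratch region
        size_t size;    // total capacity in bytes
        size_t offset;  // next free byte

        void * alloc(size_t n) {
            void * ptr = data + offset;
            offset += n;
            assert(offset <= size);
            return ptr;
        }
    };

    int main() {
        static char backing[1024];
        ScratchBuffer scratch{backing, sizeof(backing), 0};
        void * a = scratch.alloc(256);
        void * b = scratch.alloc(512);
        printf("a=%p b=%p bytes used=%zu\n", a, b, scratch.offset);
        return 0;
    }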
llama.cpp: 10 changed lines
@@ -1241,10 +1241,10 @@ static void llama_model_load_internal(
         } else {
             vram_scratch = n_batch * MB;
             ggml_cuda_set_scratch_size(vram_scratch);
         }
-        if (n_gpu_layers > 0) {
-            fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
-                    __func__, vram_scratch / MB);
+        if (n_gpu_layers > 0) {
+            fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
+                    __func__, vram_scratch / MB);
         }
     }
 #endif // GGML_USE_CUBLAS
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -1572,7 +1572,6 @@ static bool llama_eval_internal(
         }

         lctx.use_buf(ctx0, 1);
-        //ggml_cuda_set_scratch(1);

         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
         offload_func(inpFF);
@@ -1630,7 +1629,6 @@ static bool llama_eval_internal(
         }

         lctx.use_buf(ctx0, 0);
-        //ggml_cuda_set_scratch(0);

         // used at the end to optionally extract the embeddings
         struct ggml_tensor * embeddings = NULL;