sync : ggml

This commit is contained in:
Georgi Gerganov 2024-08-27 22:01:45 +03:00
parent 3246fe84d7
commit 231cff5f6f
21 changed files with 1422 additions and 178 deletions

View file

@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
return a + b;
}
static __device__ __forceinline__ float op_sub(const float a, const float b) {
return a - b;
}
static __device__ __forceinline__ float op_mul(const float a, const float b) {
return a * b;
}
@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

View file

@ -2,5 +2,6 @@
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -0,0 +1,106 @@
#include "common.cuh"
#include "cross-entropy-loss.cuh"
#include "sumrows.cuh"
#include <cmath>
#include <cstdint>
static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
const int ne_tmp = WARP_SIZE*nclasses;
extern __shared__ float tmp_all[];
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
// Each warp first loads ne_tmp logits/labels into shared memory:
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
const int ig = i0*nclasses + i; // ig == i global
tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
}
// Each thread in the warp then calculates the cross entropy loss for a single row.
// TODO: pad in order to avoid shared memory bank conflicts.
// Find maximum for softmax:
float max = -INFINITY;
for (int i = 0; i < nclasses; ++i) {
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
}
// Calculate log(softmax(logits)) which is just logits - max:
float sum = 0.0f;
for (int i = 0; i < nclasses; ++i) {
float val = tmp_logits[lane_id*nclasses + i] - max;
sum += expf(val);
tmp_logits[lane_id*nclasses + i] = val;
}
sum = logf(sum);
// log(exp(logits - max) / sum) = (logits - max) - log(sum)
float loss = 0.0f;
for (int i = 0; i < nclasses; ++i) {
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
}
loss = -warp_reduce_sum(loss) / (float)k;
__syncthreads();
if (lane_id == 0) {
tmp_all[warp_id] = loss;
}
__syncthreads();
if (warp_id != 0) {
return;
}
loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
loss = warp_reduce_sum(loss);
if (lane_id != 0) {
return;
}
dst[blockIdx.x] = loss;
}
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t stream = ctx.stream();
const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
// Combine results from individual blocks:
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
}

View file

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
}
}
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

View file

@ -1,3 +1,5 @@
#include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
dst[i] = sqrtf(x[i]);
}
static __global__ void sin_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = sinf(x[i]);
}
static __global__ void cos_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = cosf(x[i]);
}
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

View file

@ -9,6 +9,8 @@
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
#define CUDA_SIN_BLOCK_SIZE 256
#define CUDA_COS_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);