add missing kernels

Author: FSSRepo
Date: 2023-11-24 11:55:30 -05:00
Parent: 08e7afacf7
Commit: f4f0b06a9c
2 changed files with 86 additions and 3 deletions

examples/llava/clip.cpp

@@ -330,7 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
         float scale = 1.0f / sqrt((float)d_head);
         ggml_backend_tensor_set(KQ_scale, &scale, 0, ggml_nbytes(KQ_scale));
-        printf("alloc scale\n");
     }

     // loop over layers
@@ -428,7 +427,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         }
         ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
         free(patches_data);
-        printf("patches");
     }

     embeddings = ggml_get_rows(ctx0, embeddings, patches);

ggml-cuda.cu

@@ -510,6 +510,24 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i%ky];
 }

+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+                               const int ne10, const int ne11, const int ne12,
+                               const int nb1, const int nb2) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    int oz = i / nb2;
+    int oy = (i - (oz * nb2)) / nb1;
+    int ox = i % nb1;
+    if (ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
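Note (not part of the commit): acc_f32 implements GGML_OP_ACC for f32 data. Each flat destination index i is unflattened into (ox, oy, oz) using the destination's element strides nb1 and nb2; indices that fall inside the ne10 x ne11 x ne12 view receive the corresponding src1 element, everything else is copied through from src0. A minimal CPU sketch of the same computation, assuming contiguous f32 buffers (acc_f32_ref is an illustrative name, not a ggml function):

    static void acc_f32_ref(const float * x, const float * y, float * dst, const int ne,
                            const int ne10, const int ne11, const int ne12,
                            const int nb1, const int nb2) {
        for (int i = 0; i < ne; i++) {
            const int oz = i / nb2;                // plane index into dst
            const int oy = (i - (oz * nb2)) / nb1; // row index into dst
            const int ox = i % nb1;                // column index into dst
            if (ox < ne10 && oy < ne11 && oz < ne12) {
                // inside the src1 view: accumulate
                dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
            } else {
                // outside the view: pass src0 through unchanged
                dst[i] = x[i];
            }
        }
    }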
@@ -568,6 +586,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }

+static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
 static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
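gelu_quick_f32 evaluates the quick GELU approximation gelu_quick(x) = x * sigmoid(1.702 * x), a cheap stand-in for the exact GELU x * Phi(x); the coefficient is stored negated so the sigmoid needs only a single expf. A one-line CPU reference (a sketch, not a ggml API):

    #include <math.h>

    // quick GELU: x * sigmoid(1.702 * x), approximates x * Phi(x)
    static float gelu_quick_ref(float x) {
        return x * (1.0f / (1.0f + expf(-1.702f * x)));
    }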
@@ -4810,6 +4837,13 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }

+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+                         const int ne10, const int ne11, const int ne12,
+                         const int nb1, const int nb2, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2);
+}
+
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
     add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
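Like the other launchers in this file, acc_f32_cuda sizes the grid with ceiling division so every element is covered by a thread; out-of-range threads exit through the i >= ne guard in the kernel. For example, assuming CUDA_ADD_BLOCK_SIZE is 256:

    // 1000 elements -> (1000 + 255) / 256 == 4 blocks == 1024 threads;
    // the final 24 threads fail the bounds check and return immediately.
    const int num_blocks = (1000 + 256 - 1) / 256;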
@@ -4840,6 +4874,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -6205,6 +6244,23 @@ inline void ggml_cuda_op_mul(
     (void) dst;
 }

+inline void ggml_cuda_op_acc(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // only 3D tensors are supported
+
+    int nb1 = dst->nb[1] / 4; // byte stride -> element stride (4 bytes per float32)
+    int nb2 = dst->nb[2] / 4; // byte stride -> element stride (4 bytes per float32)
+
+    acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, main_stream);
+
+    (void) dst;
+}
+
 inline void ggml_cuda_op_gelu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6247,6 +6303,21 @@ inline void ggml_cuda_op_relu(
     (void) src1_dd;
 }

+inline void ggml_cuda_op_gelu_quick(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_sqr(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7325,6 +7396,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }

+static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
+}
+
 static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
@@ -7337,6 +7412,10 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }

+static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
+}
+
 static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
 }
@@ -8080,6 +8159,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_ACC:
+            func = ggml_cuda_acc;
+            break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_GELU:
@@ -8091,6 +8173,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    func = ggml_cuda_gelu_quick;
+                    break;
                 default:
                     return false;
             } break;