add missing kernels
commit f4f0b06a9c
parent 08e7afacf7
2 changed files with 86 additions and 3 deletions
@@ -330,7 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
         float scale = 1.0f / sqrt((float)d_head);
         ggml_backend_tensor_set(KQ_scale, &scale, 0, ggml_nbytes(KQ_scale));
-        printf("alloc scale\n");
     }

     // loop over layers
@@ -424,11 +423,10 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
         int* patches_data = (int*)malloc(ggml_nbytes(patches));
         for (int i = 0; i < num_positions; i++) {
-            patches_data[i] = i+1;
+            patches_data[i] = i + 1;
         }
         ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
         free(patches_data);
-        printf("patches");
     }

     embeddings = ggml_get_rows(ctx0, embeddings, patches);
ggml-cuda.cu (85 changed lines)
@@ -510,6 +510,24 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i%ky];
 }

+
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    int oz = i / nb2;
+    int oy = (i - (oz * nb2)) / nb1;
+    int ox = i % nb1;
+    if(ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;

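Note on the new acc_f32 kernel: it adds a smaller ne10 × ne11 × ne12 tensor y onto the block of dst that starts at the origin (there is no offset parameter, so only the offset-0 case is covered), with nb1/nb2 being dst's row and plane strides counted in elements. A minimal CPU sketch of the same indexing, for reference only; the helper name is ours, not part of the commit:

    // Hedged CPU reference of acc_f32: copy x into dst element by element and
    // add the smaller tensor y (ne10 x ne11 x ne12) on top of the sub-block
    // that starts at dst's origin. nb1/nb2 are strides in floats, not bytes.
    static void acc_f32_reference(const float * x, const float * y, float * dst,
                                  const int ne, const int ne10, const int ne11, const int ne12,
                                  const int nb1, const int nb2) {
        for (int i = 0; i < ne; ++i) {
            const int oz = i / nb2;              // plane index in dst
            const int oy = (i - oz * nb2) / nb1; // row index in dst
            const int ox = i % nb1;              // position within the row
            if (ox < ne10 && oy < ne11 && oz < ne12) {
                dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
            } else {
                dst[i] = x[i];
            }
        }
    }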
@@ -568,6 +586,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }

+static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if(i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
 static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;

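The "quick" GELU variant approximates GELU with a sigmoid, x · σ(1.702·x), which is exactly the per-element expression in the kernel above. A scalar reference for comparison; the helper name is ours, not part of the commit:

    #include <cmath>

    // Hedged scalar reference: quick GELU = x * sigmoid(1.702 * x);
    // gelu_quick_f32 evaluates the same expression per element on the GPU.
    static inline float gelu_quick_reference(float x) {
        return x * (1.0f / (1.0f + std::exp(-1.702f * x)));
    }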
@@ -4810,6 +4837,13 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }

+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2);
+}
+
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
     add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
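The new launchers follow the existing pattern in this file: one thread per element and a block count obtained by rounding up. A small sketch of that rounding rule, written as a standalone helper for illustration only:

    // Hedged sketch of the grid-sizing rule used above: ceil(n / block_size)
    // with integer arithmetic. Threads beyond n are filtered out by the
    // `if (i >= ne) return;` guard inside each kernel.
    static inline int block_count(int n, int block_size) {
        return (n + block_size - 1) / block_size;
    }
    // e.g. block_count(1000, 256) == 4, so 1024 threads cover 1000 elements.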
@@ -4840,6 +4874,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -6205,6 +6244,23 @@ inline void ggml_cuda_op_mul(
     (void) dst;
 }

+inline void ggml_cuda_op_acc(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->nb[1] / 4; // 4 bytes of float32
+    int nb2 = dst->nb[2] / 4; // 4 bytes of float32
+
+    acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, main_stream);
+
+    (void) dst;
+}
+
 inline void ggml_cuda_op_gelu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
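Note on the `/ 4` above: ggml stores tensor strides nb[i] in bytes, while acc_f32 indexes in float elements, hence the division by the size of a float32. Spelled out with sizeof, purely as an illustration and not part of the commit:

    // Hedged illustration: converting a ggml byte stride to an element stride
    // for an F32 tensor; `dst->nb[1] / 4` above is equivalent to this.
    static inline int f32_elem_stride(size_t nb_bytes) {
        return (int) (nb_bytes / sizeof(float)); // sizeof(float) == 4
    }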
@@ -6247,6 +6303,21 @@ inline void ggml_cuda_op_relu(
     (void) src1_dd;
 }

+
+inline void ggml_cuda_op_gelu_quick(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_sqr(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7325,6 +7396,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }

+static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
+}
+
 static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
@@ -7337,6 +7412,10 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }

+static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
+}
+
 static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
 }
@@ -8080,6 +8159,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_ACC:
+            func = ggml_cuda_acc;
+            break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_GELU:
@@ -8091,6 +8173,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    func = ggml_cuda_gelu_quick;
+                    break;
                 default:
                     return false;
             } break;
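With the dispatch entries above, graphs containing GGML_OP_ACC and GGML_UNARY_OP_GELU_QUICK nodes (as the CLIP graph touched in the first file presumably does) are accepted by the CUDA backend instead of returning false. A rough usage sketch of the two ops through the public ggml API; exact helper names and graph-building calls vary between ggml versions, so treat this as illustrative only, not as code from this commit:

    #include "ggml.h"

    // Illustrative only: build a tiny graph that uses ggml_acc and
    // ggml_gelu_quick, the two ops these kernels implement.
    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 2);

        // accumulate b into the view of a that starts at offset 0 (strides are in bytes)
        struct ggml_tensor * acc = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);
        // quick-GELU activation on the result
        struct ggml_tensor * out = ggml_gelu_quick(ctx, acc);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

        ggml_free(ctx);
        return 0;
    }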