diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 11925d59e..4bea0c52c 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -330,7 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
     if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
         float scale = 1.0f / sqrt((float)d_head);
         ggml_backend_tensor_set(KQ_scale, &scale, 0, ggml_nbytes(KQ_scale));
-        printf("alloc scale\n");
     }
 
     // loop over layers
@@ -424,11 +423,10 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
             int* patches_data = (int*)malloc(ggml_nbytes(patches));
             for (int i = 0; i < num_positions; i++) {
-                patches_data[i] = i+1;
+                patches_data[i] = i + 1;
             }
             ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
             free(patches_data);
-            printf("patches");
         }
 
         embeddings = ggml_get_rows(ctx0, embeddings, patches);
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index f0db7ae35..8ac3e6a81 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -510,6 +510,24 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i%ky];
 }
 
+
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    int oz = i / nb2;
+    int oy = (i - (oz * nb2)) / nb1;
+    int ox = i % nb1;
+    if (ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -568,6 +586,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
 static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -4810,6 +4837,13 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2);
+}
+
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
     add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
@@ -4840,6 +4874,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -6205,6 +6244,23 @@ inline void ggml_cuda_op_mul(
     (void) dst;
 }
 
+inline void ggml_cuda_op_acc(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->nb[1] / 4; // 4 bytes of float32
+    int nb2 = dst->nb[2] / 4; // 4 bytes of float32
+
+    acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, main_stream);
+
+    (void) dst;
+}
+
 inline void ggml_cuda_op_gelu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6247,6 +6303,21 @@ inline void ggml_cuda_op_relu(
     (void) src1_dd;
 }
 
+
+inline void ggml_cuda_op_gelu_quick(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_sqr(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7325,6 +7396,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
 
+static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
+}
+
 static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
@@ -7337,6 +7412,10 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }
 
+static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
+}
+
 static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
 }
@@ -8080,6 +8159,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
            break;
+        case GGML_OP_ACC:
+            func = ggml_cuda_acc;
+            break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_GELU:
@@ -8091,6 +8173,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    func = ggml_cuda_gelu_quick;
+                    break;
                 default:
                     return false;
             } break;
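A note on the new acc_f32 kernel: it adds src1 into a 3D sub-view of src0 and writes the result to dst. The strides nb1 and nb2 arrive in float elements rather than bytes (ggml_cuda_op_acc divides the f32 byte strides dst->nb[1] and dst->nb[2] by 4), so each flat index i is split into (ox, oy, oz); positions inside src1's ne10 x ne11 x ne12 box get the addition, everything else is copied through unchanged. The sketch below is a hypothetical CPU reference, not part of the patch; it only mirrors the index arithmetic of the kernel and could be used to validate the CUDA output on small tensors.

// acc_f32_ref: hypothetical CPU mirror of the acc_f32 kernel, for validation only.
// nb1/nb2 are dst's row/plane strides in float elements, as passed by
// ggml_cuda_op_acc (dst->nb[1] / 4 and dst->nb[2] / 4 for f32 data).
static void acc_f32_ref(const float * x, const float * y, float * dst, const int ne,
                        const int ne10, const int ne11, const int ne12,
                        const int nb1, const int nb2) {
    for (int i = 0; i < ne; i++) {
        const int oz = i / nb2;                // plane index in dst
        const int oy = (i - (oz * nb2)) / nb1; // row index within the plane
        const int ox = i % nb1;                // column index within the row
        if (ox < ne10 && oy < ne11 && oz < ne12) {
            dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; // inside the src1 box
        } else {
            dst[i] = x[i];                     // outside the accumulated view: plain copy
        }
    }
}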
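Similarly, gelu_quick_f32 evaluates the sigmoid-based GELU approximation x * sigma(1.702 * x), written in the kernel as x * 1 / (1 + exp(-1.702 * x)). A minimal CPU reference, again hypothetical and only useful for spot-checking the CUDA results:

// gelu_quick_ref: hypothetical CPU reference for the quick-GELU formula used by
// gelu_quick_f32; it computes x * sigmoid(1.702 * x) with the same constant.
#include <math.h>
#include <stdio.h>

static float gelu_quick_ref(float x) {
    const float GELU_QUICK_COEF = -1.702f;
    return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));
}

int main() {
    const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for (float x : xs) {
        printf("gelu_quick(%+.2f) = %+.6f\n", x, gelu_quick_ref(x));
    }
    return 0;
}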