sync : ggml
This commit is contained in:
parent
3246fe84d7
commit
231cff5f6f
21 changed files with 1422 additions and 178 deletions
|
@ -9,8 +9,10 @@
|
|||
#include "ggml-cuda/binbcast.cuh"
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
#include "ggml-cuda/cpy.cuh"
|
||||
#include "ggml-cuda/cross-entropy-loss.cuh"
|
||||
#include "ggml-cuda/diagmask.cuh"
|
||||
#include "ggml-cuda/dmmv.cuh"
|
||||
#include "ggml-cuda/fattn.cuh"
|
||||
|
@ -29,7 +31,6 @@
|
|||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
@ -2181,6 +2182,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_ADD:
|
||||
ggml_cuda_op_add(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SUB:
|
||||
ggml_cuda_op_sub(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ACC:
|
||||
ggml_cuda_op_acc(ctx, dst);
|
||||
break;
|
||||
|
@ -2267,6 +2271,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_SQRT:
|
||||
ggml_cuda_op_sqrt(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SIN:
|
||||
ggml_cuda_op_sin(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_COS:
|
||||
ggml_cuda_op_cos(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
ggml_cuda_op_clamp(ctx, dst);
|
||||
break;
|
||||
|
@ -2303,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_cuda_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
@ -2610,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
if (node->src[j] != nullptr) {
|
||||
assert(node->src[j]->buffer);
|
||||
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
|
||||
}
|
||||
}
|
||||
|
@ -2853,12 +2867,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
|
@ -2890,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||
}
|
||||
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
return true;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
default:
|
||||
return false;
|
||||
|
|
|
@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
|
|||
return a + b;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_sub(const float a, const float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_mul(const float a, const float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
||||
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
|
|
@ -2,5 +2,6 @@
|
|||
|
||||
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
106
ggml/src/ggml-cuda/cross-entropy-loss.cu
Normal file
106
ggml/src/ggml-cuda/cross-entropy-loss.cu
Normal file
|
@ -0,0 +1,106 @@
|
|||
#include "common.cuh"
|
||||
#include "cross-entropy-loss.cuh"
|
||||
#include "sumrows.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
|
||||
|
||||
const int ne_tmp = WARP_SIZE*nclasses;
|
||||
|
||||
extern __shared__ float tmp_all[];
|
||||
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
|
||||
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
|
||||
|
||||
// Each warp first loads ne_tmp logits/labels into shared memory:
|
||||
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
|
||||
const int ig = i0*nclasses + i; // ig == i global
|
||||
|
||||
tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
|
||||
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
|
||||
}
|
||||
|
||||
// Each thread in the warp then calculates the cross entropy loss for a single row.
|
||||
// TODO: pad in order to avoid shared memory bank conflicts.
|
||||
|
||||
// Find maximum for softmax:
|
||||
float max = -INFINITY;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
|
||||
}
|
||||
|
||||
// Calculate log(softmax(logits)) which is just logits - max:
|
||||
float sum = 0.0f;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
float val = tmp_logits[lane_id*nclasses + i] - max;
|
||||
sum += expf(val);
|
||||
tmp_logits[lane_id*nclasses + i] = val;
|
||||
}
|
||||
sum = logf(sum);
|
||||
|
||||
// log(exp(logits - max) / sum) = (logits - max) - log(sum)
|
||||
float loss = 0.0f;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
|
||||
}
|
||||
loss = -warp_reduce_sum(loss) / (float)k;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
tmp_all[warp_id] = loss;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
|
||||
loss = warp_reduce_sum(loss);
|
||||
|
||||
if (lane_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[blockIdx.x] = loss;
|
||||
}
|
||||
|
||||
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
|
||||
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
||||
|
||||
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
|
||||
// Combine results from individual blocks:
|
||||
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
|
||||
}
|
5
ggml/src/ggml-cuda/cross-entropy-loss.cuh
Normal file
5
ggml/src/ggml-cuda/cross-entropy-loss.cuh
Normal file
|
@ -0,0 +1,5 @@
|
|||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
|
|||
}
|
||||
}
|
||||
|
||||
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
|
@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
|
|||
dst[i] = sqrtf(x[i]);
|
||||
}
|
||||
|
||||
static __global__ void sin_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = sinf(x[i]);
|
||||
}
|
||||
|
||||
static __global__ void cos_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = cosf(x[i]);
|
||||
}
|
||||
|
||||
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
|
@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
|||
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
|
||||
sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
|
||||
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
#define CUDA_HARDSWISH_BLOCK_SIZE 256
|
||||
#define CUDA_SQR_BLOCK_SIZE 256
|
||||
#define CUDA_SQRT_BLOCK_SIZE 256
|
||||
#define CUDA_SIN_BLOCK_SIZE 256
|
||||
#define CUDA_COS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
|
@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
@ -31,6 +31,8 @@ struct ggml_metal_kernel {
|
|||
enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ADD,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_SUB,
|
||||
GGML_METAL_KERNEL_TYPE_SUB_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_MUL,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_DIV,
|
||||
|
@ -207,6 +209,9 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||
GGML_METAL_KERNEL_TYPE_SQR,
|
||||
GGML_METAL_KERNEL_TYPE_SQRT,
|
||||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
|
@ -493,6 +498,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
||||
|
@ -669,6 +676,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
}
|
||||
|
||||
|
@ -769,15 +779,20 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
return true;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
|
@ -1057,6 +1072,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
{
|
||||
|
@ -1080,6 +1096,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
nb = ne00 / 4;
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
||||
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
|
@ -1089,6 +1106,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
} else {
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
||||
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
|
@ -1416,6 +1434,48 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SQRT:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SIN:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_COS:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
|
|
|
@ -17,7 +17,7 @@ enum ggml_sort_order {
|
|||
GGML_SORT_ORDER_DESC,
|
||||
};
|
||||
|
||||
// general-purpose kernel for addition, multiplication and division of two tensors
|
||||
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
||||
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
||||
// cons: not very efficient
|
||||
kernel void kernel_add(
|
||||
|
@ -70,6 +70,56 @@ kernel void kernel_add(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_sub(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant int64_t & offs,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
@ -226,6 +276,15 @@ kernel void kernel_add_row(
|
|||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_sub_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant uint64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] - src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_mul_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
|
@ -358,6 +417,27 @@ kernel void kernel_sqr(
|
|||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sqrt(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sqrt(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sin(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sin(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_cos(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
|
|
@ -3644,7 +3644,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
|
|||
quantize_row_q8_K_ref(x, y, k);
|
||||
}
|
||||
|
||||
//===================================== Dot ptoducts =================================
|
||||
//===================================== Dot products =================================
|
||||
|
||||
//
|
||||
// Helper functions
|
||||
|
|
|
@ -188,6 +188,8 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_upscale_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_repeat_f32;
|
||||
|
@ -1702,6 +1704,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
@ -4023,6 +4027,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_sqr_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SIN:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sin_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_COS:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_cos_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
|
@ -4171,6 +4185,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_REPEAT:
|
||||
|
@ -4381,6 +4397,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
case GGML_OP_MUL:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_REPEAT:
|
||||
|
@ -4598,6 +4616,32 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
|
@ -5658,6 +5702,8 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_CPY:
|
||||
|
@ -5735,6 +5781,14 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
case GGML_OP_SQR:
|
||||
ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_SIN:
|
||||
ggml_vk_sin(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_COS:
|
||||
ggml_vk_cos(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
@ -5851,6 +5905,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_CPY:
|
||||
|
@ -6582,6 +6638,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
|||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_CONT:
|
||||
|
@ -7024,6 +7082,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
|
||||
} else if (tensor->op == GGML_OP_SQR) {
|
||||
tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_SIN) {
|
||||
tensor_clone = ggml_sin(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_COS) {
|
||||
tensor_clone = ggml_cos(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_CLAMP) {
|
||||
tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
} else if (tensor->op == GGML_OP_PAD) {
|
||||
|
|
704
ggml/src/ggml.c
704
ggml/src/ggml.c
File diff suppressed because it is too large
Load diff
15
ggml/src/vulkan-shaders/cos.comp
Normal file
15
ggml/src/vulkan-shaders/cos.comp
Normal file
|
@ -0,0 +1,15 @@
|
|||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
}
|
15
ggml/src/vulkan-shaders/sin.comp
Normal file
15
ggml/src/vulkan-shaders/sin.comp
Normal file
|
@ -0,0 +1,15 @@
|
|||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue