tests: add gradient tests for all backends (ggml/932)
* tests: add gradient checking to test-backend-ops * remove old comment * reorder includes * adjust SIN/COS parameters * add documentation, use supports_op if possible
This commit is contained in:
parent
dbbebcab33
commit
202084d31d
10 changed files with 1080 additions and 92 deletions
|
@ -1,6 +1,6 @@
|
|||
#include "common.cuh"
|
||||
#include "cross-entropy-loss.cuh"
|
||||
#include "sumrows.cuh"
|
||||
#include "sum.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
@ -102,5 +102,5 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
|
||||
// Combine results from individual blocks:
|
||||
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
|
||||
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
|
||||
}
|
||||
|
|
41
ggml/src/ggml-cuda/sum.cu
Normal file
41
ggml/src/ggml-cuda/sum.cu
Normal file
|
@ -0,0 +1,41 @@
|
|||
#include "sumrows.cuh"
|
||||
#include "sum.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
#include <cub/cub.cuh>
|
||||
using namespace cub;
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
|
||||
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
size_t tmp_size = 0;
|
||||
DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
|
||||
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
|
||||
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
|
||||
#else
|
||||
// Use (inefficient) sum_rows implementation as a fallback.
|
||||
// For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
|
||||
sum_rows_f32_cuda(x, dst, ne, 1, stream);
|
||||
GGML_UNUSED(pool);
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
|
||||
}
|
5
ggml/src/ggml-cuda/sum.cuh
Normal file
5
ggml/src/ggml-cuda/sum.cuh
Normal file
|
@ -0,0 +1,5 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@ -1,5 +1,15 @@
|
|||
#include "unary.cuh"
|
||||
|
||||
static __global__ void neg_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[i] = -x[i];
|
||||
}
|
||||
|
||||
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
@ -119,6 +129,11 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
|
|||
dst[i] = cosf(x[i]);
|
||||
}
|
||||
|
||||
static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
||||
neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
|
@ -184,6 +199,20 @@ static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
|
|||
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#include "common.cuh"
|
||||
|
||||
#define CUDA_NEG_BLOCK_SIZE 256
|
||||
#define CUDA_GELU_BLOCK_SIZE 256
|
||||
#define CUDA_SILU_BLOCK_SIZE 256
|
||||
#define CUDA_TANH_BLOCK_SIZE 256
|
||||
|
@ -12,6 +13,8 @@
|
|||
#define CUDA_SIN_BLOCK_SIZE 256
|
||||
#define CUDA_COS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue