style fixes

This commit is contained in:
slaren 2023-12-25 21:23:31 +01:00
parent 561f1f9500
commit 0dcc1a77d7
2 changed files with 83 additions and 91 deletions

View file

@ -163,7 +163,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
return reinterpret_cast<const int&>(c);
return reinterpret_cast<const int &>(c);
#else
int8x4_t c;
int16_t tmp;
@ -174,7 +174,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
c[i] = tmp;
}
return reinterpret_cast<int&>(c);
return reinterpret_cast<int &>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}
@ -257,7 +257,6 @@ static void ggml_cuda_error(const char * stmt, const char * func, const char * f
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
@ -321,10 +320,10 @@ typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * s
typedef void (*ggml_cuda_op_mul_mat_t)(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream);
const int64_t src1_padded_row_size, cudaStream_t stream);
typedef void (*ggml_cuda_op_flatten_t)(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream);
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream);
// QK = number of values after dequantization
// QR = QK / number of values before dequantization
@ -594,6 +593,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
}
static __device__ __forceinline__ float op_add(const float a, const float b) {
@ -715,7 +715,7 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
const float GELU_QUICK_COEF = -1.702f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -724,7 +724,7 @@ static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
}
static __global__ void tanh_f32(const float *x, float *dst, int k) {
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
@ -741,7 +741,7 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
dst[i] = fmaxf(x[i], 0);
}
static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
@ -794,7 +794,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
}
}
static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) {
static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@ -819,7 +819,7 @@ static __global__ void concat_f32(const float *x,const float *y, float *dst, c
}
}
static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
int ne0 = ne00 * scale_factor;
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
@ -839,7 +839,7 @@ static __global__ void upscale_f32(const float *x, float *dst, const int ne00,
dst[offset_dst] = x[offset_src];
}
static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@ -5415,7 +5415,7 @@ struct bin_bcast_cuda {
cne[3] = 1;
};
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
@ -6579,18 +6579,16 @@ struct scoped_spin_lock {
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
// #define DEBUG_CUDA_MALLOC
struct cuda_buffer {
struct ggml_cuda_buffer {
void * ptr = nullptr;
size_t size = 0;
};
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
static ggml_cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
int nnz = 0;
size_t max_size = 0;
@ -6598,7 +6596,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
size_t best_diff = 1ull << 36;
int ibest = -1;
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[id][i];
ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
++nnz;
@ -6621,7 +6619,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
}
}
if (ibest >= 0) {
cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
ggml_cuda_buffer& b = g_cuda_buffer_pool[device][ibest];
void * ptr = b.ptr;
*actual_size = b.size;
b.ptr = nullptr;
@ -6631,9 +6629,10 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
ggml_cuda_set_device(device);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
*actual_size = look_ahead_size;
g_cuda_pool_size[id] += look_ahead_size;
g_cuda_pool_size[device] += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
@ -6641,11 +6640,11 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
return ptr;
}
static void ggml_cuda_pool_free_leg(int id, void * ptr, size_t size) {
static void ggml_cuda_pool_free_leg(int device, void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[id][i];
ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
if (b.ptr == nullptr) {
b.ptr = ptr;
b.size = size;
@ -6653,8 +6652,9 @@ static void ggml_cuda_pool_free_leg(int id, void * ptr, size_t size) {
}
}
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
ggml_cuda_set_device(device);
CUDA_CHECK(cudaFree(ptr));
g_cuda_pool_size[id] -= size;
g_cuda_pool_size[device] -= size;
}
#if !defined(GGML_USE_HIPBLAS)
@ -6663,40 +6663,38 @@ static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
// round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
const size_t alignment = 128;
size = alignment * ((size + alignment - 1) / alignment);
size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
size_t avail = g_cuda_pool_size[device] - g_cuda_pool_used[device];
if (size > avail) {
// round up to the next multiple of the granularity
size_t reserve_size = size - avail;
const size_t granularity = g_device_caps[id].vmm_granularity;
const size_t granularity = g_device_caps[device].vmm_granularity;
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
GGML_ASSERT(g_cuda_pool_size[device] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
// allocate more physical memory
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = id;
prop.location.id = device;
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
// reserve virtual address space (if not already reserved)
if (g_cuda_pool_addr[id] == 0) {
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
if (g_cuda_pool_addr[device] == 0) {
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[device], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
}
// map at the end of the pool
CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
CU_CHECK(cuMemMap(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, 0, handle, 0));
// the memory allocation handle is no longer needed after mapping
CU_CHECK(cuMemRelease(handle));
@ -6704,23 +6702,23 @@ static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = id;
access.location.id = device;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, &access, 1));
// add to the pool
g_cuda_pool_size[id] += reserve_size;
g_cuda_pool_size[device] += reserve_size;
//printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
// id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
// (unsigned long long) (reserve_size/1024/1024));
}
GGML_ASSERT(g_cuda_pool_addr[id] != 0);
GGML_ASSERT(g_cuda_pool_addr[device] != 0);
void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
void * ptr = (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]);
*actual_size = size;
g_cuda_pool_used[id] += size;
g_cuda_pool_used[device] += size;
#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
@ -6729,34 +6727,32 @@ static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
return ptr;
}
static void ggml_cuda_pool_free_vmm(int id, void * ptr, size_t size) {
static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
#endif
g_cuda_pool_used[id] -= size;
g_cuda_pool_used[device] -= size;
// all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]));
}
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_caps[id].vmm) {
return ggml_cuda_pool_malloc_vmm(size, actual_size);
static void * ggml_cuda_pool_malloc(int device, size_t size, size_t * actual_size) {
if (g_device_caps[device].vmm) {
return ggml_cuda_pool_malloc_vmm(device, size, actual_size);
} else {
return ggml_cuda_pool_malloc_leg(size, actual_size);
return ggml_cuda_pool_malloc_leg(device, size, actual_size);
}
}
static void ggml_cuda_pool_free(int id, void * ptr, size_t size) {
if (g_device_caps[id].vmm) {
ggml_cuda_pool_free_vmm(id, ptr, size);
static void ggml_cuda_pool_free(int device, void * ptr, size_t size) {
if (g_device_caps[device].vmm) {
ggml_cuda_pool_free_vmm(device, ptr, size);
} else {
ggml_cuda_pool_free_leg(id, ptr, size);
ggml_cuda_pool_free_leg(device, ptr, size);
}
}
#else
@ -6774,7 +6770,7 @@ struct cuda_pool_alloc {
T * alloc(size_t size) {
GGML_ASSERT(ptr == nullptr);
CUDA_CHECK(cudaGetDevice(&device));
ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
ptr = (T *) ggml_cuda_pool_malloc(device, size * sizeof(T), &this->actual_size);
return ptr;
}
@ -6986,7 +6982,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
static void ggml_cuda_op_get_rows(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t stream) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
@ -7030,7 +7026,7 @@ static void ggml_cuda_op_get_rows(
template<class op>
static void ggml_cuda_op_bin_bcast(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -7049,7 +7045,7 @@ static void ggml_cuda_op_bin_bcast(
static void ggml_cuda_op_repeat(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t main_stream) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
@ -7059,14 +7055,14 @@ static void ggml_cuda_op_repeat(
static void ggml_cuda_op_add(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
}
static void ggml_cuda_op_acc(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -7085,21 +7081,21 @@ static void ggml_cuda_op_acc(
static void ggml_cuda_op_mul(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
}
static void ggml_cuda_op_div(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
}
static void ggml_cuda_op_gelu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7113,7 +7109,7 @@ static void ggml_cuda_op_gelu(
static void ggml_cuda_op_silu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7127,7 +7123,7 @@ static void ggml_cuda_op_silu(
static void ggml_cuda_op_gelu_quick(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7141,7 +7137,7 @@ static void ggml_cuda_op_gelu_quick(
static void ggml_cuda_op_tanh(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7155,7 +7151,7 @@ static void ggml_cuda_op_tanh(
static void ggml_cuda_op_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7169,7 +7165,7 @@ static void ggml_cuda_op_relu(
static void ggml_cuda_op_leaky_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7186,7 +7182,7 @@ static void ggml_cuda_op_leaky_relu(
static void ggml_cuda_op_sqr(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7200,7 +7196,7 @@ static void ggml_cuda_op_sqr(
static void ggml_cuda_op_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7220,7 +7216,7 @@ static void ggml_cuda_op_norm(
static void ggml_cuda_op_group_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7236,7 +7232,7 @@ static void ggml_cuda_op_group_norm(
static void ggml_cuda_op_concat(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -7252,7 +7248,7 @@ static void ggml_cuda_op_concat(
static void ggml_cuda_op_upscale(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
@ -7269,7 +7265,7 @@ static void ggml_cuda_op_upscale(
static void ggml_cuda_op_pad(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
@ -7286,7 +7282,7 @@ static void ggml_cuda_op_pad(
static void ggml_cuda_op_rms_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7307,7 +7303,7 @@ static void ggml_cuda_op_rms_norm(
static void ggml_cuda_op_mul_mat_q(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
@ -7430,7 +7426,7 @@ static int64_t get_row_rounding(ggml_type type) {
static void ggml_cuda_op_mul_mat_vec_q(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
const int64_t src1_padded_row_size, cudaStream_t stream) {
GGML_ASSERT(ggml_nrows(src1) == 1);
@ -7483,7 +7479,7 @@ static void ggml_cuda_op_mul_mat_vec_q(
static void ggml_cuda_op_dequantize_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
@ -7557,7 +7553,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
static void ggml_cuda_op_mul_mat_cublas(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
const int64_t src1_padded_row_size, cudaStream_t stream) {
GGML_ASSERT(src0_dd_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
@ -7648,7 +7644,7 @@ static void ggml_cuda_op_mul_mat_cublas(
static void ggml_cuda_op_rope(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@ -7728,7 +7724,7 @@ static void ggml_cuda_op_rope(
static void ggml_cuda_op_alibi(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7759,7 +7755,7 @@ static void ggml_cuda_op_alibi(
static void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -7794,7 +7790,7 @@ static void ggml_cuda_op_im2col(
static void ggml_cuda_op_sum_rows(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7811,7 +7807,7 @@ static void ggml_cuda_op_sum_rows(
static void ggml_cuda_op_argsort(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
@ -7830,7 +7826,7 @@ static void ggml_cuda_op_argsort(
static void ggml_cuda_op_diag_mask_inf(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7850,7 +7846,7 @@ static void ggml_cuda_op_diag_mask_inf(
static void ggml_cuda_op_soft_max(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7871,7 +7867,7 @@ static void ggml_cuda_op_soft_max(
static void ggml_cuda_op_scale(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -7889,7 +7885,7 @@ static void ggml_cuda_op_scale(
static void ggml_cuda_op_clamp(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -8021,7 +8017,6 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t nrows0 = ggml_nrows(src0);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
@ -8921,7 +8916,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
std::vector<char> ids_host(ggml_nbytes(ids));
const cudaStream_t stream = g_cudaStreams[g_main_device][0];
cudaStream_t stream = g_cudaStreams[g_main_device][0];
if (ids->backend == GGML_BACKEND_GPU) {
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];

3
ggml.c
View file

@ -4041,7 +4041,6 @@ static struct ggml_tensor * ggml_group_norm_impl(
result->op = GGML_OP_GROUP_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = NULL; // TODO: maybe store epsilon here?
return result;
}
@ -5541,7 +5540,6 @@ static struct ggml_tensor * ggml_upscale_impl(
result->op_params[0] = scale_factor;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = NULL;
return result;
}
@ -5846,7 +5844,6 @@ struct ggml_tensor * ggml_get_rel_pos(
result->op = GGML_OP_GET_REL_POS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = NULL;
return result;
}