style fixes
This commit is contained in:
parent
561f1f9500
commit
0dcc1a77d7
2 changed files with 83 additions and 91 deletions
171
ggml-cuda.cu
171
ggml-cuda.cu
|
@ -163,7 +163,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
||||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||||
#if __has_builtin(__builtin_elementwise_sub_sat)
|
#if __has_builtin(__builtin_elementwise_sub_sat)
|
||||||
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
||||||
return reinterpret_cast<const int&>(c);
|
return reinterpret_cast<const int &>(c);
|
||||||
#else
|
#else
|
||||||
int8x4_t c;
|
int8x4_t c;
|
||||||
int16_t tmp;
|
int16_t tmp;
|
||||||
|
@ -174,7 +174,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
||||||
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
||||||
c[i] = tmp;
|
c[i] = tmp;
|
||||||
}
|
}
|
||||||
return reinterpret_cast<int&>(c);
|
return reinterpret_cast<int &>(c);
|
||||||
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,7 +257,6 @@ static void ggml_cuda_error(const char * stmt, const char * func, const char * f
|
||||||
|
|
||||||
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
|
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
|
||||||
|
|
||||||
|
|
||||||
#if !defined(GGML_USE_HIPBLAS)
|
#if !defined(GGML_USE_HIPBLAS)
|
||||||
static const char * cu_get_error_str(CUresult err) {
|
static const char * cu_get_error_str(CUresult err) {
|
||||||
const char * err_str;
|
const char * err_str;
|
||||||
|
@ -321,10 +320,10 @@ typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
typedef void (*ggml_cuda_op_mul_mat_t)(
|
typedef void (*ggml_cuda_op_mul_mat_t)(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||||
const int64_t src1_padded_row_size, const cudaStream_t & stream);
|
const int64_t src1_padded_row_size, cudaStream_t stream);
|
||||||
typedef void (*ggml_cuda_op_flatten_t)(
|
typedef void (*ggml_cuda_op_flatten_t)(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream);
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream);
|
||||||
|
|
||||||
// QK = number of values after dequantization
|
// QK = number of values after dequantization
|
||||||
// QR = QK / number of values before dequantization
|
// QR = QK / number of values before dequantization
|
||||||
|
@ -594,6 +593,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
||||||
return b;
|
return b;
|
||||||
|
GGML_UNUSED(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_add(const float a, const float b) {
|
static __device__ __forceinline__ float op_add(const float a, const float b) {
|
||||||
|
@ -715,7 +715,7 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
||||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
|
static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
|
||||||
const float GELU_QUICK_COEF = -1.702f;
|
const float GELU_QUICK_COEF = -1.702f;
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
|
@ -724,7 +724,7 @@ static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
|
||||||
dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
|
dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void tanh_f32(const float *x, float *dst, int k) {
|
static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
return;
|
return;
|
||||||
|
@ -741,7 +741,7 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
||||||
dst[i] = fmaxf(x[i], 0);
|
dst[i] = fmaxf(x[i], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
|
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
return;
|
return;
|
||||||
|
@ -794,7 +794,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) {
|
static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
|
||||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
if (nidx >= ne0) {
|
if (nidx >= ne0) {
|
||||||
return;
|
return;
|
||||||
|
@ -819,7 +819,7 @@ static __global__ void concat_f32(const float *x,const float *y, float *dst, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
|
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
|
||||||
int ne0 = ne00 * scale_factor;
|
int ne0 = ne00 * scale_factor;
|
||||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
if (nidx >= ne0) {
|
if (nidx >= ne0) {
|
||||||
|
@ -839,7 +839,7 @@ static __global__ void upscale_f32(const float *x, float *dst, const int ne00,
|
||||||
dst[offset_dst] = x[offset_src];
|
dst[offset_dst] = x[offset_src];
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
|
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
|
||||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
if (nidx >= ne0) {
|
if (nidx >= ne0) {
|
||||||
return;
|
return;
|
||||||
|
@ -5415,7 +5415,7 @@ struct bin_bcast_cuda {
|
||||||
cne[3] = 1;
|
cne[3] = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
|
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
|
||||||
cnb[1] *= cne[1];
|
cnb[1] *= cne[1];
|
||||||
cnb[2] *= cne[2];
|
cnb[2] *= cne[2];
|
||||||
cnb[3] *= cne[3];
|
cnb[3] *= cne[3];
|
||||||
|
@ -6579,18 +6579,16 @@ struct scoped_spin_lock {
|
||||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||||
|
|
||||||
// #define DEBUG_CUDA_MALLOC
|
// #define DEBUG_CUDA_MALLOC
|
||||||
struct cuda_buffer {
|
struct ggml_cuda_buffer {
|
||||||
void * ptr = nullptr;
|
void * ptr = nullptr;
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
|
static ggml_cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
|
||||||
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
|
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
|
|
||||||
static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
|
static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual_size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
int id;
|
|
||||||
CUDA_CHECK(cudaGetDevice(&id));
|
|
||||||
#ifdef DEBUG_CUDA_MALLOC
|
#ifdef DEBUG_CUDA_MALLOC
|
||||||
int nnz = 0;
|
int nnz = 0;
|
||||||
size_t max_size = 0;
|
size_t max_size = 0;
|
||||||
|
@ -6598,7 +6596,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
|
||||||
size_t best_diff = 1ull << 36;
|
size_t best_diff = 1ull << 36;
|
||||||
int ibest = -1;
|
int ibest = -1;
|
||||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||||
cuda_buffer& b = g_cuda_buffer_pool[id][i];
|
ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
|
||||||
if (b.ptr != nullptr) {
|
if (b.ptr != nullptr) {
|
||||||
#ifdef DEBUG_CUDA_MALLOC
|
#ifdef DEBUG_CUDA_MALLOC
|
||||||
++nnz;
|
++nnz;
|
||||||
|
@ -6621,7 +6619,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (ibest >= 0) {
|
if (ibest >= 0) {
|
||||||
cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
|
ggml_cuda_buffer& b = g_cuda_buffer_pool[device][ibest];
|
||||||
void * ptr = b.ptr;
|
void * ptr = b.ptr;
|
||||||
*actual_size = b.size;
|
*actual_size = b.size;
|
||||||
b.ptr = nullptr;
|
b.ptr = nullptr;
|
||||||
|
@ -6631,9 +6629,10 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
|
||||||
void * ptr;
|
void * ptr;
|
||||||
size_t look_ahead_size = (size_t) (1.05 * size);
|
size_t look_ahead_size = (size_t) (1.05 * size);
|
||||||
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
||||||
|
ggml_cuda_set_device(device);
|
||||||
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
|
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
|
||||||
*actual_size = look_ahead_size;
|
*actual_size = look_ahead_size;
|
||||||
g_cuda_pool_size[id] += look_ahead_size;
|
g_cuda_pool_size[device] += look_ahead_size;
|
||||||
#ifdef DEBUG_CUDA_MALLOC
|
#ifdef DEBUG_CUDA_MALLOC
|
||||||
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
|
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
|
||||||
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
|
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
|
||||||
|
@ -6641,11 +6640,11 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_pool_free_leg(int id, void * ptr, size_t size) {
|
static void ggml_cuda_pool_free_leg(int device, void * ptr, size_t size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
|
|
||||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||||
cuda_buffer& b = g_cuda_buffer_pool[id][i];
|
ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
|
||||||
if (b.ptr == nullptr) {
|
if (b.ptr == nullptr) {
|
||||||
b.ptr = ptr;
|
b.ptr = ptr;
|
||||||
b.size = size;
|
b.size = size;
|
||||||
|
@ -6653,8 +6652,9 @@ static void ggml_cuda_pool_free_leg(int id, void * ptr, size_t size) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
||||||
|
ggml_cuda_set_device(device);
|
||||||
CUDA_CHECK(cudaFree(ptr));
|
CUDA_CHECK(cudaFree(ptr));
|
||||||
g_cuda_pool_size[id] -= size;
|
g_cuda_pool_size[device] -= size;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(GGML_USE_HIPBLAS)
|
#if !defined(GGML_USE_HIPBLAS)
|
||||||
|
@ -6663,40 +6663,38 @@ static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
|
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
|
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
|
||||||
|
|
||||||
static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
|
static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual_size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
int id;
|
|
||||||
CUDA_CHECK(cudaGetDevice(&id));
|
|
||||||
|
|
||||||
// round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
|
// round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
|
||||||
const size_t alignment = 128;
|
const size_t alignment = 128;
|
||||||
size = alignment * ((size + alignment - 1) / alignment);
|
size = alignment * ((size + alignment - 1) / alignment);
|
||||||
|
|
||||||
size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
|
size_t avail = g_cuda_pool_size[device] - g_cuda_pool_used[device];
|
||||||
|
|
||||||
if (size > avail) {
|
if (size > avail) {
|
||||||
// round up to the next multiple of the granularity
|
// round up to the next multiple of the granularity
|
||||||
size_t reserve_size = size - avail;
|
size_t reserve_size = size - avail;
|
||||||
const size_t granularity = g_device_caps[id].vmm_granularity;
|
const size_t granularity = g_device_caps[device].vmm_granularity;
|
||||||
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
|
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
|
||||||
|
|
||||||
GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
|
GGML_ASSERT(g_cuda_pool_size[device] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
|
||||||
|
|
||||||
// allocate more physical memory
|
// allocate more physical memory
|
||||||
CUmemAllocationProp prop = {};
|
CUmemAllocationProp prop = {};
|
||||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||||
prop.location.id = id;
|
prop.location.id = device;
|
||||||
CUmemGenericAllocationHandle handle;
|
CUmemGenericAllocationHandle handle;
|
||||||
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
|
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
|
||||||
|
|
||||||
// reserve virtual address space (if not already reserved)
|
// reserve virtual address space (if not already reserved)
|
||||||
if (g_cuda_pool_addr[id] == 0) {
|
if (g_cuda_pool_addr[device] == 0) {
|
||||||
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
|
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[device], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// map at the end of the pool
|
// map at the end of the pool
|
||||||
CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
|
CU_CHECK(cuMemMap(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, 0, handle, 0));
|
||||||
|
|
||||||
// the memory allocation handle is no longer needed after mapping
|
// the memory allocation handle is no longer needed after mapping
|
||||||
CU_CHECK(cuMemRelease(handle));
|
CU_CHECK(cuMemRelease(handle));
|
||||||
|
@ -6704,23 +6702,23 @@ static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
|
||||||
// set access
|
// set access
|
||||||
CUmemAccessDesc access = {};
|
CUmemAccessDesc access = {};
|
||||||
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||||
access.location.id = id;
|
access.location.id = device;
|
||||||
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||||
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
|
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, &access, 1));
|
||||||
|
|
||||||
// add to the pool
|
// add to the pool
|
||||||
g_cuda_pool_size[id] += reserve_size;
|
g_cuda_pool_size[device] += reserve_size;
|
||||||
|
|
||||||
//printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
|
//printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
|
||||||
// id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
|
// id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
|
||||||
// (unsigned long long) (reserve_size/1024/1024));
|
// (unsigned long long) (reserve_size/1024/1024));
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ASSERT(g_cuda_pool_addr[id] != 0);
|
GGML_ASSERT(g_cuda_pool_addr[device] != 0);
|
||||||
|
|
||||||
void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
|
void * ptr = (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]);
|
||||||
*actual_size = size;
|
*actual_size = size;
|
||||||
g_cuda_pool_used[id] += size;
|
g_cuda_pool_used[device] += size;
|
||||||
|
|
||||||
#ifdef DEBUG_CUDA_MALLOC
|
#ifdef DEBUG_CUDA_MALLOC
|
||||||
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
|
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
|
||||||
|
@ -6729,34 +6727,32 @@ static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_pool_free_vmm(int id, void * ptr, size_t size) {
|
static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
|
|
||||||
#ifdef DEBUG_CUDA_MALLOC
|
#ifdef DEBUG_CUDA_MALLOC
|
||||||
printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
|
printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
g_cuda_pool_used[id] -= size;
|
g_cuda_pool_used[device] -= size;
|
||||||
|
|
||||||
// all deallocations must be in reverse order of the allocations
|
// all deallocations must be in reverse order of the allocations
|
||||||
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
|
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
static void * ggml_cuda_pool_malloc(int device, size_t size, size_t * actual_size) {
|
||||||
int id;
|
if (g_device_caps[device].vmm) {
|
||||||
CUDA_CHECK(cudaGetDevice(&id));
|
return ggml_cuda_pool_malloc_vmm(device, size, actual_size);
|
||||||
if (g_device_caps[id].vmm) {
|
|
||||||
return ggml_cuda_pool_malloc_vmm(size, actual_size);
|
|
||||||
} else {
|
} else {
|
||||||
return ggml_cuda_pool_malloc_leg(size, actual_size);
|
return ggml_cuda_pool_malloc_leg(device, size, actual_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_pool_free(int id, void * ptr, size_t size) {
|
static void ggml_cuda_pool_free(int device, void * ptr, size_t size) {
|
||||||
if (g_device_caps[id].vmm) {
|
if (g_device_caps[device].vmm) {
|
||||||
ggml_cuda_pool_free_vmm(id, ptr, size);
|
ggml_cuda_pool_free_vmm(device, ptr, size);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_pool_free_leg(id, ptr, size);
|
ggml_cuda_pool_free_leg(device, ptr, size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
@ -6774,7 +6770,7 @@ struct cuda_pool_alloc {
|
||||||
T * alloc(size_t size) {
|
T * alloc(size_t size) {
|
||||||
GGML_ASSERT(ptr == nullptr);
|
GGML_ASSERT(ptr == nullptr);
|
||||||
CUDA_CHECK(cudaGetDevice(&device));
|
CUDA_CHECK(cudaGetDevice(&device));
|
||||||
ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
|
ptr = (T *) ggml_cuda_pool_malloc(device, size * sizeof(T), &this->actual_size);
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6986,7 +6982,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
||||||
|
|
||||||
static void ggml_cuda_op_get_rows(
|
static void ggml_cuda_op_get_rows(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
|
const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7030,7 +7026,7 @@ static void ggml_cuda_op_get_rows(
|
||||||
template<class op>
|
template<class op>
|
||||||
static void ggml_cuda_op_bin_bcast(
|
static void ggml_cuda_op_bin_bcast(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
@ -7049,7 +7045,7 @@ static void ggml_cuda_op_bin_bcast(
|
||||||
|
|
||||||
static void ggml_cuda_op_repeat(
|
static void ggml_cuda_op_repeat(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
|
const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t main_stream) {
|
||||||
|
|
||||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
|
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
|
||||||
|
|
||||||
|
@ -7059,14 +7055,14 @@ static void ggml_cuda_op_repeat(
|
||||||
|
|
||||||
static void ggml_cuda_op_add(
|
static void ggml_cuda_op_add(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_op_acc(
|
static void ggml_cuda_op_acc(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
@ -7085,21 +7081,21 @@ static void ggml_cuda_op_acc(
|
||||||
|
|
||||||
static void ggml_cuda_op_mul(
|
static void ggml_cuda_op_mul(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_op_div(
|
static void ggml_cuda_op_div(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_op_gelu(
|
static void ggml_cuda_op_gelu(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7113,7 +7109,7 @@ static void ggml_cuda_op_gelu(
|
||||||
|
|
||||||
static void ggml_cuda_op_silu(
|
static void ggml_cuda_op_silu(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7127,7 +7123,7 @@ static void ggml_cuda_op_silu(
|
||||||
|
|
||||||
static void ggml_cuda_op_gelu_quick(
|
static void ggml_cuda_op_gelu_quick(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7141,7 +7137,7 @@ static void ggml_cuda_op_gelu_quick(
|
||||||
|
|
||||||
static void ggml_cuda_op_tanh(
|
static void ggml_cuda_op_tanh(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7155,7 +7151,7 @@ static void ggml_cuda_op_tanh(
|
||||||
|
|
||||||
static void ggml_cuda_op_relu(
|
static void ggml_cuda_op_relu(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7169,7 +7165,7 @@ static void ggml_cuda_op_relu(
|
||||||
|
|
||||||
static void ggml_cuda_op_leaky_relu(
|
static void ggml_cuda_op_leaky_relu(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7186,7 +7182,7 @@ static void ggml_cuda_op_leaky_relu(
|
||||||
|
|
||||||
static void ggml_cuda_op_sqr(
|
static void ggml_cuda_op_sqr(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7200,7 +7196,7 @@ static void ggml_cuda_op_sqr(
|
||||||
|
|
||||||
static void ggml_cuda_op_norm(
|
static void ggml_cuda_op_norm(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7220,7 +7216,7 @@ static void ggml_cuda_op_norm(
|
||||||
|
|
||||||
static void ggml_cuda_op_group_norm(
|
static void ggml_cuda_op_group_norm(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7236,7 +7232,7 @@ static void ggml_cuda_op_group_norm(
|
||||||
|
|
||||||
static void ggml_cuda_op_concat(
|
static void ggml_cuda_op_concat(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
@ -7252,7 +7248,7 @@ static void ggml_cuda_op_concat(
|
||||||
|
|
||||||
static void ggml_cuda_op_upscale(
|
static void ggml_cuda_op_upscale(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7269,7 +7265,7 @@ static void ggml_cuda_op_upscale(
|
||||||
|
|
||||||
static void ggml_cuda_op_pad(
|
static void ggml_cuda_op_pad(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7286,7 +7282,7 @@ static void ggml_cuda_op_pad(
|
||||||
|
|
||||||
static void ggml_cuda_op_rms_norm(
|
static void ggml_cuda_op_rms_norm(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7307,7 +7303,7 @@ static void ggml_cuda_op_rms_norm(
|
||||||
static void ggml_cuda_op_mul_mat_q(
|
static void ggml_cuda_op_mul_mat_q(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||||
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
|
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
|
|
||||||
|
@ -7430,7 +7426,7 @@ static int64_t get_row_rounding(ggml_type type) {
|
||||||
static void ggml_cuda_op_mul_mat_vec_q(
|
static void ggml_cuda_op_mul_mat_vec_q(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||||
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
|
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||||||
|
|
||||||
GGML_ASSERT(ggml_nrows(src1) == 1);
|
GGML_ASSERT(ggml_nrows(src1) == 1);
|
||||||
|
|
||||||
|
@ -7483,7 +7479,7 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
||||||
static void ggml_cuda_op_dequantize_mul_mat_vec(
|
static void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||||
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
|
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t row_diff = row_high - row_low;
|
const int64_t row_diff = row_high - row_low;
|
||||||
|
@ -7557,7 +7553,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
static void ggml_cuda_op_mul_mat_cublas(
|
static void ggml_cuda_op_mul_mat_cublas(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||||
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
|
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0_dd_i != nullptr);
|
GGML_ASSERT(src0_dd_i != nullptr);
|
||||||
GGML_ASSERT(src1_ddf_i != nullptr);
|
GGML_ASSERT(src1_ddf_i != nullptr);
|
||||||
|
@ -7648,7 +7644,7 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||||
|
|
||||||
static void ggml_cuda_op_rope(
|
static void ggml_cuda_op_rope(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||||
|
@ -7728,7 +7724,7 @@ static void ggml_cuda_op_rope(
|
||||||
|
|
||||||
static void ggml_cuda_op_alibi(
|
static void ggml_cuda_op_alibi(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7759,7 +7755,7 @@ static void ggml_cuda_op_alibi(
|
||||||
|
|
||||||
static void ggml_cuda_op_im2col(
|
static void ggml_cuda_op_im2col(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
@ -7794,7 +7790,7 @@ static void ggml_cuda_op_im2col(
|
||||||
|
|
||||||
static void ggml_cuda_op_sum_rows(
|
static void ggml_cuda_op_sum_rows(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7811,7 +7807,7 @@ static void ggml_cuda_op_sum_rows(
|
||||||
|
|
||||||
static void ggml_cuda_op_argsort(
|
static void ggml_cuda_op_argsort(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
||||||
|
@ -7830,7 +7826,7 @@ static void ggml_cuda_op_argsort(
|
||||||
|
|
||||||
static void ggml_cuda_op_diag_mask_inf(
|
static void ggml_cuda_op_diag_mask_inf(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7850,7 +7846,7 @@ static void ggml_cuda_op_diag_mask_inf(
|
||||||
|
|
||||||
static void ggml_cuda_op_soft_max(
|
static void ggml_cuda_op_soft_max(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7871,7 +7867,7 @@ static void ggml_cuda_op_soft_max(
|
||||||
|
|
||||||
static void ggml_cuda_op_scale(
|
static void ggml_cuda_op_scale(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -7889,7 +7885,7 @@ static void ggml_cuda_op_scale(
|
||||||
|
|
||||||
static void ggml_cuda_op_clamp(
|
static void ggml_cuda_op_clamp(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
@ -8021,7 +8017,6 @@ static void ggml_cuda_op_mul_mat(
|
||||||
const int64_t ne01 = src0->ne[1];
|
const int64_t ne01 = src0->ne[1];
|
||||||
const int64_t ne02 = src0->ne[2];
|
const int64_t ne02 = src0->ne[2];
|
||||||
const int64_t ne03 = src0->ne[3];
|
const int64_t ne03 = src0->ne[3];
|
||||||
const int64_t nrows0 = ggml_nrows(src0);
|
|
||||||
|
|
||||||
const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
const int64_t ne11 = src1->ne[1];
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
@ -8921,7 +8916,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
||||||
|
|
||||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||||
|
|
||||||
const cudaStream_t stream = g_cudaStreams[g_main_device][0];
|
cudaStream_t stream = g_cudaStreams[g_main_device][0];
|
||||||
|
|
||||||
if (ids->backend == GGML_BACKEND_GPU) {
|
if (ids->backend == GGML_BACKEND_GPU) {
|
||||||
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
|
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
|
||||||
|
|
3
ggml.c
3
ggml.c
|
@ -4041,7 +4041,6 @@ static struct ggml_tensor * ggml_group_norm_impl(
|
||||||
result->op = GGML_OP_GROUP_NORM;
|
result->op = GGML_OP_GROUP_NORM;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
result->src[0] = a;
|
result->src[0] = a;
|
||||||
result->src[1] = NULL; // TODO: maybe store epsilon here?
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -5541,7 +5540,6 @@ static struct ggml_tensor * ggml_upscale_impl(
|
||||||
result->op_params[0] = scale_factor;
|
result->op_params[0] = scale_factor;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
result->src[0] = a;
|
result->src[0] = a;
|
||||||
result->src[1] = NULL;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -5846,7 +5844,6 @@ struct ggml_tensor * ggml_get_rel_pos(
|
||||||
result->op = GGML_OP_GET_REL_POS;
|
result->op = GGML_OP_GET_REL_POS;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
result->src[0] = a;
|
result->src[0] = a;
|
||||||
result->src[1] = NULL;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue