sync : ggml (ggml-backend) (#3548)
* sync : ggml (ggml-backend) ggml-ci * zig : add ggml-backend to the build
This commit is contained in:
parent
eee42c670e
commit
db3abcc114
15 changed files with 1285 additions and 268 deletions
535
ggml-cuda.cu
535
ggml-cuda.cu
|
@ -62,6 +62,7 @@
|
|||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
#define cudaMemcpyKind hipMemcpyKind
|
||||
#define cudaMemset hipMemset
|
||||
#define cudaMemsetAsync hipMemsetAsync
|
||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice hipSetDevice
|
||||
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
||||
|
@ -419,6 +420,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
|||
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_GET_ROWS_BLOCK_SIZE 256
|
||||
|
||||
// dmmv = dequantize_mul_mat_vec
|
||||
#ifndef GGML_CUDA_DMMV_X
|
||||
|
@ -1574,6 +1576,34 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
|||
reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
|
||||
const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int r = y[row];
|
||||
|
||||
// copy x[r*ncols + col] to dst[row*ncols + col]
|
||||
const int xi = r*ncols + col;
|
||||
const int di = row*ncols + col;
|
||||
|
||||
const int ib = xi/qk; // block index
|
||||
const int iqs = (xi%qk)/qr; // quant index
|
||||
const int iybs = di - di%qk; // y block start index
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
// dequantize
|
||||
dfloat2 v;
|
||||
dequantize_kernel(x, ib, iqs, v);
|
||||
|
||||
dst[iybs + iqs + 0] = v.x;
|
||||
dst[iybs + iqs + y_offset] = v.y;
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
|
||||
|
@ -4555,6 +4585,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
|
|||
dst[i] = scale * x[i];
|
||||
}
|
||||
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(block_num_x, nrows, 1);
|
||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
|
||||
const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
|
||||
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
|
||||
|
@ -5703,7 +5742,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
} else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) {
|
||||
GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
|
||||
kind = cudaMemcpyDeviceToDevice;
|
||||
struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
src_ptr = (char *) extra->data_device[id];
|
||||
|
@ -5739,6 +5778,107 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_repeat(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
|
||||
// guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const size_t nb0 = dst->nb[0];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
const size_t nb2 = dst->nb[2];
|
||||
const size_t nb3 = dst->nb[3];
|
||||
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
const size_t nb02 = src0->nb[2];
|
||||
const size_t nb03 = src0->nb[3];
|
||||
|
||||
const int nr0 = (int)(ne0/ne00);
|
||||
const int nr1 = (int)(ne1/ne01);
|
||||
const int nr2 = (int)(ne2/ne02);
|
||||
const int nr3 = (int)(ne3/ne03);
|
||||
|
||||
// TODO: support for transposed / permuted tensors
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
|
||||
// TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
|
||||
for (int i3 = 0; i3 < nr3; i3++) {
|
||||
for (int k3 = 0; k3 < ne03; k3++) {
|
||||
for (int i2 = 0; i2 < nr2; i2++) {
|
||||
for (int k2 = 0; k2 < ne02; k2++) {
|
||||
for (int i1 = 0; i1 < nr1; i1++) {
|
||||
for (int k1 = 0; k1 < ne01; k1++) {
|
||||
for (int i0 = 0; i0 < nr0; i0++) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(
|
||||
(char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
|
||||
(const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
|
||||
ne00*nb0, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(void) src1;
|
||||
(void) src1_d;
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_get_rows(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
const int ncols = src0->ne[0];
|
||||
const int nrows = ggml_nelements(src1);
|
||||
|
||||
const int32_t * src1_i32 = (const int32_t *) src1_d;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
|
||||
break;
|
||||
default:
|
||||
// TODO: k-quants
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_add(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
@ -6343,7 +6483,14 @@ inline void ggml_cuda_op_scale(
|
|||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float scale = ((float *) src1->data)[0];
|
||||
float scale;
|
||||
// HACK: support for ggml backend interface
|
||||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
scale = ((float *) src1->data)[0];
|
||||
} else {
|
||||
// TODO: pass pointer to kernel instead of copying to host
|
||||
CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
@ -6362,9 +6509,9 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
|||
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
|
||||
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
|
||||
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
|
||||
|
@ -6505,9 +6652,9 @@ static void ggml_cuda_op_mul_mat(
|
|||
const size_t q8_1_ts = sizeof(block_q8_1);
|
||||
const size_t q8_1_bs = QK8_1;
|
||||
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
|
||||
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
|
||||
const bool src0_is_contiguous = ggml_is_contiguous(src0);
|
||||
|
@ -6585,7 +6732,7 @@ static void ggml_cuda_op_mul_mat(
|
|||
if (convert_src1_to_q8_1) {
|
||||
src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
|
||||
|
||||
if (split && src1_on_device && src1_is_contiguous) {
|
||||
if (src1_on_device && src1_is_contiguous) {
|
||||
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
@ -6667,7 +6814,7 @@ static void ggml_cuda_op_mul_mat(
|
|||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) {
|
||||
if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) {
|
||||
quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
@ -6758,6 +6905,14 @@ static void ggml_cuda_op_mul_mat(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_repeat);
|
||||
}
|
||||
|
||||
static void ggml_cuda_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_get_rows);
|
||||
}
|
||||
|
||||
static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
|
||||
}
|
||||
|
@ -6812,13 +6967,13 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
|
|||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
|
||||
|
@ -6843,13 +6998,13 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
|
|||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
const int64_t row_stride_x = nb01 / sizeof(half);
|
||||
|
@ -6870,11 +7025,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
}
|
||||
}
|
||||
|
||||
if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
||||
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
}else if (src0->type == GGML_TYPE_F32) {
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
|
||||
|
@ -6935,8 +7090,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
|||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
|
||||
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
|
||||
char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
|
||||
|
@ -6991,8 +7146,8 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
|
|||
|
||||
const size_t nb1 = tensor->nb[1];
|
||||
|
||||
ggml_backend backend = tensor->backend;
|
||||
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
|
||||
ggml_backend_type backend = tensor->backend;
|
||||
ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
|
@ -7046,7 +7201,6 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
|
|||
CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
|
||||
}
|
||||
|
||||
|
||||
CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice));
|
||||
|
||||
extra->data_device[id] = buf;
|
||||
|
@ -7085,17 +7239,17 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
|
|||
delete extra;
|
||||
}
|
||||
|
||||
static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
|
||||
static ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
|
||||
static size_t g_temp_tensor_extra_index = 0;
|
||||
|
||||
static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
||||
static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
||||
if (g_temp_tensor_extras == nullptr) {
|
||||
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
|
||||
}
|
||||
|
||||
size_t alloc_index = g_temp_tensor_extra_index;
|
||||
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
|
||||
struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
|
||||
ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
return extra;
|
||||
|
@ -7123,7 +7277,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra
|
|||
return;
|
||||
}
|
||||
|
||||
struct ggml_tensor_extra_gpu * extra;
|
||||
ggml_tensor_extra_gpu * extra;
|
||||
|
||||
const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
|
||||
tensor->op == GGML_OP_VIEW ||
|
||||
|
@ -7132,7 +7286,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra
|
|||
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
|
||||
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
|
||||
size_t offset = 0;
|
||||
if (tensor->op == GGML_OP_VIEW) {
|
||||
|
@ -7141,7 +7295,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra
|
|||
extra = ggml_cuda_alloc_temp_tensor_extra();
|
||||
extra->data_device[g_main_device] = src0_ddc + offset;
|
||||
} else if (tensor->op == GGML_OP_CPY) {
|
||||
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
|
||||
void * src1_ddv = src1_extra->data_device[g_main_device];
|
||||
extra = ggml_cuda_alloc_temp_tensor_extra();
|
||||
extra->data_device[g_main_device] = src1_ddv;
|
||||
|
@ -7183,13 +7337,13 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
|
|||
CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
|
||||
}
|
||||
|
||||
struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
|
||||
ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
|
||||
|
||||
const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
|
||||
tensor->op == GGML_OP_VIEW;
|
||||
|
||||
if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
|
||||
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
|
||||
size_t view_offset = 0;
|
||||
if (tensor->op == GGML_OP_VIEW) {
|
||||
|
@ -7207,7 +7361,7 @@ void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
|
|||
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor));
|
||||
|
||||
struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
@ -7264,58 +7418,47 @@ void ggml_cuda_free_scratch() {
|
|||
g_scratch_buffer = nullptr;
|
||||
}
|
||||
|
||||
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
|
||||
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
ggml_cuda_func_t func;
|
||||
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|
||||
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|
||||
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
|
||||
|
||||
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_REPEAT:
|
||||
func = ggml_cuda_repeat;
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
func = ggml_cuda_get_rows;
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_dup;
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_add;
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_mul;
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
case GGML_UNARY_OP_GELU:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_gelu;
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_silu;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_NORM:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_norm;
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_rms_norm;
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
|
@ -7325,54 +7468,30 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|||
func = ggml_cuda_mul_mat;
|
||||
break;
|
||||
case GGML_OP_SCALE:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_scale;
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_cpy;
|
||||
break;
|
||||
case GGML_OP_CONT:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_dup;
|
||||
break;
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_nop;
|
||||
break;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_diag_mask_inf;
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_soft_max;
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_rope;
|
||||
break;
|
||||
case GGML_OP_ALIBI:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_alibi;
|
||||
break;
|
||||
default:
|
||||
|
@ -7400,3 +7519,263 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des
|
|||
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
|
||||
snprintf(description, description_size, "%s", prop.name);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// backend interface
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
struct ggml_backend_context_cuda {
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
|
||||
return GGML_CUDA_NAME;
|
||||
|
||||
UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_free(ggml_backend_t backend) {
|
||||
ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
|
||||
delete cuda_ctx;
|
||||
delete backend;
|
||||
}
|
||||
|
||||
struct ggml_backend_buffer_context_cuda {
|
||||
void * device;
|
||||
|
||||
ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
|
||||
size_t temp_tensor_extra_index = 0;
|
||||
|
||||
~ggml_backend_buffer_context_cuda() {
|
||||
delete[] temp_tensor_extras;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
||||
if (temp_tensor_extras == nullptr) {
|
||||
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
|
||||
}
|
||||
|
||||
size_t alloc_index = temp_tensor_extra_index;
|
||||
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_MAX_NODES;
|
||||
ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
return extra;
|
||||
}
|
||||
};
|
||||
|
||||
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
|
||||
CUDA_CHECK(cudaFree(ctx->device));
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
|
||||
return ctx->device;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
int64_t row_low = 0;
|
||||
int64_t row_high = ggml_nrows(tensor);
|
||||
int64_t nrows_split = row_high - row_low;
|
||||
|
||||
size_t size = ggml_nbytes_split(tensor, nrows_split);
|
||||
|
||||
int64_t ne0 = tensor->ne[0];
|
||||
|
||||
if (ggml_is_quantized(tensor->type)) {
|
||||
if (ne0 % MATRIX_ROW_PADDING != 0) {
|
||||
size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
|
||||
* ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
|
||||
}
|
||||
}
|
||||
|
||||
return size;
|
||||
|
||||
UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context;
|
||||
|
||||
if (tensor->view_src != NULL && tensor->view_offs == 0) {
|
||||
assert(tensor->view_src->buffer->backend == buffer->backend);
|
||||
tensor->backend = tensor->view_src->backend;
|
||||
tensor->extra = tensor->view_src->extra;
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra();
|
||||
|
||||
extra->data_device[g_main_device] = tensor->data;
|
||||
|
||||
tensor->backend = GGML_BACKEND_GPU;
|
||||
tensor->extra = extra;
|
||||
|
||||
if (ggml_is_quantized(tensor->type)) {
|
||||
// initialize padding to 0 to avoid possible NaN values
|
||||
int64_t row_low = 0;
|
||||
int64_t row_high = ggml_nrows(tensor);
|
||||
int64_t nrows_split = row_high - row_low;
|
||||
|
||||
size_t original_size = ggml_nbytes_split(tensor, nrows_split);
|
||||
size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor);
|
||||
|
||||
if (padded_size > original_size && tensor->view_src == nullptr) {
|
||||
CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0]));
|
||||
}
|
||||
}
|
||||
|
||||
UNUSED(buffer);
|
||||
}
|
||||
|
||||
static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
|
||||
/* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_cuda_buffer_get_base,
|
||||
/* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size,
|
||||
/* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
|
||||
/* .free_tensor = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
|
||||
ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
|
||||
CUDA_CHECK(cudaMalloc(&ctx->device, size));
|
||||
return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) {
|
||||
return 128;
|
||||
UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
|
||||
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
||||
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0]));
|
||||
|
||||
UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
|
||||
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
||||
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
|
||||
|
||||
UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
|
||||
|
||||
UNUSED(backend);
|
||||
}
|
||||
|
||||
static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
GGML_ASSERT(!"not implemented");
|
||||
|
||||
return nullptr;
|
||||
|
||||
UNUSED(backend);
|
||||
UNUSED(cgraph);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
|
||||
GGML_ASSERT(!"not implemented");
|
||||
|
||||
UNUSED(backend);
|
||||
UNUSED(plan);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
|
||||
GGML_ASSERT(!"not implemented");
|
||||
|
||||
UNUSED(backend);
|
||||
UNUSED(plan);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
|
||||
ggml_compute_params params = {};
|
||||
params.type = GGML_TASK_COMPUTE;
|
||||
params.ith = 0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
assert(node->backend == GGML_BACKEND_GPU);
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
if (node->src[j] != nullptr) {
|
||||
assert(node->src[j]->backend == GGML_BACKEND_GPU);
|
||||
}
|
||||
}
|
||||
|
||||
bool ok = ggml_cuda_compute_forward(¶ms, node);
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
}
|
||||
GGML_ASSERT(ok);
|
||||
|
||||
#if 0
|
||||
if (node->type == GGML_TYPE_F32) {
|
||||
cudaDeviceSynchronize();
|
||||
std::vector<float> tmp(ggml_nelements(node), 0.0f);
|
||||
cudaMemcpy(tmp.data(), node->data, ggml_nelements(node)*sizeof(float), cudaMemcpyDeviceToHost);
|
||||
printf("\n%s (%s) (%s %s) (%s %s): ", node->name, ggml_op_name(node->op),
|
||||
ggml_type_name(node->src[0]->type),
|
||||
node->src[1] ? ggml_type_name(node->src[1]->type) : "none",
|
||||
node->src[0]->name,
|
||||
node->src[1] ? node->src[1]->name : "none");
|
||||
double sum = 0.0;
|
||||
double sq_sum = 0.0;
|
||||
for (int i = 0; i < ggml_nelements(node); i++) {
|
||||
printf("%f ", tmp[i]);
|
||||
sum += tmp[i];
|
||||
sq_sum += tmp[i]*tmp[i];
|
||||
}
|
||||
printf("\n");
|
||||
printf("sum: %f, ", sum);
|
||||
printf("sq_sum: %f\n", sq_sum);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
UNUSED(backend);
|
||||
}
|
||||
|
||||
static ggml_backend_i cuda_backend_i = {
|
||||
/* .get_name = */ ggml_backend_cuda_name,
|
||||
/* .free = */ ggml_backend_cuda_free,
|
||||
/* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_cuda_get_alignment,
|
||||
/* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
|
||||
/* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
|
||||
/* .synchronize = */ ggml_backend_cuda_synchronize,
|
||||
/* .cpy_tensor_from = */ nullptr,
|
||||
/* .cpy_tensor_to = */ nullptr,
|
||||
/* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create,
|
||||
/* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free,
|
||||
/* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute,
|
||||
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
|
||||
/* .supports_op = */ nullptr,
|
||||
};
|
||||
|
||||
ggml_backend_t ggml_backend_cuda_init() {
|
||||
ggml_init_cublas(); // TODO: remove from ggml.c
|
||||
|
||||
ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;
|
||||
|
||||
ggml_backend_t cuda_backend = new ggml_backend {
|
||||
/* .interface = */ cuda_backend_i,
|
||||
/* .context = */ ctx
|
||||
};
|
||||
|
||||
return cuda_backend;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue