CUDA op template

parent 827f5eda91
commit 071dcd351b
2 changed files with 266 additions and 254 deletions

ggml-cuda.cu (517 changed lines)
@@ -32,9 +32,23 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
        } \
    } while (0)

// Q = quantized, F = float, order is src0, src1, dst
enum ggml_cuda_op_type {
    GGML_CUDA_OP_TYPE_QQQ = 0,
    GGML_CUDA_OP_TYPE_QQF = 1,
    GGML_CUDA_OP_TYPE_QFQ = 2,
    GGML_CUDA_OP_TYPE_QFF = 3,
    GGML_CUDA_OP_TYPE_FQQ = 4,
    GGML_CUDA_OP_TYPE_FQF = 5,
    GGML_CUDA_OP_TYPE_FFQ = 6,
    GGML_CUDA_OP_TYPE_FFF = 7,
};

typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
typedef void (*ggml_cuda_op_t)(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main);

// QK = number of values after dequantization
// QR = QK / number of values before dequantization
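
In the enum above, the three letters name the expected representations of src0, src1, and dst, in that order, and the numeric value packs them as three bits with F = 1 and Q = 0. The helpers below are a minimal sketch of that decoding, not part of this commit, with made-up names; the bit positions match the op_type & 0x4 check used later in ggml_cuda_op:

// Sketch only (not in this commit): decode an op type into per-operand flags.
static inline bool ggml_cuda_op_src0_is_float(enum ggml_cuda_op_type t) { return (t & 0x4) != 0; } // bit 2 -> src0 is F
static inline bool ggml_cuda_op_src1_is_float(enum ggml_cuda_op_type t) { return (t & 0x2) != 0; } // bit 1 -> src1 is F
static inline bool ggml_cuda_op_dst_is_float (enum ggml_cuda_op_type t) { return (t & 0x1) != 0; } // bit 0 -> dst is F
// Example: GGML_CUDA_OP_TYPE_QFF == 3 == 0b011 -> quantized src0, float src1, float dst.
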
@@ -360,25 +374,6 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    }
}

static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_mul_mat_vec_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_mul_mat_vec_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_mul_mat_vec_q5_0_cuda;
        case GGML_TYPE_Q5_1:
            return dequantize_mul_mat_vec_q5_1_cuda;
        case GGML_TYPE_Q8_0:
            return dequantize_mul_mat_vec_q8_0_cuda;
        case GGML_TYPE_F16:
            return convert_mul_mat_vec_f16_cuda;
        default:
            return nullptr;
    }
}

// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256

@@ -441,20 +436,24 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
#define GGML_CUDA_MAX_EVENTS 64
static cublasHandle_t g_cublasH = nullptr;
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaEvent_t g_cudaEvents_main[GGML_CUDA_MAX_EVENTS] = { nullptr };
static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaStream_t g_cudaStreams_memcpy_dst[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_EVENTS] = { nullptr };

void ggml_init_cublas() {
    if (g_cublasH == nullptr) {
        // create streams
        for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[i], cudaStreamNonBlocking));
            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[i], cudaStreamNonBlocking));
            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_dst[i], cudaStreamNonBlocking));
        }
        // create events
        for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_main[i], cudaEventDisableTiming));
            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[i], cudaEventDisableTiming));
        }

        // create cublas handle
@@ -514,125 +513,6 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
    }
}

static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[2];
    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne12 = src1->ne[2];
    const int64_t ne13 = src1->ne[3];
    const int nb2 = dst->nb[2];
    const int nb3 = dst->nb[3];
    size_t x_size, d_size;

    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            const int i0 = i03*ne02 + i02;
            float * c_X2 = d_X + i0*ne01*ne00;
            float * c_D2 = d_D + i0*ne01*ne00;

            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
            cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];

            // copy src0 to device
            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

            // wait for data
            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

            for (int64_t i01 = 0; i01 < ne01; i01++) {
                const int64_t i13 = i03%ne13;
                const int64_t i12 = i02%ne12;
                const int64_t i11 = i01%ne11;
                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;

                float * c_X1 = c_X2 + i01*ne00;
                float * c_Y = d_Y + i1*ne10;
                float * c_D1 = c_D2 + i01*ne00;

                // compute
                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
                CUDA_CHECK(cudaGetLastError());
            }

            // copy dst to host
            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
        }
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    ggml_cuda_pool_free(d_X, x_size);
    ggml_cuda_pool_free(d_D, d_size);
}

static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];

    const int nb2 = dst->nb[2];
    const int nb3 = dst->nb[3];

    const float alpha = 1.0f;
    const float beta = 0.0f;
    const int x_ne = ne01 * ne00;
    const int y_ne = ne11 * ne10;
    const int d_ne = ne11 * ne01;
    const int n_mm = ne03 * ne02;

    size_t x_size, y_size, d_size;
    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            int i = i03*ne02 + i02;
            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];

            float * c_X = d_X + i * x_ne;
            float * c_Y = d_Y + i * y_ne;
            float * c_D = d_D + i * d_ne;

            // copy data to device
            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

            // compute
            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
            CUBLAS_CHECK(
                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                        ne01, ne11, ne10,
                        &alpha, c_X, ne00,
                                c_Y, ne10,
                        &beta,  c_D, ne01));

            // copy dst to host
            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
        }
    }

    CUDA_CHECK(cudaDeviceSynchronize());
    ggml_cuda_pool_free(d_X, x_size);
    ggml_cuda_pool_free(d_Y, y_size);
    ggml_cuda_pool_free(d_D, d_size);
}

static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
@@ -668,7 +548,7 @@ static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor *
    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            int i = i03*ne02 + i02;
            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
            cudaStream_t cudaStream = g_cudaStreams_main[i % GGML_CUDA_MAX_STREAMS];

            half * c_X = d_X + i * x_ne;
            half * c_Y = d_Y + i * y_ne;
@@ -726,7 +606,110 @@ static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor *
    ggml_cuda_pool_free(d_D, d_size);
}

static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
inline void ggml_cuda_op_mul(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];

    for (int64_t i01 = 0; i01 < ne01; i01++) {
        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0

        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;

        // compute
        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
        CUDA_CHECK(cudaGetLastError());
    }

    (void) dst;
    (void) src0_ddq_i;
}

inline void ggml_cuda_op_dequantize_mul_mat_vec(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){

    GGML_ASSERT(src0_ddq_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
            break;
        case GGML_TYPE_Q4_1:
            dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
            break;
        case GGML_TYPE_Q5_0:
            dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
            break;
        case GGML_TYPE_Q5_1:
            dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
            break;
        case GGML_TYPE_Q8_0:
            dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
            break;
        case GGML_TYPE_F16:
            convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
    CUDA_CHECK(cudaGetLastError());

    (void) src1;
    (void) dst;
    (void) src0_ddf_i;
    (void) i1;
}

inline void ggml_cuda_op_mul_mat_cublas(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const float alpha = 1.0f;
    const float beta = 0.0f;

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];

    CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream_main));
    CUBLAS_CHECK(
        cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                &alpha, src0_ddf_i, ne00,
                        src1_ddf_i, ne10,
                &beta,  dst_ddf_i,  ne01));

    (void) dst;
    (void) src0_ddq_i;
    (void) i1;
}

template<enum ggml_cuda_op_type op_type, ggml_cuda_op_t op>
static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
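
Any per-slice function that matches the ggml_cuda_op_t signature above can be plugged into this template; the op_type argument tells ggml_cuda_op which device buffers (quantized and/or dequantized float) to prepare before it calls the op once per i02/i03 slice. The following is an illustrative sketch, not part of this commit: a hypothetical op that simply copies one float slice of src0 into dst on the main stream.

// Hypothetical example op (not in this commit): copy one slice of src0 into dst.
inline void ggml_cuda_op_cpy_example(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];

    // device-to-device copy of the ne00*ne01 floats that belong to this slice
    CUDA_CHECK(cudaMemcpyAsync(dst_ddf_i, src0_ddf_i, ne00*ne01*sizeof(float),
                               cudaMemcpyDeviceToDevice, cudaStream_main));

    (void) src1;
    (void) dst;
    (void) src0_ddq_i;
    (void) src1_ddf_i;
    (void) i1;
}

// It would be dispatched the same way as the real ops, e.g.:
// ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_cpy_example>(src0, src1, dst);
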
@@ -734,107 +717,154 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne12 = src1->ne[2];
    const int64_t ne13 = src1->ne[3];

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];

    const int nb2 = dst->nb[2];
    const int nb3 = dst->nb[3];
    const ggml_type type = src0->type;
    const bool mul_mat_vec = ne11 == 1;

    const float alpha = 1.0f;
    const float beta = 0.0f;
    const int x_ne = ne01 * ne00;
    const int y_ne = ne11 * ne10;
    const int d_ne = ne11 * ne01;
    const int n_mm = ne03 * ne02;
    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
    const int64_t src0_stride = ne00 * ne01;
    const int64_t src1_stride = ne10 * ne11;
    const int64_t dst_stride = ne0 * ne1;
    const int64_t num_iters = ne02 * ne03;

    size_t x_size, y_size, d_size, q_size;
    float * d_X = nullptr;
    if (!mul_mat_vec) {
        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
    const size_t src0_ts = ggml_type_size(src0->type);

    const bool src0_on_device = src0->backend == GGML_BACKEND_CUDA;
    const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
    const bool src0_needs_f32 = op_type & 0x4; // 3rd least significant bit = src0 needs f32

    const bool src1_on_device = src1->backend == GGML_BACKEND_CUDA;

    const bool dst_on_device = dst->backend == GGML_BACKEND_CUDA;

    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);

    // dd = data device
    char * src0_ddq = nullptr; // quantized
    float * src0_ddf = nullptr; // float
    float * src1_ddf = nullptr;
    float * dst_ddf = nullptr;

    bool src0_ddq_malloced = false;
    bool src0_ddf_malloced = false;
    bool src1_ddf_malloced = false;
    bool dst_ddf_malloced = false;

    // asq = actual size quantized, asf = actual size float
    size_t src0_asq, src0_asf, src1_asf, dst_asf;

    if (src0_on_device) {
        if (src0_is_f32) {
            src0_ddf = (float *) src0->data;
        } else {
            src0_ddq = (char *) src0->data;
        }
    } else {
        if (src0_is_f32) {
            src0_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src0_stride * sizeof(float), &src0_asf);
            src0_ddf_malloced = true;
        } else {
            src0_ddq = (char *) ggml_cuda_pool_malloc(num_iters * src0_stride * src0_ts, &src0_asq);
            src0_ddq_malloced = true;
        }
    }
    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
    char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);

    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
    dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
    GGML_ASSERT(to_fp32_cuda != nullptr);
    if (src0_needs_f32 && !src0_is_f32) {
        src0_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src0_stride * sizeof(float), &src0_asf);
        src0_ddf_malloced = true;
    }

    if (src1_on_device) {
        src1_ddf = (float *) src1->data;
    } else {
        src1_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src1_stride * sizeof(float), &src1_asf);
        src1_ddf_malloced = true;
    }
    if (dst_on_device) {
        dst_ddf = (float *) dst->data;
    } else {
        dst_ddf = (float *) ggml_cuda_pool_malloc(num_iters * dst_stride * sizeof(float), &dst_asf);
        dst_ddf_malloced = true;
    }

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        const int64_t i13 = i03 % ne13;
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            int i = i03*ne02 + i02;
            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
            cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
            cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
            const int64_t i12 = i02 % ne12;

            float * c_Y = d_Y + i * y_ne;
            float * c_D = d_D + i * d_ne;
            char * c_Q = d_Q + i * q_sz;
            const int64_t i0 = i03*ne02 + i02;
            const int64_t i1 = i13*ne12 + i12;

            // copy src0 to device if necessary
            if (src0->backend == GGML_BACKEND_CPU) {
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
            } else if (src0->backend == GGML_BACKEND_CUDA) {
                c_Q = ((char *) src0->data) + i * q_sz;
            cudaStream_t cudaStream_main = g_cudaStreams_main[i0 % GGML_CUDA_MAX_STREAMS];
            cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[i0 % GGML_CUDA_MAX_STREAMS];
            cudaStream_t cudaStream_memcpy_dst = g_cudaStreams_memcpy_dst[i0 % GGML_CUDA_MAX_STREAMS];
            cudaEvent_t cudaEvent_main = g_cudaEvents_main[i0 % GGML_CUDA_MAX_EVENTS];
            cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[i0 % GGML_CUDA_MAX_EVENTS];

            char * src0_ddq_i = src0_ddq + i0*src0_stride;
            float * src0_ddf_i = src0_ddf + i0*src0_stride;
            float * src1_ddf_i = src1_ddf + i1*src1_stride;
            float * dst_ddf_i = dst_ddf + i0*dst_stride;

            // copy src0, src1 to device if necessary
            if (!src1_on_device) { // src1 first to avoid blocking device queues
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf, src1, i03, i02, cudaStream_memcpy_src1));
            }
            CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
            if (!src0_on_device) {
                if (src0_is_f32) {
                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf, src0, i03, i02, cudaStream_main));
                } else {
                    GGML_ASSERT(false);
                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq, src0, i03, i02, cudaStream_main));
                }
            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

                // copy src1 to device
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

                // wait for data
                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

                // compute
                dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                CUDA_CHECK(cudaGetLastError());

            } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
                float * c_X = d_X + i * x_ne;

                // convert src0 to fp32 on device
                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
                CUDA_CHECK(cudaGetLastError());
                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));

                // copy src1 to device
                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));

                // wait for conversion
                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

                // compute
                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
                CUBLAS_CHECK(
                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                            ne01, ne11, ne10,
                            &alpha, c_X, ne00,
                                    c_Y, ne10,
                            &beta,  c_D, ne01));
            }

            // copy dst to host
            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
            if (src0_needs_f32 && !src0_is_f32) {
                to_fp32_cuda(src0_ddq_i, src0_ddf_i, src0_stride, cudaStream_main);
                CUDA_CHECK(cudaGetLastError());
            }

            // wait with main stream until src1 memcpy is done
            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));

            // do the computation
            op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i1, cudaStream_main);

            CUDA_CHECK(cudaEventRecord(cudaEvent_main, cudaStream_main));

            // copy dst to host if necessary
            if (!dst_on_device) {
                // wait with memcpy until main stream is done
                CUDA_CHECK(cudaStreamWaitEvent(cudaStream_memcpy_dst, cudaEvent_main, 0));

                float * dhf_dst_i = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
                CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_memcpy_dst));
            }
        }
    }

    CUDA_CHECK(cudaDeviceSynchronize());
    if (!mul_mat_vec) {
        ggml_cuda_pool_free(d_X, x_size);
    if (src0_ddf_malloced) {
        ggml_cuda_pool_free(src0_ddf, src0_asf);
    }
    if (src0_ddq_malloced) {
        ggml_cuda_pool_free(src0_ddq, src0_asq);
    }
    if (src1_ddf_malloced) {
        ggml_cuda_pool_free(src1_ddf, src1_asf);
    }
    if (dst_ddf_malloced) {
        ggml_cuda_pool_free(dst_ddf, dst_asf);
    }
    ggml_cuda_pool_free(d_Y, y_size);
    ggml_cuda_pool_free(d_D, d_size);
    ggml_cuda_pool_free(d_Q, q_size);
}

void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    ggml_cuda_mul_f32(src0, src1, dst);
    ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul>(src0, src1, dst);
}

bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
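
To make the buffer indexing in the template body above concrete: each (i03, i02) pair owns one contiguous slice of the pooled device buffers. A small worked example, with illustrative sizes ne02 = 2, ne03 = 3 and ne00*ne01 = 4096 that are not taken from the commit:

// i0 = i03*ne02 + i02 enumerates the 6 slices 0..5; with src0_stride = ne00*ne01 = 4096:
//   (i03, i02) = (0, 0) -> i0 = 0 -> src0_ddf_i = src0_ddf + 0
//   (i03, i02) = (0, 1) -> i0 = 1 -> src0_ddf_i = src0_ddf + 4096
//   (i03, i02) = (2, 1) -> i0 = 5 -> src0_ddf_i = src0_ddf + 20480
// src1 uses i1 = i13*ne12 + i12 with i13 = i03 % ne13 and i12 = i02 % ne12, so a src1
// with fewer slices than src0 is reused (broadcast) across the extra src0 slices.
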
@@ -873,18 +903,27 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));

    if (src0->type == GGML_TYPE_F32) {
        ggml_cuda_mul_mat_f32(src0, src1, dst);
        ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
    }
    else if (src0->type == GGML_TYPE_F16) {
        if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
            // ggml_cuda_op<GGML_CUDA_OP_TYPE_QQF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
            ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
        }
        else {
            ggml_cuda_mul_mat_q_f32(src0, src1, dst);
            if (src1->ne[1] == 1) {
                ggml_cuda_op<GGML_CUDA_OP_TYPE_QFF, ggml_cuda_op_dequantize_mul_mat_vec>(src0, src1, dst);
            } else {
                ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
            }
        }
    }
    else if (ggml_is_quantized(src0->type)) {
        ggml_cuda_mul_mat_q_f32(src0, src1, dst);
        if (src1->ne[1] == 1) {
            ggml_cuda_op<GGML_CUDA_OP_TYPE_QFF, ggml_cuda_op_dequantize_mul_mat_vec>(src0, src1, dst);
        } else {
            ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
        }
    }
    else {
        GGML_ASSERT(false);
@@ -900,32 +939,6 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
    }
}

void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
    const int64_t ne0 = tensor->ne[0];
    const int64_t ne1 = tensor->ne[1];
    const int64_t ne2 = tensor->ne[2];
    const int64_t ne3 = tensor->ne[3];

    const ggml_type type = tensor->type;
    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);

    size_t q_size;
    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);

    cudaStream_t cudaStream2 = g_cudaStreams2[0];

    // copy tensor to device
    for (int64_t i3 = 0; i3 < ne3; i3++) {
        for (int64_t i2 = 0; i2 < ne2; i2++) {
            int i = i3*ne2 + i2;
            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
        }
    }

    tensor->data = dst;
    tensor->backend = GGML_BACKEND_CUDA;
}

void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
    FILE * fp = fopen(fname, "rb");

ggml-cuda.h
@@ -15,7 +15,6 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);

void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);

#ifdef __cplusplus