Fixed single GPU performance regression
This commit is contained in:
parent
4f9640b8fe
commit
11af67866e
2 changed files with 134 additions and 0 deletions
133
ggml-cuda.cu
133
ggml-cuda.cu
|
@ -934,6 +934,30 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
|
||||||
|
size_t src0_sz = ggml_nbytes(src0);
|
||||||
|
size_t src1_sz = ggml_nbytes(src1);
|
||||||
|
|
||||||
|
// mul_mat_q: src0 is converted to fp32 on device
|
||||||
|
size_t mul_mat_q_transfer = src0_sz + src1_sz;
|
||||||
|
|
||||||
|
// mul_mat_f16: src1 is converted to fp16 on cpu
|
||||||
|
size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
|
||||||
|
|
||||||
|
// choose the smaller one to transfer to the device
|
||||||
|
// TODO: this is not always the best choice due to the overhead of converting to fp16
|
||||||
|
return mul_mat_f16_transfer < mul_mat_q_transfer;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
||||||
|
if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
|
||||||
|
return ggml_nelements(src1) * sizeof(ggml_fp16_t);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
|
GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
|
||||||
|
|
||||||
|
@ -950,6 +974,99 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
|
||||||
|
const int64_t ne00 = src0->ne[0];
|
||||||
|
const int64_t ne01 = src0->ne[1];
|
||||||
|
const int64_t ne02 = src0->ne[2];
|
||||||
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
|
||||||
|
const int nb10 = src1->nb[0];
|
||||||
|
const int nb11 = src1->nb[1];
|
||||||
|
const int nb12 = src1->nb[2];
|
||||||
|
const int nb13 = src1->nb[3];
|
||||||
|
|
||||||
|
const int nb2 = dst->nb[2];
|
||||||
|
const int nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
const int x_ne = ne01 * ne00;
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
|
const int d_ne = ne11 * ne01;
|
||||||
|
const int n_mm = ne03 * ne02;
|
||||||
|
|
||||||
|
size_t x_size, y_size, d_size;
|
||||||
|
half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
|
||||||
|
half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
|
||||||
|
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
|
||||||
|
|
||||||
|
bool src1_cont_rows = nb10 == sizeof(float);
|
||||||
|
bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
|
||||||
|
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
int i = i03*ne02 + i02;
|
||||||
|
cudaStream_t cudaStream = g_cudaStreams_main[0][i % GGML_CUDA_MAX_STREAMS];
|
||||||
|
|
||||||
|
half * c_X = d_X + i * x_ne;
|
||||||
|
half * c_Y = d_Y + i * y_ne;
|
||||||
|
float * c_D = d_D + i * d_ne;
|
||||||
|
|
||||||
|
// copy src0 to device
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, 0, ne01, cudaStream));
|
||||||
|
|
||||||
|
// convert src1 to fp16
|
||||||
|
// TODO: use multiple threads
|
||||||
|
ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
|
||||||
|
char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
|
||||||
|
if (src1_cont_rows) {
|
||||||
|
if (src1_cont_cols) {
|
||||||
|
ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
||||||
|
ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
||||||
|
for (int64_t i00 = 0; i00 < ne10; i00++) {
|
||||||
|
// very slow due to no inlining
|
||||||
|
tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy src1 to device
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
|
||||||
|
// compute
|
||||||
|
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[0], cudaStream));
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasGemmEx(g_cublas_handles[0], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, c_X, CUDA_R_16F, ne00,
|
||||||
|
c_Y, CUDA_R_16F, ne10,
|
||||||
|
&beta, c_D, CUDA_R_32F, ne01,
|
||||||
|
CUBLAS_COMPUTE_32F_FAST_16F,
|
||||||
|
CUBLAS_GEMM_DEFAULT));
|
||||||
|
|
||||||
|
// copy dst to host
|
||||||
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
ggml_cuda_pool_free(d_X, x_size);
|
||||||
|
ggml_cuda_pool_free(d_Y, y_size);
|
||||||
|
ggml_cuda_pool_free(d_D, d_size);
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset, int n_layer) {
|
void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset, int n_layer) {
|
||||||
FILE * fp = fopen(fname, "rb");
|
FILE * fp = fopen(fname, "rb");
|
||||||
int nrows = ggml_nrows(tensor);
|
int nrows = ggml_nrows(tensor);
|
||||||
|
@ -1054,6 +1171,22 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||||
if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
|
if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For prompt processing the multi GPU code is currently slower than the single GPU code that existed before.
|
||||||
|
// To avoid a performance regression the old code is kept for now:
|
||||||
|
if (g_device_count == 1 && tensor->src0->type == GGML_TYPE_F16 &&
|
||||||
|
ggml_cuda_mul_mat_use_f16(tensor->src0, tensor->src1, tensor)) {
|
||||||
|
|
||||||
|
if (params->ith != 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (params->type == GGML_TASK_COMPUTE) {
|
||||||
|
ggml_cuda_mul_mat_f16(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
func = ggml_cuda_mul_mat;
|
func = ggml_cuda_mul_mat;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
1
ggml.c
1
ggml.c
|
@ -14203,6 +14203,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
||||||
node->n_tasks = 1; // TODO: this actually is doing nothing
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
// the threads are still spinning
|
// the threads are still spinning
|
||||||
|
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue