Fix possible synchronization issue

This commit is contained in:
Slaren 2023-04-19 23:01:53 +02:00
parent 891af05e7d
commit 95cf9597aa
3 changed files with 19 additions and 18 deletions

View file

@ -1,6 +1,6 @@
#include <stdint.h> #include <stdint.h>
#include "ggml-cuda.h"
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include "ggml-cuda.h"
typedef uint16_t ggml_fp16_t; typedef uint16_t ggml_fp16_t;
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size"); static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
@ -31,7 +31,7 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
static __global__ void dequantize_block_q4_0(const void * vx, float * y) { static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
const block_q4_0 * x = (const block_q4_0 *) vx; const block_q4_0 * x = (const block_q4_0 *) vx;
int i = blockIdx.x; const int i = blockIdx.x;
const float d = x[i].d; const float d = x[i].d;
@ -54,7 +54,7 @@ static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
static __global__ void dequantize_block_q4_1(const void * vx, float * y) { static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
const block_q4_1 * x = (const block_q4_1 *) vx; const block_q4_1 * x = (const block_q4_1 *) vx;
int i = blockIdx.x; const int i = blockIdx.x;
const float d = x[i].d; const float d = x[i].d;
const float m = x[i].m; const float m = x[i].m;
@ -78,7 +78,7 @@ static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
static __global__ void dequantize_block_q4_2(const void * vx, float * y) { static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
const block_q4_2 * x = (const block_q4_2 *) vx; const block_q4_2 * x = (const block_q4_2 *) vx;
int i = blockIdx.x; const int i = blockIdx.x;
const float d = x[i].d; const float d = x[i].d;
@ -99,18 +99,18 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
} }
extern "C" { extern "C" {
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k) { __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0; const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1>>>(vx, y); dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
} }
__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k) { __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1; const int nb = k / QK4_1;
dequantize_block_q4_1<<<nb, 1>>>(vx, y); dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
} }
__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k) { __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2; const int nb = k / QK4_2;
dequantize_block_q4_2<<<nb, 1>>>(vx, y); dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
} }
} }

View file

@ -2,9 +2,9 @@
extern "C" { extern "C" {
#endif #endif
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k); void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k); void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} }

11
ggml.c
View file

@ -178,6 +178,7 @@ static void init_cublas(void) {
CUBLAS_CHECK(cublasCreate(&cublasH)); CUBLAS_CHECK(cublasCreate(&cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking)); CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream)); CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
// configure logging to stdout // configure logging to stdout
@ -7758,7 +7759,6 @@ static void ggml_compute_forward_mul_mat_f32(
// copy data to host // copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else #else
// zT = y * xT // zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@ -7770,6 +7770,7 @@ static void ggml_compute_forward_mul_mat_f32(
} }
} }
#if defined(GGML_USE_CUBLAS) #if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X)); CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y)); CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D)); CUDA_CHECK(cudaFree(d_D));
@ -7982,7 +7983,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// copy data to host // copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else #else
const float * x = wdata; const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@ -8000,6 +8000,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
} }
#if defined(GGML_USE_CUBLAS) #if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X)); CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y)); CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D)); CUDA_CHECK(cudaFree(d_D));
@ -8185,7 +8186,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type])); CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
dequantize_row_q_t dequantize_row_q_cuda = NULL; void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
if (type == GGML_TYPE_Q4_0) { if (type == GGML_TYPE_Q4_0) {
dequantize_row_q_cuda = dequantize_row_q4_0_cuda; dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
} }
@ -8215,7 +8216,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02, cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream)); GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00); dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
#else #else
{ {
@ -8243,7 +8244,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
// copy data to host // copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else #else
// zT = y * xT // zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@ -8256,6 +8256,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
} }
#if defined(GGML_USE_CUBLAS) #if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X)); CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y)); CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D)); CUDA_CHECK(cudaFree(d_D));