Fix possible synchronization issue
This commit is contained in:
parent
891af05e7d
commit
95cf9597aa
3 changed files with 19 additions and 18 deletions
20
ggml-cuda.cu
20
ggml-cuda.cu
|
@ -1,6 +1,6 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include "ggml-cuda.h"
|
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
#include "ggml-cuda.h"
|
||||||
|
|
||||||
typedef uint16_t ggml_fp16_t;
|
typedef uint16_t ggml_fp16_t;
|
||||||
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
|
@ -31,7 +31,7 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
|
||||||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
||||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||||
|
|
||||||
int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
||||||
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
|
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
|
||||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||||
|
|
||||||
int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
const float m = x[i].m;
|
const float m = x[i].m;
|
||||||
|
@ -78,7 +78,7 @@ static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
|
||||||
static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
|
static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
|
||||||
const block_q4_2 * x = (const block_q4_2 *) vx;
|
const block_q4_2 * x = (const block_q4_2 *) vx;
|
||||||
|
|
||||||
int i = blockIdx.x;
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
|
|
||||||
|
@ -99,18 +99,18 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k) {
|
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_0;
|
const int nb = k / QK4_0;
|
||||||
dequantize_block_q4_0<<<nb, 1>>>(vx, y);
|
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k) {
|
__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_1;
|
const int nb = k / QK4_1;
|
||||||
dequantize_block_q4_1<<<nb, 1>>>(vx, y);
|
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k) {
|
__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_2;
|
const int nb = k / QK4_2;
|
||||||
dequantize_block_q4_2<<<nb, 1>>>(vx, y);
|
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,9 +2,9 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k);
|
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
||||||
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k);
|
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
||||||
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k);
|
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
11
ggml.c
11
ggml.c
|
@ -178,6 +178,7 @@ static void init_cublas(void) {
|
||||||
CUBLAS_CHECK(cublasCreate(&cublasH));
|
CUBLAS_CHECK(cublasCreate(&cublasH));
|
||||||
|
|
||||||
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
|
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
|
||||||
|
|
||||||
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
|
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
|
||||||
|
|
||||||
// configure logging to stdout
|
// configure logging to stdout
|
||||||
|
@ -7758,7 +7759,6 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||||
|
|
||||||
// copy data to host
|
// copy data to host
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
|
||||||
#else
|
#else
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
|
@ -7770,6 +7770,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||||
CUDA_CHECK(cudaFree(d_X));
|
CUDA_CHECK(cudaFree(d_X));
|
||||||
CUDA_CHECK(cudaFree(d_Y));
|
CUDA_CHECK(cudaFree(d_Y));
|
||||||
CUDA_CHECK(cudaFree(d_D));
|
CUDA_CHECK(cudaFree(d_D));
|
||||||
|
@ -7982,7 +7983,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
|
|
||||||
// copy data to host
|
// copy data to host
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
|
||||||
#else
|
#else
|
||||||
const float * x = wdata;
|
const float * x = wdata;
|
||||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||||
|
@ -8000,6 +8000,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||||
CUDA_CHECK(cudaFree(d_X));
|
CUDA_CHECK(cudaFree(d_X));
|
||||||
CUDA_CHECK(cudaFree(d_Y));
|
CUDA_CHECK(cudaFree(d_Y));
|
||||||
CUDA_CHECK(cudaFree(d_D));
|
CUDA_CHECK(cudaFree(d_D));
|
||||||
|
@ -8185,7 +8186,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||||
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
|
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
|
||||||
|
|
||||||
dequantize_row_q_t dequantize_row_q_cuda = NULL;
|
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
||||||
if (type == GGML_TYPE_Q4_0) {
|
if (type == GGML_TYPE_Q4_0) {
|
||||||
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
|
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
|
||||||
}
|
}
|
||||||
|
@ -8215,7 +8216,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
||||||
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
|
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
|
||||||
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00);
|
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
#else
|
#else
|
||||||
{
|
{
|
||||||
|
@ -8243,7 +8244,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
|
|
||||||
// copy data to host
|
// copy data to host
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
|
||||||
#else
|
#else
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
|
@ -8256,6 +8256,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||||
CUDA_CHECK(cudaFree(d_X));
|
CUDA_CHECK(cudaFree(d_X));
|
||||||
CUDA_CHECK(cudaFree(d_Y));
|
CUDA_CHECK(cudaFree(d_Y));
|
||||||
CUDA_CHECK(cudaFree(d_D));
|
CUDA_CHECK(cudaFree(d_D));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue