Add NVIDIA cuBLAS support
This commit is contained in:
parent
42747220b4
commit
4440d198c0
4 changed files with 203 additions and 13 deletions
4
Makefile
4
Makefile
|
@ -97,6 +97,10 @@ ifdef LLAMA_OPENBLAS
|
||||||
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
|
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
|
||||||
LDFLAGS += -lopenblas
|
LDFLAGS += -lopenblas
|
||||||
endif
|
endif
|
||||||
|
ifdef LLAMA_CUBLAS
|
||||||
|
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
|
||||||
|
LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
|
||||||
|
endif
|
||||||
ifdef LLAMA_GPROF
|
ifdef LLAMA_GPROF
|
||||||
CFLAGS += -pg
|
CFLAGS += -pg
|
||||||
CXXFLAGS += -pg
|
CXXFLAGS += -pg
|
||||||
|
|
209
ggml.c
209
ggml.c
|
@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) {
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
#ifdef GGML_USE_ACCELERATE
|
#if defined(GGML_USE_ACCELERATE)
|
||||||
#include <Accelerate/Accelerate.h>
|
#include <Accelerate/Accelerate.h>
|
||||||
#elif GGML_USE_OPENBLAS
|
#elif defined(GGML_USE_OPENBLAS)
|
||||||
#include <cblas.h>
|
#include <cblas.h>
|
||||||
|
#elif defined(GGML_USE_CUBLAS)
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#define CUDA_CHECK(err) \
|
||||||
|
do { \
|
||||||
|
cudaError_t err_ = (err); \
|
||||||
|
if (err_ != cudaSuccess) { \
|
||||||
|
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||||
|
cudaGetErrorString(err_)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define CUBLAS_CHECK(err) \
|
||||||
|
do { \
|
||||||
|
cublasStatus_t err_ = (err); \
|
||||||
|
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||||
|
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
static cublasHandle_t cublasH = NULL;
|
||||||
|
static cudaStream_t cudaStream = NULL;
|
||||||
|
static void init_cublas(void) {
|
||||||
|
if (cublasH == NULL) {
|
||||||
|
/* step 1: create cublas handle, bind a stream */
|
||||||
|
CUBLAS_CHECK(cublasCreate(&cublasH));
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
|
||||||
|
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
|
||||||
|
|
||||||
|
// configure logging to stdout
|
||||||
|
//CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#undef MIN
|
#undef MIN
|
||||||
|
@ -3605,6 +3641,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
||||||
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initialize cuBLAS
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
init_cublas();
|
||||||
|
#endif
|
||||||
|
|
||||||
is_first_call = false;
|
is_first_call = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7161,7 +7202,7 @@ static void ggml_compute_forward_rms_norm(
|
||||||
|
|
||||||
// ggml_compute_forward_mul_mat
|
// ggml_compute_forward_mul_mat
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
// helper function to determine if it is better to use BLAS or not
|
// helper function to determine if it is better to use BLAS or not
|
||||||
// for large matrices, BLAS is faster
|
// for large matrices, BLAS is faster
|
||||||
static bool ggml_compute_forward_mul_mat_use_blas(
|
static bool ggml_compute_forward_mul_mat_use_blas(
|
||||||
|
@ -7201,7 +7242,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||||
const int64_t ne02 = src0->ne[2];
|
const int64_t ne02 = src0->ne[2];
|
||||||
const int64_t ne03 = src0->ne[3];
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
#endif
|
#endif
|
||||||
const int64_t ne11 = src1->ne[1];
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
@ -7258,7 +7299,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||||
// nb01 >= nb00 - src0 is not transposed
|
// nb01 >= nb00 - src0 is not transposed
|
||||||
// compute by src0 rows
|
// compute by src0 rows
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||||
if (params->ith != 0) {
|
if (params->ith != 0) {
|
||||||
return;
|
return;
|
||||||
|
@ -7272,6 +7313,21 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
float *d_X = NULL;
|
||||||
|
float *d_Y = NULL;
|
||||||
|
float *d_D = NULL;
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
const int x_ne = ne01 * ne10;
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||||
|
#endif
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
||||||
|
@ -7279,15 +7335,38 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||||
|
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
/* step 2: copy data to device */
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
|
||||||
|
/* step 3: compute */
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, d_X, ne00,
|
||||||
|
d_Y, ne10,
|
||||||
|
&beta, d_D, ne01));
|
||||||
|
|
||||||
|
/* step 4: copy data to host */
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
|
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||||
|
#else
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
1.0f, y, ne10,
|
1.0f, y, ne10,
|
||||||
x, ne00,
|
x, ne00,
|
||||||
0.0f, d, ne01);
|
0.0f, d, ne01);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
/* free resources */
|
||||||
|
CUDA_CHECK(cudaFree(d_X));
|
||||||
|
CUDA_CHECK(cudaFree(d_Y));
|
||||||
|
CUDA_CHECK(cudaFree(d_D));
|
||||||
|
#endif
|
||||||
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||||
|
|
||||||
return;
|
return;
|
||||||
|
@ -7417,7 +7496,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
// nb01 >= nb00 - src0 is not transposed
|
// nb01 >= nb00 - src0 is not transposed
|
||||||
// compute by src0 rows
|
// compute by src0 rows
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||||
GGML_ASSERT(nb10 == sizeof(float));
|
GGML_ASSERT(nb10 == sizeof(float));
|
||||||
|
|
||||||
|
@ -7433,10 +7512,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * const wdata = params->wdata;
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
ggml_fp16_t * const wdata = params->wdata;
|
||||||
|
|
||||||
|
float *d_X = NULL;
|
||||||
|
float *d_Y = NULL;
|
||||||
|
float *d_D = NULL;
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
const int x_ne = ne01 * ne10;
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||||
|
#else
|
||||||
|
float * const wdata = params->wdata;
|
||||||
|
#endif
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
|
||||||
|
{
|
||||||
|
size_t id = 0;
|
||||||
|
for (int64_t i01 = 0; i01 < ne11; ++i01) {
|
||||||
|
for (int64_t i00 = 0; i00 < ne10; ++i00) {
|
||||||
|
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
{
|
{
|
||||||
size_t id = 0;
|
size_t id = 0;
|
||||||
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
||||||
|
@ -7445,7 +7551,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
||||||
|
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
|
||||||
|
|
||||||
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
|
/* step 2: copy data to device */
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
|
||||||
|
/* step 3: compute */
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, d_X, CUDA_R_16F, ne00,
|
||||||
|
d_Y, CUDA_R_16F, ne10,
|
||||||
|
&beta, d_D, CUDA_R_32F, ne01,
|
||||||
|
CUBLAS_COMPUTE_32F,
|
||||||
|
CUBLAS_GEMM_DEFAULT));
|
||||||
|
|
||||||
|
/* step 4: copy data to host */
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
|
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||||
|
#else
|
||||||
const float * x = wdata;
|
const float * x = wdata;
|
||||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||||
|
|
||||||
|
@ -7457,9 +7588,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||||
1.0f, y, ne10,
|
1.0f, y, ne10,
|
||||||
x, ne00,
|
x, ne00,
|
||||||
0.0f, d, ne01);
|
0.0f, d, ne01);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
/* free resources */
|
||||||
|
CUDA_CHECK(cudaFree(d_X));
|
||||||
|
CUDA_CHECK(cudaFree(d_Y));
|
||||||
|
CUDA_CHECK(cudaFree(d_D));
|
||||||
|
#endif
|
||||||
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
||||||
|
|
||||||
return;
|
return;
|
||||||
|
@ -7611,7 +7749,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
// nb01 >= nb00 - src0 is not transposed
|
// nb01 >= nb00 - src0 is not transposed
|
||||||
// compute by src0 rows
|
// compute by src0 rows
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||||
if (params->ith != 0) {
|
if (params->ith != 0) {
|
||||||
return;
|
return;
|
||||||
|
@ -7628,6 +7766,21 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
float * const wdata = params->wdata;
|
float * const wdata = params->wdata;
|
||||||
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
float *d_X = NULL;
|
||||||
|
float *d_Y = NULL;
|
||||||
|
float *d_D = NULL;
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
const int x_ne = ne01 * ne10;
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||||
|
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||||
|
#endif
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
{
|
{
|
||||||
|
@ -7643,15 +7796,39 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
|
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
/* step 2: copy data to device */
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
|
||||||
|
/* step 3: compute */
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, d_X, ne00,
|
||||||
|
d_Y, ne10,
|
||||||
|
&beta, d_D, ne01));
|
||||||
|
|
||||||
|
/* step 4: copy data to host */
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
|
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||||
|
#else
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
1.0f, y, ne10,
|
1.0f, y, ne10,
|
||||||
x, ne00,
|
x, ne00,
|
||||||
0.0f, d, ne01);
|
0.0f, d, ne01);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
/* free resources */
|
||||||
|
CUDA_CHECK(cudaFree(d_X));
|
||||||
|
CUDA_CHECK(cudaFree(d_Y));
|
||||||
|
CUDA_CHECK(cudaFree(d_D));
|
||||||
|
#endif
|
||||||
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||||
|
|
||||||
return;
|
return;
|
||||||
|
@ -10466,7 +10643,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
size_t cur = 0;
|
size_t cur = 0;
|
||||||
|
|
||||||
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
node->n_tasks = 1; // TODO: this actually is doing nothing
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
// the threads are still spinning
|
// the threads are still spinning
|
||||||
|
@ -10483,7 +10660,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
||||||
cur = 0;
|
cur = 0;
|
||||||
} else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
|
} else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
node->n_tasks = 1;
|
node->n_tasks = 1;
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
||||||
|
@ -11800,7 +11977,15 @@ int ggml_cpu_has_wasm_simd(void) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int ggml_cpu_has_blas(void) {
|
int ggml_cpu_has_blas(void) {
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
||||||
|
return 1;
|
||||||
|
#else
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
int ggml_cpu_has_cublas(void) {
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
return 1;
|
return 1;
|
||||||
#else
|
#else
|
||||||
return 0;
|
return 0;
|
||||||
|
|
1
ggml.h
1
ggml.h
|
@ -823,6 +823,7 @@ int ggml_cpu_has_f16c(void);
|
||||||
int ggml_cpu_has_fp16_va(void);
|
int ggml_cpu_has_fp16_va(void);
|
||||||
int ggml_cpu_has_wasm_simd(void);
|
int ggml_cpu_has_wasm_simd(void);
|
||||||
int ggml_cpu_has_blas(void);
|
int ggml_cpu_has_blas(void);
|
||||||
|
int ggml_cpu_has_cublas(void);
|
||||||
int ggml_cpu_has_sse3(void);
|
int ggml_cpu_has_sse3(void);
|
||||||
int ggml_cpu_has_vsx(void);
|
int ggml_cpu_has_vsx(void);
|
||||||
|
|
||||||
|
|
|
@ -1066,7 +1066,7 @@ static bool llama_eval_internal(
|
||||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||||
ggml_cgraph gf = {};
|
ggml_cgraph gf = {};
|
||||||
gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
|
gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads;
|
||||||
|
|
||||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue