Compare commits


8 commits

Author SHA1 Message Date
Georgi Gerganov
6966474928
cuda : play with faster Q4_0 dequantization 2023-10-24 10:29:40 +03:00
Georgi Gerganov
d415669087
cuda : add ROCm / hipBLAS cublasGemmBatchedEx define 2023-10-24 00:18:49 +03:00
Kerfuffle
878aa4f209
Apply suggestions from code review
These changes plus:

```c++
#define cublasGemmBatchedEx hipblasGemmBatchedEx
```

are needed to compile with ROCm. I haven't done performance testing, but it seems to work.

I couldn't figure out how to propose a change for lines outside what the pull request changed. Also, this is the first time I've tried to create a multi-part review, so please forgive me if I mess something up.
2023-10-23 15:09:50 -06:00
Georgi Gerganov
c13fcfbfc0
cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops) 2023-10-23 20:37:04 +03:00
Georgi Gerganov
84d4ca0e47
cuda : minor indentation 2023-10-23 20:36:50 +03:00
Georgi Gerganov
8d8d54f834
ggml : skip nops in compute_forward 2023-10-23 20:36:32 +03:00
Georgi Gerganov
6a30bf3e51
batched : add NGL arg 2023-10-23 20:36:12 +03:00
Georgi Gerganov
8fb1be642e
cmake : add helper for faster CUDA builds 2023-10-23 20:35:19 +03:00
4 changed files with 269 additions and 13 deletions

CMakeLists.txt

@@ -331,6 +331,7 @@ if (LLAMA_CUBLAS)
            set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
        else()
            set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
+            #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
        endif()
    endif()
    message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
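
The commented-out empty architecture list is presumably a developer shortcut: without an explicit architecture list, nvcc builds for only its default target instead of three, which compiles much faster, but the quantized kernels rely on architecture-gated integer intrinsics, hence the in-line warning that only F16 models work in that configuration.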

examples/batched/batched.cpp

@@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
    gpt_params params;

    if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
        return 1 ;
    }
@@ -21,6 +21,9 @@ int main(int argc, char ** argv) {
    // total length of the sequences including the prompt
    int n_len = 32;

+    // number of layers to offload to the GPU
+    int n_gpu_layers = 0;
+
    if (argc >= 2) {
        params.model = argv[1];
    }
@@ -37,6 +40,10 @@ int main(int argc, char ** argv) {
        n_len = std::atoi(argv[4]);
    }

+    if (argc >= 6) {
+        n_gpu_layers = std::atoi(argv[5]);
+    }
+
    if (params.prompt.empty()) {
        params.prompt = "Hello my name is";
    }
@@ -49,7 +56,7 @@ int main(int argc, char ** argv) {
    llama_model_params model_params = llama_model_default_params();

-    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+    model_params.n_gpu_layers = n_gpu_layers;

    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
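
With the new optional fifth argument the example can offload layers without recompiling. For instance (hypothetical model path), `./batched ./model.gguf "Hello my name is" 4 32 99` runs 4 parallel sequences of length 32 and requests up to 99 offloaded layers, matching the previously hard-coded `// model_params.n_gpu_layers = 99` hint that this change replaces.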

ggml-cuda.cu

@@ -29,6 +29,7 @@
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
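
On ROCm builds, ggml-cuda.cu is compiled through the HIP toolchain and this block of defines maps cuBLAS names onto their hipBLAS equivalents. Without the new cublasGemmBatchedEx mapping, the batched GEMM call added further down would not resolve, which is the build break Kerfuffle's review comment above refers to.
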
@@ -4326,13 +4327,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
    const half * x = (const half *) vx;

    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
    const int channel_x = channel / channel_x_divisor;

    const int nrows_y = ncols_x;
    const int nrows_dst = nrows_x;
    const int row_dst = row_x;

    const int idst = channel*nrows_dst + row_dst;
@@ -4345,13 +4346,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
            break;
        }

        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
        const float xi = __half2float(x[ix]);

        const int row_y = col_x;
        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
        const int iy = channel*nrows_y + row_y;

        const float xi = __half2float(x[ix]);

        tmp += xi * y[iy];
    }
@@ -4658,12 +4659,94 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
+#ifdef GGML_CUDA_F16
+#define make_dfloat2(x, y) __halves2half2((x), (y))
+#else
+#define make_dfloat2(x, y) make_float2((x), (y))
+#endif
+
+static __device__ __forceinline__ dfloat2 dfmul2(dfloat2 a, dfloat2 b) {
+#ifdef GGML_CUDA_F16
+    return __hmul2(a, b);
+#else
+    return make_float2(a.x * b.x, a.y * b.y);
+#endif
+}
+
+static __device__ __forceinline__ float2 dfloat22float2(dfloat2 a) {
+#ifdef GGML_CUDA_F16
+    return __half22float2(a);
+#else
+    return a;
+#endif
+}
+
+static __global__ void dequantize_block_q4_0_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i*4 >= k) {
+        return;
+    }
+
+    const int ib = i/(QK4_0/4);
+    const int iqs = i%(QK4_0/4);
+
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2);
+    const dfloat d = x[ib].d;
+
+    dfloat2 dv0 = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8);
+    const float2 v0 = dfloat22float2(dfmul2(dv0, {d, d}));
+    *(float2 *)(y + ib*QK4_0 + iqs*2) = v0;
+
+    dfloat2 dv1 = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8);
+    const float2 v1 = dfloat22float2(dfmul2(dv1, {d, d}));
+    *(float2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = v1;
+}
+
+static __global__ void dequantize_block_q4_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i*4 >= k) {
+        return;
+    }
+
+    const int ib = i/(QK4_0/4);
+    const int iqs = i%(QK4_0/4);
+
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2);
+    const dfloat d = x[ib].d;
+
+    dfloat2 dv0 = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8);
+    const float2 v0 = dfloat22float2(dfmul2(dv0, {d, d}));
+    *(half2 *)(y + ib*QK4_0 + iqs*2) = __float22half2_rn(v0);
+
+    dfloat2 dv1 = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8);
+    const float2 v1 = dfloat22float2(dfmul2(dv1, {d, d}));
+    *(half2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = __float22half2_rn(v1);
+}
template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
+template<>
+void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    GGML_ASSERT(k % 4 == 0);
+    const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block_q4_0_f32<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+template<>
+void dequantize_row_q4_0_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
+    GGML_ASSERT(k % 4 == 0);
+    const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block_q4_0_f16<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -7013,7 +7096,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
}

static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
-    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
    GGML_ASSERT(!ggml_is_permuted(src0));
    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -7023,11 +7107,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];

    const int64_t ne12 = src1->ne[2];

    const int64_t nb01 = src0->nb[1];
    const int64_t nb02 = src0->nb[2];

    const int64_t ne12 = src1->ne[2];

    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -7046,6 +7130,154 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
}
+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+    half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    &alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                    &beta_f16,  (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
+                    CUBLAS_COMPUTE_16F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+    // use cublasGemmBatchedEx
+    {
+        const int ne23 = ne12*ne13;
+
+        // TODO: avoid this alloc
+        void ** src0_ptrs = (void **) malloc(ne23*sizeof(void *));
+        void ** src1_ptrs = (void **) malloc(ne23*sizeof(void *));
+        void **  dst_ptrs = (void **) malloc(ne23*sizeof(void *));
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                src0_ptrs[i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
+                src1_ptrs[i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
+                 dst_ptrs[i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+            }
+        }
+
+        // allocate device memory for pointers
+        const void ** src0_ptrs_as = nullptr;
+        const void ** src1_ptrs_as = nullptr;
+              void **  dst_ptrs_as = nullptr;
+
+        CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *)));
+        CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *)));
+        CUDA_CHECK(cudaMalloc(& dst_ptrs_as, ne23*sizeof(void *)));
+
+        // copy pointers to device
+        CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy( dst_ptrs_as,  dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
+
+        CUBLAS_CHECK(
+        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+            ne01, ne11, ne10,
+            &alpha_f16, (const void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half),
+                        (const void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float),
+            &beta_f16,  (      void **)  dst_ptrs_as, CUDA_R_16F, ne01,
+            ne23,
+            CUBLAS_COMPUTE_16F,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        // free device memory for pointers
+        CUDA_CHECK(cudaFree(src0_ptrs_as));
+        CUDA_CHECK(cudaFree(src1_ptrs_as));
+        CUDA_CHECK(cudaFree( dst_ptrs_as));
+
+        free(src0_ptrs);
+        free(src1_ptrs);
+        free( dst_ptrs);
+    }
+#endif
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
        src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
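
A few notes on the new batched path: src1 is first converted to F16 so the whole GEMM runs in half precision, and the F16 result is converted back to F32 at the end. The factors `r2 = ne12/ne02` and `r3 = ne13/ne03` implement ggml's broadcast semantics, so several src1 matrices can reuse one src0 matrix via `i02 = i12 / r2` and `i03 = i13 / r3` (as happens in the attention ops this commit targets). The disabled `#if 0` branch keeps the equivalent per-matrix cublasGemmEx loop for comparison, while the live branch gathers one device pointer per `(i12, i13)` pair, uploads the three pointer arrays, and issues a single cublasGemmBatchedEx over all `ne12*ne13` multiplications; the `// TODO: avoid this alloc` comment flags that the per-call host and device allocations for these pointer arrays are meant to be removed later.
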
@@ -7058,10 +7290,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
        }
    }

+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
    if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ
        ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV
        ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
    } else if (src0->type == GGML_TYPE_F32) {
        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {

ggml.c

@@ -16602,6 +16602,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
    GGML_ASSERT(params);

+    if (tensor->op == GGML_OP_NONE) {
+        return;
+    }
+
#ifdef GGML_USE_CUBLAS
    bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
    if (skip_cpu) {
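
The early return makes ggml_compute_forward bail out for tensors whose op is GGML_OP_NONE (graph nodes that carry data but require no computation), presumably so the per-node dispatch below, including the CUDA forward hook, is not invoked for them at all.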