From 8fb1be642e043e173985305211e8e531e777a438 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Oct 2023 20:35:19 +0300 Subject: [PATCH 1/8] cmake : add helper for faster CUDA builds --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6af42a6c2..202f26049 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -331,6 +331,7 @@ if (LLAMA_CUBLAS) set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics else() set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") From 6a30bf3e5152d572d2f65cbea5d80c0b4e72e3fa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Oct 2023 20:36:12 +0300 Subject: [PATCH 2/8] batched : add NGL arg --- examples/batched/batched.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 2797329b4..9c9819e17 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -11,7 +11,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]); + printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]); return 1 ; } @@ -21,6 +21,9 @@ int main(int argc, char ** argv) { // total length of the sequences including the prompt int n_len = 32; + // number of layers to offload to the GPU + int n_gpu_layers = 0; + if (argc >= 2) { params.model = argv[1]; } @@ -37,6 +40,10 @@ int main(int argc, char ** argv) { n_len = std::atoi(argv[4]); } + if (argc >= 6) { + n_gpu_layers = std::atoi(argv[5]); + } + if (params.prompt.empty()) { params.prompt = "Hello my name is"; } @@ -49,7 +56,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = llama_model_default_params(); - // model_params.n_gpu_layers = 99; // offload all layers to the GPU + model_params.n_gpu_layers = n_gpu_layers; llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); From 8d8d54f83412d8ee2e5a54212de01956306fa07f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Oct 2023 20:36:32 +0300 Subject: [PATCH 3/8] ggml : skip nops in compute_forward --- ggml.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml.c b/ggml.c index 49f3b7aba..17f0ce487 100644 --- a/ggml.c +++ b/ggml.c @@ -16602,6 +16602,10 @@ static void ggml_compute_forward_cross_entropy_loss_back( static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { GGML_ASSERT(params); + if (tensor->op == GGML_OP_NONE) { + return; + } + #ifdef GGML_USE_CUBLAS bool skip_cpu = ggml_cuda_compute_forward(params, tensor); if (skip_cpu) { From 84d4ca0e47785091f63245279f758eaded754d3c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Oct 2023 20:36:50 +0300 Subject: [PATCH 4/8] cuda : minor indentation --- ggml-cuda.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 654d3632f..ab1d34212 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4326,13 +4326,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const half * x = (const half *) vx; - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; const int channel_x = channel / channel_x_divisor; - const int nrows_y = ncols_x; + const int nrows_y = ncols_x; const int nrows_dst = nrows_x; - const int row_dst = row_x; + const int row_dst = row_x; const int idst = channel*nrows_dst + row_dst; @@ -4345,13 +4345,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous break; } - const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const float xi = __half2float(x[ix]); - const int row_y = col_x; + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; const int iy = channel*nrows_y + row_y; + const float xi = __half2float(x[ix]); + tmp += xi * y[iy]; } From c13fcfbfc09d5185977ed81c5bcfa43b5da08514 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Oct 2023 20:37:04 +0300 Subject: [PATCH 5/8] cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops) --- ggml-cuda.cu | 169 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 165 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ab1d34212..ebfd6c15e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7013,7 +7013,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens } static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ - GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); @@ -7023,11 +7024,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; - const int64_t ne12 = src1->ne[2]; - const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; + const int64_t ne12 = src1->ne[2]; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; @@ -7046,6 +7047,154 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } +static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00); + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02); + const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); + const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + + const int64_t ne1 = ggml_nelements(src1); + const int64_t ne = ggml_nelements(dst); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + half * src0_as_f16 = (half *) src0_ddq; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + // convert src1 to fp16 + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + + size_t src1_as = 0; + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + +#if 0 + // use cublasGemmEx + { + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + int i03 = i13 / r3; + int i02 = i12 / r2; + + CUBLAS_CHECK( + cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), + (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), + &beta_f16, (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + } + } +#else + // use cublasGemmBatchedEx + { + const int ne23 = ne12*ne13; + + // TODO: avoid this alloc + void ** src0_ptrs = (void **) malloc(ne23*sizeof(void *)); + void ** src1_ptrs = (void **) malloc(ne23*sizeof(void *)); + void ** dst_ptrs = (void **) malloc(ne23*sizeof(void *)); + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + int i03 = i13 / r3; + int i02 = i12 / r2; + + src0_ptrs[i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3]; + src1_ptrs[i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2; + dst_ptrs [i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2; + } + } + + // allocate device memory for pointers + void ** src0_ptrs_as = nullptr; + void ** src1_ptrs_as = nullptr; + void ** dst_ptrs_as = nullptr; + + CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *))); + CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *))); + CUDA_CHECK(cudaMalloc(& dst_ptrs_as, ne23*sizeof(void *))); + + // copy pointers to device + CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( dst_ptrs_as, dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice)); + + CUBLAS_CHECK( + cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half), + (void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float), + &beta_f16, (void **) dst_ptrs_as, CUDA_R_16F, ne01, + ne23, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + // free device memory for pointers + CUDA_CHECK(cudaFree(src0_ptrs_as)); + CUDA_CHECK(cudaFree(src1_ptrs_as)); + CUDA_CHECK(cudaFree( dst_ptrs_as)); + + free(src0_ptrs); + free(src1_ptrs); + free( dst_ptrs); + } +#endif + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); +} + static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; @@ -7058,10 +7207,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } + // debug helpers + //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); + //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); + //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); + //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); + //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); + //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // KQ ggml_cuda_mul_mat_vec_p021(src0, src1, dst); - } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { + } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // KQV ggml_cuda_mul_mat_vec_nc(src0, src1, dst); + } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { From 878aa4f209e453254a6640afc29d41b1d35273bf Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Mon, 23 Oct 2023 15:09:50 -0600 Subject: [PATCH 6/8] Apply suggestions from code review These changes plus: ```c++ #define cublasGemmBatchedEx hipblasGemmBatchedEx ``` are needed to compile with ROCM. I haven't done performance testing, but it seems to work. I couldn't figure out how to propose a change for lines outside what the pull changed, also this is the first time trying to create a multi-part review so please forgive me if I mess something up. --- ggml-cuda.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ebfd6c15e..c0383d19e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7154,9 +7154,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const } // allocate device memory for pointers - void ** src0_ptrs_as = nullptr; - void ** src1_ptrs_as = nullptr; - void ** dst_ptrs_as = nullptr; + const void ** src0_ptrs_as = nullptr; + const void ** src1_ptrs_as = nullptr; + void ** dst_ptrs_as = nullptr; CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *))); CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *))); @@ -7170,9 +7170,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - &alpha_f16, (void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half), - (void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float), - &beta_f16, (void **) dst_ptrs_as, CUDA_R_16F, ne01, + &alpha_f16, (const void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half), + (const void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float), + &beta_f16, ( void **) dst_ptrs_as, CUDA_R_16F, ne01, ne23, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); From d4156690877abfce586354fcd494c4c072968c5e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 24 Oct 2023 00:18:49 +0300 Subject: [PATCH 7/8] cuda : add ROCm / hipBLAS cublasGemmBatchedEx define --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c0383d19e..e2dea9eab 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -29,6 +29,7 @@ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasCreate hipblasCreate #define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx #define cublasHandle_t hipblasHandle_t #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream From 6966474928233a0e2767a83a1f7e9a4cd314f169 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 24 Oct 2023 10:29:40 +0300 Subject: [PATCH 8/8] cuda : play with faster Q4_0 dequantization --- ggml-cuda.cu | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e2dea9eab..d0cc6f13a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4659,12 +4659,94 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con quantize_q8_1<<>>(x, vy, kx, kx_padded); } +#ifdef GGML_CUDA_F16 +#define make_dfloat2(x, y) __halves2half2((x), (y)) +#else +#define make_dfloat2(x, y) make_float2((x), (y)) +#endif + +static __device__ __forceinline__ dfloat2 dfmul2(dfloat2 a, dfloat2 b) { +#ifdef GGML_CUDA_F16 + return __hmul2(a, b); +#else + return make_float2(a.x * b.x, a.y * b.y); +#endif +} + +static __device__ __forceinline__ float2 dfloat22float2(dfloat2 a) { +#ifdef GGML_CUDA_F16 + return __half22float2(a); +#else + return a; +#endif +} + +static __global__ void dequantize_block_q4_0_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i*4 >= k) { + return; + } + + const int ib = i/(QK4_0/4); + const int iqs = i%(QK4_0/4); + + const block_q4_0 * x = (const block_q4_0 *) vx; + const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2); + const dfloat d = x[ib].d; + + dfloat2 dv0 = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8); + const float2 v0 = dfloat22float2(dfmul2(dv0, {d, d})); + *(float2 *)(y + ib*QK4_0 + iqs*2) = v0; + + dfloat2 dv1 = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8); + const float2 v1 = dfloat22float2(dfmul2(dv1, {d, d})); + *(float2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = v1; +} + +static __global__ void dequantize_block_q4_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i*4 >= k) { + return; + } + + const int ib = i/(QK4_0/4); + const int iqs = i%(QK4_0/4); + + const block_q4_0 * x = (const block_q4_0 *) vx; + const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2); + const dfloat d = x[ib].d; + + dfloat2 dv0 = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8); + const float2 v0 = dfloat22float2(dfmul2(dv0, {d, d})); + *(half2 *)(y + ib*QK4_0 + iqs*2) = __float22half2_rn(v0); + + dfloat2 dv1 = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8); + const float2 v1 = dfloat22float2(dfmul2(dv1, {d, d})); + *(half2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = __float22half2_rn(v1); +} + template static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } +template<> +void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + GGML_ASSERT(k % 4 == 0); + const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block_q4_0_f32<<>>(vx, y, k); +} + +template<> +void dequantize_row_q4_0_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { + GGML_ASSERT(k % 4 == 0); + const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block_q4_0_f16<<>>(vx, y, k); +} + template static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;