diff --git a/CMakeLists.txt b/CMakeLists.txt index ea9f80b80..f5a968533 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) @@ -158,23 +159,59 @@ if (LLAMA_BLAS) if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) set(BLA_SIZEOF_INTEGER 8) endif() + set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) find_package(BLAS) + if (BLAS_FOUND) message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") - # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. - # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 - find_path(BLAS_INCLUDE_DIRS - NAMES cblas.h - HINTS - /usr/include - /usr/local/include - /usr/include/openblas - ) + if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") + # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. + # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 + find_package(PkgConfig REQUIRED) + if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") + pkg_check_modules(DepBLAS REQUIRED blas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") + pkg_check_modules(DepBLAS REQUIRED openblas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") + pkg_check_modules(DepBLAS REQUIRED blis) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") + pkg_check_modules(DepBLAS REQUIRED blas-atlas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") + pkg_check_modules(DepBLAS REQUIRED flexiblas_api) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") + # all Intel* libraries share the same include path + pkg_check_modules(DepBLAS REQUIRED mkl-sdl) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") + # this doesn't provide pkg-config + # suggest to assign BLAS_INCLUDE_DIRS on your own + if ("${NVHPC_VERSION}" STREQUAL "") + message(WARNING "Better to set NVHPC_VERSION") + else() + set(DepBLAS_FOUND ON) + set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") + endif() + endif() + if (DepBLAS_FOUND) + set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() + endif() message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") - add_compile_options(${BLAS_LINKER_FLAGS}) add_compile_definitions(GGML_USE_OPENBLAS) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) @@ -201,6 +238,7 @@ if (LLAMA_CUBLAS) add_compile_definitions(GGML_USE_CUBLAS) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) + add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) @@ -423,8 +461,10 @@ target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES}) target_compile_features(ggml PUBLIC c_std_11) # don't bump target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) +add_library(ggml_static STATIC $) if (BUILD_SHARED_LIBS) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) + add_library(ggml_shared SHARED $) endif() add_library(llama diff --git a/Makefile b/Makefile index 99e772623..eee9eeb53 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple ifdef LLAMA_BUILD_SERVER BUILD_TARGETS += server @@ -173,6 +173,11 @@ ifdef LLAMA_CUDA_DMMV_Y else NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 endif # LLAMA_CUDA_DMMV_Y +ifdef LLAMA_CUDA_KQUANTS_ITER + NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) +else + NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 +endif ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif # LLAMA_CUBLAS @@ -273,6 +278,12 @@ main: examples/main/main.cpp build-info.h ggml. @echo '==== Run ./main -h for help. ====' @echo +simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + @echo + @echo '==== Run ./simple -h for help. ====' + @echo + quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 0add6adc0..50e14c4ac 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -4,6 +4,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + float frand() { return (float)rand()/(float)RAND_MAX; } @@ -1470,7 +1474,7 @@ struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_te } struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - const float eps = 1e-3; + const float eps = 1e-3f; return ggml_sum(ctx, ggml_neg(ctx, diff --git a/examples/benchmark/benchmark-matmult.cpp b/examples/benchmark/benchmark-matmult.cpp index 9f9ed9db0..39d15caeb 100644 --- a/examples/benchmark/benchmark-matmult.cpp +++ b/examples/benchmark/benchmark-matmult.cpp @@ -16,6 +16,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + float tensor_sum_elements(const ggml_tensor * tensor) { float sum = 0; if (tensor->type==GGML_TYPE_F32) { @@ -29,9 +33,9 @@ float tensor_sum_elements(const ggml_tensor * tensor) { } void tensor_dump(const ggml_tensor * tensor, const char * name) { - printf("%15s: type = %i (%5s) ne = %5d x %5d x %5d, nb = (%5li, %5li, %5li) - ", name, + printf("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi) - ", name, tensor->type, ggml_type_name(tensor->type), - (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], tensor->nb[0], tensor->nb[1], tensor->nb[2]); + tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->nb[0], tensor->nb[1], tensor->nb[2]); float sum = tensor_sum_elements(tensor); printf("Sum of tensor %s is %6.2f\n", name, sum); } @@ -120,7 +124,7 @@ int main(int argc, char ** argv) { ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS ctx_size += 1024*1024*16; - printf("Allocating Memory of size %li bytes, %li MB\n",ctx_size, (ctx_size/1024/1024)); + printf("Allocating Memory of size %zi bytes, %zi MB\n",ctx_size, (ctx_size/1024/1024)); struct ggml_init_params params = { /*.mem_size =*/ ctx_size, diff --git a/examples/common.cpp b/examples/common.cpp index b47f06273..055383bef 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -28,6 +28,10 @@ #include #endif +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + int32_t get_num_physical_cores() { #ifdef __linux__ // enumerate the set of thread siblings, num entries is num cores @@ -373,7 +377,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else { throw std::exception(); } - } catch (const std::exception &e) { + } catch (const std::exception&) { invalid_param = true; break; } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 03603b10f..860f99f67 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -4,6 +4,10 @@ #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + int main(int argc, char ** argv) { gpt_params params; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index efa913e16..a051fcbc5 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -23,11 +23,17 @@ #include #elif defined (_WIN32) #define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX #define NOMINMAX +#endif #include #include #endif +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + static console_state con_st; static llama_context ** g_ctx; @@ -348,7 +354,7 @@ int main(int argc, char ** argv) { if ((int)embd.size() > max_embd_size) { auto skipped_tokens = embd.size() - max_embd_size; console_set_color(con_st, CONSOLE_COLOR_ERROR); - printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); console_set_color(con_st, CONSOLE_COLOR_DEFAULT); fflush(stdout); embd.resize(max_embd_size); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index e19c6825f..ae8cfe0af 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -5,6 +5,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + std::vector softmax(const std::vector& logits) { std::vector probs(logits.size()); float max_logit = logits[0]; diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 6e4f7e1e0..6b8018ee2 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -19,6 +19,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + struct quantize_stats_params { std::string model = "models/7B/ggml-model-f16.bin"; bool verbose = false; diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 91f04b6c7..da4d37ad0 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -37,7 +37,7 @@ int main(int argc, char ** argv) { // init auto ctx = llama_init_from_file(params.model.c_str(), lparams); auto tokens = std::vector(params.n_ctx); - auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true); + auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true); if (n_prompt_tokens < 1) { fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); diff --git a/examples/simple/CMakeLists.txt b/examples/simple/CMakeLists.txt new file mode 100644 index 000000000..1568f7364 --- /dev/null +++ b/examples/simple/CMakeLists.txt @@ -0,0 +1,7 @@ +set(TARGET simple) +add_executable(${TARGET} simple.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp new file mode 100644 index 000000000..76f991cdc --- /dev/null +++ b/examples/simple/simple.cpp @@ -0,0 +1,177 @@ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include "common.h" +#include "llama.h" +#include "build-info.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#include +#endif + + + +int main(int argc, char ** argv) +{ + gpt_params params; + + //--------------------------------- + // Print help : + //--------------------------------- + + if ( argc == 1 || argv[1][0] == '-' ) + { + printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] ); + return 1 ; + } + + //--------------------------------- + // Load parameters : + //--------------------------------- + + if ( argc >= 2 ) + { + params.model = argv[1]; + } + + if ( argc >= 3 ) + { + params.prompt = argv[2]; + } + + if ( params.prompt.empty() ) + { + params.prompt = "Hello my name is"; + } + + //--------------------------------- + // Init LLM : + //--------------------------------- + + llama_init_backend(); + + llama_context * ctx ; + + ctx = llama_init_from_gpt_params( params ); + + if ( ctx == NULL ) + { + fprintf( stderr , "%s: error: unable to load model\n" , __func__ ); + return 1; + } + + //--------------------------------- + // Tokenize the prompt : + //--------------------------------- + + std::vector tokens_list; + tokens_list = ::llama_tokenize( ctx , params.prompt , true ); + + const int max_context_size = llama_n_ctx( ctx ); + const int max_tokens_list_size = max_context_size - 4 ; + + if ( (int)tokens_list.size() > max_tokens_list_size ) + { + fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" , + __func__ , (int)tokens_list.size() , max_tokens_list_size ); + return 1; + } + + fprintf( stderr, "\n\n" ); + + // Print the tokens from the prompt : + + for( auto id : tokens_list ) + { + printf( "%s" , llama_token_to_str( ctx , id ) ); + } + + fflush(stdout); + + + //--------------------------------- + // Main prediction loop : + //--------------------------------- + + // The LLM keeps a contextual cache memory of previous token evaluation. + // Usually, once this cache is full, it is required to recompute a compressed context based on previous + // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist + // example, we will just stop the loop once this cache is full or once an end of stream is detected. + + while ( llama_get_kv_cache_token_count( ctx ) < max_context_size ) + { + //--------------------------------- + // Evaluate the tokens : + //--------------------------------- + + if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) ) + { + fprintf( stderr, "%s : failed to eval\n" , __func__ ); + return 1; + } + + tokens_list.clear(); + + //--------------------------------- + // Select the best prediction : + //--------------------------------- + + llama_token new_token_id = 0; + + auto logits = llama_get_logits( ctx ); + auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens) + + std::vector candidates; + candidates.reserve( n_vocab ); + + for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ ) + { + candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } ); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // Select it using the "Greedy sampling" method : + new_token_id = llama_sample_token_greedy( ctx , &candidates_p ); + + + // is it an end of stream ? + if ( new_token_id == llama_token_eos() ) + { + fprintf(stderr, " [end of text]\n"); + break; + } + + // Print the new token : + printf( "%s" , llama_token_to_str( ctx , new_token_id ) ); + fflush( stdout ); + + // Push this new token for next evaluation : + tokens_list.push_back( new_token_id ); + + } // wend of main loop + + llama_free( ctx ); + + return 0; +} + +// EOF diff --git a/examples/train-text-from-scratch/README.md b/examples/train-text-from-scratch/README.md index 5344d1f52..726ec47c0 100644 --- a/examples/train-text-from-scratch/README.md +++ b/examples/train-text-from-scratch/README.md @@ -4,7 +4,7 @@ Basic usage instructions: ```bash # get training data -wget https://github.com/brunoklein99/deep-learning-notes/blob/master/shakespeare.txt +wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/shakespeare.txt # train ./bin/train-text-from-scratch \ diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 51271b497..7ec85951a 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -12,6 +12,9 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif struct random_normal_distribution { std::mt19937 gen; @@ -20,7 +23,6 @@ struct random_normal_distribution { float max; }; - struct random_uniform_distribution { std::mt19937 gen; std::uniform_real_distribution rd; @@ -2366,7 +2368,7 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) { file->write_u32(0); file->write_u32(0); file->write_u32(GGML_TYPE_F32); - file->seek(-file->tell() & 31, SEEK_CUR); + file->seek(0-file->tell() & 31, SEEK_CUR); return; } const char * name = ggml_get_name(tensor); @@ -2381,7 +2383,7 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) { file->write_u32(tensor->type); file->write_raw(ne, sizeof(ne[0]) * nd); file->write_raw(name, name_len); - file->seek(-file->tell() & 31, SEEK_CUR); + file->seek(0-file->tell() & 31, SEEK_CUR); file->write_raw(tensor->data, ggml_nbytes(tensor)); } @@ -2402,7 +2404,7 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) { std::string name = file->read_string(name_len); GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0); - file->seek(-file->tell() & 31, SEEK_CUR); + file->seek(0-file->tell() & 31, SEEK_CUR); file->read_raw(tensor->data, ggml_nbytes(tensor)); } @@ -2756,8 +2758,8 @@ struct train_params get_default_train_params() { params.lbfgs_n_iter = 16; params.adam_n_iter = 16; - params.adam_alpha = 1e-3; - params.adam_decay = 1e-3; + params.adam_alpha = 1e-3f; + params.adam_decay = 1e-3f; params.mem_model_gb = 2; params.mem_compute_gb = 24; @@ -3331,8 +3333,8 @@ int main(int argc, char ** argv) { int n_gen = params.n_predict; int sample_ctx = n_tokens - n_tokens/8; - sampler.params.temp = 0.2; - sampler.params.repeat_penalty = 1.1; + sampler.params.temp = 0.2f; + sampler.params.repeat_penalty = 1.1f; sampler.params.mirostat = 2; init_sampler(&sampler, lctx); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bd89d0a1f..7edd1a9f8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -167,6 +167,12 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define GGML_CUDA_DMMV_Y 1 #endif +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -326,37 +332,6 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { } -static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q2_K * x = (const block_q2_K *) vx; - - // if n is 0, we want to do the lower 128, else the upper 128, - // covering y[l+0], y[l+32], y[l+64], y[l+96] and - // y[l+16], y[l+48], y[l+80], y[l+112] - int n = iqs/128; // 0 or 1 - int r = iqs - 128*n; // 0...120 in steps of 8 - int l = r/8; // 0...15 in steps of 1 - - const float * y = yy + 128*n + l; - const uint8_t * q = x[ib].qs + 32*n + l; - const uint8_t * s = x[ib].scales + 8*n; - - const float dall = x[ib].d; - const float dmin = x[ib].dmin; - - float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4)) - + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4)) - + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4)) - + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4)) - + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4)) - + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4)) - + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4)) - + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4)); - - result = sum; - -} - static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { int r = threadIdx.x/4; @@ -388,51 +363,6 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { } -static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q3_K * x = (const block_q3_K *) vx; - - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - - uint32_t aux[3]; - uint32_t utmp[4]; - - // if n is 0, we want to do the lower 128, else the upper 128, - // covering y[l+0], y[l+32], y[l+64], y[l+96] and - // y[l+16], y[l+48], y[l+80], y[l+112] - int n = iqs/128; // 0 or 1 - int r = iqs - 128*n; // 0...120 in steps of 8 - int l = r/8; // 0...15 in steps of 1 - - const float * y = yy + 128*n + l; - const uint8_t * q = x[ib].qs + 32*n + l; - const uint8_t * hm = x[ib].hmask + l; - const int8_t * s = (const int8_t *)utmp + 8*n; - - memcpy(aux, x[ib].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - const float dall = x[ib].d; - - const uint8_t m = 1 << (4*n); - - float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4)) - + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4)) - + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4)) - + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4)) - + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4)) - + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4)) - + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4)) - + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4)); - - result = sum * dall; - -} - static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; @@ -479,38 +409,6 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { } } -static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q4_K * x = (const block_q4_K *) vx; - - // iqs is in 0...248 in steps of 8 => - const int j = iqs / 64; // j is in 0...3 - const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 - const int is = 2*j; // is is in 0...6 in steps of 2 - - const float * y = yy + 64*j + ir; - const uint8_t * q = x[ib].qs + 32*j + ir; - - const float dall = x[ib].d; - const float dmin = x[ib].dmin; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[ib].scales, sc, m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[ib].scales, sc, m); - const float d2 = dall * sc; - const float m2 = dmin * m; - - float sum = 0; - for (int k = 0; k < 4; ++k) { - sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1); - sum += y[k + 32] * (d2 * (q[k] >> 4) - m2); - } - result = sum; - -} - static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { const block_q5_K * x = (const block_q5_K *) vx; @@ -544,43 +442,6 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; } -static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q5_K * x = (const block_q5_K *) vx; - - // iqs is in 0...248 in steps of 8 => - const int j = iqs / 64; // j is in 0...3 - const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 - const int is = 2*j; // is is in 0...6 in steps of 2 - - const float * y = yy + 64*j + ir; - const uint8_t * ql = x[ib].qs + 32*j + ir; - const uint8_t * qh = x[ib].qh + ir; - - const float dall = x[ib].d; - const float dmin = x[ib].dmin; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[ib].scales, sc, m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[ib].scales, sc, m); - const float d2 = dall * sc; - const float m2 = dmin * m; - - uint8_t hm = 1 << is; - float sum = 0; - for (int k = 0; k < 4; ++k) { - sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1); - } - hm <<= 1; - for (int k = 0; k < 4; ++k) { - sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2); - } - result = sum; - -} - static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { const block_q6_K * x = (const block_q6_K *) vx; @@ -606,31 +467,376 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); } -static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { +static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const block_q6_K * x = (const block_q6_K *) vx; + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int ip = iqs / 128; // 0 or 1 - const int il = (iqs - 128*ip)/8; // 0...15 - const int is = 8*ip; + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; - const float * y = yy + 128*ip + il; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; - const float d = x[ib].d; + const block_q2_K * x = (const block_q2_K *)vx + ib0; - const uint8_t * ql = x[ib].ql + 64*ip + il; - const uint8_t * qh = x[ib].qh + 32*ip + il; - const int8_t * sc = x[ib].scales + is; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 - result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32) - + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32) - + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32) - + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32) - + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32) - + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32) - + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32) - + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32); + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...7 + + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; + + } + tmp += dall * sum1 - dmin * sum2; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) { + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q3_K * x = (const block_q3_K *)vx + ib0; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; // 0, 1 + + const int n = 2; // iterations in the inner loop + const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - 8*im; // 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...28 in steps of 4 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + const uint8_t * h = x[i].hmask + l0; + + const uint16_t * a = (const uint16_t *)x[i].scales; + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = x[i].d; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 4; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + const block_q4_K * x = (const block_q4_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const uint8_t * q1 = x[i].qs + q_offset; + const uint8_t * q2 = q1 + 64; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < n; ++l) { + s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4); + s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4); + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + //const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 4; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1 << (2*im); + const uint8_t hm2 = hm1 << 4; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const uint8_t * ql1 = x[i].qs + q_offset; + const uint8_t * ql2 = ql1 + 64; + const uint8_t * qh = x[i].qh + l0; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 sum = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < n; ++l) { + sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * ((ql1[l] >> 4) + (qh[l] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * ((ql2[l] >> 4) + (qh[l] & (hm2 << 1) ? 16 : 0)); + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q6_K * x = (const block_q6_K *)vx + ib0; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const float d = x[i].d; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } } static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ @@ -712,46 +918,6 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } -template -static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int tid = threadIdx.x; - - const int iter_stride = QK_K; - const int vals_per_iter = iter_stride / n_thread; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - float tmp = 0; // partial sum for thread in warp - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int ib = ib0 + col/QK_K; // x block index - const int iqs = col%QK_K; // x quant index - const int iybs = col - col%QK_K; // y block start index - - float v; - dot_kernel(vx, ib, iqs, y + iybs, v); - tmp += v; - } - - // sum up partial sums and write back result - __syncthreads(); -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (tid == 0) { - dst[row] = tmp; - } -} - static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { const half * x = (half *) vx; @@ -1094,43 +1260,34 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<>>(vx, y, dst, ncols, nrows); + dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; + const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<>>(vx, y, dst, ncols, nrows); + dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 5df922abd..1d4db96ee 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -15,7 +15,7 @@ #include "ggml.h" -#define CL_DMMV_BLOCK_SIZE 32; +#define CL_DMMV_BLOCK_SIZE 32 #define MULTILINE_QUOTE(...) #__VA_ARGS__ static std::string program_source = MULTILINE_QUOTE( @@ -59,6 +59,46 @@ struct __attribute__ ((packed)) block_q8_0 int8_t qs[QK8_0]; }; +struct __attribute__((packed)) block_q2_K +{ + uint8_t scales[16]; + uint8_t qs[64]; + half d; + half dmin; +}; + +struct __attribute__((packed)) block_q3_K +{ + uint8_t hmask[32]; + uint8_t qs[64]; + uint8_t scales[12]; + half d; +}; + +struct __attribute__((packed)) block_q4_K +{ + half d; + half dmin; + uint8_t scales[12]; + uint8_t qs[128]; +}; + +struct __attribute__((packed)) block_q5_K +{ + half d; + half dmin; + uint8_t scales[12]; + uint8_t qh[32]; + uint8_t qs[128]; +}; + +struct __attribute__((packed)) block_q6_K +{ + uint8_t ql[128]; + uint8_t qh[64]; + int8_t scales[16]; + half d; +}; __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) { const uint i = get_global_id(0); @@ -131,8 +171,314 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float *v0 = vload_half(0, &x[ib + 0]); *v1 = vload_half(0, &x[ib + 1]); } + +inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m) +{ + if (j < 4) + { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } + else + { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +__kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy) +{ + const int i = get_group_id(0); + const int tid = get_local_id(0); + const int n = tid / 32; + const int l = tid - 32 * n; + const int is = 8 * n + l / 16; + + const uint8_t q = x[i].qs[32 * n + l]; + __global float *y = yy + i * 256 + 128 * n; + + const float dall = vload_half(0, &x[i].d); + const float dmin = vload_half(0, &x[i].dmin); + + y[l + 0] = dall * (x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is + 0] >> 4); + y[l + 32] = dall * (x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is + 2] >> 4); + y[l + 64] = dall * (x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is + 4] >> 4); + y[l + 96] = dall * (x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is + 6] >> 4); +} + +__kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy) +{ + int r = get_local_id(0) / 4; + int i = get_group_id(0); + int tid = r / 2; + int is0 = r % 2; + int l0 = 16 * is0 + 4 * (get_local_id(0) % 4); + int n = tid / 4; + int j = tid - 4 * n; + + uint8_t m = 1 << (4 * n + j); + int is = 8 * n + 2 * j + is0; + int shift = 2 * j; + + int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4) + : is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4) + : is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4) + : (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4); + float d_all = vload_half(0, &x[i].d); + float dl = d_all * (us - 32); + + __global float *y = yy + i * 256 + 128 * n + 32 * j; + const __global uint8_t *q = x[i].qs + 32 * n; + const __global uint8_t *hm = x[i].hmask; + + for (int l = l0; l < l0 + 4; ++l) + y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +} + +__kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy) +{ + const int i = get_group_id(0); + const int tid = get_local_id(0); + const int il = tid / 8; + const int ir = tid % 8; + const int is = 2 * il; + const int n = 4; + + __global float *y = yy + i * 256 + 64 * il + n * ir; + + const float dall = vload_half(0, &x[i].d); + const float dmin = vload_half(0, &x[i].dmin); + + __global const uint8_t *q = x[i].qs + 32 * il + n * ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + float d1 = dall * sc; + float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + float d2 = dall * sc; + float m2 = dmin * m; + for (int l = 0; l < n; ++l) + { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l + 32] = d2 * (q[l] >> 4) - m2; + } +} + +__kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy) +{ + const int i = get_group_id(0); + const int tid = get_local_id(0); + const int il = tid / 16; + const int ir = tid % 16; + const int is = 2 * il; + + __global float *y = yy + i * 256 + 64 * il + 2 * ir; + + const float dall = vload_half(0, &x[i].d); + const float dmin = vload_half(0, &x[i].dmin); + + __global const uint8_t *ql = x[i].qs + 32 * il + 2 * ir; + __global const uint8_t *qh = x[i].qh + 2 * ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + uint8_t hm = 1 << (2 * il); + y[0] = d1 * ((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0)) - m1; + y[1] = d1 * ((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[0] >> 4) + (qh[0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[1] >> 4) + (qh[1] & hm ? 16 : 0)) - m2; +} + +__kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy) +{ + const int i = get_group_id(0); + const int tid = get_local_id(0); + const int ip = tid / 32; + const int il = tid - 32 * ip; + const int is = 8 * ip + il / 16; + + __global float *y = yy + i * 256 + 128 * ip + il; + + const float d = vload_half(0, &x[i].d); + + __global const uint8_t *ql = x[i].ql + 64 * ip + il; + const uint8_t qh = x[i].qh[32 * ip + il]; + __global const int8_t *sc = x[i].scales + is; + + y[0] = d * sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + + +void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) { + + int n = iqs / 128; + int r = iqs - 128 * n; + int l = r / 8; + + __global const float *y = yy + 128 * n + l; + __global const uint8_t *q = x[ib].qs + 32 * n + l; + __global const uint8_t *s = x[ib].scales + 8 * n; + + const float dall = vload_half(0, &x[ib].d); + const float dmin = vload_half(0, &x[ib].dmin); + + float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4)) + + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4)) + + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4)) + + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4)) + + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4)) + + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4)) + + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4)) + + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4)); + + *result = sum; +} + +void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) { + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[3]; + uint32_t utmp[4]; + + int n = iqs/128; + int r = iqs - 128*n; + int l = r/8; + + __global const float * y = yy + 128*n + l; + __global const uint8_t * q = x[ib].qs + 32*n + l; + __global const uint8_t * hm = x[ib].hmask + l; + const int8_t * s = (const int8_t *)utmp + 8*n; + + aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24; + aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24; + aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24; + + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + const float dall = vload_half(0, &x[ib].d); + const uint8_t m = 1 << (4*n); + + float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4)) + + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4)) + + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4)) + + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4)) + + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4)) + + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4)) + + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4)) + + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4)); + + *result = sum * dall; + +} + +void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) { + + const int j = iqs / 64; // j is in 0...3 + const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 + const int is = 2*j; // is is in 0...6 in steps of 2 + + __global const float * y = yy + 64*j + ir; + __global const uint8_t * q = x[ib].qs + 32*j + ir; + + const float dall = vload_half(0, &x[ib].d); + const float dmin = vload_half(0, &x[ib].dmin); + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[ib].scales, &sc, &m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[ib].scales, &sc, &m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + float sum = 0; + for (int k = 0; k < 4; ++k) { + sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1); + sum += y[k + 32] * (d2 * (q[k] >> 4) - m2); + } + + *result = sum; +} + +void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) { + + const int j = iqs / 64; + const int ir = (iqs - 64*j)/2; + const int is = 2*j; + + __global const float * y = yy + 64*j + ir; + __global const uint8_t * ql = x[ib].qs + 32*j + ir; + __global const uint8_t * qh = x[ib].qh + ir; + + const float dall = vload_half(0, &x[ib].d); + const float dmin = vload_half(0, &x[ib].dmin); + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[ib].scales, &sc, &m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[ib].scales, &sc, &m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + uint8_t hm = 1 << is; + float sum = 0; + for (int k = 0; k < 4; ++k) { + sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1); + } + hm <<= 1; + for (int k = 0; k < 4; ++k) { + sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2); + } + *result = sum; + +} + +void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) { + + + const int ip = iqs / 128; // 0 or 1 + const int il = (iqs - 128*ip)/8; // 0...15 + const int is = 8*ip; + + __global const float * y = yy + 128*ip + il; + + const float d = vload_half(0, &x[ib].d); + + __global const uint8_t * ql = x[ib].ql + 64*ip + il; + __global const uint8_t * qh = x[ib].qh + 32*ip + il; + __global const int8_t * sc = x[ib].scales + is; + + *result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32) + + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32) + + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32) + + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32) + + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32) + + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32) + + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32) + + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32); + +} + ); + std::string dequant_template = MULTILINE_QUOTE( __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) { const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2; @@ -160,7 +506,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) { std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE( __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { const int block_size = get_local_size(0); - const int row = get_global_id(0) / block_size; + const int row = get_group_id(0); const int tid = get_local_id(0); const uint qk = QUANT_K; @@ -199,6 +545,45 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float } ); +std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE( +__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { + const int block_size = get_local_size(0); + const int row = get_group_id(0); + const int tid = get_local_id(0); + + const int iter_stride = 256; + const int vals_per_iter = iter_stride / block_size; + const int num_blocks_per_row = ncols / 256; + const int ib0 = row*num_blocks_per_row; + + tmp[tid] = 0; + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = ib0 + col/256; // x block index + const int iqs = col%256; // x quant index + const int iybs = col - col%256; // y block start index + + // dequantize + float v; + DOT_KERNEL(x, ib, iqs, y + iybs, &v); + tmp[tid] += v; + } + + // sum up partial sums and write back result + barrier(CLK_LOCAL_MEM_FENCE); + for (int s=block_size/2; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + dst[row] = tmp[0]; + } +} +); + std::string mul_template = MULTILINE_QUOTE( __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) { const int i = get_group_id(0)*get_local_size(0) + get_local_id(0); @@ -260,6 +645,18 @@ std::array mul_str_values = { "mul_f32", "float" }; +std::array dmmv_k_str_keys = { + "KERNEL_NAME", "X_TYPE", "DOT_KERNEL" +}; + +std::array dmmv_k_str_values = { + "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K", + "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K", + "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K", + "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K", + "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K", +}; + std::string& replace(std::string& s, const std::string& from, const std::string& to) { size_t pos = 0; while ((pos = s.find(from, pos)) != std::string::npos) { @@ -289,6 +686,14 @@ std::string generate_kernels() { } src << mul_kernel << '\n'; } + for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) { + std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template; + for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) { + replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]); + } + src << dmmv_k_kernel << '\n'; + } + return src.str(); } @@ -300,6 +705,8 @@ static cl_program program; static cl_kernel convert_row_f16_cl; static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl; static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl; +static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q5_k_cl, dequantize_block_q6_k_cl; +static cl_kernel dequantize_mul_mat_vec_q2_K_cl, dequantize_mul_mat_vec_q3_K_cl, dequantize_mul_mat_vec_q4_K_cl, dequantize_mul_mat_vec_q5_K_cl, dequantize_mul_mat_vec_q6_K_cl; static cl_kernel mul_f32_cl; static bool fp16_support; @@ -529,6 +936,12 @@ void ggml_cl_init(void) { CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err)); CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err)); CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err)); + CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err)); + CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err)); + CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err)); + CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err)); + CL_CHECK((dequantize_block_q5_k_cl = clCreateKernel(program, "dequantize_block_q5_K", &err), err)); + CL_CHECK((dequantize_block_q6_k_cl = clCreateKernel(program, "dequantize_block_q6_K", &err), err)); // dequant mul mat kernel CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err)); @@ -537,6 +950,11 @@ void ggml_cl_init(void) { CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err)); CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err)); + CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err)); + CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err)); + CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err)); + CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err)); + CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err)); // mul kernel CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err)); @@ -554,6 +972,16 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) { return &dequantize_row_q5_1_cl; case GGML_TYPE_Q8_0: return &dequantize_row_q8_0_cl; + case GGML_TYPE_Q2_K: + return &dequantize_block_q2_k_cl; + case GGML_TYPE_Q3_K: + return &dequantize_block_q3_k_cl; + case GGML_TYPE_Q4_K: + return &dequantize_block_q4_k_cl; + case GGML_TYPE_Q5_K: + return &dequantize_block_q5_k_cl; + case GGML_TYPE_Q6_K: + return &dequantize_block_q6_k_cl; case GGML_TYPE_F16: return &convert_row_f16_cl; default: @@ -561,6 +989,50 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) { } } +static size_t ggml_cl_global_denom(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return 1; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + return 4; + case GGML_TYPE_Q4_K: + return 8; + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return 4; + case GGML_TYPE_F16: + default: + return 1; + } +} + +static size_t ggml_cl_local_size(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return 0; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + return 64; + case GGML_TYPE_Q4_K: + return 32; + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return 64; + case GGML_TYPE_F16: + default: + return 0; + } +} + static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -575,6 +1047,16 @@ static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) { return &dequantize_mul_mat_vec_q8_0_cl; case GGML_TYPE_F16: return &convert_mul_mat_vec_f16_cl; + case GGML_TYPE_Q2_K: + return &dequantize_mul_mat_vec_q2_K_cl; + case GGML_TYPE_Q3_K: + return &dequantize_mul_mat_vec_q3_K_cl; + case GGML_TYPE_Q4_K: + return &dequantize_mul_mat_vec_q4_K_cl; + case GGML_TYPE_Q5_K: + return &dequantize_mul_mat_vec_q5_K_cl; + case GGML_TYPE_Q6_K: + return &dequantize_mul_mat_vec_q6_K_cl; default: return nullptr; } @@ -1017,6 +1499,9 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type); GGML_ASSERT(to_fp32_cl != nullptr); + const size_t global_denom = ggml_cl_global_denom(type); + const size_t local = ggml_cl_local_size(type); + size_t ev_idx = 0; std::vector events; @@ -1049,10 +1534,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++)); } else { // general dequantization kernel + CLBlast matrix matrix multiplication // convert src0 to fp32 on device - const size_t global = x_ne; + const size_t global = x_ne / global_denom; CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q)); CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X)); - CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); // copy src1 to device CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL)); diff --git a/ggml.c b/ggml.c index c0efa1977..0eda7f338 100644 --- a/ggml.c +++ b/ggml.c @@ -35,6 +35,12 @@ #define static_assert(cond, msg) struct global_scope_noop_trick #endif +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) +#endif + #if defined(_WIN32) #include diff --git a/llama.cpp b/llama.cpp index b8bc0d821..81f047ed2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -40,6 +40,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1654,7 +1658,7 @@ static bool llama_eval_internal( // cur = cur*norm(broadcasted) cur = ggml_mul(ctx0, cur, model.norm); - offload_func_nr(cur); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend ggml_set_name(cur, "result_norm"); embeddings = cur; diff --git a/pocs/vdot/vdot.cpp b/pocs/vdot/vdot.cpp index 26bf50c9a..7b18090d6 100644 --- a/pocs/vdot/vdot.cpp +++ b/pocs/vdot/vdot.cpp @@ -10,6 +10,10 @@ #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + constexpr int kVecSize = 1 << 18; float drawFromGaussianPdf(std::mt19937& rndm) { diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 728460b5e..c40f1b29c 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -9,12 +9,15 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif -const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001; -const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002; -const float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075; -const float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040; -const float MAX_DOT_PRODUCT_ERROR = 0.02; +const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f; +const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f; +const float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; +const float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; +const float MAX_DOT_PRODUCT_ERROR = 0.02f; const char* RESULT_STR[] = {"ok", "FAILED"}; diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp index d5514455d..600375771 100644 --- a/tests/test-quantize-perf.cpp +++ b/tests/test-quantize-perf.cpp @@ -13,6 +13,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + #define MAX_ALIGNMENT 64 #define QK 32 #define WARMUP 5 diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 0e675127f..5d693f7b5 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -176,27 +176,27 @@ void test_frequency_presence_penalty( int main(void) { ggml_time_init(); - test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1); - test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3); - test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4}, 0); - test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3}, 0.7); - test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2, 0.1}, 1); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1); - test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3}, 0.25); - test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.75); - test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.99); + test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f); + test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f); + test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f); - test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5); - test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5); + test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); + test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); - test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.25, 0.25, 0.25, 0.25, 0}, 50.0); - test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.5, 0.5, 0, 0, 0}, 50.0); - test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.5, 0.5, 0, 0, 0}, 50.0); + test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f); + test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f); + test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f); - test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0); - test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0); - test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0); + test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f); + test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f); + test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f); printf("OK\n"); } diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index b08984571..ab1538a0c 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -53,7 +53,7 @@ int main(int argc, char **argv) { for (const auto & test_kv : k_tests()) { std::vector res(test_kv.first.size()); - const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), res.size(), true); + const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true); res.resize(n); bool correct = res.size() == test_kv.second.size();