Merge 'origin/master' into hipblas

This commit is contained in:
Henri Vasserman 2023-06-20 01:23:12 +03:00
commit 5dd2fbe6ea
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986
12 changed files with 1281 additions and 259 deletions

View file

@ -70,6 +70,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
@ -239,6 +240,9 @@ if (LLAMA_CUBLAS)
add_compile_definitions(GGML_USE_CUBLAS) add_compile_definitions(GGML_USE_CUBLAS)
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
if (LLAMA_CUDA_DMMV_F16)
add_compile_definitions(GGML_CUDA_DMMV_F16)
endif()
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
if (LLAMA_STATIC) if (LLAMA_STATIC)
@ -497,6 +501,7 @@ add_library(ggml_static STATIC $<TARGET_OBJECTS:ggml>)
if (BUILD_SHARED_LIBS) if (BUILD_SHARED_LIBS)
set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
add_library(ggml_shared SHARED $<TARGET_OBJECTS:ggml>) add_library(ggml_shared SHARED $<TARGET_OBJECTS:ggml>)
target_link_libraries(ggml_shared PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
endif() endif()
add_library(llama add_library(llama
@ -522,9 +527,18 @@ endif()
if (GGML_SOURCES_CUDA) if (GGML_SOURCES_CUDA)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture") message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES "native")
set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET ggml_static PROPERTY CUDA_ARCHITECTURES "native")
set_property(TARGET ggml_static PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
if (BUILD_SHARED_LIBS)
set_property(TARGET ggml_shared PROPERTY CUDA_ARCHITECTURES "native")
set_property(TARGET ggml_shared PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
endif()
set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES "native")
endif() endif()

View file

@ -169,6 +169,9 @@ ifdef LLAMA_CUDA_DMMV_Y
else else
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
endif # LLAMA_CUDA_DMMV_Y endif # LLAMA_CUDA_DMMV_Y
ifdef LLAMA_CUDA_DMMV_F16
NVCCFLAGS += -DGGML_CUDA_DMMV_F16
endif # LLAMA_CUDA_DMMV_F16
ifdef LLAMA_CUDA_KQUANTS_ITER ifdef LLAMA_CUDA_KQUANTS_ITER
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
else else
@ -270,7 +273,7 @@ $(info )
ggml.o: ggml.c ggml.h ggml-cuda.h ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@ $(CC) $(CFLAGS) -c $< -o $@
llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
common.o: examples/common.cpp examples/common.h common.o: examples/common.cpp examples/common.h

View file

@ -337,7 +337,14 @@ Building the program with BLAS support may lead to some performance improvements
cmake --build . --config Release cmake --build . --config Release
``` ```
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:
| Option | Legal values | Default | Description |
|-------------------------|------------------------|---------|-------------|
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
| LLAMA_CUDA_DMMV_Y | Positive integer | 1 | Block size in y direction for the CUDA dequantization + mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
| LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value 2 1 can improve performance for slow GPUs. |
- #### CLBlast - #### CLBlast
@ -617,7 +624,12 @@ And after 4.45 hours, you will have the final perplexity.
#### Building the Project using Android NDK #### Building the Project using Android NDK
You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/). You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/).
First, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
First, install the essential packages for termux:
```
pkg install clang wget git cmake
```
Second, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
``` ```
$ mkdir build-android $ mkdir build-android
$ cd build-android $ cd build-android

View file

@ -40,8 +40,10 @@ int main(int argc, char ** argv) {
// this allocates all Metal resources and memory buffers // this allocates all Metal resources and memory buffers
auto * ctx_metal = ggml_metal_init(); auto * ctx_metal = ggml_metal_init();
ggml_metal_add_buffer(ctx_metal, "data", ggml_get_mem_buffer(ctx_data), ggml_get_mem_size(ctx_data)); const size_t max_size_data = ggml_get_max_tensor_size(ctx_data);
ggml_metal_add_buffer(ctx_metal, "eval", ggml_get_mem_buffer(ctx_eval), ggml_get_mem_size(ctx_eval)); const size_t max_size_eval = ggml_get_max_tensor_size(ctx_eval);
ggml_metal_add_buffer(ctx_metal, "data", ggml_get_mem_buffer(ctx_data), ggml_get_mem_size(ctx_data), max_size_data);
ggml_metal_add_buffer(ctx_metal, "eval", ggml_get_mem_buffer(ctx_eval), ggml_get_mem_size(ctx_eval), max_size_eval);
// main // main
{ {

View file

@ -21,6 +21,7 @@ Command line options:
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`. - `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`. - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
- `--port`: Set the port to listen. Default: `8080`. - `--port`: Set the port to listen. Default: `8080`.
- `--embedding`: Enable embedding extraction, Default: disabled.
## Build ## Build
@ -119,14 +120,14 @@ node .
`top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9). `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9).
`n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. (default: 128, -1 = infinity). `n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: 128, -1 = infinity).
`n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context. `n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context.
By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt. By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
`stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`. `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.
`prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. `prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. A space is inserted in the front like main.cpp does.
`stop`: Specify a JSON array of stopping strings. `stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []). These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
@ -163,6 +164,14 @@ node .
`content`: Set the text to tokenize. `content`: Set the text to tokenize.
Note that the special `BOS` token is not added in fron of the text and also a space character is not inserted automatically as it is for `/completion`.
- **POST** `/embedding`: Generate embedding of a given text just as [the embedding example](../embedding) does.
*Options:*
`content`: Set the text to process.
## More examples ## More examples
### Interactive mode ### Interactive mode

View file

@ -254,6 +254,11 @@ struct llama_server_context {
n_past += n_eval; n_past += n_eval;
} }
if (params.n_predict == 0) {
has_next_token = false;
return llama_token_eos();
}
// out of user input, sample next token // out of user input, sample next token
const float temp = params.temp; const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
@ -419,6 +424,19 @@ struct llama_server_context {
return token_text; return token_text;
} }
std::vector<float> getEmbedding() {
static const int n_embd = llama_n_embd(ctx);
if (!params.embedding) {
LOG_WARNING("embedding disabled", {
{ "params.embedding", params.embedding },
});
return std::vector<float>(n_embd, 0.0f);
}
const float * data = llama_get_embeddings(ctx);
std::vector<float> embedding(data, data + n_embd);
return embedding;
}
}; };
static void server_print_usage(const char * argv0, const gpt_params & params, static void server_print_usage(const char * argv0, const gpt_params & params,
@ -457,6 +475,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
fprintf(stderr, " --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); fprintf(stderr, " --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT port to listen (default (default: %d)\n", sparams.port); fprintf(stderr, " --port PORT port to listen (default (default: %d)\n", sparams.port);
fprintf(stderr, " -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); fprintf(stderr, " -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
fprintf(stderr, " --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -603,6 +622,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
params.use_mlock = true; params.use_mlock = true;
} else if (arg == "--no-mmap") { } else if (arg == "--no-mmap") {
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--embedding") {
params.embedding = true;
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
server_print_usage(argv[0], default_params, default_sparams); server_print_usage(argv[0], default_params, default_sparams);
@ -646,6 +667,12 @@ static json format_generation_settings(llama_server_context & llama) {
}; };
} }
static json format_embedding_response(llama_server_context & llama) {
return json {
{ "embedding", llama.getEmbedding() },
};
}
static json format_final_response(llama_server_context & llama, const std::string & content) { static json format_final_response(llama_server_context & llama, const std::string & content) {
return json { return json {
{ "content", content }, { "content", content },
@ -881,12 +908,27 @@ int main(int argc, char ** argv) {
svr.Post("/tokenize", [&llama](const Request & req, Response & res) { svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
const json body = json::parse(req.body); const json body = json::parse(req.body);
const std::string content = body["content"].get<std::string>(); const std::string content = body.value("content", "");
const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false); const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
const json data = format_tokenizer_response(tokens); const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json"); return res.set_content(data.dump(), "application/json");
}); });
svr.Post("/embedding", [&llama](const Request & req, Response & res) {
const json body = json::parse(req.body);
llama.rewind();
llama_reset_timings(llama.ctx);
llama.params.prompt = body.value("content", "");
llama.params.n_predict = 0;
llama.loadPrompt();
llama.beginCompletion();
llama.doCompletion();
const json data = format_embedding_response(llama);
return res.set_content(data.dump(), "application/json");
});
svr.set_logger(log_server_request); svr.set_logger(log_server_request);
svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) { svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {

View file

@ -105,7 +105,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} while (0) } while (0)
#endif // CUDART_VERSION >= 11 #endif // CUDART_VERSION >= 11
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); #ifdef GGML_CUDA_DMMV_F16
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float dfloat; // dequantize float
typedef float2 dfloat2;
#endif //GGML_CUDA_DMMV_F16
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
@ -289,82 +297,106 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
} }
} }
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx; const block_q4_0 * x = (const block_q4_0 *) vx;
const float d = x[ib].d; const dfloat d = x[ib].d;
const uint8_t vui = x[ib].qs[iqs]; const int vui = x[ib].qs[iqs];
const int8_t vi0 = vui & 0xF; v.x = vui & 0xF;
const int8_t vi1 = vui >> 4; v.y = vui >> 4;
v0 = (vi0 - 8)*d; #ifdef GGML_CUDA_DMMV_F16
v1 = (vi1 - 8)*d; v = __hsub2(v, {8.0f, 8.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 8.0f) * d;
v.y = (v.y - 8.0f) * d;
#endif // GGML_CUDA_DMMV_F16
} }
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx; const block_q4_1 * x = (const block_q4_1 *) vx;
const float d = x[ib].d; const dfloat d = x[ib].d;
const float m = x[ib].m; const dfloat m = x[ib].m;
const uint8_t vui = x[ib].qs[iqs]; const int vui = x[ib].qs[iqs];
const int8_t vi0 = vui & 0xF; v.x = vui & 0xF;
const int8_t vi1 = vui >> 4; v.y = vui >> 4;
v0 = vi0*d + m; #ifdef GGML_CUDA_DMMV_F16
v1 = vi1*d + m; v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
#endif // GGML_CUDA_DMMV_F16
} }
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_0 * x = (const block_q5_0 *) vx; const block_q5_0 * x = (const block_q5_0 *) vx;
const float d = x[ib].d; const dfloat d = x[ib].d;
uint32_t qh; uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16; v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16; v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
v0 = x0*d; #ifdef GGML_CUDA_DMMV_F16
v1 = x1*d; v = __hsub2(v, {16.0f, 16.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 16.0f) * d;
v.y = (v.y - 16.0f) * d;
#endif // GGML_CUDA_DMMV_F16
} }
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx; const block_q5_1 * x = (const block_q5_1 *) vx;
const float d = x[ib].d; const dfloat d = x[ib].d;
const float m = x[ib].m; const dfloat m = x[ib].m;
uint32_t qh; uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0); v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1); v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
v0 = x0*d + m; #ifdef GGML_CUDA_DMMV_F16
v1 = x1*d + m; v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
#endif // GGML_CUDA_DMMV_F16
} }
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx; const block_q8_0 * x = (const block_q8_0 *) vx;
const float d = x[ib].d; const dfloat d = x[ib].d;
const int8_t vi0 = x[ib].qs[iqs + 0]; v.x = x[ib].qs[iqs + 0];
const int8_t vi1 = x[ib].qs[iqs + 1]; v.y = x[ib].qs[iqs + 1];
v0 = vi0*d; #ifdef GGML_CUDA_DMMV_F16
v1 = vi1*d; v = __hmul2(v, {d, d});
#else
v.x *= d;
v.y *= d;
#endif // GGML_CUDA_DMMV_F16
} }
//================================== k-quants //================================== k-quants
@ -538,15 +570,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const block_q2_K * x = (const block_q2_K *)vx + ib0; const block_q2_K * x = (const block_q2_K *)vx + ib0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int step = 16/K_QUANTS_PER_ITERATION; const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...7 const int in = tid - step*im; // 0...15 or 0...7
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4 const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0; const int q_offset = 32*im + l0;
const int s_offset = 8*im; const int s_offset = 8*im;
const int y_offset = 128*im + l0; const int y_offset = 128*im + l0;
@ -601,27 +633,30 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
} }
} }
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) { static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const uint16_t kmask1 = 0x0303; const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f; const uint16_t kmask2 = 0x0f0f;
const int row = blockIdx.x; const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row;
const block_q3_K * x = (const block_q3_K *)vx + ib0; const block_q3_K * x = (const block_q3_K *)vx + ib0;
const int tid = threadIdx.x/2; // 0...15 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%2; // 0, 1 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int n = 2; // iterations in the inner loop const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128... const int step = 16/K_QUANTS_PER_ITERATION;
const int in = tid - 8*im; // 0...7 const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7
const uint8_t m = 1 << (4*im); const uint8_t m = 1 << (4*im);
const int l0 = n*in; // 0...28 in steps of 4 const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0; const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0; const int y_offset = 128*im + l0;
@ -632,7 +667,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) { for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
@ -673,22 +708,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
} }
} }
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) { static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const uint16_t kmask1 = 0x3f3f; const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f; const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0; const uint16_t kmask3 = 0xc0c0;
const int row = blockIdx.x; const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row; const int ib0 = row*num_blocks_per_row;
const int tid = threadIdx.x/2; // 0...15 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%2; const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int il = tid/4; // 0...3 const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int ir = tid - 4*il;// 0...3
const int n = 4; const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2; const int in = il%2;
@ -704,7 +742,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) { for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const uint8_t * q1 = x[i].qs + q_offset; const uint8_t * q1 = x[i].qs + q_offset;
const uint8_t * q2 = q1 + 64; const uint8_t * q2 = q1 + 64;
@ -759,7 +797,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
const int il = tid/4; // 0...3 const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3 const int ir = tid - 4*il;// 0...3
const int n = 4; const int n = 2;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2; const int in = il%2;
@ -798,11 +836,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
float4 sum = {0.f, 0.f, 0.f, 0.f}; float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0; float smin = 0;
for (int l = 0; l < n; ++l) { for (int l = 0; l < n; ++l) {
sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0)); sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
sum.y += y1[l+32] * ((ql1[l] >> 4) + (qh[l] & (hm1 << 1) ? 16 : 0)); + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0)); sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
sum.w += y2[l+32] * ((ql2[l] >> 4) + (qh[l] & (hm2 << 1) ? 16 : 0)); + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
} }
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
@ -898,11 +941,12 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
} }
} }
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx; const half * x = (const half *) vx;
v0 = __half2float(x[ib + iqs + 0]); // automatic half -> float type cast if dfloat == float
v1 = __half2float(x[ib + iqs + 1]); v.x = x[ib + iqs + 0];
v.y = x[ib + iqs + 1];
} }
template <int qk, int qr, dequantize_kernel_t dequantize_kernel> template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@ -919,13 +963,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize // dequantize
float & v0 = y[iybs + iqs + 0]; dfloat2 v;
float & v1 = y[iybs + iqs + y_offset]; dequantize_kernel(vx, ib, iqs, v);
dequantize_kernel(vx, ib, iqs, v0, v1);
y[iybs + iqs + 0] = v.x;
y[iybs + iqs + y_offset] = v.y;
} }
template <int qk, int qr, dequantize_kernel_t dequantize_kernel> template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) { static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
// qk = quantized weights per x block // qk = quantized weights per x block
// qr = number of quantized weights per data value in x block // qr = number of quantized weights per data value in x block
const int row = blockIdx.y*blockDim.y + threadIdx.y; const int row = blockIdx.y*blockDim.y + threadIdx.y;
@ -940,7 +986,12 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
float tmp = 0.0f; // partial sum for thread in warp // partial sum for each thread
#ifdef GGML_CUDA_DMMV_F16
half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
#else
float tmp = 0.0f;
#endif // GGML_CUDA_DMMV_F16
for (int i = 0; i < ncols; i += iter_stride) { for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid; const int col = i + vals_per_iter*tid;
@ -954,14 +1005,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
// process 2 vals per j iter // process 2 vals per j iter
// dequantize // dequantize
float v0, v1;
dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
dfloat2 v;
dequantize_kernel(vx, ib, iqs + j/qr, v);
// matrix multiplication // matrix multiplication
tmp += v0 * y[iybs + iqs + j/qr + 0];
tmp += v1 * y[iybs + iqs + j/qr + y_offset];
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_CUDA_DMMV_F16
tmp += __hmul2(v, {
y[iybs + iqs + j/qr + 0],
y[iybs + iqs + j/qr + y_offset]
});
#else
tmp += v.x * y[iybs + iqs + j/qr + 0];
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
#endif // GGML_CUDA_DMMV_F16
} }
} }
@ -973,7 +1031,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
} }
if (tid == 0) { if (tid == 0) {
#ifdef GGML_CUDA_DMMV_F16
dst[row] = tmp.x + tmp.y;
#else
dst[row] = tmp; dst[row] = tmp;
#endif // GGML_CUDA_DMMV_F16
} }
} }
@ -1268,7 +1330,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y); dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
} }
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
@ -1277,7 +1339,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
@ -1286,7 +1348,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
@ -1295,7 +1357,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
@ -1304,7 +1366,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
@ -1315,7 +1377,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);
@ -1324,14 +1386,20 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1); const int ny = 2 / K_QUANTS_PER_ITERATION;
dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols); const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1); const int ny = 2 / K_QUANTS_PER_ITERATION;
dequantize_mul_mat_vec_q4_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols); const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@ -1354,7 +1422,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
} }
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
@ -1769,21 +1837,40 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t nrows = i01_high - i01_low; const int64_t nrows = i01_high - i01_low;
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_DMMV_F16
size_t ash;
dfloat * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
if (src1_convert_f16) {
src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
ne00, 1, sizeof(float), 0, 0,
ne00, 1, sizeof(half), 0, 0, cudaStream_main);
}
#else
dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
#endif // GGML_CUDA_DMMV_F16
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break; break;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
@ -1801,7 +1888,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
break; break;
case GGML_TYPE_F16: case GGML_TYPE_F16:
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -1809,6 +1896,12 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
} }
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
#ifdef GGML_CUDA_DMMV_F16
if (src1_convert_f16) {
ggml_cuda_pool_free(src1_dfloat, ash);
}
#endif // GGML_CUDA_DMMV_F16
(void) src1; (void) src1;
(void) dst; (void) dst;
(void) src0_ddf_i; (void) src0_ddf_i;

View file

@ -41,12 +41,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx);
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
// - the mapping is used during computation to determine the arguments of the compute kernels // - the mapping is used during computation to determine the arguments of the compute kernels
// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal // - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
// - max_size specifies the maximum size of a tensor and is used to create shared views such
// that it is guaranteed that the tensor will fit in at least one of the views
// //
bool ggml_metal_add_buffer( bool ggml_metal_add_buffer(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
const char * name, const char * name,
void * data, void * data,
size_t size); size_t size,
size_t max_size);
// set data from host memory into the device // set data from host memory into the device
void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);

View file

@ -183,6 +183,14 @@ struct ggml_metal_context * ggml_metal_init(void) {
#undef GGML_METAL_ADD_KERNEL #undef GGML_METAL_ADD_KERNEL
} }
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
if (ctx->device.maxTransferRate != 0) {
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
} else {
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
}
return ctx; return ctx;
} }
@ -199,10 +207,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) { static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
//fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
const int64_t tsize = ggml_nbytes(t);
// find the view that contains the tensor fully
for (int i = 0; i < ctx->n_buffers; ++i) { for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs; *offs = (size_t) ioffs;
//fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
@ -220,7 +231,8 @@ bool ggml_metal_add_buffer(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
const char * name, const char * name,
void * data, void * data,
size_t size) { size_t size,
size_t max_size) {
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) { if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
fprintf(stderr, "%s: too many buffers\n", __func__); fprintf(stderr, "%s: too many buffers\n", __func__);
return false; return false;
@ -237,30 +249,68 @@ bool ggml_metal_add_buffer(
} }
} }
size_t page_size = getpagesize(); const size_t size_page = getpagesize();
size_t aligned_size = size;
if ((aligned_size % page_size) != 0) { size_t size_aligned = size;
aligned_size += (page_size - (aligned_size % page_size)); if ((size_aligned % size_page) != 0) {
size_aligned += (size_page - (size_aligned % size_page));
} }
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= ctx->device.maxBufferLength) {
ctx->buffers[ctx->n_buffers].name = name; ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = data; ctx->buffers[ctx->n_buffers].data = data;
ctx->buffers[ctx->n_buffers].size = size; ctx->buffers[ctx->n_buffers].size = size;
if (ctx->device.maxBufferLength < aligned_size) { ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
fprintf(stderr, "%s: buffer '%s' size %zu is larger than buffer maximum of %zu\n", __func__, name, aligned_size, ctx->device.maxBufferLength);
return false;
}
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) { if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0); fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
return false; return false;
} }
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0); fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
++ctx->n_buffers; ++ctx->n_buffers;
} else {
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
// one of the views
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
const size_t size_view = ctx->device.maxBufferLength;
for (size_t i = 0; i < size; i += size_step) {
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
return false;
}
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
if (i + size_step < size) {
fprintf(stderr, "\n");
}
++ctx->n_buffers;
}
}
fprintf(stderr, ", (%8.2f / %8.2f)",
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
} else {
fprintf(stderr, "\n");
}
} }
return true; return true;
@ -909,4 +959,14 @@ void ggml_metal_graph_compute(
dispatch_barrier_sync(queue, ^{}); dispatch_barrier_sync(queue, ^{});
[command_buffers[n_cb - 1] waitUntilCompleted]; [command_buffers[n_cb - 1] waitUntilCompleted];
// check status of command buffers
// needed to detect if the device ran out-of-memory for example (#1881)
for (int i = 0; i < n_cb; i++) {
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
if (status != MTLCommandBufferStatusCompleted) {
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
GGML_ASSERT(false);
}
}
} }

820
ggml.c

File diff suppressed because it is too large Load diff

147
ggml.h
View file

@ -303,6 +303,7 @@ extern "C" {
GGML_OP_STEP, GGML_OP_STEP,
GGML_OP_RELU, GGML_OP_RELU,
GGML_OP_GELU, GGML_OP_GELU,
GGML_OP_GELU_QUICK,
GGML_OP_SILU, GGML_OP_SILU,
GGML_OP_SILU_BACK, GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize GGML_OP_NORM, // normalize
@ -331,12 +332,15 @@ extern "C" {
GGML_OP_ROPE_BACK, GGML_OP_ROPE_BACK,
GGML_OP_ALIBI, GGML_OP_ALIBI,
GGML_OP_CLAMP, GGML_OP_CLAMP,
GGML_OP_CONV_1D_1S, GGML_OP_CONV_1D_S1_PH,
GGML_OP_CONV_1D_2S, GGML_OP_CONV_1D_S2_PH,
GGML_OP_CONV_2D_SK_P0,
GGML_OP_FLASH_ATTN, GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF, GGML_OP_FLASH_FF,
GGML_OP_FLASH_ATTN_BACK, GGML_OP_FLASH_ATTN_BACK,
GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART,
GGML_OP_MAP_UNARY, GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY, GGML_OP_MAP_BINARY,
@ -500,8 +504,9 @@ extern "C" {
GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
GGML_API void * ggml_get_mem_buffer(struct ggml_context * ctx); GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
GGML_API size_t ggml_get_mem_size (struct ggml_context * ctx); GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx);
GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx);
GGML_API struct ggml_tensor * ggml_new_tensor( GGML_API struct ggml_tensor * ggml_new_tensor(
struct ggml_context * ctx, struct ggml_context * ctx,
@ -557,7 +562,7 @@ extern "C" {
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor); GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
GGML_API void ggml_set_name(struct ggml_tensor * tensor, const char * name); GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
// //
// operations on tensors with backpropagation // operations on tensors with backpropagation
@ -610,24 +615,47 @@ extern "C" {
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_sub_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_mul( GGML_API struct ggml_tensor * ggml_mul(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_mul_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_div( GGML_API struct ggml_tensor * ggml_div(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_div_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_sqr( GGML_API struct ggml_tensor * ggml_sqr(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sqr_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sqrt( GGML_API struct ggml_tensor * ggml_sqrt(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sqrt_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_log( GGML_API struct ggml_tensor * ggml_log(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
@ -667,31 +695,67 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_abs_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sgn( GGML_API struct ggml_tensor * ggml_sgn(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sgn_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_neg( GGML_API struct ggml_tensor * ggml_neg(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_neg_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_step( GGML_API struct ggml_tensor * ggml_step(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_step_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu( GGML_API struct ggml_tensor * ggml_relu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// TODO: double-check this computation is correct // TODO: double-check this computation is correct
GGML_API struct ggml_tensor * ggml_gelu( GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_silu( GGML_API struct ggml_tensor * ggml_silu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_silu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// a - x // a - x
// b - dy // b - dy
GGML_API struct ggml_tensor * ggml_silu_back( GGML_API struct ggml_tensor * ggml_silu_back(
@ -705,10 +769,18 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_rms_norm( GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// a - x // a - x
// b - dy // b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back( GGML_API struct ggml_tensor * ggml_rms_norm_back(
@ -998,16 +1070,55 @@ extern "C" {
float min, float min,
float max); float max);
// padding = 1 // TODO: implement general-purpose convolutions
// GGML_API struct ggml_tensor * ggml_conv_1d(
// struct ggml_context * ctx,
// struct ggml_tensor * a,
// struct ggml_tensor * b,
// int s0
// int p0,
// int d0);
//
// GGML_API struct ggml_tensor * ggml_conv_2d(
// struct ggml_context * ctx,
// struct ggml_tensor * a,
// struct ggml_tensor * b,
// int s0,
// int s1,
// int p0,
// int p1,
// int d0,
// int d1);
// padding = half
// TODO: we don't support extra parameters for now // TODO: we don't support extra parameters for now
// that's why we are hard-coding the stride, padding, and dilation // that's why we are hard-coding the stride, padding, and dilation
// not great .. // not great ..
GGML_API struct ggml_tensor * ggml_conv_1d_1s( // example:
// a: 3 80 768 1
// b: 3000 80 1 1
// res: 3000 768 1 1
// used in whisper
GGML_API struct ggml_tensor * ggml_conv_1d_s1_ph(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_conv_1d_2s( // used in whisper
GGML_API struct ggml_tensor * ggml_conv_1d_s2_ph(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// kernel size is a->ne[0] x a->ne[1]
// stride is equal to kernel size
// padding is zero
// example:
// a: 16 16 3 768
// b: 1024 1024 3 1
// res: 64 64 768 1
// used in sam
GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
@ -1035,6 +1146,26 @@ extern "C" {
struct ggml_tensor * c0, struct ggml_tensor * c0,
struct ggml_tensor * c1); struct ggml_tensor * c1);
// partition into non-overlapping windows with padding if needed
// example:
// a: 768 64 64 1
// w: 14
// res: 768 14 14 25
// used in sam
GGML_API struct ggml_tensor * ggml_win_part(
struct ggml_context * ctx,
struct ggml_tensor * a,
int w);
// reverse of ggml_win_part
// used in sam
GGML_API struct ggml_tensor * ggml_win_unpart(
struct ggml_context * ctx,
struct ggml_tensor * a,
int w0,
int h0,
int w);
// Mapping operations // Mapping operations
typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *); typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);

View file

@ -19,6 +19,11 @@
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
#include "ggml-metal.h" #include "ggml-metal.h"
#endif #endif
#ifdef GGML_USE_K_QUANTS
#ifndef QK_K
#define QK_K 256
#endif
#endif
#include <array> #include <array>
#include <ctime> #include <ctime>
@ -1615,7 +1620,7 @@ static bool llama_eval_internal(
model.layers[il].w1, model.layers[il].w1,
cur); cur);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "result_w2"); ggml_set_name(cur, "result_w1");
// SILU activation // SILU activation
cur = ggml_silu(ctx0, cur); cur = ggml_silu(ctx0, cur);
@ -1652,11 +1657,7 @@ static bool llama_eval_internal(
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL);
offload_func_nr(cur); offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_inpL"); ggml_set_name(cur, "rms_norm_2");
cur = ggml_rms_norm(ctx0, cur);
offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_after");
// cur = cur*norm(broadcasted) // cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.norm); cur = ggml_mul(ctx0, cur, model.norm);
@ -2491,8 +2492,23 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
} else { } else {
new_type = quantized_type; new_type = quantized_type;
#ifdef GGML_USE_K_QUANTS #ifdef GGML_USE_K_QUANTS
if (quantized_type == GGML_TYPE_Q2_K || quantized_type == GGML_TYPE_Q3_K || quantized_type == GGML_TYPE_Q4_K ||
quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) {
int nx = tensor.ne.at(0);
int ny = tensor.ne.at(1);
if (nx % QK_K != 0 || ny % QK_K != 0) {
fprintf(stderr, "\n\n========================= Tensor sizes %d x %d are not divisible by %d\n",nx,ny,QK_K);
fprintf(stderr, "This is required to be able to use k-quants for now!\n");
fprintf(stderr, "========================================================================================\n\n");
throw std::runtime_error("Unsupported tensor size encountered\n");
}
}
if (tensor.name == "output.weight") { if (tensor.name == "output.weight") {
int nx = tensor.ne.at(0);
int ny = tensor.ne.at(1);
if (nx % QK_K == 0 && ny % QK_K == 0) {
new_type = GGML_TYPE_Q6_K; new_type = GGML_TYPE_Q6_K;
}
} else if (tensor.name.find("attention.wv.weight") != std::string::npos) { } else if (tensor.name.find("attention.wv.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
@ -2696,16 +2712,21 @@ struct llama_context * llama_init_from_file(
// this allocates all Metal resources and memory buffers // this allocates all Metal resources and memory buffers
ctx->ctx_metal = ggml_metal_init(); ctx->ctx_metal = ggml_metal_init();
void *data_ptr = NULL; void * data_ptr = NULL;
size_t data_size = 0; size_t data_size = 0;
if (params.use_mmap) { if (params.use_mmap) {
data_ptr = ctx->model.mapping->addr; data_ptr = ctx->model.mapping->addr;
data_size= ctx->model.mapping->size; data_size = ctx->model.mapping->size;
} else { } else {
data_ptr = ggml_get_mem_buffer(ctx->model.ctx); data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
data_size= ggml_get_mem_size(ctx->model.ctx); data_size = ggml_get_mem_size (ctx->model.ctx);
} }
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
printf("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
#define LLAMA_METAL_CHECK_BUF(result) \ #define LLAMA_METAL_CHECK_BUF(result) \
if (!(result)) { \ if (!(result)) { \
fprintf(stderr, "%s: failed to add buffer\n", __func__); \ fprintf(stderr, "%s: failed to add buffer\n", __func__); \
@ -2713,12 +2734,13 @@ struct llama_context * llama_init_from_file(
return NULL; \ return NULL; \
} }
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
#undef LLAMA_METAL_CHECK_BUF #undef LLAMA_METAL_CHECK_BUF
} }
#endif #endif
@ -3104,9 +3126,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
if (kv_size) { if (kv_size) {
const size_t elt_size = ggml_element_size(kv_self.k); const size_t elt_size = ggml_element_size(kv_self.k);
char buffer[4096]; ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
ggml_cgraph gf{}; ggml_cgraph gf{};
gf.n_threads = 1; gf.n_threads = 1;
@ -3212,9 +3232,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
const size_t elt_size = ggml_element_size(kv_self.k); const size_t elt_size = ggml_element_size(kv_self.k);
char buffer[4096]; ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
ggml_cgraph gf{}; ggml_cgraph gf{};
gf.n_threads = 1; gf.n_threads = 1;
@ -3449,9 +3467,12 @@ void llama_print_timings(struct llama_context * ctx) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample); fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval); __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample, 1e6 / ctx->t_sample_us * n_sample);
fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval); fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval, 1e6 / ctx->t_p_eval_us * n_p_eval);
fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval, 1e6 / ctx->t_eval_us * n_eval);
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0); fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0);
} }