From 84ca9c2ecf3391d911589d0fe2b483cbfb4b82a6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Apr 2023 13:48:11 +0300 Subject: [PATCH 1/7] examples : fix save-load-state + rename llama-util.h --- examples/save-load-state/save-load-state.cpp | 72 +++++++++++--------- llama_util.h => llama-util.h | 1 - llama.cpp | 3 +- 3 files changed, 42 insertions(+), 34 deletions(-) rename llama_util.h => llama-util.h (99%) mode change 100755 => 100644 diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 07dfa2c74..f5f02ec1d 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -1,13 +1,10 @@ +#include "common.h" +#include "llama.h" + #include #include #include -#include "common.h" -#include "llama.h" -#include "llama.cpp" - -using namespace std; - int main(int argc, char ** argv) { gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; @@ -20,21 +17,25 @@ int main(int argc, char ** argv) { return 1; } + if (params.n_predict < 0) { + params.n_predict = 16; + } + auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; + lparams.n_ctx = params.n_ctx; + lparams.n_parts = params.n_parts; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; auto n_past = 0; - auto last_n_tokens_data = vector(params.repeat_last_n, 0); + auto last_n_tokens_data = std::vector(params.repeat_last_n, 0); // init auto ctx = llama_init_from_file(params.model.c_str(), lparams); - auto tokens = vector(params.n_ctx); + auto tokens = std::vector(params.n_ctx); auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true); if (n_prompt_tokens < 1) { @@ -43,23 +44,25 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; + const size_t state_size = llama_get_state_size(ctx); + uint8_t * state_mem = new uint8_t[state_size]; + // Save state (rng, logits, embedding and kv_cache) to file - FILE *fp_write = fopen("dump_state.bin", "wb"); - auto state_size = llama_get_state_size(ctx); - auto state_mem = new uint8_t[state_size]; - llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file - fwrite(state_mem, 1, state_size, fp_write); - fclose(fp_write); + { + FILE *fp_write = fopen("dump_state.bin", "wb"); + llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file + fwrite(state_mem, 1, state_size, fp_write); + fclose(fp_write); + } // save state (last tokens) - auto last_n_tokens_data_saved = vector(last_n_tokens_data); - auto n_past_saved = n_past; + const auto last_n_tokens_data_saved = std::vector(last_n_tokens_data); + const auto n_past_saved = n_past; // first run printf("\n%s", params.prompt.c_str()); @@ -75,6 +78,7 @@ int main(int argc, char ** argv) { auto next_token = llama_sample_token(ctx, &candidates_p); auto next_token_str = llama_token_to_str(ctx, next_token); last_n_tokens_data.push_back(next_token); + printf("%s", next_token_str); if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); @@ -88,18 +92,21 @@ int main(int argc, char ** argv) { llama_free(ctx); // load new model - auto ctx2 = llama_init_from_file(params.model.c_str(), lparams); // Load state (rng, logits, embedding and kv_cache) from file - FILE *fp_read = fopen("dump_state.bin", "rb"); - auto state_size2 = llama_get_state_size(ctx2); - if (state_size != state_size2) { - fprintf(stderr, "\n%s : failed to validate state size\n", __func__); + { + FILE *fp_read = fopen("dump_state.bin", "rb"); + if (state_size != llama_get_state_size(ctx2)) { + fprintf(stderr, "\n%s : failed to validate state size\n", __func__); + return 1; + } + fread(state_mem, 1, state_size, fp_read); + llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file + fclose(fp_read); } - fread(state_mem, 1, state_size, fp_read); - llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file - fclose(fp_read); + + delete[] state_mem; // restore state (last tokens) last_n_tokens_data = last_n_tokens_data_saved; @@ -118,6 +125,7 @@ int main(int argc, char ** argv) { auto next_token = llama_sample_token(ctx2, &candidates_p); auto next_token_str = llama_token_to_str(ctx2, next_token); last_n_tokens_data.push_back(next_token); + printf("%s", next_token_str); if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); @@ -125,6 +133,8 @@ int main(int argc, char ** argv) { } n_past += 1; } + printf("\n\n"); + return 0; } diff --git a/llama_util.h b/llama-util.h old mode 100755 new mode 100644 similarity index 99% rename from llama_util.h rename to llama-util.h index 6e66d12a8..ca4dd162f --- a/llama_util.h +++ b/llama-util.h @@ -430,5 +430,4 @@ struct llama_ctx_buffer { typedef llama_buffer llama_ctx_buffer; #endif - #endif diff --git a/llama.cpp b/llama.cpp index 1032fb9fa..dc4bdc534 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5,7 +5,7 @@ #include #endif -#include "llama_util.h" +#include "llama-util.h" #include "llama.h" #include "ggml.h" @@ -33,7 +33,6 @@ #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 - // available llama models enum e_model { MODEL_UNKNOWN, From 305eb5afd51325e3142c01c17431febb7c67de87 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Apr 2023 13:53:12 +0300 Subject: [PATCH 2/7] build : fix reference to old llama_util.h --- CMakeLists.txt | 2 +- Makefile | 2 +- examples/save-load-state/save-load-state.cpp | 10 +++++++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fdbeddfc..bbf599559 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -337,7 +337,7 @@ endif() add_library(llama llama.cpp llama.h - llama_util.h) + llama-util.h) target_include_directories(llama PUBLIC .) target_compile_features(llama PUBLIC cxx_std_11) # don't bump diff --git a/Makefile b/Makefile index 5a1cb3e83..fd695d7dd 100644 --- a/Makefile +++ b/Makefile @@ -168,7 +168,7 @@ $(info ) ggml.o: ggml.c ggml.h ggml-cuda.h $(CC) $(CFLAGS) -c $< -o $@ -llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama_util.h +llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h $(CXX) $(CXXFLAGS) -c $< -o $@ common.o: examples/common.cpp examples/common.h diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index f5f02ec1d..f1531ba39 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -66,6 +66,7 @@ int main(int argc, char ** argv) { // first run printf("\n%s", params.prompt.c_str()); + for (auto i = 0; i < params.n_predict; i++) { auto logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(ctx); @@ -86,6 +87,7 @@ int main(int argc, char ** argv) { } n_past += 1; } + printf("\n\n"); // free old model @@ -101,7 +103,13 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n%s : failed to validate state size\n", __func__); return 1; } - fread(state_mem, 1, state_size, fp_read); + + const size_t ret = fread(state_mem, 1, state_size, fp_read); + if (ret != state_size) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + return 1; + } + llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file fclose(fp_read); } From 214b6a35702a489e3738acd81fad6d46182d3036 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Apr 2023 18:43:28 +0300 Subject: [PATCH 3/7] ggml : adjust mul_mat_f16 work memory (#1226) * llama : minor - remove explicity int64_t cast * ggml : reduce memory buffer for F16 mul_mat when not using cuBLAS * ggml : add asserts to guard for incorrect wsize --- Makefile | 9 +++++++-- ggml.c | 21 +++++++++++++++------ llama.cpp | 2 +- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index fd695d7dd..4516e8556 100644 --- a/Makefile +++ b/Makefile @@ -34,10 +34,15 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC +CFLAGS = -I. -O3 -std=c11 -fPIC +CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC LDFLAGS = +ifndef LLAMA_DEBUG + CFLAGS += -DNDEBUG + CXXFLAGS += -DNDEBUG +endif + # warnings CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar diff --git a/ggml.c b/ggml.c index 64ecd0867..0dc1939f6 100644 --- a/ggml.c +++ b/ggml.c @@ -8245,8 +8245,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); -#else - float * const wdata = params->wdata; #endif for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -8263,8 +8261,11 @@ static void ggml_compute_forward_mul_mat_f16_f32( wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)); } } + + assert(id*sizeof(ggml_fp16_t) <= params->wsize); } #else + float * const wdata = params->wdata; { size_t id = 0; for (int64_t i01 = 0; i01 < ne01; ++i01) { @@ -8272,6 +8273,8 @@ static void ggml_compute_forward_mul_mat_f16_f32( wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); } } + + assert(id*sizeof(float) <= params->wsize); } #endif @@ -8537,7 +8540,10 @@ static void ggml_compute_forward_mul_mat_q_f32( dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); id += ne00; } + + assert(id*sizeof(float) <= params->wsize); } + const float * x = wdata; #endif @@ -11571,10 +11577,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning - cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0)); - //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); - //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); - //printf("cur = %zu\n", cur); +#if defined(GGML_USE_CUBLAS) + // with cuBLAS, we need memory for the full 3D / 4D data of src1 + cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); +#else + // here we need memory just for single 2D matrix from src0 + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); +#endif } else { cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); } diff --git a/llama.cpp b/llama.cpp index dc4bdc534..f8b4c8e46 100644 --- a/llama.cpp +++ b/llama.cpp @@ -780,7 +780,7 @@ static bool kv_cache_init( const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; - const int64_t n_mem = (int64_t)n_layer*n_ctx; + const int64_t n_mem = n_layer*n_ctx; const int64_t n_elements = n_embd*n_mem; cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); From ec728e44d7488c2da3560970317708b2b12b9c04 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Apr 2023 18:43:42 +0300 Subject: [PATCH 4/7] ggml : fix #if for f32_f32 mul_mat (CLBlast) (#1229) --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 0dc1939f6..6b9237186 100644 --- a/ggml.c +++ b/ggml.c @@ -11592,7 +11592,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) #endif } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; } From 0b5a9350993e6fc8be45dc2a3eafc1fd0812d392 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Apr 2023 19:28:36 +0300 Subject: [PATCH 5/7] ggml : fix visibility and unused warnings --- ggml.c | 4 ++-- ggml.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml.c b/ggml.c index 6b9237186..ebbaf11c6 100644 --- a/ggml.c +++ b/ggml.c @@ -9124,7 +9124,7 @@ static void ggml_compute_forward_alibi_f32( //const int nb3 = src0->nb[3]; assert(nb0 == sizeof(float)); - assert(ne1+n_past == ne0); + assert(ne1 + n_past == ne0); (void) n_past; // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); @@ -9185,7 +9185,7 @@ static void ggml_compute_forward_alibi_f16( //const int nb3 = src0->nb[3]; assert(nb0 == sizeof(ggml_fp16_t)); - assert(ne1+n_past == ne0); + assert(ne1 + n_past == ne0); (void) n_past; // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); diff --git a/ggml.h b/ggml.h index 38ae9a6ee..c1c5495c6 100644 --- a/ggml.h +++ b/ggml.h @@ -701,8 +701,8 @@ extern "C" { struct ggml_tensor * c1); // Mapping operations - GGML_API typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *); - GGML_API typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); + typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *); + typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); GGML_API struct ggml_tensor * ggml_map_unary_f32( struct ggml_context * ctx, From e8c051611abfc9a7f37fd4bba48217180893bd68 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Apr 2023 21:12:56 +0300 Subject: [PATCH 6/7] ggml : use vzip instead of vuzp for consistency --- ggml.c | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml.c b/ggml.c index ebbaf11c6..c9f0f09ea 100644 --- a/ggml.c +++ b/ggml.c @@ -2658,35 +2658,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + // interleave + const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs); + const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs); + const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs); + const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs); + // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - // interleave - const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h); - const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h); - const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); - const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); - #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs); + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); From c3ca7a5f0546c561eb278be3f2fe335795679e01 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 29 Apr 2023 21:34:23 +0300 Subject: [PATCH 7/7] ggml : fix 32-bit ARM NEON --- ggml.c | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/ggml.c b/ggml.c index c9f0f09ea..4d53b4628 100644 --- a/ggml.c +++ b/ggml.c @@ -668,6 +668,33 @@ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { return vget_high_u8(vcombine_u8(a, b)); } +int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) { + return vcombine_s8(vget_low_s8(a), vget_low_s8(b)); +} + +int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) { + return vcombine_s8(vget_high_s8(a), vget_high_s8(b)); +} + +uint8x16_t vzip1q_u8(uint8x16_t a, uint8x16_t b) { + return vcombine_u8(vget_low_u8(a), vget_low_u8(b)); +} + +uint8x16_t vzip2q_u8(uint8x16_t a, uint8x16_t b) { + return vcombine_u8(vget_high_u8(a), vget_high_u8(b)); +} + +int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + #endif #endif