diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6206737c2..6efba527d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,6 +50,8 @@ set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA
 set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for dmmv CUDA kernels"   OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
+                                             "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                               OFF)
 option(LLAMA_K_QUANTS                        "llama: use k-quants"                              ON)
 
@@ -93,6 +95,7 @@ if (LLAMA_CUBLAS)
             add_compile_definitions(GGML_CUDA_F16)
         endif()
         add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+        add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
 
         if (LLAMA_STATIC)
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
diff --git a/common/common.cpp b/common/common.cpp
index 02ec0f8d0..6d655fd55 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -801,10 +801,10 @@ std::vector<llama_token> llama_tokenize(
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
+    n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
+        int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 947aa7ed3..59c90c7ba 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -965,10 +965,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
 
     buf[size] = '\0';
 
-    int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
+    int n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
     if (n_tokens < 0) {
         out.resize(-n_tokens);
-        n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
+        n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
     }
     GGML_ASSERT(n_tokens >= 0);
     out.resize(n_tokens);
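Note: the tokenize call sites patched above all rely on the same two-pass contract: when the output buffer is too small, llama_tokenize returns the negated number of tokens required, so the caller resizes and retries. A minimal sketch of that contract against the new signature (the helper name tokenize_with_retry is illustrative, not part of the patch):

    // Two-pass tokenization with the patched llama_tokenize, which now takes
    // an explicit byte length instead of a NUL-terminated string.
    static std::vector<llama_token> tokenize_with_retry(
            llama_context * ctx, const std::string & text, bool add_bos) {
        std::vector<llama_token> result(text.length() + add_bos); // upper bound
        int n = llama_tokenize(ctx, text.data(), text.length(),
                               result.data(), result.size(), add_bos);
        if (n < 0) {
            result.resize(-n); // -n is the exact number of tokens needed
            n = llama_tokenize(ctx, text.data(), text.length(),
                               result.data(), result.size(), add_bos);
        }
        result.resize(n);
        return result;
    }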
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index aa179487e..78ba94f9a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -31,6 +31,9 @@
 #define cublasSetStream hipblasSetStream
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -61,7 +64,7 @@
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #else
@@ -190,6 +193,12 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     } while (0)
 #endif // CUDART_VERSION >= 11
 
+#if CUDART_VERSION >= 11100
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
 #ifdef GGML_CUDA_F16
 typedef half dfloat; // dequantize float
 typedef half2 dfloat2;
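Note: GGML_CUDA_ASSUME exists because __builtin_assume only ships with CUDA 11.1 and newer; on older toolkits the macro expands to nothing, so the load_tiles_* hunks that follow degrade gracefully. A toy illustration of what the intrinsic buys (scale_lane is a made-up kernel, not from the patch):

    // __builtin_assume(expr) promises the compiler that expr holds, which can
    // let it drop range checks and simplify the index arithmetic below.
    __global__ void scale_lane(float * x, const int k) {
        GGML_CUDA_ASSUME(k >= 0);
        GGML_CUDA_ASSUME(k < WARP_SIZE); // WARP_SIZE is 32 in this file
        x[blockIdx.x*WARP_SIZE + k] *= 2.0f;
    }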
@@ -418,6 +427,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
 #define MUL_MAT_SRC1_COL_STRIDE 128
 #define MAX_STREAMS 8
 
@@ -2145,10 +2158,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI4_0;
     const int kqsx = k % QI4_0;
@@ -2239,10 +2252,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI4_1;
     const int kqsx = k % QI4_1;
@@ -2331,10 +2344,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI5_0;
     const int kqsx = k % QI5_0;
@@ -2445,10 +2458,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI5_1;
     const int kqsx = k % QI5_1;
@@ -2551,10 +2564,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI8_0;
     const int kqsx = k % QI8_0;
@@ -2642,10 +2655,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI2_K;
     const int kqsx = k % QI2_K;
@@ -2763,10 +2776,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI3_K;
     const int kqsx = k % QI3_K;
@@ -2981,10 +2994,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI4_K; // == 0 if QK_K == 256
     const int kqsx = k % QI4_K; // == k if QK_K == 256
@@ -3162,10 +3175,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI5_K; // == 0 if QK_K == 256
     const int kqsx = k % QI5_K; // == k if QK_K == 256
@@ -3291,10 +3304,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset < nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k < WARP_SIZE);
+    GGML_CUDA_ASSUME(i_offset >= 0);
+    GGML_CUDA_ASSUME(i_offset < nwarps);
+    GGML_CUDA_ASSUME(k >= 0);
+    GGML_CUDA_ASSUME(k < WARP_SIZE);
 
     const int kbx  = k / QI6_K; // == 0 if QK_K == 256
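Note: the ten hunks above are one mechanical substitution across the load_tiles_* kernels; the index math they guard is a plain div/mod split of the lane index k. A worked toy example, assuming QI4_0 evaluates to 4 for the Q4_0 layout (QK4_0/(4*QR4_0) = 32/8 with the definitions elsewhere in this file):

    // k in [0, WARP_SIZE) selects one int of quantized data per lane:
    //   kbx  = k / QI4_0   // which Q4_0 block along the row
    //   kqsx = k % QI4_0   // which int inside that block
    // e.g. k = 13 with QI4_0 == 4 gives kbx = 3, kqsx = 1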
    const int kqsx = k % QI6_K; // == k if QK_K == 256
@@ -6248,6 +6261,43 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     }
 }
 
+void ggml_cuda_set_peer_access(const int n_tokens) {
+    static bool peer_access_enabled = false;
+
+    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+    if (peer_access_enabled == enable_peer_access) {
+        return;
+    }
+
+#ifdef NDEBUG
+    for (int id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        for (int id_other = 0; id_other < g_device_count; ++id_other) {
+            if (id == id_other) {
+                continue;
+            }
+            if (id != g_main_device && id_other != g_main_device) {
+                continue;
+            }
+
+            int can_access_peer;
+            CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
+            if (can_access_peer) {
+                if (enable_peer_access) {
+                    CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
+                } else {
+                    CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other));
+                }
+            }
+        }
+    }
+#endif // NDEBUG
+
+    peer_access_enabled = enable_peer_access;
+}
+
 static void ggml_cuda_op_mul_mat(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
     const bool convert_src1_to_q8_1) {
@@ -6272,6 +6322,8 @@ static void ggml_cuda_op_mul_mat(
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
 
+    ggml_cuda_set_peer_access(ne11);
+
     GGML_ASSERT(dst->backend  != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
 
@@ -6404,7 +6456,7 @@ static void ggml_cuda_op_mul_mat(
 
                 // wait for main GPU data if necessary
                 if (split && (id != g_main_device || is != 0)) {
-                    CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
+                    CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0));
                 }
 
                 for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
@@ -6526,7 +6578,7 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
             for (int64_t is = 0; is < is_max; ++is) {
-                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
         }
     }
@@ -6960,6 +7012,7 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
         return;
     }
     if (g_scratch_buffer == nullptr) {
+        ggml_cuda_set_device(g_main_device);
         CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
     }
 
@@ -6999,7 +7052,7 @@ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
-void ggml_cuda_set_main_device(int main_device) {
+void ggml_cuda_set_main_device(const int main_device) {
     if (main_device >= g_device_count) {
         fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
                 main_device, g_device_count, g_main_device);
@@ -7013,11 +7066,11 @@ void ggml_cuda_set_main_device(int main_device) {
     }
 }
 
-void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
     g_mul_mat_q = mul_mat_q;
 }
 
-void ggml_cuda_set_scratch_size(size_t scratch_size) {
+void ggml_cuda_set_scratch_size(const size_t scratch_size) {
     g_scratch_size = scratch_size;
 }
 
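Note: ggml_cuda_set_peer_access() above toggles peer access lazily, only for batches up to GGML_CUDA_PEER_MAX_BATCH_SIZE, and only in NDEBUG (release) builds. For readers unfamiliar with the runtime calls it maps onto, a standalone sketch of the underlying CUDA pattern (not part of the patch; error handling trimmed):

    #include <cuda_runtime.h>

    // Enable direct access from device 0 to device 1, then copy between them.
    void copy_peer_to_peer(float * dst1, const float * src0, size_t n) {
        int can_access = 0;
        cudaDeviceCanAccessPeer(&can_access, /*device=*/0, /*peerDevice=*/1);
        if (can_access) {
            cudaSetDevice(0);                 // enable applies to the current device
            cudaDeviceEnablePeerAccess(1, 0); // flags must be 0
        }
        // With peer access enabled this is a direct GPU-to-GPU transfer;
        // without it, the runtime stages the copy through host memory.
        cudaMemcpyPeer(dst1, /*dstDevice=*/1, src0, /*srcDevice=*/0, n * sizeof(float));
    }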
diff --git a/llama.cpp b/llama.cpp
index 621404d6a..c07c35466 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -932,6 +932,7 @@ enum e_model {
 
 static const size_t kB = 1024;
 static const size_t MB = kB*kB;
+static const size_t GB = kB*kB*kB;
 
 // default hparams (LLaMA 7B)
 struct llama_hparams {
@@ -1285,6 +1286,7 @@ struct llama_model_loader {
     int n_created = 0;
 
     int64_t n_elements = 0;
+    size_t  n_bytes    = 0;
 
     bool use_mmap = false;
 
@@ -1317,6 +1319,7 @@ struct llama_model_loader {
             const char * name = gguf_get_tensor_name(ctx_gguf, i);
             struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name);
             n_elements += ggml_nelements(t);
+            n_bytes    += ggml_nbytes(t);
         }
 
         LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
@@ -1915,7 +1918,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: freq_scale   = %g\n",     __func__, hparams.rope_freq_scale);
     LLAMA_LOG_INFO("%s: model type   = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype  = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
-    LLAMA_LOG_INFO("%s: model size   = %.2f B\n", __func__, ml.n_elements*1e-9);
+    LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
+    if (ml.n_bytes < GB) {
+        LLAMA_LOG_INFO("%s: model size   = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+    } else {
+        LLAMA_LOG_INFO("%s: model size   = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+    }
 
     // general kv
     LLAMA_LOG_INFO("%s: general.name = %s\n",    __func__, model.name.c_str());
@@ -3505,7 +3513,7 @@ static struct ggml_cgraph * llm_build_starcoder(
 
             ggml_allocr_alloc(lctx.alloc, token);
             if (!ggml_allocr_is_measure(lctx.alloc)) {
-                memcpy(token->data, embd, N * n_embd * ggml_element_size(inpL));
+                memcpy(token->data, embd, N * n_embd * ggml_element_size(token));
             }
         }
 
@@ -7045,19 +7053,21 @@ llama_token llama_token_nl(const struct llama_context * ctx) {
 int llama_tokenize(
         struct llama_context * ctx,
                   const char * text,
+                         int   text_len,
                  llama_token * tokens,
                          int   n_max_tokens,
                         bool   add_bos) {
-    return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
+    return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos);
 }
 
 int llama_tokenize_with_model(
     const struct llama_model * model,
                   const char * text,
+                         int   text_len,
                  llama_token * tokens,
                          int   n_max_tokens,
                         bool   add_bos) {
-    auto res = llama_tokenize_internal(model->vocab, text, add_bos);
+    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
 
     if (n_max_tokens < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
diff --git a/llama.h b/llama.h
index c6ee038c7..369be048c 100644
--- a/llama.h
+++ b/llama.h
@@ -374,6 +374,7 @@ extern "C" {
     LLAMA_API int llama_tokenize(
             struct llama_context * ctx,
                       const char * text,
+                             int   text_len,
                      llama_token * tokens,
                              int   n_max_tokens,
                             bool   add_bos);
@@ -381,6 +382,7 @@ extern "C" {
     LLAMA_API int llama_tokenize_with_model(
         const struct llama_model * model,
                       const char * text,
+                             int   text_len,
                      llama_token * tokens,
                              int   n_max_tokens,
                             bool   add_bos);
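Note: on the C API side, the new text_len parameter lets callers tokenize a slice of a larger buffer without copying it into a NUL-terminated string first. A usage sketch against the declarations above (ctx and the sizes are assumed for illustration):

    // Tokenize only the first 11 bytes ("Hello world") of a longer buffer.
    const char * doc = "Hello world, plus trailing text that is ignored";
    llama_token toks[64];   // arbitrary capacity for the sketch
    int n = llama_tokenize(ctx, doc, 11, toks, 64, /*add_bos=*/true);
    if (n < 0) {
        // toks was too small; -n is the required capacity
    }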
diff --git a/tests/test-tokenizer-0-llama.cpp b/tests/test-tokenizer-0-llama.cpp
index edbd86f85..dfb2e81a9 100644
--- a/tests/test-tokenizer-0-llama.cpp
+++ b/tests/test-tokenizer-0-llama.cpp
@@ -36,6 +36,7 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
         { "  Hello"                 , {   1678, 15043, }, },
         { "   Hello"                , {    268, 15043, }, },
         { "   Hello\n   Hello"      , {    268, 15043,    13,  1678, 15043, }, },
+        { " ("                      , {  29871,   313, }, },
     };
 
     return _k_tests;
diff --git a/tests/test-tokenizer-1-llama.cpp b/tests/test-tokenizer-1-llama.cpp
index 804ea2486..a95d462cf 100644
--- a/tests/test-tokenizer-1-llama.cpp
+++ b/tests/test-tokenizer-1-llama.cpp
@@ -87,10 +87,9 @@ int main(int argc, char **argv) {
         std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
         std::string check = llama_detokenize_spm(ctx, tokens);
         if (check != str) {
-            fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%llu) but tokenization of this detokenizes to >%s<(%llu)\n",
+            fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                 __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
-            if(i != 3)
-                return 2;
+            return 2;
         }
     }
 
@@ -99,11 +98,10 @@ int main(int argc, char **argv) {
         std::string str = codepoint_to_utf8(cp);
         std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
         std::string check = llama_detokenize_spm(ctx, tokens);
-        if (str != check) {
-            fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
+        if (cp != 9601 && str != check) {
+            fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                 __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
-            if(cp != 0 && cp != 9601)
-                return 3;
+            return 3;
         }
     }
 
@@ -112,7 +110,7 @@ int main(int argc, char **argv) {
         std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
         std::string check = llama_detokenize_spm(ctx, tokens);
         if (str != check) {
-            fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
+            fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                 __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
             return 4;
         }