Merge branch 'master' into concedo_experimental
# Conflicts: # .github/workflows/build.yml # Makefile # README.md
This commit is contained in:
commit
0142760fc3
8 changed files with 129 additions and 62 deletions
|
@ -50,6 +50,8 @@ set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA
|
|||
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
|
||||
option(LLAMA_CUDA_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
|
||||
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
|
||||
set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
||||
"llama: max. batch size for using peer access")
|
||||
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
|
||||
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
|
||||
|
||||
|
@ -93,6 +95,7 @@ if (LLAMA_CUBLAS)
|
|||
add_compile_definitions(GGML_CUDA_F16)
|
||||
endif()
|
||||
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
|
||||
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
|
||||
|
||||
if (LLAMA_STATIC)
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
|
|
|
@ -801,10 +801,10 @@ std::vector<llama_token> llama_tokenize(
|
|||
// upper limit for the number of tokens
|
||||
int n_tokens = text.length() + add_bos;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
||||
n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
||||
int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
|
|
|
@ -965,10 +965,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
|
|||
|
||||
buf[size] = '\0';
|
||||
|
||||
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
|
||||
int n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
|
||||
if (n_tokens < 0) {
|
||||
out.resize(-n_tokens);
|
||||
n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
|
||||
n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false);
|
||||
}
|
||||
GGML_ASSERT(n_tokens >= 0);
|
||||
out.resize(n_tokens);
|
||||
|
|
145
ggml-cuda.cu
145
ggml-cuda.cu
|
@ -31,6 +31,9 @@
|
|||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
|
@ -61,7 +64,7 @@
|
|||
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
||||
#define cudaStreamNonBlocking hipStreamNonBlocking
|
||||
#define cudaStreamSynchronize hipStreamSynchronize
|
||||
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
|
||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||
#define cudaStream_t hipStream_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#else
|
||||
|
@ -190,6 +193,12 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
|||
} while (0)
|
||||
#endif // CUDART_VERSION >= 11
|
||||
|
||||
#if CUDART_VERSION >= 11100
|
||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||
#else
|
||||
#define GGML_CUDA_ASSUME(x)
|
||||
#endif // CUDART_VERSION >= 11100
|
||||
|
||||
#ifdef GGML_CUDA_F16
|
||||
typedef half dfloat; // dequantize float
|
||||
typedef half2 dfloat2;
|
||||
|
@ -418,6 +427,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
|||
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
|
||||
#endif
|
||||
|
||||
#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
|
||||
#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
|
||||
#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
|
||||
|
||||
#define MUL_MAT_SRC1_COL_STRIDE 128
|
||||
|
||||
#define MAX_STREAMS 8
|
||||
|
@ -2145,10 +2158,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI4_0;
|
||||
const int kqsx = k % QI4_0;
|
||||
|
@ -2239,10 +2252,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI4_1;
|
||||
const int kqsx = k % QI4_1;
|
||||
|
@ -2331,10 +2344,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI5_0;
|
||||
const int kqsx = k % QI5_0;
|
||||
|
@ -2445,10 +2458,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI5_1;
|
||||
const int kqsx = k % QI5_1;
|
||||
|
@ -2551,10 +2564,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI8_0;
|
||||
const int kqsx = k % QI8_0;
|
||||
|
@ -2642,10 +2655,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI2_K;
|
||||
const int kqsx = k % QI2_K;
|
||||
|
@ -2763,10 +2776,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI3_K;
|
||||
const int kqsx = k % QI3_K;
|
||||
|
@ -2981,10 +2994,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI4_K; // == 0 if QK_K == 256
|
||||
const int kqsx = k % QI4_K; // == k if QK_K == 256
|
||||
|
@ -3162,10 +3175,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI5_K; // == 0 if QK_K == 256
|
||||
const int kqsx = k % QI5_K; // == k if QK_K == 256
|
||||
|
@ -3291,10 +3304,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
|
||||
|
||||
__builtin_assume(i_offset >= 0);
|
||||
__builtin_assume(i_offset < nwarps);
|
||||
__builtin_assume(k >= 0);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
GGML_CUDA_ASSUME(i_offset >= 0);
|
||||
GGML_CUDA_ASSUME(i_offset < nwarps);
|
||||
GGML_CUDA_ASSUME(k >= 0);
|
||||
GGML_CUDA_ASSUME(k < WARP_SIZE);
|
||||
|
||||
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
||||
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
||||
|
@ -6248,6 +6261,43 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_set_peer_access(const int n_tokens) {
|
||||
static bool peer_access_enabled = false;
|
||||
|
||||
const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
|
||||
|
||||
if (peer_access_enabled == enable_peer_access) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef NDEBUG
|
||||
for (int id = 0; id < g_device_count; ++id) {
|
||||
CUDA_CHECK(ggml_cuda_set_device(id));
|
||||
|
||||
for (int id_other = 0; id_other < g_device_count; ++id_other) {
|
||||
if (id == id_other) {
|
||||
continue;
|
||||
}
|
||||
if (id != g_main_device && id_other != g_main_device) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int can_access_peer;
|
||||
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
|
||||
if (can_access_peer) {
|
||||
if (enable_peer_access) {
|
||||
CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
|
||||
} else {
|
||||
CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // NDEBUG
|
||||
|
||||
peer_access_enabled = enable_peer_access;
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_mul_mat(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
|
||||
const bool convert_src1_to_q8_1) {
|
||||
|
@ -6272,6 +6322,8 @@ static void ggml_cuda_op_mul_mat(
|
|||
const int nb2 = dst->nb[2];
|
||||
const int nb3 = dst->nb[3];
|
||||
|
||||
ggml_cuda_set_peer_access(ne11);
|
||||
|
||||
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
|
||||
|
@ -6404,7 +6456,7 @@ static void ggml_cuda_op_mul_mat(
|
|||
|
||||
// wait for main GPU data if necessary
|
||||
if (split && (id != g_main_device || is != 0)) {
|
||||
CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
|
||||
CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0));
|
||||
}
|
||||
|
||||
for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
|
||||
|
@ -6526,7 +6578,7 @@ static void ggml_cuda_op_mul_mat(
|
|||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
for (int64_t is = 0; is < is_max; ++is) {
|
||||
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
|
||||
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6960,6 +7012,7 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
|
|||
return;
|
||||
}
|
||||
if (g_scratch_buffer == nullptr) {
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
|
||||
}
|
||||
|
||||
|
@ -6999,7 +7052,7 @@ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
|
|||
ggml_cuda_assign_buffers_impl(tensor, false, true, false);
|
||||
}
|
||||
|
||||
void ggml_cuda_set_main_device(int main_device) {
|
||||
void ggml_cuda_set_main_device(const int main_device) {
|
||||
if (main_device >= g_device_count) {
|
||||
fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
|
||||
main_device, g_device_count, g_main_device);
|
||||
|
@ -7013,11 +7066,11 @@ void ggml_cuda_set_main_device(int main_device) {
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
|
||||
void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
|
||||
g_mul_mat_q = mul_mat_q;
|
||||
}
|
||||
|
||||
void ggml_cuda_set_scratch_size(size_t scratch_size) {
|
||||
void ggml_cuda_set_scratch_size(const size_t scratch_size) {
|
||||
g_scratch_size = scratch_size;
|
||||
}
|
||||
|
||||
|
|
18
llama.cpp
18
llama.cpp
|
@ -932,6 +932,7 @@ enum e_model {
|
|||
|
||||
static const size_t kB = 1024;
|
||||
static const size_t MB = kB*kB;
|
||||
static const size_t GB = kB*kB*kB;
|
||||
|
||||
// default hparams (LLaMA 7B)
|
||||
struct llama_hparams {
|
||||
|
@ -1285,6 +1286,7 @@ struct llama_model_loader {
|
|||
int n_created = 0;
|
||||
|
||||
int64_t n_elements = 0;
|
||||
size_t n_bytes = 0;
|
||||
|
||||
bool use_mmap = false;
|
||||
|
||||
|
@ -1317,6 +1319,7 @@ struct llama_model_loader {
|
|||
const char * name = gguf_get_tensor_name(ctx_gguf, i);
|
||||
struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name);
|
||||
n_elements += ggml_nelements(t);
|
||||
n_bytes += ggml_nbytes(t);
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
|
||||
|
@ -1915,7 +1918,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
|||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
|
||||
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
|
||||
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
|
||||
LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9);
|
||||
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
|
||||
if (ml.n_bytes < GB) {
|
||||
LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
|
||||
}
|
||||
|
||||
// general kv
|
||||
LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
|
||||
|
@ -3505,7 +3513,7 @@ static struct ggml_cgraph * llm_build_starcoder(
|
|||
|
||||
ggml_allocr_alloc(lctx.alloc, token);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(token->data, embd, N * n_embd * ggml_element_size(inpL));
|
||||
memcpy(token->data, embd, N * n_embd * ggml_element_size(token));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7045,19 +7053,21 @@ llama_token llama_token_nl(const struct llama_context * ctx) {
|
|||
int llama_tokenize(
|
||||
struct llama_context * ctx,
|
||||
const char * text,
|
||||
int text_len,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos) {
|
||||
return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
|
||||
return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos);
|
||||
}
|
||||
|
||||
int llama_tokenize_with_model(
|
||||
const struct llama_model * model,
|
||||
const char * text,
|
||||
int text_len,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos) {
|
||||
auto res = llama_tokenize_internal(model->vocab, text, add_bos);
|
||||
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
|
||||
|
||||
if (n_max_tokens < (int) res.size()) {
|
||||
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
|
||||
|
|
2
llama.h
2
llama.h
|
@ -374,6 +374,7 @@ extern "C" {
|
|||
LLAMA_API int llama_tokenize(
|
||||
struct llama_context * ctx,
|
||||
const char * text,
|
||||
int text_len,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
|
@ -381,6 +382,7 @@ extern "C" {
|
|||
LLAMA_API int llama_tokenize_with_model(
|
||||
const struct llama_model * model,
|
||||
const char * text,
|
||||
int text_len,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
|
|
|
@ -36,6 +36,7 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
|
|||
{ " Hello" , { 1678, 15043, }, },
|
||||
{ " Hello" , { 268, 15043, }, },
|
||||
{ " Hello\n Hello" , { 268, 15043, 13, 1678, 15043, }, },
|
||||
{ " (" , { 29871, 313, }, },
|
||||
};
|
||||
|
||||
return _k_tests;
|
||||
|
|
|
@ -87,10 +87,9 @@ int main(int argc, char **argv) {
|
|||
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
||||
std::string check = llama_detokenize_spm(ctx, tokens);
|
||||
if (check != str) {
|
||||
fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%llu) but tokenization of this detokenizes to >%s<(%llu)\n",
|
||||
fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
|
||||
__func__, i, str.c_str(), str.length(), check.c_str(), check.length());
|
||||
if(i != 3)
|
||||
return 2;
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,11 +98,10 @@ int main(int argc, char **argv) {
|
|||
std::string str = codepoint_to_utf8(cp);
|
||||
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
||||
std::string check = llama_detokenize_spm(ctx, tokens);
|
||||
if (str != check) {
|
||||
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
|
||||
if (cp != 9601 && str != check) {
|
||||
fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
|
||||
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
|
||||
if(cp != 0 && cp != 9601)
|
||||
return 3;
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -112,7 +110,7 @@ int main(int argc, char **argv) {
|
|||
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
||||
std::string check = llama_detokenize_spm(ctx, tokens);
|
||||
if (str != check) {
|
||||
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
|
||||
fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
|
||||
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
|
||||
return 4;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue