Merge branch 'master' of github.com:ggerganov/llama.cpp into phillip-kravtsov/support-adept-persimmon-8b

Phillip Kravtsov 2023-10-02 14:00:14 -07:00
commit 5a0990c1c3
24 changed files with 1565 additions and 212 deletions


@@ -1,6 +1,9 @@
*.o
*.a
.cache/
+.git/
+.github/
+.gitignore
.vs/
.vscode/
.DS_Store

.gitignore (vendored, 1 change)

@@ -40,6 +40,7 @@ models-mnt
/embedding
/gguf
/gguf-llama-simple
+/infill
/libllama.so
/llama-bench
/main


@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
+cmake_minimum_required(VERSION 3.13) # for add_link_options
project("llama.cpp" C CXX)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@@ -343,8 +343,9 @@ if (LLAMA_MPI)
set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h)
add_compile_definitions(GGML_USE_MPI)
add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS})
-set(cxx_flags ${cxx_flags} -Wno-cast-qual)
-set(c_flags ${c_flags} -Wno-cast-qual)
+if (NOT MSVC)
+add_compile_options(-Wno-cast-qual)
+endif()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES})
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS})
# Even if you're only using the C header, C++ programs may bring in MPI
@@ -418,10 +419,11 @@ if (LLAMA_ALL_WARNINGS)
set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int
-Werror=implicit-function-declaration)
set(cxx_flags -Wmissing-declarations -Wmissing-noreturn)
+set(host_cxx_flags "")
if (CMAKE_C_COMPILER_ID MATCHES "Clang")
set(warning_flags ${warning_flags} -Wunreachable-code-break -Wunreachable-code-return)
-set(cxx_flags ${cxx_flags} -Wmissing-prototypes -Wextra-semi)
+set(host_cxx_flags ${host_cxx_flags} -Wmissing-prototypes -Wextra-semi)
if (
(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 3.8.0) OR
@@ -431,27 +433,38 @@ if (LLAMA_ALL_WARNINGS)
endif()
elseif (CMAKE_C_COMPILER_ID STREQUAL "GNU")
set(c_flags ${c_flags} -Wdouble-promotion)
-set(cxx_flags ${cxx_flags} -Wno-array-bounds)
+set(host_cxx_flags ${host_cxx_flags} -Wno-array-bounds)
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 7.1.0)
-set(cxx_flags ${cxx_flags} -Wno-format-truncation)
+set(host_cxx_flags ${host_cxx_flags} -Wno-format-truncation)
endif()
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1.0)
-set(cxx_flags ${cxx_flags} -Wextra-semi)
+set(host_cxx_flags ${host_cxx_flags} -Wextra-semi)
endif()
endif()
else()
# todo : msvc
endif()
-add_compile_options(
-${warning_flags}
-"$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
-"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
-)
+set(c_flags ${c_flags} ${warning_flags})
+set(cxx_flags ${cxx_flags} ${warning_flags})
+add_compile_options("$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
+"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags} ${host_cxx_flags}>")
endif()
+if (NOT MSVC)
+set(cuda_flags -Wno-pedantic)
+endif()
+set(cuda_flags ${cxx_flags} -use_fast_math ${cuda_flags})
+list(JOIN host_cxx_flags " " cuda_host_flags) # pass host compiler flags as a single argument
+if (NOT cuda_host_flags STREQUAL "")
+set(cuda_flags ${cuda_flags} -Xcompiler ${cuda_host_flags})
+endif()
+add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${cuda_flags}>")
if (WIN32)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
@@ -705,6 +718,7 @@ set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR}
set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER})
set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER})
+get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
configure_package_config_file(
${CMAKE_CURRENT_SOURCE_DIR}/scripts/LlamaConfig.cmake.in


@@ -1,5 +1,5 @@
# Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative benchmark-matmult parallel finetune export-lora tests/test-c.o
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama
@@ -543,6 +543,9 @@ main: examples/main/main.cpp build-info.h ggml.
@echo '==== Run ./main -h for help. ===='
@echo
+infill: examples/infill/infill.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS)
+$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)


@@ -389,6 +389,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.interactive_first = true;
} else if (arg == "-ins" || arg == "--instruct") {
params.instruct = true;
+} else if (arg == "--infill") {
+params.infill = true;
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {


@@ -120,6 +120,7 @@ struct gpt_params {
bool use_mlock = false; // use mlock to keep model in memory
bool numa = false; // attempt optimizations that help on some NUMA systems
bool verbose_prompt = false; // print prompt tokens before generation
+bool infill = false; // use infill mode
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);


@@ -0,0 +1,135 @@
from convert import lazy_load_safetensors_file
import sys
import torch
from safetensors import safe_open
from pathlib import Path
from pprint import pprint
from sentencepiece import SentencePieceProcessor
import argparse
import gguf
import json
import struct

def file_is_safetensors(path: Path) -> bool:
    fp = open(path, 'rb')
    first8 = fp.read(8)
    fp.seek(0)
    if first8[:2] == b'PK':
        # A zip file, i.e. PyTorch format
        return False
    return struct.unpack('<Q', first8)[0] < 16 * 1024 * 1024

def get_tokenizer_info(dir_model: Path):
    tokenizer_path = dir_model / 'adept_vocab.model'
    print('gguf: getting sentencepiece tokenizer from', tokenizer_path)
    tokenizer = SentencePieceProcessor(str(tokenizer_path))

    tokens: list[bytes] = []
    scores: list[float] = []
    toktypes: list[int] = []

    for i in range(tokenizer.vocab_size()):
        text: bytes
        score: float

        piece = tokenizer.id_to_piece(i)
        text = piece.encode("utf-8")
        score = tokenizer.get_score(i)

        toktype = 1  # default to normal token type
        if tokenizer.is_unknown(i):
            toktype = 2
        if tokenizer.is_control(i):
            toktype = 3
        # toktype = 4 is user-defined = tokens from added_tokens.json
        if tokenizer.is_unused(i):
            toktype = 5
        if tokenizer.is_byte(i):
            toktype = 6

        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype)
        pass
    return tokens, scores, toktypes

def get_args():
    parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.safetensors)")
    args = parser.parse_args()
    return args

def main() -> None:
    args = get_args()
    assert file_is_safetensors(args.model), 'Error: model file is not a SafeTensors file'
    dir_model = args.model.parent
    with open(dir_model / 'config.json', 'r') as f:
        hparams = json.load(f)
    arch = gguf.MODEL_ARCH.PERSIMMON
    gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch])

    block_count = hparams['num_layers']
    head_count = hparams['num_attention_heads']
    head_count_kv = head_count
    ctx_length = hparams['seq_length']
    hidden_size = hparams['hidden_size']

    gguf_writer.add_name('persimmon-8b-chat')
    gguf_writer.add_context_length(ctx_length)
    gguf_writer.add_embedding_length(hidden_size)
    gguf_writer.add_block_count(block_count)
    gguf_writer.add_feed_forward_length(hparams['ffn_hidden_size'])
    gguf_writer.add_rope_dimension_count(hidden_size // head_count)
    gguf_writer.add_head_count(head_count)
    gguf_writer.add_head_count_kv(head_count_kv)
    gguf_writer.add_rope_freq_base(hparams['rotary_emb_base'])
    gguf_writer.add_layer_norm_eps(hparams['layernorm_epsilon'])

    tokens, scores, toktypes = get_tokenizer_info(dir_model)
    gguf_writer.add_tokenizer_model('llama')
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_scores(scores)
    gguf_writer.add_token_types(toktypes)
    gguf_writer.add_bos_token_id(71013)
    gguf_writer.add_eos_token_id(71013)

    tensor_map = gguf.get_tensor_name_map(arch, block_count)
    print(tensor_map)
    tensors = {}
    with safe_open(args.model, framework="pt") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)

    for name in tensors.keys():
        data = tensors[name]
        if name.endswith(".self_attention.rotary_emb.inv_freq"):
            continue
        old_dtype = data.dtype
        # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
        data = data.to(torch.float32).squeeze().numpy()
        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
        if new_name is None:
            print("Can not map tensor '" + name + "'")
            sys.exit()
        n_dims = len(data.shape)
        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
        gguf_writer.add_tensor(new_name, data)

    print("gguf: write header")
    gguf_writer.write_header_to_file()
    print("gguf: write metadata")
    gguf_writer.write_kv_data_to_file()
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

    gguf_writer.close()

    print(f"gguf: model successfully exported to '{args.outfile}'")
    print("")

if __name__ == '__main__':
    main()
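For reference, a usage sketch for the conversion script above, derived from its argparse interface; the script filename and all paths here are illustrative assumptions, not part of this diff:

```bash
# config.json and adept_vocab.model are expected in the same directory as the model file
python3 convert-persimmon-to-gguf.py /path/to/persimmon-8b-chat/model.safetensors \
    --outfile /path/to/persimmon-8b-chat.gguf
```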


@@ -41,8 +41,7 @@ if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
-ARCH=gguf.MODEL_ARCH.LLAMA
-NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
+ARCH = gguf.MODEL_ARCH.LLAMA
DEFAULT_CONCURRENCY = 8
#
@@ -953,7 +952,7 @@ class OutputFile:
of.close()
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
-wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
+wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
return GGMLFileType.AllF32


@@ -313,7 +313,7 @@ class ModelParams:
gguf_writer.add_feed_forward_length(self.get_n_ff())
def tensor_name(key, bid=None, suffix=".weight"):
-return gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][key].format(bid=bid) + suffix
+return gguf.TENSOR_NAMES[key].format(bid=bid) + suffix
class Layer:
def __init__(self, params, lora_params, bid):


@@ -332,8 +332,8 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
assert_shape_1d(layer.attention_norm, hparams.n_embd);
assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
-assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd);
-assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd);
+assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd_gqa());
+assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd_gqa());
assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
assert_shape_1d(layer.ffn_norm, hparams.n_embd);
assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);


@@ -0,0 +1,8 @@
set(TARGET infill)
add_executable(${TARGET} infill.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

examples/infill/README.md (new file, 41 lines)

@@ -0,0 +1,41 @@
# llama.cpp/example/infill
This example shows how to use infill mode with Code Llama models that support it.
Currently the 7B and 13B models support infill mode.
Infill supports most of the options available in the main example.
For further information, have a look at the main example's README in llama.cpp/example/main/README.md.
## Common Options
In this section, we cover the most commonly used options for running the `infill` program with the LLaMA models:
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
## Input Prompts
The `infill` program provides several ways to interact with the LLaMA models using input prompts:
- `--in-prefix PROMPT_BEFORE_CURSOR`: Provide the prefix directly as a command-line option.
- `--in-suffix PROMPT_AFTER_CURSOR`: Provide the suffix directly as a command-line option.
- `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.)
## Interaction
The `infill` program offers a seamless way to interact with LLaMA models, allowing users to receive real-time infill suggestions. Interactive mode can be triggered with `--interactive` or `--interactive-first`.
### Interaction Options
- `-i, --interactive`: Run the program in interactive mode, allowing users to get real-time code suggestions from the model.
- `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation.
- `--color`: Enable colorized output to visually distinguish between prompts, user input, and generated text.
### Example
```bash
./infill -t 10 -ngl 0 -m models/codellama-13b.Q5_K_S.gguf -c 4096 --temp 0.7 --repeat_penalty 1.1 -n 20 --in-prefix "def helloworld():\n print(\"hell" --in-suffix "\n print(\"goodbye world\")\n "
```
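As a further illustration, an interactive session can be started with the flags documented above; the model path and sampling settings below are placeholders:

```bash
# Start interactive infill and enter the prefix/suffix when prompted
./infill -t 10 -ngl 0 -m models/codellama-13b.Q5_K_S.gguf -c 4096 --temp 0.7 --interactive-first --color
```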

examples/infill/infill.cpp (new file, 769 lines)

@@ -0,0 +1,769 @@
#include "common.h"
#include "console.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
static llama_context ** g_ctx;
static llama_model ** g_model;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model,
const std::vector<llama_token> & input_tokens, const std::string & output,
const std::vector<llama_token> & output_tokens
) {
if (params.logdir.empty()) {
return;
}
const std::string timestamp = get_sortable_timestamp();
const bool success = create_directory_with_parents(params.logdir);
if (!success) {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
return;
}
const std::string logfile_path = params.logdir + timestamp + ".yml";
FILE * logfile = fopen(logfile_path.c_str(), "w");
if (logfile == NULL) {
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
return;
}
fprintf(logfile, "binary: infill\n");
char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc);
fprintf(logfile, "\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "# Generation Results #\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "\n");
dump_string_yaml_multiline(logfile, "output", output.c_str());
dump_vector_int_yaml(logfile, "output_tokens", output_tokens);
llama_dump_timing_info_yaml(logfile, ctx);
fclose(logfile);
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
if (!is_interacting) {
is_interacting = true;
} else {
console::cleanup();
printf("\n");
llama_print_timings(*g_ctx);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
}
}
#endif
int main(int argc, char ** argv) {
gpt_params params;
g_params = &params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("infill", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
if (params.logits_all) {
printf("\n************\n");
printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
printf("************\n\n");
return 0;
}
if (params.embedding) {
printf("\n************\n");
printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
printf("************\n\n");
return 0;
}
if (params.n_ctx != 0 && params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
}
if (params.instruct) {
printf("\n************\n");
printf("%s: please use the 'main' tool for instruct mode\n", __func__);
printf("************\n\n");
return 0;
}
if (!params.antiprompt.empty()) {
printf("\n************\n");
printf("%s: please use the 'main' tool for antiprompt mode\n", __func__);
printf("************\n\n");
return 0;
}
if (!params.interactive_first && (params.input_prefix.empty() && params.input_suffix.empty())) {
printf("\n************\n");
printf("%s: please use '--interactive_first' or specify '--in_prefix' and/or '--in_suffix'\n", __func__);
printf("************\n\n");
return 0;
}
if (params.random_prompt) {
printf("\n************\n");
printf("%s: please use the 'main' tool for random prompt mode\n", __func__);
printf("************\n\n");
return 0;
}
if (!params.path_prompt_cache.empty()) {
printf("\n************\n");
printf("%s: infill does not support prompt caching\n", __func__);
printf("************\n\n");
return 0;
}
if (params.rope_freq_base != 0.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
}
if (params.rope_freq_scale != 0.0) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
LOG_TEE("%s: built with %s for %s\n", __func__, BUILD_COMPILER, BUILD_TARGET);
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
std::mt19937 rng(params.seed);
LOG("%s: llama backend init\n", __func__);
llama_backend_init(params.numa);
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
g_model = &model;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
}
if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
return 1;
}
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);
if (n_ctx > n_ctx_train) {
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, n_ctx);
}
// print system information
{
LOG_TEE("\n");
LOG_TEE("%s\n", get_system_info(params).c_str());
}
const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp;
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(ctx));
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
}
// Tokenize negative prompt
std::vector<llama_token> guidance_inp;
int guidance_offset = 0;
int original_prompt_len = 0;
if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
LOG("guidance_offset: %s", log_tostr(guidance_offset));
}
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
}
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size()) {
params.n_keep = (int)embd_inp.size();
}
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
// enable interactive mode if interactive start is specified
if (params.interactive_first) {
params.interactive = true;
}
if (params.verbose_prompt) {
LOG_TEE("\n");
LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
if (ctx_guidance) {
LOG_TEE("\n");
LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
}
}
if (params.n_keep > 0) {
LOG_TEE("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
LOG_TEE("'\n");
}
LOG_TEE("\n");
}
if (params.interactive) {
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
LOG_TEE("%s: interactive mode on.\n", __func__);
if (params.input_prefix_bos) {
LOG_TEE("Input prefix with BOS\n");
}
if (!params.input_prefix.empty()) {
LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str());
}
if (!params.input_suffix.empty()) {
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
}
}
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
struct llama_grammar * grammar = NULL;
grammar_parser::parse_state parsed_grammar;
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
return 1;
}
LOG_TEE("%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
LOG_TEE("\n");
{
auto it = params.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
// TODO: replace with ring-buffer
std::vector<llama_token> last_tokens(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
printf("no need to specify '--infill', always running infill\n");
printf("************\n\n");
}
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
" - To return control without starting a new line, end your input with '/'.\n";
} else {
control_message = " - Press Return to return control to LLaMa.\n"
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}
LOG_TEE("== Running in interactive mode. ==\n");
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
LOG_TEE( " - Press Ctrl+C to interject at any time.\n");
#endif
LOG_TEE( "%s\n", control_message);
is_interacting = params.interactive_first;
}
bool input_echo = true;
int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
int n_past_guidance = 0;
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
std::ostringstream output_ss; g_output_ss = &output_ss;
// the first thing we will do is to output the prompt, so set color accordingly
console::set_display(console::prompt);
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
const int n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
while (n_remain != 0 || params.interactive) {
// predict
if (!embd.empty()) {
// Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
// --prompt or --file which uses the same value.
int max_embd_size = n_ctx - 4;
// Ensure the input doesn't exceed the context size by truncating embd if necessary.
if ((int) embd.size() > max_embd_size) {
const int skipped_tokens = (int) embd.size() - max_embd_size;
embd.resize(max_embd_size);
console::set_display(console::error);
printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
console::set_display(console::reset);
fflush(stdout);
}
// infinite text generation via context swapping
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
if (params.n_predict == -2) {
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
if (ctx_guidance) {
n_past_guidance -= n_discard;
}
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
}
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
if (ctx_guidance) {
int input_size = 0;
llama_token * input_buf = NULL;
if (n_past_guidance < (int) guidance_inp.size()) {
// Guidance context should have the same data with these modifications:
//
// * Replace the initial prompt
// * Shift everything by guidance_offset
embd_guidance = guidance_inp;
if (embd.begin() + original_prompt_len < embd.end()) {
embd_guidance.insert(
embd_guidance.end(),
embd.begin() + original_prompt_len,
embd.end()
);
}
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
} else {
input_buf = embd.data();
input_size = embd.size();
}
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
n_past_guidance += n_eval;
}
}
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
n_past += n_eval;
LOG("n_past = %d\n", n_past);
}
}
embd.clear();
embd_guidance.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
embd.push_back(id);
// echo this to console
input_echo = true;
// decrement remaining sampling budget
--n_remain;
LOG("n_remain: %d\n", n_remain);
} else {
// some user input remains from prompt or interaction, forward it to processing
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
}
}
}
// display text
if (input_echo) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id);
printf("%s", token_str.c_str());
if (embd.size() > 1) {
input_tokens.push_back(id);
} else {
output_tokens.push_back(id);
output_ss << token_str;
}
}
fflush(stdout);
}
// reset color to default if there is no pending user input
if (input_echo && (int) embd_inp.size() == n_consumed) {
console::set_display(console::reset);
}
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
}
fflush(stdout);
printf("\n");
console::set_display(console::user_input);
std::string buffer;
std::string line;
bool another_line=true;
// set a new prefix via stdin
do {
another_line = console::readline(line, params.multiline_input);
buffer += line;
} while (another_line);
// check if we got an empty line, if so we use the old input
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_prefix = buffer;
}
buffer.clear();
// set a new suffix via stdin
do {
another_line = console::readline(line, params.multiline_input);
buffer += line;
} while (another_line);
// check if we got an empty line
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_suffix = buffer;
}
buffer.clear();
// done taking input, reset color
console::set_display(console::reset);
// tokenize new prefix and suffix
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(ctx));
embd.clear();
embd_guidance.clear();
n_remain = params.n_predict;
n_past = 0;
n_consumed = 0;
// LOG_TEE("took new input\n");
is_interacting = false;
}
// deal with end of text token in interactive mode
else if (last_tokens.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
is_interacting = true;
printf("\n");
console::set_display(console::user_input);
fflush(stdout);
}
}
if (n_past > 0 && is_interacting && !params.interactive) {
LOG("waiting for user input\n");
if (params.input_prefix_bos) {
LOG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(ctx));
}
std::string buffer;
if (!params.input_prefix.empty()) {
LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
buffer += params.input_prefix;
printf("%s", buffer.c_str());
}
std::string line;
bool another_line = true;
do {
another_line = console::readline(line, params.multiline_input);
buffer += line;
} while (another_line);
// done taking input, reset color
console::set_display(console::reset);
// Add tokens to embd only if the input buffer is non-empty
// Entering an empty line lets the user pass control back
if (buffer.length() > 1) {
// append input suffix if any
if (!params.input_suffix.empty()) {
LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
buffer += params.input_suffix;
printf("%s", params.input_suffix.c_str());
}
LOG("buffer: '%s'\n", buffer.c_str());
const size_t original_size = embd_inp.size();
const auto line_inp = ::llama_tokenize(ctx, buffer, false);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i];
output_tokens.push_back(token);
output_ss << llama_token_to_piece(ctx, token);
}
n_remain -= line_inp.size();
LOG("n_remain: %d\n", n_remain);
} else {
LOG("empty line, passing control back\n");
}
input_echo = false; // do not echo this again
}
if (n_past > 0) {
if (is_interacting) {
// reset grammar state if we're restarting generation
if (grammar != NULL) {
llama_grammar_free(grammar);
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(),
parsed_grammar.symbol_ids.at("root"));
}
}
is_interacting = false;
}
}
// end of text token
if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !params.interactive) {
break;
}
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
// We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
n_remain = params.n_predict;
is_interacting = true;
}
}
if (!params.interactive && n_remain <= 0) {
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
fflush(stdout);
}
llama_print_timings(ctx);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx);
llama_free_model(model);
if (grammar != NULL) {
llama_grammar_free(grammar);
}
llama_backend_free();
#ifndef LOG_DISABLE_LOGS
LOG_TEE("Log end\n");
#endif // LOG_DISABLE_LOGS
return 0;
}


@@ -28,6 +28,16 @@ configure_file(${_common_path}/../build-info.h
target_include_directories(common PUBLIC ${LLAMA_INCLUDE_DIR}
${CMAKE_CURRENT_BINARY_DIR})
+# If the common project was part of "main-cmake-pkg" the transient
+# defines would automatically be attached. Because the common func-
+# tionality is separate, but dependent upon the defines, it must be
+# explicitly extracted from the "llama" target.
+#
+get_target_property(_llama_transient_defines llama
+INTERFACE_COMPILE_DEFINITIONS)
+target_compile_definitions(common PRIVATE "${_llama_transient_defines}")
add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../main/main.cpp)
target_include_directories(${TARGET} PRIVATE ${_common_path})
install(TARGETS ${TARGET} RUNTIME)


@@ -176,6 +176,16 @@ node index.js
`content`: Set the text to process.
+**POST** `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as a stream.
+*Options:*
+`input_prefix`: Set the prefix of the code to infill.
+`input_suffix`: Set the suffix of the code to infill.
+It also accepts all the options of `/completion` except `stream` and `prompt`.
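A minimal request sketch for this endpoint is shown below; the host and port assume a locally running server with default settings, and the JSON fields follow the options listed above:

```bash
curl --request POST \
    --url http://localhost:8080/infill \
    --header "Content-Type: application/json" \
    --data '{"input_prefix": "def helloworld():\n    print(\"hell", "input_suffix": "\n    print(\"goodbye world\")\n", "n_predict": 32}'
```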
## More examples
### Interactive mode


@@ -342,6 +342,70 @@ struct llama_server_context
return true;
}
void loadInfill()
{
auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx));
auto prompt_tokens = prefix_tokens;
num_prompt_tokens = prompt_tokens.size();
if (params.n_keep < 0)
{
params.n_keep = (int)num_prompt_tokens;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx)
{
printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
// todo we probably want to cut from both sides
const int n_left = (params.n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
}
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
}
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
n_past--;
}
LOG_VERBOSE("prompt ingested", {
{"n_past", n_past},
{"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
{"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = true;
}
void loadPrompt()
{
auto prompt_tokens = tokenize(prompt, true); // always add BOS
@@ -1219,6 +1283,27 @@ static void parse_options_completion(const json &body, llama_server_context &lla
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
static void parse_options_infill(const json &body, llama_server_context &llama)
{
if (body.count("input_prefix") != 0)
{
llama.params.input_prefix = body["input_prefix"];
}
else
{
llama.params.input_prefix = "";
}
if (body.count("input_suffix") != 0)
{
llama.params.input_suffix = body["input_suffix"];
}
else
{
llama.params.input_suffix = "";
}
parse_options_completion(body, llama);
}
static void log_server_request(const Request &req, const Response &res)
{
LOG_INFO("request", {
@@ -1519,6 +1604,127 @@ int main(int argc, char **argv)
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
} });
svr.Post("/infill", [&llama](const Request &req, Response &res)
{
auto lock = llama.lock();
llama.rewind();
llama_reset_timings(llama.ctx);
parse_options_infill(json::parse(req.body), llama);
if (!llama.loadGrammar())
{
res.status = 400;
return;
}
llama.loadInfill();
llama.beginCompletion();
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
size_t sent_count = 0;
size_t sent_token_probs_index = 0;
while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion();
if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
continue;
}
const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
size_t pos = std::min(sent_count, llama.generated_text.size());
const std::string str_test = llama.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
is_stop_full = true;
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
is_stop_full = false;
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL);
}
if (
stop_pos == std::string::npos ||
// Send rest of the text if we are at the end of the generation
(!llama.has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {};
if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
const json data = format_partial_response(llama, to_send, probs_output);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.data(), str.size())) {
LOG_VERBOSE("stream closed", {});
llama_print_timings(llama.ctx);
return false;
}
}
if (!llama.has_next_token) {
// Generation is done, send extra information.
const json data = format_final_response(
llama,
"",
std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.data(), str.size())) {
LOG_VERBOSE("stream closed", {});
llama_print_timings(llama.ctx);
return false;
}
}
}
llama_print_timings(llama.ctx);
sink.done();
return true;
};
const auto on_complete = [&](bool) {
llama.mutex.unlock();
};
lock.release();
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
});
svr.Get("/model.json", [&llama](const Request &, Response &res) svr.Get("/model.json", [&llama](const Request &, Response &res)
{ {
const json data = format_generation_settings(llama); const json data = format_generation_settings(llama);


@@ -364,7 +364,7 @@ class ModelParams:
gguf_writer.add_feed_forward_length(self.get_n_ff())
def tensor_name(key, bid=None):
-return gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][key].format(bid=bid) + ".weight"
+return gguf.TENSOR_NAMES[key].format(bid=bid) + ".weight"
class Layer:
def __init__(self, params, bid):


@@ -1476,10 +1476,15 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
+const int64_t ne12 = src1->ne[2];
+const int64_t ne13 = src1->ne[3];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
+const int64_t r2 = ne12 / ne02;
+const int64_t r3 = ne13 / ne03;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
@@ -1498,13 +1503,22 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
-for (int64_t i03 = 0; i03 < ne03; i03++) {
-for (int64_t i02 = 0; i02 < ne02; i02++) {
+int64_t pi02 = -1;
+int64_t pi03 = -1;
+for (int64_t i13 = 0; i13 < ne13; i13++) {
+int64_t i03 = i13 / r3;
+for (int64_t i12 = 0; i12 < ne12; i12++) {
+int64_t i02 = i12 / r2;
// copy data to device
-if (src0->backend != GGML_BACKEND_GPU) {
+if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) {
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+pi02 = i02;
+pi03 = i03;
}
-CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
+CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
CL_CHECK(clFinish(queue));
@@ -1525,7 +1539,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
}
// copy dst to host
-float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
}
}
@@ -1547,6 +1561,8 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
+const int64_t ne12 = src1->ne[2];
+const int64_t ne13 = src1->ne[3];
const int nb10 = src1->nb[0];
const int nb11 = src1->nb[1];
@@ -1556,6 +1572,9 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
+const int64_t r2 = ne12 / ne02;
+const int64_t r3 = ne13 / ne03;
const ggml_fp16_t alpha = ggml_fp32_to_fp16(1.0f);
const ggml_fp16_t beta = ggml_fp32_to_fp16(0.0f);
const int x_ne = ne01 * ne00;
@@ -1577,32 +1596,41 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
bool src1_cont_rows = nb10 == sizeof(float);
bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
-for (int64_t i03 = 0; i03 < ne03; i03++) {
-for (int64_t i02 = 0; i02 < ne02; i02++) {
+int64_t pi02 = -1;
+int64_t pi03 = -1;
+for (int64_t i13 = 0; i13 < ne13; i13++) {
+int64_t i03 = i13 / r3;
+for (int64_t i12 = 0; i12 < ne12; i12++) {
+int64_t i02 = i12 / r2;
// copy src0 to device
-if (src0->backend != GGML_BACKEND_GPU) {
+if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) {
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+pi02 = i02;
+pi03 = i03;
}
// convert src1 to fp16
// TODO: use multiple threads
-ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
+ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
+char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
if (src1_cont_rows) {
if (src1_cont_cols) {
ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
}
else {
-for (int64_t i01 = 0; i01 < ne11; i01++) {
-ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
+for (int64_t i11 = 0; i11 < ne11; i11++) {
+ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10);
}
}
}
else {
-for (int64_t i01 = 0; i01 < ne11; i01++) {
-for (int64_t i00 = 0; i00 < ne10; i00++) {
+for (int64_t i11 = 0; i11 < ne11; i11++) {
+for (int64_t i10 = 0; i10 < ne10; i10++) {
// very slow due to no inlining
-tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
+tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10));
}
}
}
@@ -1631,7 +1659,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
// copy dst to host, then convert to float
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
-float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
ggml_fp16_to_fp32_row(tmp, d, d_ne);
}
@ -1652,12 +1680,17 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
const int64_t ne10 = src1->ne[0]; const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1]; const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
const int nb2 = dst->nb[2]; const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3]; const int nb3 = dst->nb[3];
const ggml_type type = src0->type; const ggml_type type = src0->type;
const bool mul_mat_vec = ne11 == 1; const bool mul_mat_vec = ne11 == 1;
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
const int x_ne = ne01 * ne00; const int x_ne = ne01 * ne00;
@ -1690,12 +1723,23 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
size_t ev_idx = 0; size_t ev_idx = 0;
std::vector<cl_event> events; std::vector<cl_event> events;
for (int64_t i03 = 0; i03 < ne03; i03++) { int64_t pi02 = -1;
for (int64_t i02 = 0; i02 < ne02; i02++) { int64_t pi03 = -1;
for (int64_t i13 = 0; i13 < ne13; i13++) {
int64_t i03 = i13 / r3;
for (int64_t i12 = 0; i12 < ne12; i12++) {
int64_t i02 = i12 / r2;
// copy src0 to device if necessary // copy src0 to device if necessary
if (src0->backend == GGML_BACKEND_CPU) { if (src0->backend == GGML_BACKEND_CPU) {
events.emplace_back(); if (i02 != pi02 || i03 != pi03) {
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++)); events.emplace_back();
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
pi02 = i02;
pi03 = i03;
}
} else if (src0->backend == GGML_BACKEND_GPU) { } else if (src0->backend == GGML_BACKEND_GPU) {
d_Q = (cl_mem) src0->extra; d_Q = (cl_mem) src0->extra;
} else { } else {
@ -1704,7 +1748,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
// copy src1 to device // copy src1 to device
events.emplace_back(); events.emplace_back();
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, events.data() + ev_idx++)); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));
// compute // compute
const size_t global = ne01 * CL_DMMV_BLOCK_SIZE; const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
@ -1725,7 +1769,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
// copy src1 to device // copy src1 to device
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL)); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
events.emplace_back(); events.emplace_back();
@ -1749,7 +1793,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
} }
// copy dst to host // copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL)); CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
for (auto *event : events) { for (auto *event : events) {
clReleaseEvent(event); clReleaseEvent(event);
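
Both OpenCL mat-mul paths now iterate over src1's batch dimensions (i12, i13) and map each slice back to a src0 slice through the broadcast ratios r2 = ne12/ne02 and r3 = ne13/ne03, re-uploading src0 to the device only when the mapped (i02, i03) pair changes. A self-contained sketch of that index mapping and upload caching, with made-up batch sizes:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative batch sizes: src1 has 8x2 batch slices, src0 only 2x1,
// so each src0 slice is shared ("broadcast") across several src1 slices.
int main() {
    const int64_t ne02 = 2, ne03 = 1;   // src0 batch dims
    const int64_t ne12 = 8, ne13 = 2;   // src1 batch dims
    const int64_t r2 = ne12 / ne02;     // broadcast ratio, 2nd dim
    const int64_t r3 = ne13 / ne03;     // broadcast ratio, 3rd dim

    int64_t pi02 = -1, pi03 = -1;       // last uploaded src0 slice
    for (int64_t i13 = 0; i13 < ne13; i13++) {
        const int64_t i03 = i13 / r3;
        for (int64_t i12 = 0; i12 < ne12; i12++) {
            const int64_t i02 = i12 / r2;
            const bool reupload = (i02 != pi02 || i03 != pi03);
            printf("src1 slice (%2lld,%2lld) -> src0 slice (%lld,%lld)%s\n",
                   (long long) i12, (long long) i13,
                   (long long) i02, (long long) i03,
                   reupload ? "  [upload src0]" : "");
            if (reupload) { pi02 = i02; pi03 = i03; }
        }
    }
    return 0;
}
```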

5
ggml.c
View file

@ -11621,11 +11621,6 @@ static void ggml_compute_forward_mul_mat(
#if defined(GGML_USE_CLBLAST) #if defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(src0, src1, dst)) { if (ggml_cl_can_mul_mat(src0, src1, dst)) {
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
// ref: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
} }
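
These asserts are dropped because the CLBlast path above now performs the 2nd/3rd-dimension broadcast itself, so a graph where src0 is shared across src1's batch slices can take this route. A shape-only sketch of such a broadcasted mul_mat, with illustrative sizes and no graph execution:

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // src0: one 64x64 weight slice; src1: 8 batch slices of 64x4 activations
    struct ggml_tensor * src0 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 64, 1);
    struct ggml_tensor * src1 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64,  4, 8);

    // src0 is broadcast across src1's 3rd dimension (ne12 = 8 vs ne02 = 1)
    struct ggml_tensor * dst = ggml_mul_mat(ctx, src0, src1);
    printf("dst shape: %lld x %lld x %lld\n",
           (long long) dst->ne[0], (long long) dst->ne[1], (long long) dst->ne[2]);

    ggml_free(ctx);
    return 0;
}
```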

View file

@ -86,10 +86,12 @@ class MODEL_ARCH(IntEnum):
MPT : int = auto() MPT : int = auto()
STARCODER : int = auto() STARCODER : int = auto()
PERSIMMON : int = auto() PERSIMMON : int = auto()
BERT : int = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
TOKEN_EMBD : int = auto() TOKEN_EMBD : int = auto()
TOKEN_TYPES : int = auto()
POS_EMBD : int = auto() POS_EMBD : int = auto()
OUTPUT : int = auto() OUTPUT : int = auto()
OUTPUT_NORM : int = auto() OUTPUT_NORM : int = auto()
@ -122,90 +124,150 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PERSIMMON: "persimmon", MODEL_ARCH.PERSIMMON: "persimmon",
} }
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = { TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_ARCH.LLAMA: { MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.TOKEN_TYPES: "token_types",
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.POS_EMBD: "position_embd",
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
}, MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_ARCH.GPTNEOX: { MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", }
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_ARCH.LLAMA: [
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.OUTPUT_NORM,
}, MODEL_TENSOR.OUTPUT,
MODEL_ARCH.FALCON: { MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.FFN_DOWN,
}, MODEL_TENSOR.FFN_UP,
MODEL_ARCH.BAICHUAN: { ],
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_ARCH.GPTNEOX: [
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", ],
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_ARCH.FALCON: [
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.OUTPUT_NORM,
}, MODEL_TENSOR.OUTPUT,
MODEL_ARCH.STARCODER: { MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.ATTN_NORM_2,
MODEL_TENSOR.POS_EMBD: "position_embd", MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", ],
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_ARCH.BAICHUAN: [
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.OUTPUT,
}, MODEL_TENSOR.ROPE_FREQS,
MODEL_ARCH.PERSIMMON: { MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", ],
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", MODEL_ARCH.STARCODER: [
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.TOKEN_EMBD,
}, MODEL_TENSOR.POS_EMBD,
MODEL_ARCH.GPT2: { MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GPTJ: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.PERSIMMON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.GPT2: [
# TODO # TODO
}, ],
# TODO # TODO
} }
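
The per-architecture name dictionaries are replaced by one flat TENSOR_NAMES table plus per-architecture MODEL_TENSORS lists, so a tensor's GGUF name is derived instead of being repeated for every architecture. A short sketch of how the two tables combine; the `from gguf import ...` path is an assumption about how the package exposes these symbols:

```python
# Rebuild a per-architecture name mapping from the new flat tables.
from gguf import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

def tensor_names_for(arch: MODEL_ARCH, bid: int = 0) -> dict:
    # Only tensors listed for the architecture get a name; block tensors
    # substitute the block id into the "blk.{bid}." prefix.
    return {
        tensor: TENSOR_NAMES[tensor].format(bid=bid)
        for tensor in MODEL_TENSORS[arch]
    }

print(tensor_names_for(MODEL_ARCH.PERSIMMON, bid=3))
# e.g. MODEL_TENSOR.ATTN_QKV -> 'blk.3.attn_qkv'
```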
@ -229,33 +291,42 @@ class TensorNameMap:
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Token embeddings # Token embeddings
MODEL_TENSOR.TOKEN_EMBD: ( MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox "gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 mpt "transformer.wte", # gpt2 mpt
"transformer.word_embeddings", # falcon "transformer.word_embeddings", # falcon
"model.embed_tokens", # llama-hf "model.embed_tokens", # llama-hf
"tok_embeddings", # llama-pth "tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert
"language_model.embedding.word_embeddings", # persimmon "language_model.embedding.word_embeddings", # persimmon
), ),
# Token type embeddings
MODEL_TENSOR.TOKEN_TYPES: (
"embeddings.token_type_embeddings", # bert
),
# Position embeddings # Position embeddings
MODEL_TENSOR.POS_EMBD: ( MODEL_TENSOR.POS_EMBD: (
"transformer.wpe", # gpt2 "transformer.wpe", # gpt2
"embeddings.position_embeddings", # bert
), ),
# Output # Output
MODEL_TENSOR.OUTPUT: ( MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox "embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan "lm_head", # gpt2 mpt falcon llama-hf baichuan
"output", # llama-pth "output", # llama-pth
"word_embeddings_for_head", # persimmon "word_embeddings_for_head", # persimmon
), ),
# Output norm # Output norm
MODEL_TENSOR.OUTPUT_NORM: ( MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox "gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 falcon "transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan "model.norm", # llama-hf baichuan
"norm", # llama-pth "norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
"language_model.encoder.final_layernorm", # persimmon "language_model.encoder.final_layernorm", # persimmon
), ),
@ -268,13 +339,14 @@ class TensorNameMap:
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Attention norm # Attention norm
MODEL_TENSOR.ATTN_NORM: ( MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox "gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 "transformer.h.{bid}.ln_1", # gpt2 gpt-j
"transformer.blocks.{bid}.norm_1", # mpt "transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b "transformer.h.{bid}.input_layernorm", # falcon7b
"transformer.h.{bid}.ln_mlp", # falcon40b "transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf "model.layers.{bid}.input_layernorm", # llama-hf
"layers.{bid}.attention_norm", # llama-pth "layers.{bid}.attention_norm", # llama-pth
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
), ),
@ -285,39 +357,47 @@ class TensorNameMap:
# Attention query-key-value # Attention query-key-value
MODEL_TENSOR.ATTN_QKV: ( MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.attn.c_attn", # gpt2
"transformer.blocks.{bid}.attn.Wqkv", # mpt "transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.h.{bid}.self_attention.query_key_value", # falcon "transformer.h.{bid}.self_attention.query_key_value", # falcon
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
), ),
# Attention query # Attention query
MODEL_TENSOR.ATTN_Q: ( MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf "model.layers.{bid}.self_attn.q_proj", # llama-hf
"layers.{bid}.attention.wq", # llama-pth "layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
), ),
# Attention key # Attention key
MODEL_TENSOR.ATTN_K: ( MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf "model.layers.{bid}.self_attn.k_proj", # llama-hf
"layers.{bid}.attention.wk", # llama-pth "layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
), ),
# Attention value # Attention value
MODEL_TENSOR.ATTN_V: ( MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf "model.layers.{bid}.self_attn.v_proj", # llama-hf
"layers.{bid}.attention.wv", # llama-pth "layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
), ),
# Attention output # Attention output
MODEL_TENSOR.ATTN_OUT: ( MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox "gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 "transformer.h.{bid}.attn.c_proj", # gpt2
"transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon "transformer.h.{bid}.self_attention.dense", # falcon
"model.layers.{bid}.self_attn.o_proj", # llama-hf "model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth "layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense" # persimmon "language_model.encoder.layers.{bid}.self_attention.dense" # persimmon
), ),
@ -329,22 +409,25 @@ class TensorNameMap:
# Feed-forward norm # Feed-forward norm
MODEL_TENSOR.FFN_NORM: ( MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 "transformer.h.{bid}.ln_2", # gpt2
"transformer.blocks.{bid}.norm_2", # mpt "transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf "model.layers.{bid}.post_attention_layernorm", # llama-hf
"layers.{bid}.ffn_norm", # llama-pth "layers.{bid}.ffn_norm", # llama-pth
"encoder.layer.{bid}.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
), ),
# Feed-forward up # Feed-forward up
MODEL_TENSOR.FFN_UP: ( MODEL_TENSOR.FFN_UP: (
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
"transformer.h.{bid}.mlp.c_fc", # gpt2 "transformer.h.{bid}.mlp.c_fc", # gpt2
"transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"model.layers.{bid}.mlp.up_proj", # llama-hf "model.layers.{bid}.mlp.up_proj", # llama-hf
"layers.{bid}.feed_forward.w3", # llama-pth "layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
), ),
@ -356,12 +439,14 @@ class TensorNameMap:
# Feed-forward down # Feed-forward down
MODEL_TENSOR.FFN_DOWN: ( MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
"transformer.h.{bid}.mlp.c_proj", # gpt2 "transformer.h.{bid}.mlp.c_proj", # gpt2
"transformer.blocks.{bid}.ffn.down_proj", # mpt "transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"model.layers.{bid}.mlp.down_proj", # llama-hf "model.layers.{bid}.mlp.down_proj", # llama-hf
"layers.{bid}.feed_forward.w2", # llama-pth "layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
), ),
@ -380,28 +465,24 @@ class TensorNameMap:
mapping: dict[str, tuple[MODEL_TENSOR, str]] mapping: dict[str, tuple[MODEL_TENSOR, str]]
tensor_names: dict[MODEL_TENSOR, str]
def __init__(self, arch: MODEL_ARCH, n_blocks: int): def __init__(self, arch: MODEL_ARCH, n_blocks: int):
mapping = self.mapping = {} self.mapping = {}
tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
for tensor, keys in self.mappings_cfg.items(): for tensor, keys in self.mappings_cfg.items():
tensor_name = tensor_names.get(tensor) if tensor not in MODEL_TENSORS[arch]:
if tensor_name is None:
continue continue
mapping[tensor_name] = (tensor, tensor_name) tensor_name = TENSOR_NAMES[tensor]
self.mapping[tensor_name] = (tensor, tensor_name)
for key in keys: for key in keys:
mapping[key] = (tensor, tensor_name) self.mapping[key] = (tensor, tensor_name)
for bid in range(n_blocks): for bid in range(n_blocks):
for tensor, keys in self.block_mappings_cfg.items(): for tensor, keys in self.block_mappings_cfg.items():
tensor_name = tensor_names.get(tensor) if tensor not in MODEL_TENSORS[arch]:
if tensor_name is None:
continue continue
tensor_name = tensor_name.format(bid = bid) tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
mapping[tensor_name] = (tensor, tensor_name) self.mapping[tensor_name] = (tensor, tensor_name)
for key in keys: for key in keys:
key = key.format(bid = bid) key = key.format(bid = bid)
mapping[key] = (tensor, tensor_name) self.mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key) result = self.mapping.get(key)
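
TensorNameMap no longer keeps a per-architecture tensor_names dict; it filters on MODEL_TENSORS[arch] and formats names from the shared TENSOR_NAMES table. A usage sketch for the lookup API, under the same import assumption as above; the checkpoint tensor name and the exact suffix handling shown are illustrative:

```python
from gguf import MODEL_ARCH, TensorNameMap

# Build a mapping for a 32-block LLaMA-style model.
tmap = TensorNameMap(MODEL_ARCH.LLAMA, n_blocks=32)

# A Hugging Face checkpoint name resolves to the canonical GGUF name;
# try_suffixes lets ".weight"/".bias" suffixes pass through the lookup.
result = tmap.get_type_and_name(
    "model.layers.10.self_attn.q_proj.weight",
    try_suffixes=(".weight", ".bias"),
)
if result is not None:
    tensor_type, name = result  # e.g. MODEL_TENSOR.ATTN_Q, 'blk.10.attn_q.weight'
    print(tensor_type, name)
```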
@ -842,22 +923,25 @@ class SpecialVocab:
special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad') special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
special_token_ids: dict[str, int] = {} special_token_ids: dict[str, int] = {}
def __init__(self, path: Path, load_merges: bool = False, special_token_types: tuple[str, ...] | None = None): def __init__(
self, path: str | os.PathLike[str], load_merges: bool = False,
special_token_types: tuple[str, ...] | None = None,
):
self.special_token_ids = {} self.special_token_ids = {}
self.load_merges = load_merges self.load_merges = load_merges
if special_token_types is not None: if special_token_types is not None:
self.special_token_types = special_token_types self.special_token_types = special_token_types
self.load(path) self._load(Path(path))
def load(self, path: Path): def _load(self, path: Path) -> None:
if not self.try_load_from_tokenizer_json(path): if not self._try_load_from_tokenizer_json(path):
self.try_load_from_config_json(path) self._try_load_from_config_json(path)
def try_load_from_tokenizer_json(self, path: Path) -> bool: def _try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer_file = path / 'tokenizer.json' tokenizer_file = path / 'tokenizer.json'
if not tokenizer_file.is_file(): if not tokenizer_file.is_file():
return False return False
with open(tokenizer_file, 'r', encoding = 'utf-8') as f: with open(tokenizer_file, encoding = 'utf-8') as f:
tokenizer = json.load(f) tokenizer = json.load(f)
if self.load_merges: if self.load_merges:
merges = tokenizer.get('model', {}).get('merges') merges = tokenizer.get('model', {}).get('merges')
@ -867,7 +951,7 @@ class SpecialVocab:
added_tokens = tokenizer.get('added_tokens') added_tokens = tokenizer.get('added_tokens')
if added_tokens is None or not tokenizer_config_file.is_file(): if added_tokens is None or not tokenizer_config_file.is_file():
return True return True
with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f: with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f) tokenizer_config = json.load(f)
for typ in self.special_token_types: for typ in self.special_token_types:
entry = tokenizer_config.get(f'{typ}_token') entry = tokenizer_config.get(f'{typ}_token')
@ -886,11 +970,11 @@ class SpecialVocab:
break break
return True return True
def try_load_from_config_json(self, path: Path) -> bool: def _try_load_from_config_json(self, path: Path) -> bool:
config_file = path / 'config.json' config_file = path / 'config.json'
if not config_file.is_file(): if not config_file.is_file():
return False return False
with open(config_file, 'r', encoding = 'utf-8') as f: with open(config_file, encoding = 'utf-8') as f:
config = json.load(f) config = json.load(f)
for typ in self.special_token_types: for typ in self.special_token_types:
maybe_token_id = config.get(f'{typ}_token_id') maybe_token_id = config.get(f'{typ}_token_id')
@ -898,7 +982,7 @@ class SpecialVocab:
self.special_token_ids[typ] = maybe_token_id self.special_token_ids[typ] = maybe_token_id
return True return True
def add_to_gguf(self, gw: GGUFWriter): def add_to_gguf(self, gw: GGUFWriter) -> None:
if len(self.merges) > 0: if len(self.merges) > 0:
print(f'gguf: Adding {len(self.merges)} merge(s).') print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges) gw.add_token_merges(self.merges)
@ -910,8 +994,8 @@ class SpecialVocab:
print(f'gguf: Setting special token type {typ} to {tokid}') print(f'gguf: Setting special token type {typ} to {tokid}')
handler(tokid) handler(tokid)
def __repr__(self): def __repr__(self) -> str:
return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>' return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
# Example usage: # Example usage:
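
SpecialVocab now accepts any path-like object, loads eagerly through the private _load helper, and keeps add_to_gguf as its public entry point. A usage sketch to match the "# Example usage:" comment; the model directory, output filename and architecture string are illustrative:

```python
from gguf import GGUFWriter, SpecialVocab

# Collect BPE merges and special token ids (bos/eos/unk/sep/pad) from a
# Hugging Face model directory; a plain str path now works as well as Path.
special_vocab = SpecialVocab("models/my-model", load_merges=True)

# Add them to a GGUF file alongside whatever else the converter writes.
writer = GGUFWriter("my-model.gguf", "llama")  # illustrative arch name
special_vocab.add_to_gguf(writer)
```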

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "gguf" name = "gguf"
version = "0.3.3" version = "0.4.0"
description = "Write ML models in GGUF for GGML" description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"] authors = ["GGML <ggml@ggml.ai>"]
packages = [ packages = [

View file

@ -1102,6 +1102,10 @@ struct llama_vocab {
id special_pad_id = -1; id special_pad_id = -1;
id linefeed_id = 13; id linefeed_id = 13;
id special_prefix_id = 32007;
id special_middle_id = 32009;
id special_suffix_id = 32008;
id special_eot_id = 32010;
int find_bpe_rank(std::string token_left, std::string token_right) const { int find_bpe_rank(std::string token_left, std::string token_right) const {
replace_all(token_left, " ", "\u0120"); replace_all(token_left, " ", "\u0120");
@ -7217,13 +7221,14 @@ struct llama_context * llama_new_context_with_model(
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (model->n_gpu_layers > 0) { if (model->n_gpu_layers > 0) {
ggml_metal_log_set_callback(llama_log_callback_default, NULL);
ctx->ctx_metal = ggml_metal_init(1); ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) { if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
llama_free(ctx); llama_free(ctx);
return NULL; return NULL;
} }
ggml_metal_log_set_callback(llama_log_callback_default, NULL);
//ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); //ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false);
//ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
} }
@ -7950,6 +7955,22 @@ llama_token llama_token_eos(const struct llama_context * ctx) {
llama_token llama_token_nl(const struct llama_context * ctx) { llama_token llama_token_nl(const struct llama_context * ctx) {
return ctx->model.vocab.linefeed_id; return ctx->model.vocab.linefeed_id;
} }
llama_token llama_token_prefix(const struct llama_context * ctx) {
return ctx->model.vocab.special_prefix_id;
}
llama_token llama_token_middle(const struct llama_context * ctx) {
return ctx->model.vocab.special_middle_id;
}
llama_token llama_token_suffix(const struct llama_context * ctx) {
return ctx->model.vocab.special_suffix_id;
}
llama_token llama_token_eot(const struct llama_context * ctx) {
return ctx->model.vocab.special_eot_id;
}
int llama_tokenize( int llama_tokenize(
const struct llama_model * model, const struct llama_model * model,

View file

@ -490,6 +490,11 @@ extern "C" {
LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
// codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle
LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix
LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of infill middle
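
These accessors expose Code Llama's fill-in-the-middle special tokens, so a client can lay out an infill prompt as prefix token, prefix, suffix token, suffix, middle token, and stop once the end-of-infill token is sampled. A hedged sketch of that assembly, assuming the prefix and suffix text has already been tokenized with llama_tokenize and omitting error handling:

```cpp
#include "llama.h"
#include <vector>

// Arrange already-tokenized prefix/suffix into the infill ordering
//   <PRE> prefix <SUF> suffix <MID>
// generation then runs until llama_token_eot(ctx) is sampled.
static std::vector<llama_token> build_infill_prompt(
        struct llama_context * ctx,
        const std::vector<llama_token> & prefix_tokens,
        const std::vector<llama_token> & suffix_tokens) {
    std::vector<llama_token> inp;
    inp.push_back(llama_token_prefix(ctx));
    inp.insert(inp.end(), prefix_tokens.begin(), prefix_tokens.end());
    inp.push_back(llama_token_suffix(ctx));
    inp.insert(inp.end(), suffix_tokens.begin(), suffix_tokens.end());
    inp.push_back(llama_token_middle(ctx));
    return inp;
}

// During sampling:
//   if (new_token == llama_token_eot(ctx)) { /* infill complete */ }
```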
// //
// Tokenization // Tokenization

View file

@ -56,11 +56,13 @@ find_library(llama_LIBRARY llama
HINTS ${LLAMA_LIB_DIR}) HINTS ${LLAMA_LIB_DIR})
set(_llama_link_deps "Threads::Threads" "@LLAMA_EXTRA_LIBS@") set(_llama_link_deps "Threads::Threads" "@LLAMA_EXTRA_LIBS@")
set(_llama_transient_defines "@LLAMA_TRANSIENT_DEFINES@")
add_library(llama UNKNOWN IMPORTED) add_library(llama UNKNOWN IMPORTED)
set_target_properties(llama set_target_properties(llama
PROPERTIES PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}" INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES "${_llama_link_deps}" INTERFACE_LINK_LIBRARIES "${_llama_link_deps}"
INTERFACE_COMPILE_DEFINITIONS "${_llama_transient_defines}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${llama_LIBRARY}" IMPORTED_LOCATION "${llama_LIBRARY}"
INTERFACE_COMPILE_FEATURES cxx_std_11 INTERFACE_COMPILE_FEATURES cxx_std_11