Merge branch 'master' into xsn/vision_2

commit 32daa38333

65 changed files with 7551 additions and 952 deletions
.github/workflows/build.yml (vendored, 6 changes)

@@ -87,6 +87,7 @@ jobs:
           if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
           run: |
             cp LICENSE ./build/bin/
+            cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
             zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*

       - name: Upload artifacts
@@ -149,6 +150,7 @@ jobs:
           if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
           run: |
             cp LICENSE ./build/bin/
+            cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
             zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*

       - name: Upload artifacts
@@ -217,6 +219,7 @@ jobs:
           if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
           run: |
             cp LICENSE ./build/bin/
+            cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp
             zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/*

       - name: Upload artifacts
@@ -234,7 +237,7 @@ jobs:
     strategy:
       matrix:
         sanitizer: [ADDRESS, THREAD, UNDEFINED]
-        build_type: [Debug, Release]
+        build_type: [Debug]

     steps:
       - name: Clone
@@ -796,6 +799,7 @@ jobs:
           if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
           run: |
             Copy-Item LICENSE .\build\bin\Release\llama.cpp.txt
+            Copy-Item .\examples\run\linenoise.cpp\LICENSE .\build\bin\Release\linenoise.cpp.txt
             7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\*

       - name: Upload artifacts
.github/workflows/server.yml (vendored, 25 changes)

@@ -112,9 +112,9 @@ jobs:
             -DGGML_OPENMP=OFF ;
           cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server

-      - name: Build
-        id: cmake_build
-        if: ${{ matrix.sanitizer != 'THREAD' }}
+      - name: Build (sanitizers)
+        id: cmake_build_sanitizers
+        if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
         run: |
           cmake -B build \
             -DGGML_NATIVE=OFF \
@@ -124,12 +124,31 @@ jobs:
             -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
           cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server

+      - name: Build (sanitizers)
+        id: cmake_build
+        if: ${{ matrix.sanitizer == '' }}
+        run: |
+          cmake -B build \
+            -DGGML_NATIVE=OFF \
+            -DLLAMA_BUILD_SERVER=ON \
+            -DLLAMA_CURL=ON \
+            -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
+          cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
+
       - name: Tests
         id: server_integration_tests
+        if: ${{ matrix.sanitizer == '' }}
         run: |
           cd examples/server/tests
           ./tests.sh

+      - name: Tests (sanitizers)
+        id: server_integration_tests_sanitizers
+        if: ${{ matrix.sanitizer != '' }}
+        run: |
+          cd examples/server/tests
+          LLAMA_SANITIZE=1 ./tests.sh
+
       - name: Slow tests
         id: server_integration_tests_slow
         if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
CMakeLists.txt

@@ -83,11 +83,8 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)

 # override ggml options
-set(GGML_SANITIZE_THREAD    ${LLAMA_SANITIZE_THREAD})
-set(GGML_SANITIZE_ADDRESS   ${LLAMA_SANITIZE_ADDRESS})
-set(GGML_SANITIZE_UNDEFINED ${LLAMA_SANITIZE_UNDEFINED})
-set(GGML_ALL_WARNINGS       ${LLAMA_ALL_WARNINGS})
-set(GGML_FATAL_WARNINGS     ${LLAMA_FATAL_WARNINGS})
+set(GGML_ALL_WARNINGS   ${LLAMA_ALL_WARNINGS})
+set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS})

 # change the default for these ggml options
 if (NOT DEFINED GGML_LLAMAFILE)
@@ -117,16 +114,62 @@ llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
 llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
 llama_option_depr(WARNING LLAMA_CANN GGML_CANN)

+if (NOT MSVC)
+    if (LLAMA_SANITIZE_THREAD)
+        message(STATUS "Using -fsanitize=thread")
+
+        add_compile_options(-fsanitize=thread)
+        link_libraries     (-fsanitize=thread)
+    endif()
+
+    if (LLAMA_SANITIZE_ADDRESS)
+        message(STATUS "Using -fsanitize=address")
+
+        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
+        link_libraries     (-fsanitize=address)
+    endif()
+
+    if (LLAMA_SANITIZE_UNDEFINED)
+        message(STATUS "Using -fsanitize=undefined")
+
+        add_compile_options(-fsanitize=undefined)
+        link_libraries     (-fsanitize=undefined)
+    endif()
+endif()
+
 #
-# build the library
+# 3rd-party
 #

 if (NOT TARGET ggml)
     add_subdirectory(ggml)
     # ... otherwise assume ggml is added by a parent CMakeLists.txt
 endif()

+#
+# build the library
+#
+
 add_subdirectory(src)

+#
+# utils, programs, examples and tests
+#
+
+if (LLAMA_BUILD_COMMON)
+    add_subdirectory(common)
+endif()
+
+if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
+    include(CTest)
+    add_subdirectory(tests)
+endif()
+
+if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES)
+    add_subdirectory(examples)
+    add_subdirectory(pocs)
+endif()
+
 #
 # install
 #
@@ -200,21 +243,3 @@ configure_file(cmake/llama.pc.in

 install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
         DESTINATION lib/pkgconfig)
-
-#
-# utils, programs, examples and tests
-#
-
-if (LLAMA_BUILD_COMMON)
-    add_subdirectory(common)
-endif()
-
-if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
-    include(CTest)
-    add_subdirectory(tests)
-endif()
-
-if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES)
-    add_subdirectory(examples)
-    add_subdirectory(pocs)
-endif()
Makefile (2 changes)

@@ -1361,7 +1361,9 @@ llama-server: \
     examples/server/httplib.h \
     examples/server/index.html.hpp \
     examples/server/loading.html.hpp \
+    common/chat-template.hpp \
     common/json.hpp \
+    common/minja.hpp \
     $(OBJ_ALL)
     $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
     $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
cmake/build-info.cmake

@@ -44,7 +44,7 @@ if(MSVC)
     set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
 else()
     execute_process(
-        COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER}
+        COMMAND sh -c "\"$@\" --version | head -1" _ ${CMAKE_C_COMPILER}
         OUTPUT_VARIABLE OUT
         OUTPUT_STRIP_TRAILING_WHITESPACE
     )
common/CMakeLists.txt

@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.hpp
     common.cpp
     common.h
     console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
     json.hpp
     log.cpp
     log.h
+    minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp
common/arg.cpp

@@ -133,7 +133,8 @@ static void common_params_handle_model_default(
         const std::string & model_url,
         std::string & hf_repo,
         std::string & hf_file,
-        const std::string & hf_token) {
+        const std::string & hf_token,
+        const std::string & model_default) {
     if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
         if (hf_file.empty()) {
@@ -163,7 +164,7 @@ static void common_params_handle_model_default(
             model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
     } else if (model.empty()) {
-        model = DEFAULT_MODEL_PATH;
+        model = model_default;
     }
 }

@@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }

     // TODO: refactor model params in a common struct
-    common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
-    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
+    common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH);
+    common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
+    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, "");

     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -323,6 +325,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
     }

+    if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
+        throw std::runtime_error(string_format(
+            "error: the supplied chat template is not supported: %s%s\n",
+            params.chat_template.c_str(),
+            params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
+        ));
+    }
+
     return true;
 }

@@ -1629,6 +1639,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.hf_repo = value;
         }
     ).set_env("LLAMA_ARG_HF_REPO"));
+    add_opt(common_arg(
+        {"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
+        "Same as --hf-repo, but for the draft model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.speculative.hf_repo = value;
+        }
+    ).set_env("LLAMA_ARG_HFD_REPO"));
     add_opt(common_arg(
         {"-hff", "--hf-file"}, "FILE",
         "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
@@ -1938,24 +1955,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--jinja"},
+        "use jinja template for chat (default: disabled)",
+        [](common_params & params) {
+            params.use_jinja = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
             "set custom jinja chat template (default: template taken from model's metadata)\n"
             "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
             "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
         ),
         [](common_params & params, const std::string & value) {
-            if (!common_chat_verify_template(value)) {
-                throw std::runtime_error(string_format(
-                    "error: the supplied chat template is not supported: %s\n"
-                    "note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
-                    value.c_str()
-                ));
-            }
             params.chat_template = value;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    add_opt(common_arg(
+        {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
+        string_format(
+            "set custom jinja chat template file (default: template taken from model's metadata)\n"
+            "if suffix/prefix are specified, template will be disabled\n"
+            "only commonly used templates are accepted (unless --jinja is set before this flag):\n"
+            "list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
+        ),
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(params.chat_template));
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
@@ -2254,6 +2291,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.vocoder.model = value;
         }
     ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--tts-use-guide-tokens"},
+        "Use guide tokens to improve TTS word recall",
+        [](common_params & params) {
+            params.vocoder.use_guide_tokens = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));

     // model-specific
     add_opt(common_arg(
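The hunks above thread an explicit per-model default through every call site of the model-resolution helper. For orientation, here is a condensed restatement of the resulting precedence; this sketch is not part of the commit, the function name is hypothetical, and the cache-path derivation the real helper performs for --hf-repo/--model-url is elided:

    #include <cstdio>
    #include <string>

    // Illustrative restatement of common_params_handle_model_default():
    // an explicit --model wins; --hf-repo/--model-url leave the path to be
    // derived from the download cache (elided here); otherwise the caller's
    // per-model default applies (DEFAULT_MODEL_PATH for the main model,
    // "" for the draft and vocoder models).
    static std::string resolve_model(const std::string & model,      // --model
                                     const std::string & hf_repo,    // --hf-repo, or -hfd for the draft model
                                     const std::string & model_url,  // --model-url
                                     const std::string & model_default) {
        if (!model.empty() || !hf_repo.empty() || !model_url.empty()) {
            return model; // explicit path, or a cache path derived elsewhere
        }
        return model_default;
    }

    int main() {
        printf("[%s]\n", resolve_model("", "", "", "<default-model-path>").c_str());
        printf("[%s]\n", resolve_model("", "", "", "").c_str()); // draft model: stays unset
    }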
common/chat-template.hpp (new file, 249 lines)

@@ -0,0 +1,249 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "minja.hpp"
+#include <json.hpp>
+#include <string>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+class chat_template {
+  public:
+
+  private:
+    bool supports_tools_ = true;
+    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
+    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+    bool requires_object_arguments_ = false;
+    bool supports_system_role_ = true;
+    bool supports_parallel_tool_calls_ = false;
+    std::string source_;
+    std::string bos_token_;
+    std::string eos_token_;
+    std::shared_ptr<minja::TemplateNode> template_root_;
+
+    std::string try_render(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        try {
+            auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
+            // fprintf(stderr, "Prompt: %s\n", prompt.c_str());
+            return prompt;
+        } catch (const std::exception & e) {
+            // fprintf(stderr, "Error: %s\n", e.what());
+            return "";
+        }
+    }
+
+  public:
+    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
+        : source_(source), bos_token_(bos_token), eos_token_(eos_token)
+    {
+        template_root_ = minja::Parser::parse(source_, {
+            /* .trim_blocks = */ true,
+            /* .lstrip_blocks = */ true,
+            /* .keep_trailing_newline = */ false,
+        });
+        supports_tools_ = source.find("tools") != std::string::npos;
+
+        auto renders_string_arguments =
+            try_render({
+                {
+                    {"role", "user"},
+                    {"content", "Hey"}
+                },
+                {
+                    {"role", "assistant"},
+                    {"tool_calls", json::array({
+                        {
+                            {"id", "call_1___"},
+                            {"type", "function"},
+                            {"function", {
+                                {"arguments", "{\"code\": \"print('Hello, World!')\"}"},
+                                {"name", "ipython"},
+                            }},
+                        },
+                    })},
+                }
+            }, {}, false).find("{\"code\": \"print") != std::string::npos;
+        if (!renders_string_arguments) {
+            auto renders_object_arguments =
+                try_render({
+                    {
+                        {"role", "user"},
+                        {"content", "Hey"}
+                    },
+                    {
+                        {"role", "assistant"},
+                        {"tool_calls", json::array({
+                            {
+                                {"id", "call_1___"},
+                                {"type", "function"},
+                                {"function", {
+                                    {"arguments", {
+                                        {"code", "print('Hello, World!')"},
+                                    }},
+                                    {"name", "ipython"},
+                                }},
+                            },
+                        })},
+                    }
+                }, {}, false).find("{\"code\": \"print") != std::string::npos;
+            requires_object_arguments_ = renders_object_arguments;
+        }
+        supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
+
+        supports_system_role_ = try_render({
+            {{"role", "system"}, {"content", "<System Needle>"}},
+            {{"role", "user"},   {"content", "Hey"}}
+        }, {}, false).find("<System Needle>") != std::string::npos;
+    }
+
+    const std::string & source() const { return source_; }
+    const std::string & bos_token() const { return bos_token_; }
+    const std::string & eos_token() const { return eos_token_; }
+    bool supports_tools() const { return supports_tools_; }
+    bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
+
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        json actual_messages;
+
+        // First, "fix" messages so they have a chance to be rendered correctly by the template
+
+        if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
+            actual_messages = json::array();
+
+            std::string pending_system;
+            auto flush_sys = [&]() {
+                if (!pending_system.empty()) {
+                    actual_messages.push_back({
+                        {"role", "user"},
+                        {"content", pending_system},
+                    });
+                    pending_system.clear();
+                }
+            };
+            for (const auto & message_ : messages) {
+                auto message = message_;
+                if (!message.contains("role") || !message.contains("content")) {
+                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+                }
+                std::string role = message.at("role");
+
+                if (message.contains("tool_calls")) {
+                    if (requires_object_arguments_ || !supports_tools_) {
+                        for (auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call["type"] == "function") {
+                                auto & function = tool_call.at("function");
+                                std::string arguments = function.at("arguments");
+                                function["arguments"] = json::parse(arguments);
+                            }
+                        }
+                    }
+                    if (!supports_tools_) {
+                        auto content = message.at("content");
+                        auto tool_calls = json::array();
+                        for (const auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call.at("type") != "function") {
+                                continue;
+                            }
+                            const auto & function = tool_call.at("function");
+                            auto tc = json {
+                                {"name", function.at("name")},
+                                {"arguments", function.at("arguments")},
+                            };
+                            if (tool_call.contains("id")) {
+                                tc["id"] = tool_call["id"];
+                            }
+                            tool_calls.push_back(tc);
+                        }
+                        auto obj = json {
+                            {"tool_calls", tool_calls},
+                        };
+                        if (!content.is_null() && content != "") {
+                            obj["content"] = content;
+                        }
+                        message["content"] = obj.dump(2);
+                        message.erase("tool_calls");
+                    }
+                }
+                if (!supports_tools_ && role == "tool") {
+                    message["role"] = "user";
+                    auto obj = json {
+                        {"tool_response", {
+                            {"tool", message.at("name")},
+                            {"content", message.at("content")},
+                        }},
+                    };
+                    if (message.contains("tool_call_id")) {
+                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
+                    }
+                    message["content"] = obj.dump(2);
+                    message.erase("name");
+                }
+
+                if (!message["content"].is_null() && !supports_system_role_) {
+                    std::string content = message.at("content");
+                    if (role == "system") {
+                        if (!pending_system.empty()) pending_system += "\n";
+                        pending_system += content;
+                        continue;
+                    } else {
+                        if (role == "user") {
+                            if (!pending_system.empty()) {
+                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
+                                pending_system.clear();
+                            }
+                        } else {
+                            flush_sys();
+                        }
+                    }
+                }
+                actual_messages.push_back(message);
+            }
+            flush_sys();
+        } else {
+            actual_messages = messages;
+        }
+
+        auto context = minja::Context::make(json({
+            {"messages", actual_messages},
+            {"add_generation_prompt", add_generation_prompt},
+            {"bos_token", bos_token_},
+            {"eos_token", eos_token_},
+        }));
+
+        if (!tools.is_null()) {
+            auto tools_val = minja::Value(tools);
+            context->set("tools", tools_val);
+        }
+        if (!extra_context.is_null()) {
+            for (auto & kv : extra_context.items()) {
+                minja::Value val(kv.value());
+                context->set(kv.key(), val);
+            }
+        }
+
+        return template_root_->render(context);
+    }
+};

+} // namespace minja
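The header above is the whole public surface of the new minja chat-template wrapper: construct from Jinja source plus BOS/EOS token text, then call apply(). A minimal usage sketch, not part of the commit; the inline chatml-style template string is just an example input:

    #include "chat-template.hpp"

    #include <cstdio>

    int main() {
        // Construct a template from Jinja source plus the BOS/EOS token texts.
        minja::chat_template tmpl(
            "{%- for message in messages -%}"
            "{{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>\\n' -}}"
            "{%- endfor -%}",
            /* bos_token= */ "<s>", /* eos_token= */ "</s>");

        // apply() renders a message array; tools can be passed as a JSON array.
        auto prompt = tmpl.apply(
            json::array({{{"role", "user"}, {"content", "Hello"}}}),
            /* tools= */ json(), /* add_generation_prompt= */ true);
        printf("%s\n", prompt.c_str());
    }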
common/common.cpp

@@ -12,6 +12,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "chat-template.hpp"

 #include <algorithm>
 #include <cinttypes>
@@ -483,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
     s = std::move(builder);
 }

+std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
+    std::ostringstream result;
+    for (size_t i = 0; i < values.size(); ++i) {
+        if (i > 0) {
+            result << separator;
+        }
+        result << values[i];
+    }
+    return result.str();
+}
+
+std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
+    std::vector<std::string> parts;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        parts.push_back(str.substr(start, end - start));
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    parts.push_back(str.substr(start));
+
+    return parts;
+}
+
+std::string string_repeat(const std::string & str, size_t n) {
+    if (n == 0) {
+        return "";
+    }
+
+    std::string result;
+    result.reserve(str.length() * n);
+
+    for (size_t i = 0; i < n; ++i) {
+        result += str;
+    }
+
+    return result;
+}
+
 std::string string_from(bool value) {
     return value ? "true" : "false";
 }
@@ -1728,67 +1771,75 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
 // Chat template utils
 //

-std::string common_get_builtin_chat_template(const struct llama_model * model) {
-    const char * ptr_tmpl = llama_model_chat_template(model);
-    return ptr_tmpl == nullptr ? "" : ptr_tmpl;
-}
-
-bool common_chat_verify_template(const std::string & tmpl) {
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
     llama_chat_message chat[] = {{"user", "test"}};
     const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }

-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & msgs,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
+    if (use_jinja) {
+        auto messages = json::array();
+        for (const auto & msg : msgs) {
+            messages.push_back({{"role", msg.role}, {"content", msg.content}});
+        }
+        return tmpl.apply(messages, /* tools= */ json(), add_ass);
+    }
+
     int alloc_size = 0;
-    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }

-    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);

     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());

     // error: chat template is not supported
     if (res < 0) {
-        if (ptr_tmpl != nullptr) {
-            // if the custom "tmpl" is not supported, we throw an error
-            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
-            throw std::runtime_error("this custom template is not supported");
-        }
-
-        // If the built-in template is not supported, we default to chatml
-        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-        fallback = true;
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
     }

     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(
-            fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }

     std::string formatted_chat(buf.data(), res);
     return formatted_chat;
 }

-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1796,21 +1847,74 @@ std::string common_chat_format_single(
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }

-std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl) {
+std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, msgs, true);
+    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
+}
+
+common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
+{
+    auto vocab = llama_model_get_vocab(model);
+    std::string default_template_src = chat_template_override;
+    std::string template_tool_use_src = chat_template_override;
+    bool has_explicit_template = !chat_template_override.empty();
+    if (chat_template_override.empty()) {
+        auto str = llama_model_chat_template(model, /* name */ nullptr);
+        if (str) {
+            default_template_src = str;
+            has_explicit_template = true;
+        }
+        str = llama_model_chat_template(model, /* name */ "tool_use");
+        if (str) {
+            template_tool_use_src = str;
+            has_explicit_template = true;
+        }
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!template_tool_use_src.empty()) {
+            default_template_src = template_tool_use_src;
+        } else {
+            default_template_src = R"(
+                {%- for message in messages -%}
+                    {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+                {%- endfor -%}
+                {%- if add_generation_prompt -%}
+                    {{- "<|im_start|>assistant\n" -}}
+                {%- endif -%}
+            )";
+        }
+    }
+    const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+        if (token == LLAMA_TOKEN_NULL) {
+            if (default_template_src.find(jinja_variable_name) != std::string::npos
+                || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+                LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
+            }
+            return std::string();
+        } else {
+            return common_token_to_piece(vocab, token, true);
+        }
+    };
+    auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+    auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+    return {
+        has_explicit_template,
+        std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
+        template_tool_use_src.empty()
+            ? nullptr
+            : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
+    };
 }

 //
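The three string helpers promoted into common.cpp above replace file-local copies that json-schema-to-grammar.cpp used to carry (see the deletions further down). A small demonstration, not from the commit; main() and the sample values are illustrative only:

    #include "common.h" // declares string_join / string_split / string_repeat per the diff above

    #include <cstdio>

    int main() {
        // "a/b/c" split on "/" -> {"a", "b", "c"}
        std::vector<std::string> parts = string_split("a/b/c", "/");

        // Joined back with " | ", the separator the grammar converter uses.
        printf("%s\n", string_join(parts, " | ").c_str()); // prints: a | b | c

        // string_repeat concatenates n copies: "ab" x 3 -> "ababab"
        printf("%s\n", string_repeat("ab", 3).c_str());
    }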
common/common.h

@@ -176,7 +176,11 @@ struct common_params_speculative {
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;

-    std::string model = ""; // draft model for speculative decoding // NOLINT
+    std::string hf_repo = ""; // HF repo // NOLINT
+    std::string hf_file = ""; // HF file // NOLINT
+
+    std::string model     = ""; // draft model for speculative decoding // NOLINT
+    std::string model_url = ""; // model url to download // NOLINT
 };

 struct common_params_vocoder {
@@ -185,6 +189,8 @@ struct common_params_vocoder {

     std::string model = ""; // model path // NOLINT
     std::string model_url = ""; // model url to download // NOLINT
+
+    bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
 };

 struct common_params {
@@ -329,6 +335,7 @@ struct common_params {
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
     std::string chat_template = ""; // NOLINT
+    bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;

     std::vector<std::string> api_keys;
@@ -423,6 +430,10 @@ std::string string_format(const char * fmt, ...);
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();

+std::string string_join(const std::vector<std::string> & values, const std::string & separator);
+std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
+std::string string_repeat(const std::string & str, size_t n);
+
 void string_replace_all(std::string & s, const std::string & search, const std::string & replace);

 template<class T>
@@ -507,12 +518,14 @@ struct llama_model * common_load_model_from_url(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);

 struct llama_model * common_load_model_from_hf(
     const std::string & repo,
     const std::string & remote_path,
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);

 std::pair<std::string, std::string> common_get_hf_file(
     const std::string & hf_repo_with_tag,
     const std::string & hf_token);
@@ -596,30 +609,43 @@ struct common_chat_msg {
     std::string content;
 };

-// Get the built-in chat template for the model. Return empty string if not present.
-std::string common_get_builtin_chat_template(const struct llama_model * model);
-
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl);
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
+
+namespace minja {
+    class chat_template;
+}
+
+typedef minja::chat_template common_chat_template;
+
+struct common_chat_templates {
+    bool has_explicit_template; // Model had builtin template or template overridde was specified.
+    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+    std::unique_ptr<common_chat_template> template_tool_use;
+};

 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & chat,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);

 // Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const common_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);

 // Returns an example of formatted chat
-std::string common_chat_format_example(const struct llama_model * model,
-        const std::string & tmpl);
+std::string common_chat_format_example(
+        const common_chat_template & tmpl, bool use_jinja);
+
+common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);

 //
 // KV cache utils
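The declarations above replace the model-pointer wrappers with template-object ones. A sketch of the intended call flow, not from the commit; it assumes a llama_model has already been loaded, includes chat-template.hpp so the forward-declared minja::chat_template is complete, and the helper name is hypothetical:

    #include "common.h"
    #include "chat-template.hpp" // completes the forward-declared minja::chat_template

    #include <cstdio>

    // Hypothetical helper showing how callers are expected to use the new API.
    static void print_example_chat(const struct llama_model * model, const common_params & params) {
        // Resolve the templates once: the override string wins, otherwise the
        // model's metadata, otherwise a built-in chatml fallback.
        common_chat_templates tmpls = common_chat_templates_from_model(model, params.chat_template);

        // template_default is always set; use_jinja routes through minja.
        std::string example = common_chat_format_example(*tmpls.template_default, params.use_jinja);
        printf("example chat:\n%s\n", example.c_str());
    }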
common/json-schema-to-grammar.cpp

@@ -1,4 +1,6 @@
 #include "json-schema-to-grammar.h"
+#include "common.h"
+
 #include <algorithm>
 #include <fstream>
 #include <map>
@@ -11,11 +13,6 @@

 using json = nlohmann::ordered_json;

-template <typename Iterator>
-static std::string join(Iterator begin, Iterator end, const std::string & separator);
-
-static std::string repeat(const std::string & str, size_t n);
-
 static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
     auto has_max = max_items != std::numeric_limits<int>::max();
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         if (sub_len > 0) {
             auto from_sub = from.substr(i + 1);
             auto to_sub = to.substr(i + 1);
-            auto sub_zeros = repeat("0", sub_len);
-            auto sub_nines = repeat("9", sub_len);
+            auto sub_zeros = string_repeat("0", sub_len);
+            auto sub_nines = string_repeat("9", sub_len);

             auto to_reached = false;
             out << "(";
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
     auto max_digits = max_s.length();

     for (auto digits = min_digits; digits < max_digits; digits++) {
-        uniform_range(min_s, repeat("9", digits));
-        min_s = "1" + repeat("0", digits);
+        uniform_range(min_s, string_repeat("9", digits));
+        min_s = "1" + string_repeat("0", digits);
         out << " | ";
     }
     uniform_range(min_s, max_s);
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
 std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
 std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};

-template <typename Iterator>
-std::string join(Iterator begin, Iterator end, const std::string & separator) {
-    std::ostringstream result;
-    if (begin != end) {
-        result << *begin;
-        for (Iterator it = begin + 1; it != end; ++it) {
-            result << separator << *it;
-        }
-    }
-    return result.str();
-}
-
-static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
-    std::vector<std::string> tokens;
-    size_t start = 0;
-    size_t end = str.find(delimiter);
-
-    while (end != std::string::npos) {
-        tokens.push_back(str.substr(start, end - start));
-        start = end + delimiter.length();
-        end = str.find(delimiter, start);
-    }
-
-    tokens.push_back(str.substr(start));
-
-    return tokens;
-}
-
-static std::string repeat(const std::string & str, size_t n) {
-    if (n == 0) {
-        return "";
-    }
-
-    std::string result;
-    result.reserve(str.length() * n);
-
-    for (size_t i = 0; i < n; ++i) {
-        result += str;
-    }
-
-    return result;
-}
-
 static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
     std::smatch match;
     std::string result;
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {

 class SchemaConverter {
 private:
+    friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
     std::map<std::string, std::string> _rules;
@@ -418,7 +373,7 @@ private:
         for (size_t i = 0; i < alt_schemas.size(); i++) {
             rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
         }
-        return join(rules.begin(), rules.end(), " | ");
+        return string_join(rules, " | ");
     }

     std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -481,7 +436,7 @@ private:
             for (const auto & item : ret) {
                 results.push_back(to_rule(item));
             }
-            return std::make_pair(join(results.begin(), results.end(), " "), false);
+            return std::make_pair(string_join(results, " "), false);
         };

         while (i < length) {
@@ -539,7 +494,7 @@ private:
                     }
                     curly_brackets += '}';
                     i++;
-                    auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
+                    auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
                     int min_times = 0;
                     int max_times = std::numeric_limits<int>::max();
                     try {
@@ -854,7 +809,7 @@ public:
             return;
         }
         std::string pointer = ref.substr(ref.find('#') + 1);
-        std::vector<std::string> tokens = split(pointer, "/");
+        std::vector<std::string> tokens = string_split(pointer, "/");
         for (size_t i = 1; i < tokens.size(); ++i) {
             std::string sel = tokens[i];
             if (target.is_null() || !target.contains(sel)) {
@@ -905,7 +860,7 @@ public:
             for (const auto & v : schema["enum"]) {
                 enum_values.push_back(_generate_constant_rule(v));
             }
-            return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
+            return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
         } else if ((schema_type.is_null() || schema_type == "object")
                 && (schema.contains("properties") ||
                     (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -1019,10 +974,10 @@ public:

     void check_errors() {
         if (!_errors.empty()) {
-            throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
+            throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
         }
         if (!_warnings.empty()) {
-            fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
+            fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
         }
     }
@@ -1036,10 +991,27 @@ public:
 };

 std::string json_schema_to_grammar(const json & schema) {
-    SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
-    auto copy = schema;
-    converter.resolve_refs(copy, "input");
-    converter.visit(copy, "");
+    return build_grammar([&](const llama_grammar_builder & callbacks) {
+        auto copy = schema;
+        callbacks.resolve_refs(copy);
+        callbacks.add_schema("", copy);
+    });
+}
+
+std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
+    SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
+    llama_grammar_builder builder {
+        /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
+            return converter._add_rule(name, rule);
+        },
+        /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
+            return converter.visit(schema, name == "root" ? "" : name);
+        },
+        /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
+            converter.resolve_refs(schema, "");
+        }
+    };
+    cb(builder);
     converter.check_errors();
     return converter.format_grammar();
 }
@@ -5,4 +5,12 @@
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"

-std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
+std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
+
+struct llama_grammar_builder {
+    std::function<std::string(const std::string &, const std::string &)> add_rule;
+    std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
+    std::function<void(nlohmann::ordered_json &)> resolve_refs;
+};
+
+std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
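The new `llama_grammar_builder` struct turns the schema converter into a small callback API: `add_rule` registers a raw GBNF rule, `add_schema` converts a JSON-schema subtree into rules, and `resolve_refs` inlines local `$ref` pointers first. A minimal sketch of how a caller might drive it (illustrative only; the include path is an assumption and the schema literal is invented for the example):

```cpp
#include "json-schema-to-grammar.h" // assumed header name for the declarations above

static std::string example_grammar() {
    return build_grammar([](const llama_grammar_builder & builder) {
        auto schema = nlohmann::ordered_json::parse(R"({
            "type": "object",
            "properties": { "answer": { "type": "string" } },
            "required": ["answer"]
        })");
        builder.resolve_refs(schema);   // inline any local $ref pointers first
        builder.add_schema("", schema); // "" designates the root rule
    });
}
```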
common/minja.hpp (new file, 2788 lines; diff suppressed because it is too large)
@@ -751,6 +751,9 @@ class Model:
         if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
             # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
             res = "deepseek-v3"
+        if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5":
+            # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+            res = "deepseek-r1-qwen"

         if res is None:
             logger.warning("\n")
@@ -65,49 +65,50 @@ else:

 # TODO: add models here, base models preferred
 models = [
     {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", },
     {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", },
     {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", },
     {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", },
     {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
     {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
     {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
     {"name": "falcon3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", },
     {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", },
     {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
     {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
     {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
     {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", },
     {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
     {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
     {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
     {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
     {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
     {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },
     {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
     {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
     {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
     {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
     {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
     {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
     {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
     {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
     {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
     {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
     {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
     {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
     {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
     {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
     {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
     {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
     {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", },
     {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
     {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
     {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
     {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
+    {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
 ]
@@ -345,8 +345,18 @@ struct lora_merge_ctx {
         gf = ggml_new_graph(ctx0);
         struct ggml_tensor * cur = inp_base;
         for (size_t i = 0; i < adapters.size(); ++i) {
-            struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)));
-            struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+            struct ggml_tensor * delta;
+            bool is_tok_embd = string_starts_with(name_base, "token_embd");
+            if (is_tok_embd) {
+                printf("%s : detected token embeddings tensor\n", __func__);
+                delta = ggml_mul_mat(ctx0,
+                    ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
+                    ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
+            } else {
+                delta = ggml_mul_mat(ctx0,
+                    ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
+                    ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
+            }
             // scale
             const float alpha = adapters[i]->alpha;
             const float rank = (float) inp_b[i]->ne[0];
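For reference, both branches realize the standard LoRA merge, just with the operand order adjusted to ggml's matrix-multiplication layout; in math terms (our summary of the graph, not text from the patch):

```latex
W' = W + \frac{\alpha}{r}\, B A, \qquad
A \in \mathbb{R}^{r \times d_{\text{in}}},\;
B \in \mathbb{R}^{d_{\text{out}} \times r}
```

where $\alpha$ is the adapter's `alpha` and $r$ is the rank read from `inp_b`, applied by the `// scale` step that follows. The separate `token_embd` branch appears to exist because the embedding tensor is laid out transposed relative to the other weight matrices, so the factors are fed to `ggml_mul_mat` in the opposite order.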
examples/llava/README-minicpmo2.6.md (new file, 46 lines)

@@ -0,0 +1,46 @@
+## MiniCPM-o 2.6
+This readme currently covers only MiniCPM-o's image capabilities; full-mode support will be documented as soon as possible.
+
+### Prepare models and code
+
+Download the [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from Hugging Face into the "MiniCPM-o-2_6" folder.
+
+Clone llama.cpp:
+```bash
+git clone git@github.com:OpenBMB/llama.cpp.git
+cd llama.cpp
+git checkout minicpm-omni
+```
+
+### Usage of MiniCPM-o 2.6
+
+Convert the PyTorch model to gguf files (you can also download the pre-converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) files):
+
+```bash
+python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6
+python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4
+python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
+
+# quantize int4 version
+./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
+```
+
+Build llama.cpp using `CMake` (see https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md):
+
+```bash
+cmake -B build
+cmake --build build --config Release
+```
+
+Inference on Linux or Mac:
+
+```bash
+# run f16 version
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+
+# run quantized int4 version
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+
+# or run in interactive mode
+./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
+```
@@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         else if (ctx->minicpmv_version == 3) {
             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
         }
+        else if (ctx->minicpmv_version == 4) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
+        }
         ggml_set_name(pos_embed, "pos_embed");
         ggml_set_input(pos_embed);
     }

@@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             n_head = hidden_size/d_head;
             num_query = 64;
         }
+        else if (ctx->minicpmv_version == 4) {
+            hidden_size = 3584;
+            n_head = hidden_size/d_head;
+            num_query = 64;
+        }

         struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
         Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));

@@ -2041,6 +2049,7 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
                 images[images.size()-1].push_back(patch);
             }
         }
+        clip_image_u8_free(refine_image);
     }
     return images;
 }

@@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
                 clip_image_f32_free(res);
             }
         }
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                if (imgs[i][j] != nullptr) {
+                    clip_image_u8_free(imgs[i][j]);
+                }
+            }
+        }
         return true;
     }
     else if (ctx->has_qwen2vl_merger) {

@@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         else if (ctx->minicpmv_version == 3) {
             n_patches = 64;
         }
+        else if (ctx->minicpmv_version == 4) {
+            n_patches = 64;
+        }
     } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         int patch_size = params.patch_size * 2;
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);

@@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
         struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
         int* positions_data = (int*)malloc(ggml_nbytes(positions));
-        int bucket_coords_h[70];
-        int bucket_coords_w[70];
+        int bucket_coords_h[1024];
+        int bucket_coords_w[1024];
         for (int i = 0; i < pos_h; i++){
             bucket_coords_h[i] = std::floor(70.0*i/pos_h);
         }

@@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         else if (ctx->minicpmv_version == 3) {
             embed_dim = 3584;
         }
+        else if (ctx->minicpmv_version == 4) {
+            embed_dim = 3584;
+        }
         auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));

         float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));

@@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
     else if (ctx->minicpmv_version == 3) {
         return 3584;
     }
+    else if (ctx->minicpmv_version == 4) {
+        return 3584;
+    }
 }
 if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
     return ctx->vision_model.mm_1_b->ne[0];
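Two things are worth noting about the hunks above. First, enlarging `bucket_coords_h`/`bucket_coords_w` to 1024 entries guards against buffer overruns when `pos_h`/`pos_w` exceed the old fixed size of 70 (the bucket values themselves still clamp to 0..69). Second, every new `minicpmv_version == 4` branch reuses the version-3 constants, so the checks could in principle collapse into one table-driven helper; a hypothetical sketch, not part of the patch, with the per-version values taken from the converter script later in this diff:

```cpp
// Hypothetical consolidation of the repeated version checks (not in the patch).
// Versions 3 (MiniCPM-V 2.6) and 4 (MiniCPM-o 2.6) share the same geometry.
static int minicpmv_embed_dim(int minicpmv_version) {
    switch (minicpmv_version) {
        case 1:  return 2304; // MiniCPM-V-2
        case 2:  return 4096; // MiniCPM-V-2.5
        case 3:               // MiniCPM-V-2.6
        case 4:  return 3584; // MiniCPM-o-2.6
        default: return -1;   // unknown version
    }
}
```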
@@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
     return true;
 }

-static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
+static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
     int width = image->nx;
     int height = image->ny;
     int num_patches = (height / patch_size) * (width / patch_size);

@@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
             encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
         }
         else {
-            int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
-            if (has_minicpmv_projector == 2) {
-                encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
-            }
-            else if (has_minicpmv_projector == 3) {
-                encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
-            }
+            encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
         }

         if (!encoded) {

@@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
         load_image_size->height = img->ny;
         clip_add_load_image_size(ctx_clip, load_image_size);
         LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
+        delete[] img_res_v.data;
+        img_res_v.size = 0;
+        img_res_v.data = nullptr;
     }
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
         // flat / default llava-1.5 type embedding
@@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
     else if (has_minicpmv_projector == 3) {
         system_prompt = "<|im_start|>user\n";
     }
+    else if (has_minicpmv_projector == 4) {
+        system_prompt = "<|im_start|>user\n";
+    }
     LOG_INF("%s: image token past: %d\n", __func__, n_past);
     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);

@@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
         else if (has_minicpmv_projector == 3) {
             user_prompt = "<|im_start|>user\n" + prompt;
         }
+        else if (has_minicpmv_projector == 4) {
+            user_prompt = "<|im_start|>user\n" + prompt;
+        }
     }

     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);

@@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
     else if (has_minicpmv_projector == 3) {
         eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
     }
+    else if (has_minicpmv_projector == 4) {
+        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
+    }

     // generate the response

@@ -308,7 +317,6 @@ int main(int argc, char ** argv) {
             const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
             response += tmp;
             if (strcmp(tmp, "</s>") == 0) break;
-            if (strstr(tmp, "###")) break; // Yi-VL behavior
             printf("%s", tmp);// mistral llava-1.6
             if (strstr(response.c_str(), "<user>")) break; // minicpm-v
             fflush(stdout);
@@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
 default_image_std = [0.26862954, 0.26130258, 0.27577711]
 ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
 ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
-ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
+ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)

 # with proper
 args = ap.parse_args()

@@ -545,12 +545,19 @@ if args.use_f32:

 minicpmv_version = args.minicpmv_version
 emb_dim = 4096
+block_count = 26
 if minicpmv_version == 1:
     emb_dim = 2304
+    block_count = 26
 elif minicpmv_version == 2:
     emb_dim = 4096
+    block_count = 27
 elif minicpmv_version == 3:
     emb_dim = 3584
+    block_count = 27
+elif minicpmv_version == 4:
+    emb_dim = 3584
+    block_count = 27

 default_vision_config = {
     "hidden_size": 1152,

@@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config)
 if minicpmv_version == 3:
     vision_config = SiglipVisionConfig(**default_vision_config)
     model = SiglipVisionTransformer(vision_config)
+elif minicpmv_version == 4:
+    vision_config = SiglipVisionConfig(**default_vision_config)
+    model = SiglipVisionTransformer(vision_config)

 processor = None
 # if model.attn_pool is not None:

@@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None:
     fname_middle = "mmproj-"
     has_text_encoder = False
     has_minicpmv_projector = True
-    minicpmv_version = 3
+    minicpmv_version = 4
 elif args.vision_only:
     fname_middle = "vision-"
     has_text_encoder = False

@@ -625,7 +635,6 @@ if has_vision_encoder:
     fout.add_uint32("clip.vision.projection_dim", 0)
     fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
     fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    block_count = 26
     fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)

     if processor is not None:
@@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
 args = ap.parse_args()

 # find the model part that includes the multimodal projector weights
-model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
+model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16)
 checkpoint = model.state_dict()

 # get a list of mm tensor names
@@ -4,6 +4,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "llama.h"
+#include "chat-template.hpp"

 #include <cstdio>
 #include <cstring>

@@ -84,14 +85,6 @@ static void sigint_handler(int signo) {
 }
 #endif

-static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
-    common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
-    chat_msgs.push_back({role, content});
-    LOG_DBG("formatted: '%s'\n", formatted.c_str());
-    return formatted;
-}
-
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;

@@ -165,6 +158,7 @@ int main(int argc, char ** argv) {
     }

     const llama_vocab * vocab = llama_model_get_vocab(model);
+    auto chat_templates = common_chat_templates_from_model(model, params.chat_template);

     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);

@@ -207,7 +201,7 @@ int main(int argc, char ** argv) {
     }

     // auto enable conversation mode if chat template is available
-    const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
+    const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
     if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
         if (has_chat_template) {
             LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);

@@ -225,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }

@@ -269,10 +263,18 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd_inp;

+    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+        common_chat_msg new_msg{role, content};
+        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back({role, content});
+        LOG_DBG("formatted: '%s'\n", formatted.c_str());
+        return formatted;
+    };
+
     {
         auto prompt = (params.conversation_mode && params.enable_chat_template)
             // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
             // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {

@@ -779,7 +781,7 @@ int main(int argc, char ** argv) {
             }

             if (params.enable_chat_template) {
                 chat_add_and_format("assistant", assistant_ss.str());
             }
             is_interacting = true;
             LOG("\n");

@@ -844,7 +846,7 @@ int main(int argc, char ** argv) {

             bool format_chat = params.conversation_mode && params.enable_chat_template;
             std::string user_inp = format_chat
-                ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                ? chat_add_and_format("user", std::move(buffer))
                 : std::move(buffer);
             // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
             const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
@@ -1,5 +1,5 @@
 set(TARGET llama-run)
-add_executable(${TARGET} run.cpp)
+add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
examples/run/linenoise.cpp/LICENSE (new file, 26 lines)

@@ -0,0 +1,26 @@
+Copyright (c) 2010-2014, Salvatore Sanfilippo <antirez at gmail dot com>
+Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
+Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+   this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
examples/run/linenoise.cpp/linenoise.cpp (new file, 1350 lines; diff suppressed because it is too large)
examples/run/linenoise.cpp/linenoise.h (new file, 128 lines)

@@ -0,0 +1,128 @@
+/* linenoise.h -- VERSION 1.0
+ *
+ * Guerrilla line editing library against the idea that a line editing lib
+ * needs to be 20,000 lines of C++ code.
+ *
+ * See linenoise.cpp for more information.
+ *
+ * ------------------------------------------------------------------------
+ *
+ * Copyright (c) 2010-2023, Salvatore Sanfilippo <antirez at gmail dot com>
+ * Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
+ * Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
+ *
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *  * Redistributions of source code must retain the above copyright
+ *    notice, this list of conditions and the following disclaimer.
+ *
+ *  * Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in the
+ *    documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef __LINENOISE_H
+#define __LINENOISE_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stddef.h> /* For size_t. */
+#include <stdlib.h>
+
+extern const char *linenoiseEditMore;
+
+/* The linenoiseState structure represents the state during line editing.
+ * We pass this state to functions implementing specific editing
+ * functionalities. */
+struct linenoiseState {
+    int in_completion;     /* The user pressed TAB and we are now in completion
+                            * mode, so input is handled by completeLine(). */
+    size_t completion_idx; /* Index of next completion to propose. */
+    int ifd;               /* Terminal stdin file descriptor. */
+    int ofd;               /* Terminal stdout file descriptor. */
+    char *buf;             /* Edited line buffer. */
+    size_t buflen;         /* Edited line buffer size. */
+    const char *prompt;    /* Prompt to display. */
+    size_t plen;           /* Prompt length. */
+    size_t pos;            /* Current cursor position. */
+    size_t oldpos;         /* Previous refresh cursor position. */
+    size_t len;            /* Current edited line length. */
+    size_t cols;           /* Number of columns in terminal. */
+    size_t oldrows;        /* Rows used by last refreshed line (multiline mode) */
+    int history_index;     /* The history index we are currently editing. */
+};
+
+struct linenoiseCompletions {
+    size_t len = 0;
+    char ** cvec = nullptr;
+    bool to_free = true;
+
+    ~linenoiseCompletions() {
+        if (!to_free) {
+            return;
+        }
+
+        for (size_t i = 0; i < len; ++i) {
+            free(cvec[i]);
+        }
+
+        free(cvec);
+    }
+};
+
+/* Non blocking API. */
+int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt);
+const char *linenoiseEditFeed(struct linenoiseState *l);
+void linenoiseEditStop(struct linenoiseState *l);
+void linenoiseHide(struct linenoiseState *l);
+void linenoiseShow(struct linenoiseState *l);
+
+/* Blocking API. */
+const char *linenoise(const char *prompt);
+void linenoiseFree(void *ptr);
+
+/* Completion API. */
+typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *);
+typedef const char*(linenoiseHintsCallback)(const char *, int *color, int *bold);
+typedef void(linenoiseFreeHintsCallback)(const char *);
+void linenoiseSetCompletionCallback(linenoiseCompletionCallback *);
+void linenoiseSetHintsCallback(linenoiseHintsCallback *);
+void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *);
+void linenoiseAddCompletion(linenoiseCompletions *, const char *);
+
+/* History API. */
+int linenoiseHistoryAdd(const char *line);
+int linenoiseHistorySetMaxLen(int len);
+int linenoiseHistorySave(const char *filename);
+int linenoiseHistoryLoad(const char *filename);
+
+/* Other utilities. */
+void linenoiseClearScreen(void);
+void linenoiseSetMultiLine(int ml);
+void linenoisePrintKeyCodes(void);
+void linenoiseMaskModeEnable(void);
+void linenoiseMaskModeDisable(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* __LINENOISE_H */
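The blocking API above is what `llama-run` uses for its interactive prompt. A minimal, self-contained sketch of that usage pattern (our illustration, not code from the diff; the history filename is an arbitrary example):

```cpp
#include "linenoise.cpp/linenoise.h"
#include <cstdio>

int main() {
    linenoiseHistoryLoad("history.txt"); // no-op if the file does not exist yet
    // linenoise() returns a heap-allocated line, or nullptr on EOF/Ctrl-D
    while (const char * line = linenoise("> ")) {
        if (*line) {
            linenoiseHistoryAdd(line);
            linenoiseHistorySave("history.txt");
            printf("you typed: %s\n", line);
        }
        linenoiseFree((void *) line); // each returned buffer must be freed
    }
    return 0;
}
```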
@@ -19,13 +19,16 @@
 #include <cstring>
 #include <filesystem>
 #include <iostream>
+#include <list>
 #include <sstream>
 #include <string>
 #include <vector>

 #include "common.h"
 #include "json.hpp"
+#include "linenoise.cpp/linenoise.h"
 #include "llama-cpp.h"
+#include "chat-template.hpp"

 #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
 [[noreturn]] static void sigint_handler(int) {

@@ -103,6 +106,7 @@ class Opt {
     llama_model_params model_params;
     std::string model_;
     std::string user;
+    bool use_jinja = false;
     int context_size = -1, ngl = -1;
     float temperature = -1;
     bool verbose = false;

@@ -154,6 +158,8 @@ class Opt {
         } else if (options_parsing &&
                    (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
             verbose = true;
+        } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
+            use_jinja = true;
         } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
             help = true;
             return 0;

@@ -536,7 +542,7 @@ class LlamaData {
     llama_sampler_ptr sampler;
     llama_context_ptr context;
     std::vector<llama_chat_message> messages;
-    std::vector<std::string> msg_strs;
+    std::list<std::string> msg_strs;
     std::vector<char> fmtted;

     int init(Opt & opt) {

@@ -711,13 +717,31 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }

 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(LlamaData & llama_data, const bool append) {
+static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
+    if (use_jinja) {
+        json messages = json::array();
+        for (const auto & msg : llama_data.messages) {
+            messages.push_back({
+                {"role", msg.role},
+                {"content", msg.content},
+            });
+        }
+        try {
+            auto result = tmpl.apply(messages, /* tools= */ json(), append);
+            llama_data.fmtted.resize(result.size() + 1);
+            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+            return result.size();
+        } catch (const std::exception & e) {
+            printe("failed to render the chat template: %s\n", e.what());
+            return -1;
+        }
+    }
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
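When `--jinja` is set, the function above serializes the chat history into the JSON array shape that minja's `common_chat_template::apply` expects; a small sketch of that payload (illustrative only; whether the last argument maps exactly to "add the generation prompt" is our reading of how `append` is used here):

```cpp
#include "chat-template.hpp" // minja's common_chat_template, as included above
#include "json.hpp"

using json = nlohmann::ordered_json;

// Illustrative only: build the message array the Jinja path feeds to
// common_chat_template::apply in the hunk above.
static std::string render_example(const common_chat_template & tmpl) {
    json messages = json::array();
    messages.push_back({ {"role", "system"}, {"content", "You are a helpful assistant."} });
    messages.push_back({ {"role", "user"},   {"content", "Hello!"} });
    // the last argument mirrors `append` in the diff; we read it as
    // "append the assistant generation prompt to the rendered text"
    return tmpl.apply(messages, /* tools= */ json(), /* append= */ true);
}
```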
@@ -727,10 +751,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {

 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
-                           std::vector<llama_token> & prompt_tokens) {
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
+    const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
+
+    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
                        true) < 0) {
         printe("failed to tokenize the prompt\n");
         return -1;
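The effect of this change is that special tokens (notably BOS) are only added on the first turn of a session, while the KV cache is still empty; later turns are tokenized without them so the running context is not corrupted by repeated BOS tokens. Condensed to its core (a sketch of our reading, not the literal patch):

```cpp
#include "llama.h"

// Add BOS/special tokens only while the KV cache is still empty,
// i.e. on the first tokenization of the conversation.
static bool should_add_special(llama_context * ctx) {
    return llama_get_kv_cache_used_cells(ctx) == 0;
}
```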
@ -776,7 +802,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
|
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
|
||||||
|
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
if (tokenize_prompt(vocab, prompt, tokens) < 0) {
|
if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -807,24 +833,44 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||||
batch = llama_batch_get_one(&new_token_id, 1);
|
batch = llama_batch_get_one(&new_token_id, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
printf("\033[0m");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int read_user_input(std::string & user) {
|
static int read_user_input(std::string & user_input) {
|
||||||
std::getline(std::cin, user);
|
static const char * prompt_prefix = "> ";
|
||||||
|
#ifdef WIN32
|
||||||
|
printf(
|
||||||
|
"\r%*s"
|
||||||
|
"\r\033[0m%s",
|
||||||
|
get_terminal_width(), " ", prompt_prefix);
|
||||||
|
|
||||||
|
std::getline(std::cin, user_input);
|
||||||
if (std::cin.eof()) {
|
if (std::cin.eof()) {
|
||||||
printf("\n");
|
printf("\n");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
if (user == "/bye") {
|
std::unique_ptr<char, decltype(&std::free)> line(const_cast<char *>(linenoise(prompt_prefix)), free);
|
||||||
|
if (!line) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (user.empty()) {
|
user_input = line.get();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (user_input == "/bye") {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (user_input.empty()) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef WIN32
|
||||||
|
linenoiseHistoryAdd(line.get());
|
||||||
|
#endif
|
||||||
|
|
||||||
return 0; // Should have data in happy path
|
return 0; // Should have data in happy path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -847,8 +893,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }

 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
-    const int new_len = apply_chat_template(llama_data, append);
+static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@@ -865,10 +911,6 @@ static int handle_user_input(std::string & user_input, const std::string & user)
         return 0; // No need for interactive input
     }

-    printf(
-        "\r%*s"
-        "\r\033[32m> \033[0m",
-        get_terminal_width(), " ");
     return read_user_input(user_input); // Returns true if input ends the loop
 }

@@ -911,9 +953,11 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 }

 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
+    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
+    GGML_ASSERT(chat_templates.template_default);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -924,7 +968,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {

         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
+        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
             return 1;
         }

@@ -939,7 +983,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
         }

         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
+        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
             return 1;
         }
     }
@@ -999,7 +1043,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.user)) {
+    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
         return 1;
     }

@@ -126,7 +126,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
 | `--grammar-file FNAME` | file to read grammar from |
 | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
+| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |

 **Example-specific params**
File diff suppressed because it is too large
@@ -19,6 +19,7 @@
 #include "loading.html.hpp"

 #include <atomic>
+#include <chrono>
 #include <condition_variable>
 #include <cstddef>
 #include <cinttypes>
@@ -32,6 +33,8 @@

 using json = nlohmann::ordered_json;

+constexpr int HTTP_POLLING_SECONDS = 1;
+
 enum stop_type {
     STOP_TYPE_NONE,
     STOP_TYPE_EOS,
@@ -264,6 +267,11 @@ struct server_task {
             params.speculative.n_min = std::max(params.speculative.n_min, 2);
             params.speculative.n_max = std::max(params.speculative.n_max, 0);

+            // Use OpenAI API logprobs only if n_probs wasn't provided
+            if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
+                params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
+            }
+
             if (data.contains("lora")) {
                 if (data.at("lora").is_array()) {
                     params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
@@ -1602,6 +1610,30 @@ struct server_response {
         // should never reach here
     }

+    // same as recv(), but with a timeout in seconds
+    // if the timeout is reached, nullptr is returned
+    server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex_results);
+            bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
+                return !queue_results.empty();
+            });
+            if (!cr_res) {
+                return nullptr;
+            }
+
+            for (int i = 0; i < (int) queue_results.size(); i++) {
+                if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+                    server_task_result_ptr res = std::move(queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    return res;
+                }
+            }
+        }
+
+        // should never reach here
+    }
+
     // single-task version of recv()
     server_task_result_ptr recv(int id_task) {
         std::unordered_set<int> id_tasks = {id_task};
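recv_with_timeout() is built on std::condition_variable::wait_for with a predicate: the wait returns false on timeout, which the server turns into a nullptr result so the caller can poll the connection state and retry. The same pattern reduced to a toy queue (all names here are illustrative, not server internals):

    #include <chrono>
    #include <condition_variable>
    #include <deque>
    #include <mutex>

    std::mutex              m;
    std::condition_variable cv;
    std::deque<int>         q;

    // pop an item, waiting at most `seconds`; false means the wait timed out
    bool pop_with_timeout(int & out, int seconds) {
        std::unique_lock<std::mutex> lock(m);
        if (!cv.wait_for(lock, std::chrono::seconds(seconds), [&] { return !q.empty(); })) {
            return false; // timeout: caller may check for a closed connection and retry
        }
        out = q.front();
        q.pop_front();
        return true;
    }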
@@ -1661,6 +1693,8 @@ struct server_context {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;

+    common_chat_templates chat_templates;
+
     ~server_context() {
         // Clear any sampling context
         for (server_slot & slot : slots) {
@@ -1701,13 +1735,16 @@ struct server_context {
         add_bos_token = llama_vocab_get_add_bos(vocab);
         has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;

-        if (!params_base.speculative.model.empty()) {
+        if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());

             auto params_dft = params_base;

             params_dft.devices      = params_base.speculative.devices;
+            params_dft.hf_file      = params_base.speculative.hf_file;
+            params_dft.hf_repo      = params_base.speculative.hf_repo;
             params_dft.model        = params_base.speculative.model;
+            params_dft.model_url    = params_base.speculative.model_url;
             params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
@@ -1737,14 +1774,39 @@ struct server_context {
             cparams_dft.type_v = GGML_TYPE_F16;
         }

+        chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
+        GGML_ASSERT(chat_templates.template_default.get() != nullptr);
+
         return true;
     }

-    bool validate_builtin_chat_template() const {
+    bool validate_builtin_chat_template(bool use_jinja) const {
         llama_chat_message chat[] = {{"user", "test"}};
-        const char * tmpl = llama_model_chat_template(model);
-        const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
-        return chat_res > 0;
+
+        if (use_jinja) {
+            auto templates = common_chat_templates_from_model(model, "");
+            GGML_ASSERT(templates.template_default);
+            try {
+                templates.template_default->apply({{
+                    {"role",    "user"},
+                    {"content", "test"},
+                }}, json(), true);
+                if (templates.template_tool_use) {
+                    templates.template_tool_use->apply({{
+                        {"role",    "user"},
+                        {"content", "test"},
+                    }}, json(), true);
+                }
+                return true;
+            } catch (const std::exception & e) {
+                SRV_ERR("failed to apply template: %s\n", e.what());
+                return false;
+            }
+        } else {
+            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
+            const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+            return chat_res > 0;
+        }
     }

     void init() {
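The Jinja branch above validates a template by simply trying to render it: a dummy user message is applied, and any std::exception is treated as "template not supported". Reduced to its core (a sketch assuming a loaded llama_model * model and the common_chat_templates helpers introduced in this patch):

    auto templates = common_chat_templates_from_model(model, "");
    bool supported = true;
    try {
        // render a trivial conversation; throws if the template cannot be applied
        templates.template_default->apply({{ {"role", "user"}, {"content", "test"} }}, json(), true);
    } catch (const std::exception & e) {
        supported = false; // caller falls back to the chatml template
    }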
@@ -2322,10 +2384,21 @@ struct server_context {
     void receive_multi_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
-            const std::function<void(json)> & error_handler) {
+            const std::function<void(json)> & error_handler,
+            const std::function<bool()> & is_connection_closed) {
         std::vector<server_task_result_ptr> results(id_tasks.size());
-        for (size_t i = 0; i < id_tasks.size(); i++) {
-            server_task_result_ptr result = queue_results.recv(id_tasks);
+        for (int i = 0; i < (int)id_tasks.size(); i++) {
+            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+            if (is_connection_closed()) {
+                cancel_tasks(id_tasks);
+                return;
+            }
+
+            if (result == nullptr) {
+                i--; // retry
+                continue;
+            }
+
             if (result->is_error()) {
                 error_handler(result->to_json());
@@ -2349,10 +2422,20 @@ struct server_context {
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks,
             const std::function<bool(server_task_result_ptr&)> & result_handler,
-            const std::function<void(json)> & error_handler) {
+            const std::function<void(json)> & error_handler,
+            const std::function<bool()> & is_connection_closed) {
         size_t n_finished = 0;
         while (true) {
-            server_task_result_ptr result = queue_results.recv(id_tasks);
+            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+            if (is_connection_closed()) {
+                cancel_tasks(id_tasks);
+                return;
+            }
+
+            if (result == nullptr) {
+                continue; // retry
+            }
+
             if (result->is_error()) {
                 error_handler(result->to_json());
@@ -3609,9 +3692,12 @@ int main(int argc, char ** argv) {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots", ctx_server.params_base.n_parallel },
             { "model_path", ctx_server.params_base.model },
-            { "chat_template", common_get_builtin_chat_template(ctx_server.model) },
+            { "chat_template", ctx_server.chat_templates.template_default->source() },
             { "build_info", build_info },
         };
+        if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
+            data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+        }

         res_ok(res, data);
     };
@@ -3634,6 +3720,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
             server_task_type type,
             json & data,
+            std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
@@ -3695,7 +3782,7 @@ int main(int argc, char ** argv) {
                 }
             }, [&](const json & error_data) {
                 res_error(res, error_data);
-            });
+            }, is_connection_closed);

             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
|
||||||
if (res_json.is_array()) {
|
if (res_json.is_array()) {
|
||||||
for (const auto & res : res_json) {
|
for (const auto & res : res_json) {
|
||||||
if (!server_sent_event(sink, "data", res)) {
|
if (!server_sent_event(sink, "data", res)) {
|
||||||
|
// sending failed (HTTP connection closed), cancel the generation
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -3714,6 +3802,9 @@ int main(int argc, char ** argv) {
                 }
             }, [&](const json & error_data) {
                 server_sent_event(sink, "error", error_data);
+            }, [&sink]() {
+                // note: do not use req.is_connection_closed here because req is already destroyed
+                return !sink.is_writable();
             });
             if (oaicompat != OAICOMPAT_TYPE_NONE) {
                 static const std::string ev_done = "data: [DONE]\n\n";
|
||||||
return handle_completions_impl(
|
return handle_completions_impl(
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
data,
|
data,
|
||||||
|
req.is_connection_closed,
|
||||||
res,
|
res,
|
||||||
OAICOMPAT_TYPE_NONE);
|
OAICOMPAT_TYPE_NONE);
|
||||||
};
|
};
|
||||||
|
@ -3745,6 +3837,7 @@ int main(int argc, char ** argv) {
|
||||||
return handle_completions_impl(
|
return handle_completions_impl(
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
data,
|
data,
|
||||||
|
req.is_connection_closed,
|
||||||
res,
|
res,
|
||||||
OAICOMPAT_TYPE_COMPLETION);
|
OAICOMPAT_TYPE_COMPLETION);
|
||||||
};
|
};
|
||||||
|
@ -3821,6 +3914,7 @@ int main(int argc, char ** argv) {
|
||||||
return handle_completions_impl(
|
return handle_completions_impl(
|
||||||
SERVER_TASK_TYPE_INFILL,
|
SERVER_TASK_TYPE_INFILL,
|
||||||
data,
|
data,
|
||||||
|
req.is_connection_closed,
|
||||||
res,
|
res,
|
||||||
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||||
};
|
};
|
||||||
|
@ -3831,10 +3925,14 @@ int main(int argc, char ** argv) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
auto body = json::parse(req.body);
|
||||||
|
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||||
|
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
|
||||||
|
|
||||||
return handle_completions_impl(
|
return handle_completions_impl(
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
data,
|
data,
|
||||||
|
req.is_connection_closed,
|
||||||
res,
|
res,
|
||||||
OAICOMPAT_TYPE_CHAT);
|
OAICOMPAT_TYPE_CHAT);
|
||||||
};
|
};
|
||||||
|
@ -3981,7 +4079,7 @@ int main(int argc, char ** argv) {
|
||||||
}, [&](const json & error_data) {
|
}, [&](const json & error_data) {
|
||||||
res_error(res, error_data);
|
res_error(res, error_data);
|
||||||
error = true;
|
error = true;
|
||||||
});
|
}, req.is_connection_closed);
|
||||||
|
|
||||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||||
}
|
}
|
||||||
|
@@ -4071,7 +4169,7 @@ int main(int argc, char ** argv) {
         }, [&](const json & error_data) {
             res_error(res, error_data);
             error = true;
-        });
+        }, req.is_connection_closed);
     }

     if (error) {
|
||||||
|
|
||||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||||
if (params.chat_template.empty()) {
|
if (params.chat_template.empty()) {
|
||||||
if (!ctx_server.validate_builtin_chat_template()) {
|
if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) {
|
||||||
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||||
params.chat_template = "chatml";
|
params.chat_template = "chatml";
|
||||||
}
|
}
|
||||||
|
@@ -4248,8 +4346,8 @@ int main(int argc, char ** argv) {

     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
-        common_chat_format_example(ctx_server.model, params.chat_template).c_str());
+        ctx_server.chat_templates.template_default->source().c_str(),
+        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());

     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
@@ -4,22 +4,26 @@ from utils import *

 server = ServerPreset.tinyllama2()

-@pytest.fixture(autouse=True)
+@pytest.fixture(scope="module", autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()


 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
     [
-        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None),
+        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
     global server
+    server.jinja = jinja
+    server.chat_template = chat_template
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "model": model,
@@ -1,4 +1,5 @@
 import pytest
+import requests
 import time
 from openai import OpenAI
 from utils import *
@@ -405,3 +406,23 @@ def test_n_probs_post_sampling():
             assert "bytes" in prob and type(prob["bytes"]) == list
             # because the test model usually outputs a token with either 100% or 0% probability, we need to check all the top_probs
             assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+
+
+def test_cancel_request():
+    global server
+    server.n_ctx = 4096
+    server.n_predict = -1
+    server.n_slots = 1
+    server.server_slots = True
+    server.start()
+    # send a request that will take a long time, but cancel it before it finishes
+    try:
+        server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+        }, timeout=0.1)
+    except requests.exceptions.ReadTimeout:
+        pass  # expected
+    # make sure the slot is free
+    time.sleep(1)  # wait for HTTP_POLLING_SECONDS
+    res = server.make_request("GET", "/slots")
+    assert res.body[0]["is_processing"] == False
@@ -26,6 +26,9 @@ from re import RegexFlag
 import wget


+DEFAULT_HTTP_TIMEOUT = 10 if "LLAMA_SANITIZE" not in os.environ else 30
+
+
 class ServerResponse:
     headers: dict
     status_code: int
@@ -69,13 +72,14 @@ class ServerProcess:
     pooling: str | None = None
     draft: int | None = None
     api_key: str | None = None
-    response_format: str | None = None
     lora_files: List[str] | None = None
     disable_ctx_shift: int | None = False
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
+    jinja: bool | None = None
     chat_template: str | None = None
+    chat_template_file: str | None = None

     # session variables
     process: subprocess.Popen | None = None
@@ -88,7 +92,7 @@ class ServerProcess:
         if "PORT" in os.environ:
             self.server_port = int(os.environ["PORT"])

-    def start(self, timeout_seconds: int = 10) -> None:
+    def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
         if "LLAMA_SERVER_BIN_PATH" in os.environ:
             server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
         elif os.name == "nt":
@@ -166,8 +170,12 @@ class ServerProcess:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.jinja:
+            server_args.append("--jinja")
         if self.chat_template:
             server_args.extend(["--chat-template", self.chat_template])
+        if self.chat_template_file:
+            server_args.extend(["--chat-template-file", self.chat_template_file])

         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
|
||||||
path: str,
|
path: str,
|
||||||
data: dict | Any | None = None,
|
data: dict | Any | None = None,
|
||||||
headers: dict | None = None,
|
headers: dict | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
) -> ServerResponse:
|
) -> ServerResponse:
|
||||||
url = f"http://{self.server_host}:{self.server_port}{path}"
|
url = f"http://{self.server_host}:{self.server_port}{path}"
|
||||||
parse_body = False
|
parse_body = False
|
||||||
if method == "GET":
|
if method == "GET":
|
||||||
response = requests.get(url, headers=headers)
|
response = requests.get(url, headers=headers, timeout=timeout)
|
||||||
parse_body = True
|
parse_body = True
|
||||||
elif method == "POST":
|
elif method == "POST":
|
||||||
response = requests.post(url, headers=headers, json=data)
|
response = requests.post(url, headers=headers, json=data, timeout=timeout)
|
||||||
parse_body = True
|
parse_body = True
|
||||||
elif method == "OPTIONS":
|
elif method == "OPTIONS":
|
||||||
response = requests.options(url, headers=headers)
|
response = requests.options(url, headers=headers, timeout=timeout)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unimplemented method: {method}")
|
raise ValueError(f"Unimplemented method: {method}")
|
||||||
result = ServerResponse()
|
result = ServerResponse()
|
||||||
|
|
|
@@ -16,6 +16,8 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+#include "minja.hpp"
+#include "chat-template.hpp"

 #include <random>
 #include <sstream>
@@ -349,7 +351,7 @@ static llama_tokens format_infill(
 }

 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;

     for (size_t i = 0; i < messages.size(); ++i) {
@@ -377,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }

-    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());

     return formatted_chat;
@@ -576,14 +578,23 @@ static json oaicompat_completion_params_parse(const json & body) {
     return llama_params;
 }

-static json oaicompat_chat_completion_params_parse(
-    const struct llama_model * model,
-    const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+static json oaicompat_completion_params_parse(
+    const json & body, /* openai api json semantics */
+    const common_chat_template & tmpl,
+    bool use_jinja)
+{
     json llama_params;

-    // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
+
+    if (has_tools) {
+        if (use_jinja) {
+            LOG_WRN("tools param is not fully supported yet\n");
+        } else {
+            throw std::runtime_error("tools param requires --jinja flag");
+        }
+    }

     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply chat template to the list of messages
|
||||||
|
if (use_jinja) {
|
||||||
|
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
||||||
|
} else {
|
||||||
|
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
||||||
|
}
|
||||||
|
|
||||||
// Handle "n" field
|
// Handle "n" field
|
||||||
int n_choices = json_value(body, "n", 1);
|
int n_choices = json_value(body, "n", 1);
|
||||||
if (n_choices != 1) {
|
if (n_choices != 1) {
|
||||||
|
@ -621,7 +639,7 @@ static json oaicompat_chat_completion_params_parse(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Params supported by OAI but unsupported by llama.cpp
|
// Params supported by OAI but unsupported by llama.cpp
|
||||||
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
|
static const std::vector<std::string> unsupported_params { "tool_choice" };
|
||||||
for (const auto & param : unsupported_params) {
|
for (const auto & param : unsupported_params) {
|
||||||
if (body.contains(param)) {
|
if (body.contains(param)) {
|
||||||
throw std::runtime_error("Unsupported param: " + param);
|
throw std::runtime_error("Unsupported param: " + param);
|
||||||
|
|
|
@@ -98,10 +98,12 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;

+        const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
+
         // tokenize the prompt
-        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
         std::vector<llama_token> prompt_tokens(n_prompt_tokens);
-        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
+        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
             GGML_ABORT("failed to tokenize the prompt\n");
         }

|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * tmpl = llama_model_chat_template(model);
|
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
|
||||||
|
|
||||||
// add the user input to the message list and format it
|
// add the user input to the message list and format it
|
||||||
messages.push_back({"user", strdup(user.c_str())});
|
messages.push_back({"user", strdup(user.c_str())});
|
||||||
|
|
|
@@ -425,6 +425,33 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }

+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
+    const std::string & delimiter = "<|text_sep|>";
+
+    std::vector<llama_token> result;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    // first token is always a newline, as it was not previously added
+    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
+    while (end != std::string::npos) {
+        std::string current_word = str.substr(start, end - start);
+        auto tmp = common_tokenize(vocab, current_word, false, true);
+        result.push_back(tmp[0]);
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    // add the last part
+    std::string current_word = str.substr(start);
+    auto tmp = common_tokenize(vocab, current_word, false, true);
+    if (tmp.size() > 0) {
+        result.push_back(tmp[0]);
+    }
+    return result;
+}
+
 int main(int argc, char ** argv) {
     common_params params;

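A usage sketch for prepare_guide_tokens() above: the input is the OuteTTS-style prompt in which words are separated by "<|text_sep|>", and the result holds one token per word fragment, preceded by the newline token (the input string here is illustrative):

    // e.g. prompt_clean = "hello<|text_sep|>world<|text_sep|>again"
    std::vector<llama_token> guide = prepare_guide_tokens(vocab, prompt_clean);
    // guide[0] is the "\n" token; guide[1..] hold the first token of each word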
@@ -494,6 +521,7 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();

     std::vector<llama_token> codes;
+    std::vector<llama_token> guide_tokens;

     // process prompt and generate voice codes
     {
@@ -508,6 +536,9 @@ int main(int argc, char ** argv) {
         // convert the input text into the necessary format expected by OuteTTS
         {
             std::string prompt_clean = process_text(params.prompt);
+            if (params.vocoder.use_guide_tokens) {
+                guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
+            }

             LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());

@@ -717,6 +748,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
         int n_past   = batch.n_tokens;
         int n_decode = 0;

+        bool next_token_uses_guide_token = true;
+
         while (n_decode <= n_predict) {
             // prepare the next batch
             common_batch_clear(batch);
@@ -728,7 +761,17 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
                     continue;
                 }

-                const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+                llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+                // guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+                if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
+                    llama_token guide_token = guide_tokens[0];
+                    guide_tokens.erase(guide_tokens.begin());
+                    new_token_id = guide_token; // ensure correct word fragment is used
+                }
+
+                // this is the token id that always precedes a new word
+                next_token_uses_guide_token = (new_token_id == 198);

                 common_sampler_accept(smpl[i], new_token_id, true);

@@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const half * dh = &x[ib].d;

         for (int row = 0; row < N_DST; row++) {
-
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum;
@@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl(
    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    if (tiisg == 0) {
-       for (int row = 0; row < 2; ++row) {
+       for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
            dst_f32[first_row + row] = sumf1[row];
        }
    }
@@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl(

    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum;
@@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < 2; ++row) {
+   for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = tot;
@@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl(

    const int row = 2*r0 + sgitg;

+   if (row >= args.ne0) {
+       return;
+   }
+
    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

@@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum * 0.5f;
@@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum;
@@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum;
@@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < N_DST; ++row) {
+   for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum;
@@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+   for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum;
@@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

-   for (int row = 0; row < 2; ++row) {
+   for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst_f32[first_row + row] = all_sum;
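All of the Metal hunks above add the same tail guard: a simdgroup writes N_DST (or 2) consecutive rows starting at first_row, and when the row count args.ne0 is not a multiple of that step, the final group must skip its out-of-range rows. The pattern in plain C++ (names mirror the kernels; a sketch only, not the Metal source):

    // write per-row sums, skipping rows past the end of the output
    void write_rows(float * dst_f32, const float * sumf, int first_row, int n_dst, int ne0) {
        for (int row = 0; row < n_dst && first_row + row < ne0; ++row) {
            dst_f32[first_row + row] = sumf[row]; // never writes past ne0
        }
    }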
@@ -181,7 +181,7 @@ struct ggml_backend_rpc_context {

 struct ggml_backend_rpc_buffer_context {
     std::shared_ptr<socket_t> sock;
-    std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
+    void * base_ptr;
     uint64_t remote_ptr;
 };

@@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {

 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
-        return ctx->base_cache[buffer];
+    if (ctx->base_ptr != nullptr) {
+        return ctx->base_ptr;
     }
     rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
     rpc_msg_buffer_get_base_rsp response;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
-    ctx->base_cache[buffer] = base_ptr;
-    return base_ptr;
+    ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
+    return ctx->base_ptr;
 }

 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
+            new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
             response.remote_size);
         return buffer;
     } else {
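The RPC change above works because every remote buffer owns its own ggml_backend_rpc_buffer_context, so a map keyed by buffer was redundant: one lazily-filled pointer suffices and saves a hash lookup per call. The caching shape in the abstract (fetch_base_from_server is a hypothetical stand-in for the RPC_CMD_BUFFER_GET_BASE round-trip):

    struct remote_buffer {
        void * base_ptr = nullptr; // filled on first use, then reused

        void * get_base() {
            if (base_ptr == nullptr) {
                base_ptr = fetch_base_from_server(); // hypothetical RPC call
            }
            return base_ptr;
        }

        void * fetch_base_from_server(); // assumed: one network round-trip
    };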
@ -333,8 +333,12 @@ struct ggml_backend_sycl_context {
|
||||||
// pool
|
// pool
|
||||||
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
||||||
|
|
||||||
|
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
|
||||||
|
|
||||||
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
||||||
|
|
||||||
|
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
|
||||||
|
|
||||||
ggml_sycl_pool & pool(int device) {
|
ggml_sycl_pool & pool(int device) {
|
||||||
if (pools[device] == nullptr) {
|
if (pools[device] == nullptr) {
|
||||||
pools[device] = new_pool_for_device(stream(device,0), device);
|
pools[device] = new_pool_for_device(stream(device,0), device);
|
||||||
|
@ -345,6 +349,15 @@ struct ggml_backend_sycl_context {
|
||||||
ggml_sycl_pool & pool() {
|
ggml_sycl_pool & pool() {
|
||||||
return pool(device);
|
return pool(device);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_sycl_pool & host_pool(int device) {
|
||||||
|
if (host_pools[device] == nullptr) {
|
||||||
|
host_pools[device] = new_pool_for_host(stream(device, 0), device);
|
||||||
|
}
|
||||||
|
return *host_pools[device];
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_sycl_pool & host_pool() { return host_pool(device); }
|
||||||
};
|
};
|
||||||
|
|
||||||
// common device functions
|
// common device functions
|
||||||
|
|
|
@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
||||||
return device_type.str();
|
return device_type.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Ts> struct matrix_info_t {
|
||||||
|
oneapi::mkl::transpose transpose_info[2];
|
||||||
|
Ts value_info[2];
|
||||||
|
std::int64_t size_info[3];
|
||||||
|
std::int64_t ld_info[3];
|
||||||
|
std::int64_t groupsize_info;
|
||||||
|
};
|
||||||
|
|
||||||
namespace dpct
|
namespace dpct
|
||||||
{
|
{
|
||||||
typedef sycl::queue *queue_ptr;
|
typedef sycl::queue *queue_ptr;
|
||||||
|
@@ -1727,26 +1735,13 @@ namespace dpct
     };
 
     template <class Ta, class Tb, class Tc, class Ts>
-    inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                                oneapi::mkl::transpose b_trans, int m, int n, int k,
-                                const void *alpha, const void **a, int lda,
-                                const void **b, int ldb, const void *beta, void **c,
-                                int ldc, int batch_size)
-    {
-        struct matrix_info_t
-        {
-            oneapi::mkl::transpose transpose_info[2];
-            Ts value_info[2];
-            std::int64_t size_info[3];
-            std::int64_t ld_info[3];
-            std::int64_t groupsize_info;
-        };
-
+    inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
+                                int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
+                                int ldb, const void * beta, void ** c, int ldc, int batch_size,
+                                matrix_info_t<float> * matrix_info) {
         Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
         Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
 
-        matrix_info_t *matrix_info =
-            (matrix_info_t *)std::malloc(sizeof(matrix_info_t));
         matrix_info->transpose_info[0] = a_trans;
         matrix_info->transpose_info[1] = b_trans;
         matrix_info->value_info[0] = alpha_value;
@@ -1763,23 +1758,18 @@ namespace dpct
         sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
             oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
             matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
-            matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
-            matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
-            matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
-            &(matrix_info->groupsize_info));
+            matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
+            reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
+            matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
+            reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
 #else
         sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
             q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
-            matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
+            matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
             reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-            matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
-            matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+            matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
+            reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
 #endif
-
-        q.submit([&](sycl::handler &cgh)
-                 {
-        cgh.depends_on(e);
-        cgh.host_task([=] { std::free(matrix_info); }); });
     }
 
     template <class Ta, class Tb, class Tc, class Ts>
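The two hunks above change gemm_batch_impl from malloc'ing its oneMKL descriptor and freeing it in a trailing host_task to filling a caller-provided matrix_info_t<float>. The caller now owns the descriptor's lifetime and must keep it alive until the batched GEMM event has completed. A minimal sketch of the resulting contract (queue and argument names here are hypothetical):

    matrix_info_t<float> info{};   // caller-owned; must outlive the asynchronous GEMM
    dpct::gemm_batch(q, trans_a, trans_b, m, n, k, &alpha, a_ptrs, a_type, lda,
                     b_ptrs, b_type, ldb, &beta, c_ptrs, c_type, ldc, batch, scaling_type, &info);
    q.wait();                      // only now may 'info' be reused or released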
@@ -2422,25 +2412,11 @@ namespace dpct
     /// \param [in] ldc Leading dimension of C.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                           oneapi::mkl::transpose b_trans, int m, int n, int k,
-                           const void *alpha, const void *a[],
-                           library_data_t a_type, int lda, const void *b[],
-                           library_data_t b_type, int ldb, const void *beta,
-                           void *c[], library_data_t c_type, int ldc,
-                           int batch_size, library_data_t scaling_type)
-    {
-        if (scaling_type == library_data_t::real_float &&
-            c_type == library_data_t::complex_float)
-        {
-            scaling_type = library_data_t::complex_float;
-        }
-        else if (scaling_type == library_data_t::real_double &&
-                 c_type == library_data_t::complex_double)
-        {
-            scaling_type = library_data_t::complex_double;
-        }
-
+    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
+                           int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
+                           const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
+                           library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
+                           matrix_info_t<float> * matrix_info) {
         std::uint64_t key =
             detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
         switch (key)
@@ -2449,48 +2425,24 @@ namespace dpct
             library_data_t::real_float, library_data_t::real_float,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<float, float, float, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+            detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
+                                                                beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_double, library_data_t::real_double,
             library_data_t::real_double, library_data_t::real_double):
         {
-            detail::gemm_batch_impl<double, double, double, double>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
-            break;
-        }
-        case detail::get_type_combination_id(
-            library_data_t::complex_float, library_data_t::complex_float,
-            library_data_t::complex_float, library_data_t::complex_float):
-        {
-            detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
-                                    std::complex<float>, std::complex<float>>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
-            break;
-        }
-        case detail::get_type_combination_id(
-            library_data_t::complex_double, library_data_t::complex_double,
-            library_data_t::complex_double, library_data_t::complex_double):
-        {
-            detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
-                                    std::complex<double>, std::complex<double>>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+            detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
+                                                                    beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_half, library_data_t::real_half,
             library_data_t::real_half, library_data_t::real_half):
         {
-            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
-                                    sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
-                                                a, lda, b, ldb, beta, c, ldc,
-                                                batch_size);
+            detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
 #ifdef __INTEL_MKL__
@@ -2498,19 +2450,16 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                                    oneapi::mkl::bfloat16, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
-                                    float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
-                                           b, ldb, beta, c, ldc, batch_size);
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
 #endif
@@ -2522,10 +2471,9 @@ namespace dpct
                 dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
             float beta_float =
                 dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
-            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
-                                    float>(q, a_trans, b_trans, m, n, k, &alpha_float,
-                                           a, lda, b, ldb, &beta_float, c, ldc,
-                                           batch_size);
+            detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
+                q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
+                matrix_info);
             break;
         }
         case detail::get_type_combination_id(
@@ -2533,8 +2481,7 @@ namespace dpct
             library_data_t::real_float, library_data_t::real_float):
         {
             detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
-                batch_size);
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
         case detail::get_type_combination_id(
|
||||||
library_data_t::real_float, library_data_t::real_float):
|
library_data_t::real_float, library_data_t::real_float):
|
||||||
{
|
{
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
||||||
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
||||||
batch_size);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case detail::get_type_combination_id(
|
case detail::get_type_combination_id(
|
||||||
|
@ -2557,8 +2503,7 @@ namespace dpct
|
||||||
sycl::half alpha_half(alpha_value);
|
sycl::half alpha_half(alpha_value);
|
||||||
sycl::half beta_half(beta_value);
|
sycl::half beta_half(beta_value);
|
||||||
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
||||||
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
|
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
|
||||||
batch_size);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|
|
@@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
     }
 };
 
+struct ggml_sycl_pool_host : public ggml_sycl_pool {
+    queue_ptr qptr;
+    int       device;
+
+    inline static int counter{ 0 };
+
+    struct ggml_sycl_buffer {
+        void * ptr  = nullptr;
+        size_t size = 0;
+    };
+
+    // Set arbitrarily to 64
+    static constexpr int          MAX_POOL_SIZE{ 64 };
+    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
+    size_t                        pool_size   = 0;
+
+    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+
+    ~ggml_sycl_pool_host() {
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                b.ptr = nullptr;
+                pool_size -= b.size;
+                b.size = 0;
+            }
+        }
+        counter = 0;
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+        if (counter == MAX_POOL_SIZE) {
+            ggml_sycl_buffer b = buffer_pool[0];
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            counter = 1;
+            return ptr;
+        }
+        ggml_sycl_buffer & b = buffer_pool[counter];
+
+        if (b.ptr == nullptr) {
+            void * ptr;
+
+            SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+            if (!ptr) {
+                GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
+                return nullptr;
+            }
+            pool_size += size;
+            *actual_size = size;
+            counter = counter + 1;
+            return ptr;
+        } else {
+            ++counter;
+            b.size = size;
+            return b.ptr;
+        }
+    }
+
+    void free(void * ptr, size_t size) override {
+        // if the pool is not full, add the pointer to it in place of the first nullptr found.
+        // Otherwise do nothing, pointers will be freed once the pool is deallocated.
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+    }
+};
+
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
+    // return pool for the host to speed up memory management
+    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
+}
+
 std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
     // TBD: NO VMM support
     // if (ggml_sycl_info().devices[device].vmm) {
|
||||||
|
|
||||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
||||||
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
||||||
|
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
||||||
|
|
||||||
sycl::range<3> block_dims(1, ne12, ne13);
|
sycl::range<3> block_dims(1, ne12, ne13);
|
||||||
/*
|
/*
|
||||||
|
@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||||
*main_stream, oneapi::mkl::transpose::trans,
|
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||||
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||||
(const void **)(ptrs_src.get() + 0 * ne23),
|
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
|
||||||
dpct::library_data_t::real_half, nb01 / nb00,
|
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
|
||||||
(const void **)(ptrs_src.get() + 1 * ne23),
|
|
||||||
dpct::library_data_t::real_half, nb11 / nb10, beta,
|
|
||||||
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
|
|
||||||
cu_compute_type)));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (sycl::exception const &exc) {
|
catch (sycl::exception const &exc) {
|
||||||
|
|
|
@@ -29,8 +29,6 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
-#define VK_API_VERSION VK_API_VERSION_1_2
-
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
@@ -386,10 +384,13 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;
 
+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;
@@ -1611,11 +1612,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
 
-        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-        CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1628,21 +1625,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
-        CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
@@ -2284,6 +2278,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
 #endif
 
+    VkPhysicalDeviceMaintenance4Features maint4_features {};
+    maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
+    if (maintenance4_support) {
+        last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
+        last_struct = (VkBaseOutStructure *)&maint4_features;
+        device_extensions.push_back("VK_KHR_maintenance4");
+    }
+
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
     device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2659,7 +2661,14 @@ void ggml_vk_instance_init() {
 
     vk_instance_initialized = true;
 
-    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
+    uint32_t api_version = vk::enumerateInstanceVersion();
+
+    if (api_version < VK_API_VERSION_1_2) {
+        std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
+        GGML_ABORT("fatal error");
+    }
+
+    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
     const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
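For reference (not part of the diff): vk::enumerateInstanceVersion() returns the packed uint32_t instance version, and VK_API_VERSION_1_2 is the packed value for 1.2.0, so the gate above is a plain integer comparison. A sketch of decoding it for logging, using the standard Vulkan macros (VK_API_VERSION_MAJOR and friends; older headers spell them VK_VERSION_MAJOR):

    uint32_t v = vk::enumerateInstanceVersion();
    printf("instance supports Vulkan %u.%u.%u\n",      // requires <cstdio>
           VK_API_VERSION_MAJOR(v), VK_API_VERSION_MINOR(v), VK_API_VERSION_PATCH(v));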
@@ -2969,7 +2978,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         }
     }
 
-    GGML_ASSERT(src1_type == GGML_TYPE_F32);
+    GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
 
     switch (src0_type) {
     case GGML_TYPE_Q4_0:
@@ -3809,8 +3818,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         src1_uma = d_Qy != nullptr;
     }
 
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
+    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
+    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
+                              !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
                               !ggml_vk_dim01_contiguous(src1);
 
@@ -4390,8 +4400,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         ids_uma = d_ids != nullptr;
     }
 
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
+    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
+                              !ggml_vk_dim01_contiguous(src0);
+    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+                              !ggml_vk_dim01_contiguous(src1);
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
@@ -4401,7 +4414,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
     if (qx_needs_dequant) {
-        GGML_ABORT("fatal error");
+        // Fall back to dequant + f16 mulmat
+        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
|
||||||
}
|
}
|
||||||
assert(pipelines);
|
assert(pipelines);
|
||||||
|
|
||||||
bool aligned = (KV % pipelines[1]->align) == 0;
|
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||||
|
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||||
|
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||||
|
|
||||||
|
bool aligned = (KV % pipelines[1]->align) == 0 &&
|
||||||
|
// the "aligned" shader variant will forcibly align strides, for performance
|
||||||
|
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
||||||
|
|
||||||
vk_pipeline pipeline = pipelines[aligned];
|
vk_pipeline pipeline = pipelines[aligned];
|
||||||
assert(pipeline);
|
assert(pipeline);
|
||||||
|
|
||||||
|
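A note on the predicate above (not part of the diff): a stride is a multiple of 8 elements exactly when its low three bits are zero, since 7 is 0b111, so (stride & 7) == 0 is equivalent to stride % 8 == 0:

    static_assert((24u & 7u) == 0, "24 is a multiple of 8: the aligned variant may be used");
    static_assert((20u & 7u) != 0, "20 is not: fall back to the unaligned variant");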
@@ -4845,15 +4866,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
-        ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
         Q_uma = d_Q != nullptr;
         K_uma = d_K != nullptr;
         V_uma = d_V != nullptr;
         D_uma = d_D != nullptr;
         if (mask) {
-            ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
+            ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
     }
@@ -4891,7 +4912,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
+    const vk_flash_attn_push_constants pc = { N, KV,
+                                              (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
+                                              (uint32_t)neq2, (uint32_t)neq3,
+                                              (uint32_t)nek2, (uint32_t)nek3,
+                                              (uint32_t)nev2, (uint32_t)nev3,
+                                              nem1,
+                                              q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
+                                              k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
+                                              v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
+                                              nbm1,
+                                              scale, max_bias, logit_softcap,
+                                              mask != nullptr, n_head_log2, m0, m1 };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -8668,6 +8700,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
     ggml_tensor * src2 = tensor->src[2];
+    ggml_tensor * src3 = tensor->src[3];
 
     void * tensor_data = tensor->data;
 
@@ -8730,6 +8763,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
             std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+            std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
@@ -8774,6 +8810,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
             std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
@@ -8796,6 +8835,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
         if (src2 != nullptr) {
             std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
         }
+        if (src3 != nullptr) {
+            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+        }
         std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
         std::cerr << std::endl << "Result:" << std::endl;
         ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
@@ -42,10 +42,13 @@ layout (push_constant) uniform parameter {
     uint32_t nev3;
     uint32_t nem1;
 
+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;
@@ -146,7 +149,24 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
+    // nb?1 are already divided by the type size and are in units of elements
+    uint32_t q_stride = p.nb01;
+    uint32_t k_stride = p.nb11;
+    uint32_t v_stride = p.nb21;
+    // hint to the compiler that strides are aligned for the aligned variant of the shader
+    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
+    {
+        q_stride &= ~7;
+#if !defined(BLOCK_SIZE)
+        k_stride &= ~7;
+        v_stride &= ~7;
+#endif
+    }
+    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
+    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
+    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
+
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
@@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 #if QUANT_K > 1
 #define DECODEFUNCA , dequantFuncA
-#define MAT_A_TYPE float16_t
 
 #include "dequant_funcs_cm2.comp"
 
 #else
 #define DECODEFUNCA
-#define MAT_A_TYPE A_TYPE
 #endif
 
-#define MAT_B_TYPE B_TYPE
-
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
@@ -236,16 +232,13 @@ void main() {
 
         for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
 
-            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-
             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
 
-            sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+            sum = coopMatMulAdd(mat_a, mat_b, sum);
         }
     } else
 #endif // !defined(MUL_MAT_ID)
@@ -261,10 +254,8 @@ void main() {
         [[dont_unroll]]
         for (uint block_k = start_k; block_k < end_k; block_k += BK) {
 
-            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
             // Clamping is expensive, so detect different code paths for each combination
             // of A and B needing clamping.
@@ -281,16 +272,12 @@ void main() {
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
 #endif
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (unclampedA && !unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (!unclampedA && unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
@@ -298,16 +285,12 @@ void main() {
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
 #endif
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (!unclampedA && !unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             }
         }
     }
@@ -316,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     // For aligned matmul loads
    std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
 
-    string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    // don't generate f32 variants for coopmat2
+    if (!coopmat2) {
+        string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    }
 
     if (tname != "f16" && tname != "f32") {
         string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
@@ -648,6 +648,10 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
 
     ok = ok && data != nullptr;
 
+    if (ok) {
+        ggml_set_name(data, "GGUF tensor data binary blob");
+    }
+
     // read the binary blob with the tensor data
     ok = ok && gr.read(data->data, ctx->size);
 
@@ -524,7 +524,8 @@ extern "C" {
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
     // Get the default chat template. Returns nullptr if not available
-    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+    // If name is NULL, returns the default chat template
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
 
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
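Given the new signature above, callers pick a chat template variant by name; passing NULL keeps the old behaviour. A minimal usage sketch ("tool_use" is just an example variant name, taken from the script added later in this diff):

    const char * tmpl     = llama_model_chat_template(model, NULL);       // default template
    const char * tool_use = llama_model_chat_template(model, "tool_use"); // named variant, NULL if absent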
112
models/ggml-vocab-deepseek-r1-qwen.gguf.inp
Normal file
112
models/ggml-vocab-deepseek-r1-qwen.gguf.inp
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
ied 4 ½ months
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Führer
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello world
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello world
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello World
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello World
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello World!
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello, world!
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello, world!
|
||||||
|
__ggml_vocab_test__
|
||||||
|
this is 🦙.cpp
|
||||||
|
__ggml_vocab_test__
|
||||||
|
w048 7tuijk dsdfhu
|
||||||
|
__ggml_vocab_test__
|
||||||
|
нещо на Български
|
||||||
|
__ggml_vocab_test__
|
||||||
|
កាន់តែពិសេសអាចខលចេញ
|
||||||
|
__ggml_vocab_test__
|
||||||
|
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello
|
||||||
|
Hello
|
||||||
|
__ggml_vocab_test__
|
||||||
|
(
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
=
|
||||||
|
__ggml_vocab_test__
|
||||||
|
' era
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||||
|
__ggml_vocab_test__
|
||||||
|
!!!!!!
|
||||||
|
__ggml_vocab_test__
|
||||||
|
3
|
||||||
|
__ggml_vocab_test__
|
||||||
|
33
|
||||||
|
__ggml_vocab_test__
|
||||||
|
333
|
||||||
|
__ggml_vocab_test__
|
||||||
|
3333
|
||||||
|
__ggml_vocab_test__
|
||||||
|
33333
|
||||||
|
__ggml_vocab_test__
|
||||||
|
333333
|
||||||
|
__ggml_vocab_test__
|
||||||
|
3333333
|
||||||
|
__ggml_vocab_test__
|
||||||
|
33333333
|
||||||
|
__ggml_vocab_test__
|
||||||
|
333333333
|
||||||
|
__ggml_vocab_test__
|
||||||
|
Cửa Việt
|
||||||
|
__ggml_vocab_test__
|
||||||
|
discards
|
||||||
|
__ggml_vocab_test__
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
|
||||||
|
__ggml_vocab_test__
|
46
models/ggml-vocab-deepseek-r1-qwen.gguf.out
Normal file

@ -0,0 +1,46 @@
1122 220 19 220 26062 3951
37 50753 261

220
256
262
197
198
271
1406
1572
9707 1879
21927 1879
9707 4337
21927 4337
21927 4337 0
9707 11 1879 0
21927 11 1879 0
419 374 11162 99 247 13 10821
86 15 19 23 220 22 83 1963 41808 11472 2940 16739
78762 14144 1456 13073 63471 33594 3038 133178 79012
146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848
145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8
9707
21927
220 21927
256 21927
262 21927
262 21927 198 262 21927
320
198 284
6 11385
9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216
17085 2928
18
18 18
18 18 18
18 18 18 18
18 18 18 18 18
18 18 18 18 18 18
18 18 18 18 18 18 18
18 18 18 18 18 18 18 18
18 18 18 18 18 18 18 18 18
34 90063 128324
2560 2347
198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43

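For context, the `.inp` file pairs with this `.out` file entry by entry: test strings in the input are separated by `__ggml_vocab_test__` marker lines, and each entry tokenizes to one line of space-separated token IDs here. A minimal standalone sketch of that pairing, assuming only the format visible above (the file paths and all names are illustrative, not part of this change):

// Minimal sketch of pairing a vocab test .inp file with its .out file.
// Assumes the entry/marker format described above; illustrative only.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
    std::ifstream inp("models/ggml-vocab-deepseek-r1-qwen.gguf.inp");
    std::ifstream out("models/ggml-vocab-deepseek-r1-qwen.gguf.out");

    // read the whole input and split it on the marker lines
    std::stringstream ss;
    ss << inp.rdbuf();
    const std::string all    = ss.str();
    const std::string marker = "\n__ggml_vocab_test__\n";

    std::vector<std::string> entries;
    size_t pos = 0;
    for (size_t next; (next = all.find(marker, pos)) != std::string::npos; pos = next + marker.size()) {
        entries.push_back(all.substr(pos, next - pos));
    }

    // one line of space-separated token IDs per entry
    for (const std::string & entry : entries) {
        std::string ids;
        if (!std::getline(out, ids)) {
            break;
        }
        std::cout << entry.size() << "-byte entry -> " << ids << "\n";
    }
}
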
77
scripts/get_hf_chat_template.py
Executable file

@ -0,0 +1,77 @@
#!/usr/bin/env python
'''
    Fetches the Jinja chat template of a HuggingFace model.
    If a model has multiple chat templates, you can specify the variant name.

    Syntax:
        ./scripts/get_hf_chat_template.py model_id [variant]

    Examples:
        ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
        ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
        ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
'''

import json
import re
import sys


def get_hf_chat_template(model_id, variant=None):
    try:
        # Use huggingface_hub library if available.
        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
        from huggingface_hub import hf_hub_download
        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
            config_str = f.read()
    except ImportError:
        import requests
        assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
        response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
        if response.status_code == 401:
            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
        response.raise_for_status()
        config_str = response.text

    try:
        config = json.loads(config_str)
    except json.JSONDecodeError:
        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
        # (Remove extra '}' near the end of the file)
        config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))

    chat_template = config['chat_template']
    if isinstance(chat_template, str):
        return chat_template
    else:
        variants = {
            ct['name']: ct['template']
            for ct in chat_template
        }

        def format_variants():
            return ', '.join(f'"{v}"' for v in variants.keys())

        if variant is None:
            if 'default' not in variants:
                raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
            variant = 'default'
            sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
        elif variant not in variants:
            raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")

        return variants[variant]


def main(args):
    if len(args) < 1:
        raise ValueError("Please provide a model ID and an optional variant name")
    model_id = args[0]
    variant = None if len(args) < 2 else args[1]

    template = get_hf_chat_template(model_id, variant)
    sys.stdout.write(template)


if __name__ == '__main__':
    main(sys.argv[1:])

112
scripts/hf.sh
Executable file

@ -0,0 +1,112 @@
#!/bin/bash
#
# Shortcut for downloading HF models
#
# Usage:
#   ./llama-cli -m $(./scripts/hf.sh https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf)
#   ./llama-cli -m $(./scripts/hf.sh --url https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/blob/main/mixtral-8x7b-v0.1.Q4_K_M.gguf)
#   ./llama-cli -m $(./scripts/hf.sh --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf)
#

# all logs go to stderr
function log {
    echo "$@" 1>&2
}

function usage {
    log "Usage: $0 [[--url] <url>] [--repo <repo>] [--file <file>] [--outdir <dir>] [-h|--help]"
    exit 1
}

# check for curl or wget
function has_cmd {
    if ! [ -x "$(command -v $1)" ]; then
        return 1
    fi
}

if has_cmd wget; then
    cmd="wget -q -c -O %s/%s %s"
elif has_cmd curl; then
    cmd="curl -C - -f --output-dir %s -o %s -L %s"
else
    log "[E] curl or wget not found"
    exit 1
fi

url=""
repo=""
file=""
outdir="."

# parse args
while [[ $# -gt 0 ]]; do
    case "$1" in
        --url)
            url="$2"
            shift 2
            ;;
        --repo)
            repo="$2"
            shift 2
            ;;
        --file)
            file="$2"
            shift 2
            ;;
        --outdir)
            outdir="$2"
            shift 2
            ;;
        -h|--help)
            usage
            ;;
        *)
            url="$1"
            shift
            ;;
    esac
done

if [ -n "$repo" ] && [ -n "$file" ]; then
    url="https://huggingface.co/$repo/resolve/main/$file"
fi

if [ -z "$url" ]; then
    log "[E] missing --url"
    usage
fi

# check if the URL is a HuggingFace model, and if so, try to download it
is_url=false

if [[ ${#url} -gt 22 ]]; then
    if [[ ${url:0:22} == "https://huggingface.co" ]]; then
        is_url=true
    fi
fi

if [ "$is_url" = false ]; then
    log "[E] invalid URL, must start with https://huggingface.co"
    exit 0
fi

# replace "blob/main" with "resolve/main"
url=${url/blob\/main/resolve\/main}

basename=$(basename $url)

log "[+] attempting to download $basename"

if [ -n "$cmd" ]; then
    cmd=$(printf "$cmd" "$outdir" "$basename" "$url")
    log "[+] $cmd"
    if $cmd; then
        echo $outdir/$basename
        exit 0
    fi
fi

log "[-] failed to download"

exit 1

@ -30,7 +30,7 @@ add_library(llama
     unicode-data.cpp
 )

-target_include_directories(llama PUBLIC . ../include)
+target_include_directories(llama PUBLIC . ../include ../common)
 target_compile_features   (llama PUBLIC cxx_std_17) # don't bump

 target_link_libraries(llama PUBLIC ggml)

@ -183,6 +183,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,          "tokenizer.huggingface.json"      },
     { LLM_KV_TOKENIZER_RWKV,             "tokenizer.rwkv.world"            },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,    "tokenizer.chat_template"         },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,  "tokenizer.chat_template.%s"      },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,       "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,       "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,       "tokenizer.ggml.fim_mid_token_id" },

@ -1532,10 +1533,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };

-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

 std::string LLM_KV::operator()(llm_kv kv) const {
-    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 }

 std::string LLM_TN_IMPL::str() const {

@ -181,6 +181,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,

@ -387,9 +388,10 @@ enum llm_tensor_layer {
 };

 struct LLM_KV {
-    LLM_KV(llm_arch arch);
+    LLM_KV(llm_arch arch, const char * suffix = nullptr);

     llm_arch arch;
+    const char * suffix;

     std::string operator()(llm_kv kv) const;
 };

@ -152,7 +152,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_MINICPM;
     } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
-    } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
+    } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb

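The detection change above is a deliberate loosening: `tmpl_contains` is a plain substring search, so instead of requiring one exact Jinja concatenation, a DeepSeek-V3-style template now only needs the three marker strings to appear somewhere. A self-contained sketch of the predicate (illustrative; the real code wraps the UTF-8 literals in `LU8()` for MSVC and lives inside `llm_chat_detect_template`):

#include <string>

// Sketch of the loosened DeepSeek-V3 detection above; tmpl_contains in
// llama-chat.cpp is a substring search, mirrored here with std::string::find.
static bool looks_like_deepseek3(const std::string & tmpl) {
    auto contains = [&](const char * needle) {
        return tmpl.find(needle) != std::string::npos;
    };
    return contains("<|Assistant|>") && contains("<|User|>") && contains("<|end▁of▁sentence|>");
}
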
@ -7,6 +7,7 @@
 #include <cstring>
 #include <climits>
 #include <stdexcept>
+#include <cerrno>

 #ifdef __has_include
 #if __has_include(<unistd.h>)

@ -2253,6 +2253,50 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                     }
                 } break;
+            case LLM_ARCH_PHIMOE:
+                {
+                    const int64_t n_embd_head = n_embd / n_head;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   { n_embd }, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), { n_embd, n_vocab }, 0);
+                    output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   { n_vocab }, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), { n_embd }, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias",   i), {n_embd}, 0);
+
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
+
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
+                        }
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), { n_embd }, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), { n_embd }, 0);
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert},       0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert}, 0);
+
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                    }
+                } break;
             case LLM_ARCH_PLAMO:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@ -4071,8 +4115,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }

-const char * llama_model_chat_template(const struct llama_model * model) {
-    const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+                          : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
         return nullptr;
     }

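A rough caller-side sketch of the new `name` parameter: look up a named variant first and fall back to the default template when the GGUF does not carry it. The `tool_use` variant name is an example, not something every model provides:

#include "llama.h"

// Sketch only: prefer a named chat template variant, then fall back to the
// default. "tool_use" is an example variant name, not guaranteed to exist.
static const char * pick_chat_template(const struct llama_model * model) {
    const char * tmpl = llama_model_chat_template(model, "tool_use");
    if (tmpl == nullptr) {
        tmpl = llama_model_chat_template(model, /* name = */ nullptr); // default key
    }
    return tmpl; // may still be nullptr if the model embeds no template at all
}
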
@ -1523,7 +1523,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
                 clean_spaces = false;
             } else if (
-                tokenizer_pre == "qwen2") {
+                tokenizer_pre == "qwen2" ||
+                tokenizer_pre == "deepseek-r1-qwen") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
                 clean_spaces = false;
             } else if (

@ -7,18 +7,17 @@

 #include <algorithm>
 #include <cassert>
+#include <codecvt>
 #include <cstddef>
 #include <cstdint>
+#include <locale>
 #include <map>
 #include <regex>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
-#include <unordered_set>
 #include <utility>
 #include <vector>
-#include <locale>
-#include <codecvt>

 size_t unicode_len_utf8(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

@ -1,3 +1,5 @@
+llama_add_compile_flags()
+
 function(llama_test target)
     include(CMakeParseArguments)
     set(options)

@ -3046,9 +3046,10 @@ struct test_flash_attn_ext : public test_case {
     const float logit_softcap; // Gemma 2

     const ggml_type type_KV;
+    std::array<int32_t, 4> permute;

     std::string vars() override {
-        return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
+        return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
     }

     double max_nmse_err() override {

@ -3063,19 +3064,33 @@ struct test_flash_attn_ext : public test_case {
     }

     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
-                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
+                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
+                        std::array<int32_t, 4> permute = {0, 1, 2, 3})
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));

-        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        auto const & create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
+            int64_t ne[4] = {ne0, ne1, ne2, ne3};
+            int64_t ne_perm[4];
+            for (int i = 0; i < 4; ++i) {
+                ne_perm[permute[i]] = ne[i];
+            }
+            ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
+            if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
+                t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
+            }
+            return t;
+        };
+
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
         ggml_set_name(q, "q");

-        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
         ggml_set_name(k, "k");

-        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
         ggml_set_name(v, "v");

         ggml_tensor * m = nullptr;

|
||||||
for (int nb : { 1, 3, 32, 35, }) {
|
for (int nb : { 1, 3, 32, 35, }) {
|
||||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
||||||
|
// run fewer test cases permuted
|
||||||
|
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,16 @@

 #include "llama.h"
 #include "common.h"
+#include "chat-template.hpp"
+
+static std::string normalize_newlines(const std::string & s) {
+#ifdef _WIN32
+    static const std::regex nl_regex("\r\n");
+    return std::regex_replace(s, nl_regex, "\n");
+#else
+    return s;
+#endif
+}

 int main(void) {
     std::vector<llama_chat_message> conversation {

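`normalize_newlines` only does work on Windows, where Jinja-rendered output can pick up `\r\n`; elsewhere it is the identity. A tiny self-contained check of what the `_WIN32` branch does (the helper below re-implements that branch unconditionally, purely for illustration):

#include <cassert>
#include <regex>
#include <string>

// Re-implementation of the _WIN32 branch above, applied unconditionally so
// its behavior can be checked on any platform (illustrative only).
static std::string normalize_newlines_sketch(const std::string & s) {
    static const std::regex nl_regex("\r\n");
    return std::regex_replace(s, nl_regex, "\n");
}

int main() {
    assert(normalize_newlines_sketch("a\r\nb\r\n") == "a\nb\n");
    assert(normalize_newlines_sketch("already\nnormalized") == "already\nnormalized");
    return 0;
}
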
@ -21,156 +31,228 @@ int main(void) {
         std::string name;
         std::string template_str;
         std::string expected_output;
+        std::string expected_output_jinja;
+        std::string bos_token = "";
+        std::string eos_token = "";
+        bool supported_with_jinja = true;
     };
     std::vector<TestCase> test_cases {
         {
             /* .name= */ "teknium/OpenHermes-2.5-Mistral-7B",
             /* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
             /* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
         },
         {
             /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
             /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
             /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "</s>",
         },
         {
             /* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ",
-            /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
+            /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}",
             /* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
+            /* .bos_token= */ "<s>",
+            /* .eos_token= */ "</s>",
         },
         {
             /* .name= */ "bofenghuang/vigogne-2-70b-chat",
-            /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+            /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
             /* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "</s>",
         },
         {
             /* .name= */ "mlabonne/AlphaMonarch-7B",
             /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
             /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+            /* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+            /* .bos_token= */ "<s>",
+            /* .eos_token= */ "</s>",
         },
         {
             /* .name= */ "google/gemma-7b-it",
             /* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
             /* .expected_output= */ "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
+            /* .expected_output_jinja= */ "<start_of_turn>user\nYou are a helpful assistant\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
         },
         {
             /* .name= */ "OrionStarAI/Orion-14B-Chat",
             /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
             /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
+            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "</s>",
         },
         {
             /* .name= */ "openchat/openchat-3.5-0106",
             // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
             // So we match against the included template but implement the suggested version.
             /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
             /* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
+            /* .expected_output_jinja= */ "GPT4 Correct System: You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
         },
         {
             /* .name= */ "deepseek-ai/deepseek-coder-33b-instruct",
             /* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
             /* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
+            /* .expected_output_jinja= */ "",
         },
         {
             /* .name= */ "eachadea/vicuna-13b-1.1",
             // No template included in tokenizer_config.json, so this template likely needs to be manually set.
             /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
             /* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
         },
         {
             /* .name= */ "Orca-Vicuna",
             // No template included in tokenizer_config.json, so this template likely needs to be manually set.
             /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
             /* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
         },
         {
             /* .name= */ "CohereForAI/c4ai-command-r-plus",
             /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
             /* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+            /* .expected_output_jinja= */ "",
         },
         {
             /* .name= */ "Llama-3",
             /* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
             /* .expected_output= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            /* .expected_output_jinja= */ "",
         },
         {
             /* .name= */ "Phi-3-mini",
             /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
             /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         },
         {
             /* .name= */ "Phi-3-small",
             /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
             /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "",
         },
         {
             /* .name= */ "Phi-3-medium",
             /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
             /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         },
         {
             /* .name= */ "Phi-3-vision",
             /* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
             /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
         },
         {
             /* .name= */ "ChatGLM3",
             /* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
             /* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
+            /* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
         },
         {
             /* .name= */ "ChatGLM4",
             /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
             /* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
         },
         {
             /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
             /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
             /* .expected_output= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
         },
         {
             /* .name= */ "DeepSeek-V2",
             /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
|
||||||
/* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
|
/* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
|
||||||
|
/* .expected_output_jinja= */ "",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "<|end▁of▁sentence|>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "ibm-granite/granite-3.0-8b-instruct",
|
/* .name= */ "ibm-granite/granite-3.0-8b-instruct",
|
||||||
/* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
|
/* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
|
||||||
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
|
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
|
||||||
|
/* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
|
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
|
||||||
/* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
|
/* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
|
||||||
/* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
|
/* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
|
||||||
|
/* .expected_output_jinja= */ "",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "</s>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
|
/* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
|
||||||
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
|
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
|
||||||
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
|
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
|
||||||
|
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] You are a helpful assistant\n\nAnother question[/INST]",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "</s>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
|
/* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
|
||||||
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
|
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
|
||||||
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
|
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
|
||||||
|
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]You are a helpful assistant\n\nAnother question[/INST]",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "</s>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)",
|
/* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)",
|
||||||
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||||
/* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant </s>[INST] Another question[/INST]",
|
/* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant </s>[INST] Another question[/INST]",
|
||||||
|
/* .expected_output_jinja= */ "",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "</s>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "ai-sage/GigaChat-20B-A3B-instruct",
|
/* .name= */ "ai-sage/GigaChat-20B-A3B-instruct",
|
||||||
/* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}",
|
/* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}",
|
||||||
/* .expected_output= */ "<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
|
/* .expected_output= */ "<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
|
||||||
|
/* .expected_output_jinja= */ "",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "",
|
||||||
|
/* .supported_with_jinja= */ false, // Requires additional_special_tokens as extra context
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "Infinigence/Megrez-3B-Instruct",
|
/* .name= */ "Infinigence/Megrez-3B-Instruct",
|
||||||
/* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
|
/* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
|
||||||
/* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
|
/* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
|
||||||
|
/* .expected_output_jinja= */ "",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
/* .name= */ "phi-4",
|
/* .name= */ "phi-4",
|
||||||
/* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
|
/* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
|
||||||
/* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
|
/* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
|
||||||
|
/* .expected_output_jinja= */ "",
|
||||||
|
/* .bos_token= */ "",
|
||||||
|
/* .eos_token= */ "",
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
std::vector<char> formatted_chat(1024);
|
std::vector<char> formatted_chat(1024);
|
||||||
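For reference, a minimal sketch of the test-case struct these designated-initializer comments appear to fill (the field names come straight from the /* .field= */ comments above; the struct name and the defaults are assumptions, since the definition sits earlier in the file):

    struct TestCase {
        std::string name;
        std::string template_str;
        std::string expected_output;
        std::string expected_output_jinja;      // empty => expected_output is reused for the jinja path
        std::string bos_token = "";
        std::string eos_token = "";
        bool supported_with_jinja = true;       // the GigaChat case above opts out
    };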
@@ -190,6 +272,7 @@ int main(void) {
     // test invalid chat template
     res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
+    const auto add_generation_prompt = true;
 
     for (const auto & test_case : test_cases) {
         printf("\n\n=== %s ===\n\n", test_case.name.c_str());
@@ -198,26 +281,59 @@ int main(void) {
             test_case.template_str.c_str(),
             conversation.data(),
             conversation.size(),
-            true,
+            add_generation_prompt,
             formatted_chat.data(),
             formatted_chat.size()
         );
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
-        printf("%s\n", output.c_str());
-        printf("-------------------------\n");
-        assert(output == test_case.expected_output);
+        if (output != test_case.expected_output) {
+            printf("Expected:\n%s\n", test_case.expected_output.c_str());
+            printf("-------------------------\n");
+            printf("Actual:\n%s\n", output.c_str());
+            fflush(stdout);
+            assert(output == test_case.expected_output);
+        }
     }
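A note on the call pattern in this loop: llama_chat_apply_template writes into a caller-provided buffer and returns the number of bytes the formatted prompt needs, so callers grow the buffer and retry when the first pass comes up short. A hedged sketch of that two-pass pattern, using the same post-change signature as the test above (the retry helper itself is illustrative, not code from this diff):

    #include "llama.h"
    #include <string>
    #include <vector>

    // Assumes `msgs`/`n_msgs` describe the conversation, as in the test above.
    static std::string apply_template(const char * tmpl, const llama_chat_message * msgs, size_t n_msgs) {
        std::vector<char> buf(1024);
        int32_t res = llama_chat_apply_template(tmpl, msgs, n_msgs, /* add_ass= */ true, buf.data(), buf.size());
        if (res > (int32_t) buf.size()) {
            buf.resize(res); // first call reported the required size; retry
            res = llama_chat_apply_template(tmpl, msgs, n_msgs, true, buf.data(), buf.size());
        }
        return res < 0 ? std::string() : std::string(buf.data(), res); // res < 0: unsupported template
    }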
+
+    json messages = json::array();
+    for (const auto & msg : conversation) {
+        messages.push_back({
+            {"role",    msg.role},
+            {"content", msg.content},
+        });
+    }
+    for (const auto & test_case : test_cases) {
+        if (!test_case.supported_with_jinja) {
+            continue;
+        }
+        printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
+        try {
+            minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
+            auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
+            auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
+            if (output != expected_output) {
+                printf("Expected:\n%s\n", expected_output.c_str());
+                printf("-------------------------\n");
+                printf("Actual:\n%s\n", output.c_str());
+                fflush(stdout);
+                assert(output == expected_output);
+            }
+        } catch (const std::exception & e) {
+            printf("ERROR: %s\n", e.what());
+            assert(false);
+        }
+    }
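The block above is the newly added jinja-based path. A minimal standalone sketch of the minja API it exercises (the header path and the json alias are assumptions; the constructor and apply() arguments mirror the test):

    #include "minja.hpp"                 // assumed header name
    #include <nlohmann/json.hpp>
    #include <cstdio>
    using json = nlohmann::ordered_json; // assumed alias

    int main() {
        json messages = json::array({
            { {"role", "user"}, {"content", "Hello"} },
        });
        minja::chat_template tmpl(
            "{% for m in messages %}[{{ m.role }}] {{ m.content }}\n{% endfor %}",
            /* bos_token= */ "", /* eos_token= */ "");
        // tools = json() (none), add_generation_prompt = true
        printf("%s\n", tmpl.apply(messages, json(), true).c_str());
    }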
 
     // test llama_chat_format_single for system message
     printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
     std::vector<common_chat_msg> chat2;
     common_chat_msg sys_msg{"system", "You are a helpful assistant"};
 
-    auto fmt_sys = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
-        printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_sys = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
@@ -241,9 +357,10 @@ int main(void) {
     chat2.push_back({"assistant", "I am assistant"});
     common_chat_msg new_msg{"user", "How are you"};
 
-    auto fmt_single = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
-        printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_single = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+        printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
@@ -258,7 +375,5 @@ int main(void) {
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
     assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
 
-    printf("Test chat templates: OK\n");
-
     return 0;
 }
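For context, common_chat_format_single formats only the newly appended message (plus the assistant prefix when requested), given the conversation so far; that is what the fmt_single asserts check. A hedged usage sketch with the post-change signature (the signature, the aggregate-initialized common_chat_msg, and the llama3 result string are all taken from the test above; the surrounding setup is illustrative):

    // With use_jinja=false, short template names like "llama3" select the legacy built-in path.
    minja::chat_template tmpl("llama3", "", "");
    std::vector<common_chat_msg> past = { {"system", "You are a helpful assistant"},
                                          {"user", "Hello"}, {"assistant", "I am assistant"} };
    common_chat_msg next{"user", "How are you"};
    std::string delta = common_chat_format_single(tmpl, past, next,
                                                  /* add_ass=   */ true,
                                                  /* use_jinja= */ false);
    // delta == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
    //          "<|start_header_id|>assistant<|end_header_id|>\n\n"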
tests/test-gguf.cpp
@@ -48,7 +48,7 @@ enum handcrafted_file_type {
     HANDCRAFTED_DATA_CUSTOM_ALIGN = 810 + offset_has_data,
 };
 
-std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
+static std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
     switch (hft) {
         case HANDCRAFTED_HEADER_BAD_MAGIC:     return "HEADER_BAD_MAGIC";
         case HANDCRAFTED_HEADER_BAD_VERSION_1: return "HEADER_BAD_VERSION_1";
@@ -99,7 +99,7 @@ static bool expect_context_not_null(const enum handcrafted_file_type hft) {
 
 typedef std::pair<enum ggml_type, std::array<int64_t, GGML_MAX_DIMS>> tensor_config_t;
 
-std::vector<tensor_config_t> get_tensor_configs(std::mt19937 & rng) {
+static std::vector<tensor_config_t> get_tensor_configs(std::mt19937 & rng) {
     std::vector<tensor_config_t> tensor_configs;
     tensor_configs.reserve(100);
 
@@ -122,7 +122,7 @@ std::vector<tensor_config_t> get_tensor_configs(std::mt19937 & rng) {
     return tensor_configs;
 }
 
-std::vector<std::pair<enum gguf_type, enum gguf_type>> get_kv_types(std::mt19937 rng) {
+static std::vector<std::pair<enum gguf_type, enum gguf_type>> get_kv_types(std::mt19937 rng) {
     std::vector<std::pair<enum gguf_type, enum gguf_type>> kv_types;
     kv_types.reserve(100);
 
@@ -626,8 +626,6 @@ static bool handcrafted_check_tensor_data(const gguf_context * gguf_ctx, const u
 
     bool ok = true;
 
-    const uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
-
     for (int i = 0; i < int(tensor_configs.size()); ++i) {
         const ggml_type type = tensor_configs[i].first;
         const std::array<int64_t, GGML_MAX_DIMS> shape = tensor_configs[i].second;
@@ -866,13 +864,13 @@ static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t
                 case GGUF_TYPE_COUNT:
                 default: {
                     GGML_ABORT("fatal error");
-                } break;
+                }
             }
         } break;
         case GGUF_TYPE_COUNT:
         default: {
             GGML_ABORT("fatal error");
-        } break;
+        }
     }
 }
 
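The `} break;` to `}` changes above look like dead-code cleanup rather than behavior changes: GGML_ABORT never returns, so a break after it can never execute (the noreturn nature of GGML_ABORT is the assumption here). Illustratively:

    switch (type) {
        case GGUF_TYPE_COUNT:
        default: {
            GGML_ABORT("fatal error"); // noreturn - control never reaches past this line
        }                              // so the old `break;` here was unreachable
    }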
@@ -938,7 +936,7 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
         }
 
         if (type == GGUF_TYPE_ARRAY) {
-            const int arr_n = gguf_get_arr_n(ctx, id);
+            const size_t arr_n = gguf_get_arr_n(ctx, id);
             if (arr_n != gguf_get_arr_n(other, idx_other)) {
                 ok = false;
                 continue;
@@ -953,7 +951,7 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
             if (type_arr == GGUF_TYPE_BOOL) {
                 const int8_t * data = reinterpret_cast<const int8_t *>(gguf_get_arr_data(ctx, id));
                 const int8_t * data_other = reinterpret_cast<const int8_t *>(gguf_get_arr_data(other, idx_other));
-                for (int arr_i = 0; arr_i < arr_n; ++arr_i) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
                     if (bool(data[arr_i]) != bool(data_other[arr_i])) {
                         ok = false;
                     }
@@ -962,7 +960,7 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
             }
 
             if (type_arr == GGUF_TYPE_STRING) {
-                for (int arr_i = 0; arr_i < arr_n; ++arr_i) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
                     const std::string str = gguf_get_arr_str(ctx, id, arr_i);
                     const std::string str_other = gguf_get_arr_str(other, idx_other, arr_i);
                     if (str != str_other) {
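The int-to-size_t switches above line up the local types with what gguf_get_arr_n now returns (size_t, judging by this diff), avoiding sign-compare warnings between the index and the bound. For readers unfamiliar with the GGUF array accessors used here, a hedged sketch of the access pattern (gguf_get_arr_type is assumed available; bool arrays are stored as int8_t, as the casts above suggest):

    const size_t n = gguf_get_arr_n(ctx, id);          // element count
    if (gguf_get_arr_type(ctx, id) == GGUF_TYPE_STRING) {
        for (size_t i = 0; i < n; ++i) {
            std::string s = gguf_get_arr_str(ctx, id, i); // strings are fetched one by one
        }
    } else {
        const void * raw = gguf_get_arr_data(ctx, id);    // contiguous elements for POD types
        (void) raw;
    }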
@@ -1033,6 +1031,12 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml
 
     struct ggml_tensor * t_orig = ggml_get_first_tensor(orig);
     struct ggml_tensor * t_read = ggml_get_first_tensor(read);
 
+    if (std::string(t_read->name) != "GGUF tensor data binary blob") {
+        return false;
+    }
+    t_read = ggml_get_next_tensor(read, t_read);
+
     while (t_orig) {
         if (!t_read) {
             ok = false;
@@ -1051,13 +1055,13 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml
         }
 
         t_orig = ggml_get_next_tensor(orig, t_orig);
-        t_read = ggml_get_next_tensor(orig, t_read);
+        t_read = ggml_get_next_tensor(read, t_read);
     }
     if (t_read) {
         ok = false;
     }
 
-    return true;
+    return ok;
 }
 
 static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) {
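The last two changes in same_tensor_data are genuine bug fixes: the read-side iterator was being advanced with the orig context, and the function returned true unconditionally, discarding every mismatch accumulated in ok. The corrected paired walk, in outline (comparison body elided; function names as used in the diff):

    struct ggml_tensor * a = ggml_get_first_tensor(orig);
    struct ggml_tensor * b = ggml_get_first_tensor(read);
    bool ok = true;
    while (a) {
        if (!b) { ok = false; break; }
        // ... compare a's and b's data ...
        a = ggml_get_next_tensor(orig, a); // each iterator advances in its own context
        b = ggml_get_next_tensor(read, b); // previously (orig, b): wrong context
    }
    if (b) { ok = false; }                 // extra tensors left on the read side
    return ok;                             // previously `return true`, masking failures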
tests/test-sampling.cpp
@@ -144,7 +144,6 @@ static void test_penalties(
 
     sampler_tester tester(probs, probs_expected);
 
-    const size_t n_vocab = probs.size();
     auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {