Merge branch 'master' into bugfix-292
This commit is contained in:
commit
c8e940ede7
16 changed files with 595 additions and 282 deletions
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
|
@ -54,6 +54,7 @@ jobs:
|
|||
cd build
|
||||
cmake ..
|
||||
cmake --build . --config Release
|
||||
ctest --output-on-failure
|
||||
|
||||
macOS-latest-make:
|
||||
runs-on: macos-latest
|
||||
|
@ -90,6 +91,7 @@ jobs:
|
|||
cd build
|
||||
cmake ..
|
||||
cmake --build . --config Release
|
||||
ctest --output-on-failure
|
||||
|
||||
windows-latest-cmake:
|
||||
runs-on: windows-latest
|
||||
|
@ -106,6 +108,7 @@ jobs:
|
|||
cd build
|
||||
cmake ..
|
||||
cmake --build . --config Release
|
||||
ctest --output-on-failure
|
||||
|
||||
- name: Get commit hash
|
||||
id: commit
|
||||
|
|
2
.github/workflows/docker.yml
vendored
2
.github/workflows/docker.yml
vendored
|
@ -40,7 +40,7 @@ jobs:
|
|||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and push Docker image (versioned)
|
||||
|
|
287
CMakeLists.txt
287
CMakeLists.txt
|
@ -1,131 +1,252 @@
|
|||
cmake_minimum_required(VERSION 3.8)
|
||||
project("llama.cpp")
|
||||
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
|
||||
project("llama.cpp" C CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
||||
endif()
|
||||
|
||||
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
|
||||
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
|
||||
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
|
||||
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
|
||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(LLAMA_STANDALONE ON)
|
||||
|
||||
if (APPLE)
|
||||
option(LLAMA_NO_ACCELERATE "llama: disable Accelerate framework" OFF)
|
||||
option(LLAMA_NO_AVX "llama: disable AVX" OFF)
|
||||
option(LLAMA_NO_AVX2 "llama: disable AVX2" OFF)
|
||||
option(LLAMA_NO_FMA "llama: disable FMA" OFF)
|
||||
# configure project version
|
||||
# TODO
|
||||
else()
|
||||
set(LLAMA_STANDALONE OFF)
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
|
||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
|
||||
else()
|
||||
if (MINGW)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
else()
|
||||
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Option list
|
||||
#
|
||||
|
||||
# general
|
||||
option(LLAMA_STATIC "llama: static link libraries" OFF)
|
||||
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
|
||||
option(LLAMA_LTO "llama: enable link time optimization" OFF)
|
||||
|
||||
# debug
|
||||
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
|
||||
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
|
||||
option(LLAMA_GPROF "llama: enable gprof" OFF)
|
||||
|
||||
# sanitizers
|
||||
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
|
||||
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
|
||||
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
|
||||
|
||||
# instruction set specific
|
||||
option(LLAMA_AVX "llama: enable AVX" ON)
|
||||
option(LLAMA_AVX2 "llama: enable AVX2" ON)
|
||||
option(LLAMA_FMA "llama: enable FMA" ON)
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
|
||||
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
|
||||
|
||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
|
||||
#
|
||||
# Compile flags
|
||||
#
|
||||
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
if (NOT MSVC)
|
||||
if (LLAMA_SANITIZE_THREAD)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=thread")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
|
||||
add_compile_options(-fsanitize=thread)
|
||||
endif()
|
||||
|
||||
if (LLAMA_SANITIZE_ADDRESS)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer")
|
||||
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
|
||||
endif()
|
||||
|
||||
if (LLAMA_SANITIZE_UNDEFINED)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=undefined")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined")
|
||||
add_compile_options(-fsanitize=undefined)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (APPLE AND NOT LLAMA_NO_ACCELERATE)
|
||||
if (APPLE AND LLAMA_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
set(LLAMA_EXTRA_FLAGS ${LLAMA_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||
add_compile_definitions(GGML_USE_ACCELERATE)
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
endif()
|
||||
endif()
|
||||
if (LLAMA_OPENBLAS)
|
||||
if (LLAMA_STATIC)
|
||||
set(BLA_STATIC ON)
|
||||
endif()
|
||||
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
find_package(BLAS)
|
||||
if (BLAS_FOUND)
|
||||
message(STATUS "OpenBLAS found")
|
||||
|
||||
add_compile_definitions(GGML_USE_OPENBLAS)
|
||||
add_link_options(${BLAS_LIBRARIES})
|
||||
else()
|
||||
message(WARNING "OpenBLAS not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_ALL_WARNINGS)
|
||||
if (NOT MSVC)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
-Wpedantic \
|
||||
-Wshadow \
|
||||
-Wcast-qual \
|
||||
-Wstrict-prototypes \
|
||||
-Wpointer-arith \
|
||||
-Wno-unused-function \
|
||||
")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
-Wpedantic \
|
||||
-Wcast-qual \
|
||||
")
|
||||
set(c_flags
|
||||
-Wall
|
||||
-Wextra
|
||||
-Wpedantic
|
||||
-Wshadow
|
||||
-Wcast-qual
|
||||
-Wstrict-prototypes
|
||||
-Wpointer-arith
|
||||
-Wno-unused-function
|
||||
)
|
||||
set(cxx_flags
|
||||
-Wall
|
||||
-Wextra
|
||||
-Wpedantic
|
||||
-Wcast-qual
|
||||
)
|
||||
else()
|
||||
# todo : msvc
|
||||
endif()
|
||||
|
||||
add_compile_options(
|
||||
"$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
|
||||
"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||
message(STATUS "ARM detected")
|
||||
else()
|
||||
message(STATUS "x86 detected")
|
||||
if (MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
|
||||
if (LLAMA_LTO)
|
||||
include(CheckIPOSupported)
|
||||
check_ipo_supported(RESULT result OUTPUT output)
|
||||
if (result)
|
||||
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
|
||||
else()
|
||||
if(NOT LLAMA_NO_AVX)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
|
||||
endif()
|
||||
if(NOT LLAMA_NO_AVX2)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2")
|
||||
endif()
|
||||
if(NOT LLAMA_NO_FMA)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mfma")
|
||||
endif()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mf16c")
|
||||
message(WARNING "IPO is not supported: ${output}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# if (LLAMA_PERF)
|
||||
# set(LLAMA_EXTRA_FLAGS ${LLAMA_EXTRA_FLAGS} -DGGML_PERF)
|
||||
# endif()
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
if (NOT MSVC)
|
||||
if (LLAMA_STATIC)
|
||||
add_link_options(-static)
|
||||
if (MINGW)
|
||||
add_link_options(-static-libgcc -static-libstdc++)
|
||||
endif()
|
||||
endif()
|
||||
if (LLAMA_GPROF)
|
||||
add_compile_options(-pg)
|
||||
endif()
|
||||
if (LLAMA_NATIVE)
|
||||
add_compile_options(-march=native)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_executable(llama
|
||||
main.cpp
|
||||
utils.cpp
|
||||
utils.h)
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||
message(STATUS "ARM detected")
|
||||
if (MSVC)
|
||||
# TODO: arm msvc?
|
||||
else()
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||
add_compile_options(-mcpu=native)
|
||||
endif()
|
||||
# TODO: armv6,7,8 version specific flags
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
|
||||
message(STATUS "x86 detected")
|
||||
if (MSVC)
|
||||
if (LLAMA_AVX2)
|
||||
add_compile_options(/arch:AVX2)
|
||||
elseif (LLAMA_AVX)
|
||||
add_compile_options(/arch:AVX)
|
||||
endif()
|
||||
else()
|
||||
add_compile_options(-mf16c)
|
||||
if (LLAMA_FMA)
|
||||
add_compile_options(-mfma)
|
||||
endif()
|
||||
if (LLAMA_AVX)
|
||||
add_compile_options(-mavx)
|
||||
endif()
|
||||
if (LLAMA_AVX2)
|
||||
add_compile_options(-mavx2)
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
# TODO: support PowerPC
|
||||
message(STATUS "Unknown architecture")
|
||||
endif()
|
||||
|
||||
add_executable(quantize
|
||||
quantize.cpp
|
||||
utils.cpp
|
||||
utils.h)
|
||||
|
||||
add_library(ggml
|
||||
ggml.c
|
||||
ggml.h)
|
||||
#
|
||||
# Build library
|
||||
#
|
||||
|
||||
target_compile_definitions(ggml PUBLIC ${LLAMA_EXTRA_FLAGS})
|
||||
target_compile_definitions(llama PUBLIC ${LLAMA_EXTRA_FLAGS})
|
||||
target_compile_definitions(quantize PUBLIC ${LLAMA_EXTRA_FLAGS})
|
||||
add_executable(llama main.cpp)
|
||||
|
||||
add_executable(quantize quantize.cpp)
|
||||
|
||||
add_library(utils OBJECT
|
||||
utils.cpp
|
||||
utils.h)
|
||||
|
||||
target_include_directories(utils PUBLIC .)
|
||||
target_compile_features(utils PUBLIC cxx_std_11) # don't bump
|
||||
|
||||
add_library(ggml OBJECT
|
||||
ggml.c
|
||||
ggml.h)
|
||||
|
||||
target_link_libraries(ggml PRIVATE ${LLAMA_EXTRA_LIBS})
|
||||
target_include_directories(ggml PUBLIC .)
|
||||
target_link_libraries(quantize PRIVATE ggml)
|
||||
target_link_libraries(llama PRIVATE ggml)
|
||||
target_link_libraries(ggml PRIVATE Threads::Threads)
|
||||
target_compile_features(ggml PUBLIC c_std_11) # don't bump
|
||||
|
||||
#
|
||||
# Linking
|
||||
#
|
||||
|
||||
target_link_libraries(ggml PRIVATE Threads::Threads ${LLAMA_EXTRA_LIBS})
|
||||
target_link_libraries(llama PRIVATE ggml utils)
|
||||
target_link_libraries(quantize PRIVATE ggml utils)
|
||||
|
||||
#
|
||||
# programs, examples and tests
|
||||
#
|
||||
|
||||
if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif ()
|
||||
|
||||
#if (LLAMA_BUILD_EXAMPLES)
|
||||
# add_subdirectory(examples)
|
||||
#endif()
|
||||
|
|
49
Makefile
49
Makefile
|
@ -17,7 +17,7 @@ CXXV := $(shell $(CXX) --version | head -n 1)
|
|||
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
ifneq ($(UNAME_P),arm)
|
||||
SYSCTL_M := $(shell sysctl -n hw.optional.arm64)
|
||||
SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
|
||||
ifeq ($(SYSCTL_M),1)
|
||||
# UNAME_P := arm
|
||||
# UNAME_M := arm64
|
||||
|
@ -30,8 +30,9 @@ endif
|
|||
# Compile flags
|
||||
#
|
||||
|
||||
# keep standard at C11 and C++11
|
||||
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++17 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||
LDFLAGS =
|
||||
|
||||
# OS specific
|
||||
|
@ -52,6 +53,10 @@ ifeq ($(UNAME_S),NetBSD)
|
|||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
ifeq ($(UNAME_S),OpenBSD)
|
||||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
endif
|
||||
ifeq ($(UNAME_S),Haiku)
|
||||
CFLAGS += -pthread
|
||||
CXXFLAGS += -pthread
|
||||
|
@ -95,6 +100,38 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
|||
ifneq (,$(findstring sse3,$(SSE3_M)))
|
||||
CFLAGS += -msse3
|
||||
endif
|
||||
AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512f,$(AVX512F_M)))
|
||||
CFLAGS += -mavx512f
|
||||
endif
|
||||
AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512bw,$(AVX512BW_M)))
|
||||
CFLAGS += -mavx512bw
|
||||
endif
|
||||
AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512dq,$(AVX512DQ_M)))
|
||||
CFLAGS += -mavx512dq
|
||||
endif
|
||||
AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512vl,$(AVX512VL_M)))
|
||||
CFLAGS += -mavx512vl
|
||||
endif
|
||||
AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
|
||||
CFLAGS += -mavx512cd
|
||||
endif
|
||||
AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512er,$(AVX512ER_M)))
|
||||
CFLAGS += -mavx512er
|
||||
endif
|
||||
AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M)))
|
||||
CFLAGS += -mavx512ifma
|
||||
endif
|
||||
AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo)
|
||||
ifneq (,$(findstring avx512pf,$(AVX512PF_M)))
|
||||
CFLAGS += -mavx512pf
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Haiku)
|
||||
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
|
||||
ifneq (,$(findstring avx,$(AVX1_M)))
|
||||
|
@ -116,9 +153,6 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
|
|||
CFLAGS += -mfma -mf16c -mavx -mavx2
|
||||
endif
|
||||
endif
|
||||
ifeq ($(UNAME_M),amd64)
|
||||
CFLAGS += -mavx -mavx2 -mfma -mf16c
|
||||
endif
|
||||
ifneq ($(filter ppc64%,$(UNAME_M)),)
|
||||
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
|
||||
ifneq (,$(findstring POWER9,$(POWER9_M)))
|
||||
|
@ -130,7 +164,8 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
|
|||
endif
|
||||
endif
|
||||
ifndef LLAMA_NO_ACCELERATE
|
||||
# Mac M1 - include Accelerate framework
|
||||
# Mac M1 - include Accelerate framework.
|
||||
# `-framework Accelerate` works on Mac Intel as well, with negligible performance boost (as of the predict time).
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -DGGML_USE_ACCELERATE
|
||||
LDFLAGS += -framework Accelerate
|
||||
|
@ -193,7 +228,7 @@ clean:
|
|||
|
||||
main: main.cpp ggml.o utils.o
|
||||
$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o -o main $(LDFLAGS)
|
||||
./main -h
|
||||
@echo "\x1b[36mrun ./main -h for help\x1b[0m"
|
||||
|
||||
quantize: quantize.cpp ggml.o utils.o
|
||||
$(CXX) $(CXXFLAGS) quantize.cpp ggml.o utils.o -o quantize $(LDFLAGS)
|
||||
|
|
|
@ -192,11 +192,10 @@ First, download the `ggml` Alpaca model into the `./models` folder:
|
|||
|
||||
```
|
||||
# use one of these
|
||||
# NOTE: these are copied from the alpaca.cpp repo - not sure how long these will work
|
||||
# TODO: add a script to simplify the download
|
||||
curl -o ggml-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
|
||||
curl -o ggml-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
|
||||
curl -o ggml-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC
|
||||
curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
|
||||
curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
|
||||
curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
|
||||
```
|
||||
|
||||
Now run the `main` tool like this:
|
||||
|
|
|
@ -3,4 +3,4 @@
|
|||
# Temporary script - will be removed in the future
|
||||
#
|
||||
|
||||
./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.96 --repeat_penalty 1 -t 7
|
||||
./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.2 --repeat_penalty 1 -t 7
|
||||
|
|
|
@ -10,25 +10,26 @@
|
|||
# - Name (char[name_length])
|
||||
# - Data (float[n_dims])
|
||||
#
|
||||
# By default, the bigger matrices are converted to 16-bit floats.
|
||||
# This can be disabled by adding the "use-f32" CLI argument.
|
||||
#
|
||||
# At the start of the ggml file we write the model parameters
|
||||
# and vocabulary.
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
|
||||
parser.add_argument('dir_model', help='directory containing the model checkpoint')
|
||||
parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
|
||||
parser.add_argument('dir_model', help='directory containing the model checkpoint')
|
||||
parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
|
||||
parser.add_argument('vocab_only', help='only write vocab to file', type=int, default=0, nargs='?')
|
||||
return parser.parse_args()
|
||||
|
||||
def get_n_parts(dim):
|
||||
|
@ -44,8 +45,14 @@ def get_n_parts(dim):
|
|||
|
||||
def load_hparams_and_tokenizer(dir_model):
|
||||
|
||||
# `dir_model` is something like `models/7B` or `models/7B/`.
|
||||
# "tokenizer.model" is expected under model's parent dir.
|
||||
# When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
|
||||
# Let's use the model's parent dir directly.
|
||||
model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
|
||||
|
||||
fname_hparams = f"{dir_model}/params.json"
|
||||
fname_tokenizer = f"{dir_model}/../tokenizer.model"
|
||||
fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
|
||||
|
||||
with open(fname_hparams, "r") as f:
|
||||
hparams = json.load(f)
|
||||
|
@ -60,7 +67,7 @@ def write_header(fout, hparams, ftype):
|
|||
|
||||
keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
|
||||
values = [
|
||||
0x67676d66, # magic: ggml in hex
|
||||
0x67676d66, # magic: ggmf in hex
|
||||
1, # file version
|
||||
*[hparams[key] for key in keys],
|
||||
hparams["dim"] // hparams["n_heads"], # rot (obsolete)
|
||||
|
@ -127,6 +134,29 @@ def main():
|
|||
ftype_str = ["f32", "f16"]
|
||||
|
||||
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
||||
|
||||
print(args)
|
||||
|
||||
# if only writing vocab to file
|
||||
if args.vocab_only:
|
||||
|
||||
fname_model = f"{dir_model}/consolidated.00.pth"
|
||||
fname_out = f"{dir_model}/ggml-vocab.bin"
|
||||
|
||||
print(f"Extracting only the vocab from '{fname_model}'\n")
|
||||
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
|
||||
with open(fname_out, "wb") as fout:
|
||||
fout.write(struct.pack("i", hparams["vocab_size"]))
|
||||
write_tokens(fout, tokenizer)
|
||||
|
||||
del model
|
||||
|
||||
print(f"Done. Output file: {fname_out}\n")
|
||||
|
||||
return
|
||||
|
||||
n_parts = get_n_parts(hparams["dim"])
|
||||
|
||||
for p in range(n_parts):
|
||||
|
@ -144,6 +174,7 @@ def main():
|
|||
process_and_write_variables(fout, model, ftype)
|
||||
|
||||
del model
|
||||
|
||||
print(f"Done. Output file: {fname_out}, (part {p})\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
cat ${./convert-pth-to-ggml.py} >> $out/bin/convert-pth-to-ggml
|
||||
chmod +x $out/bin/convert-pth-to-ggml
|
||||
'';
|
||||
meta.mainProgram = "llama";
|
||||
};
|
||||
devShells.default = pkgs.mkShell {
|
||||
packages = with pkgs; [
|
||||
|
|
80
ggml.c
80
ggml.c
|
@ -2,7 +2,7 @@
|
|||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||
#elif !defined(__FreeBSD__) && !defined(__NetBSD__)
|
||||
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
||||
#include <alloca.h>
|
||||
#endif
|
||||
|
||||
|
@ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|||
|
||||
// AVX routines provided by GH user Const-me
|
||||
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
|
||||
#if __AVX2__
|
||||
#if __AVX2__ || __AVX512F__
|
||||
// Unpack 32 4-bit fields into 32 bytes
|
||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
|
||||
|
@ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes )
|
|||
}
|
||||
#endif
|
||||
|
||||
|
||||
// method 5
|
||||
// blocks of QK elements
|
||||
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
|
||||
|
@ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
|||
*s = sumf;
|
||||
}
|
||||
|
||||
#if __AVX512F__ && QK == 32
|
||||
// Accumulate the scaled dot product of one pair of 4-bit quantized blocks into `acc`.
//
// For block index `i`, reads one float scale from each of `pd0`/`pd1` and 32 packed
// nibbles from each of `pb0`/`pb1` (all at byte offset i*bs), multiplies the
// de-quantized integer values pairwise, scales the result by the product of the two
// block scales, and returns `acc + scale * products`.
//
//   acc      - running accumulator; the updated value is returned, not stored
//   pd0, pd1 - base pointers from which the per-block float scale is read
//   pb0, pb1 - base pointers from which the packed 4-bit values are read
//   bs       - byte stride between consecutive blocks (presumably the size of one
//              quantized block - confirm against the caller)
//   i        - index of the block to process
//
// NOTE(review): _mm512_cvtepi8_epi16 and _mm512_madd_epi16 are listed as AVX-512BW
// intrinsics in the Intel Intrinsics Guide, while the enclosing guard only tests
// __AVX512F__ - confirm that BW is guaranteed wherever this is compiled.
static inline __m512 dot_q4_0_oneblock_avx512(
|
||||
__m512 acc,
|
||||
const uint8_t * pd0,
|
||||
const uint8_t * pd1,
|
||||
const uint8_t * pb0,
|
||||
const uint8_t * pb1,
|
||||
size_t bs,
|
||||
int i
|
||||
) {
|
||||
// Per-block scale factors of the two operands
const float * d0_0 = (const float *) (pd0 + i*bs);
|
||||
const float * d1_0 = (const float *) (pd1 + i*bs);
|
||||
|
||||
// Packed 4-bit payloads of the two operands
const uint8_t * restrict p0 = pb0 + (i+0)*bs;
|
||||
const uint8_t * restrict p1 = pb1 + (i+0)*bs;
|
||||
|
||||
// Compute combined scale for the block
|
||||
float scaleScalar = d0_0[0] * d1_0[0];
|
||||
__m512 scale = _mm512_set1_ps( scaleScalar );
|
||||
|
||||
// Unpack the nibbles of each operand into one byte per value
__m256i bx = bytesFromNibbles( p0 );
|
||||
__m256i by = bytesFromNibbles( p1 );
|
||||
|
||||
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
||||
const __m256i off = _mm256_set1_epi8( 8 );
|
||||
bx = _mm256_sub_epi8( bx, off );
|
||||
by = _mm256_sub_epi8( by, off );
|
||||
|
||||
// Sign-extend 32 signed bytes (one 256-bit vector) into 32 int16_t lanes (one 512-bit vector)
|
||||
__m512i x32 = _mm512_cvtepi8_epi16( bx );
|
||||
__m512i y32 = _mm512_cvtepi8_epi16( by );
|
||||
// Compute products of int16_t integers, add pairwise (yields 16 int32_t lanes)
|
||||
__m512i i64 = _mm512_madd_epi16( x32, y32 );
|
||||
|
||||
// Convert int32_t to float
|
||||
__m512 p = _mm512_cvtepi32_ps( i64 );
|
||||
// Apply the scale, and accumulate
|
||||
return _mm512_fmadd_ps( scale, p, acc );
|
||||
}
|
||||
#endif
|
||||
|
||||
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
||||
ggml_float sumf = 0.0;
|
||||
|
||||
|
@ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif defined(__AVX512F__)
|
||||
|
||||
#if QK == 32
|
||||
// Initialize accumulator with zeros
|
||||
__m512 acc0 = _mm512_setzero_ps();
|
||||
__m512 acc1 = _mm512_setzero_ps();
|
||||
|
||||
const int superblock_size = 8;
|
||||
const int superblock_count = nb / superblock_size;
|
||||
const int remainder = nb % superblock_size;
|
||||
|
||||
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
|
||||
int i = superblock_ix * superblock_size;
|
||||
|
||||
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
|
||||
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
|
||||
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
|
||||
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
|
||||
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
|
||||
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
|
||||
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
|
||||
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
|
||||
}
|
||||
|
||||
// Remainders
|
||||
for (int i = superblock_count * superblock_size; i < nb; ++i) {
|
||||
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
|
||||
}
|
||||
|
||||
// Horizontal sum of all lanes of the accumulator
|
||||
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif defined(__AVX2__)
|
||||
#if QK == 32
|
||||
const size_t countBlocks = nb;
|
||||
|
|
63
main.cpp
63
main.cpp
|
@ -90,7 +90,8 @@ struct llama_model {
|
|||
};
|
||||
|
||||
// load the model's weights from a file
|
||||
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
|
||||
|
||||
bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) {
|
||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
|
@ -106,12 +107,12 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
|
|||
{
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
if (magic == 0x67676d6c) {
|
||||
if (magic == FILE_MAGIC_UNVERSIONED) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
|
||||
__func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
if (magic != 0x67676d66) {
|
||||
if (magic != FILE_MAGIC) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
@ -119,15 +120,14 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
|
|||
uint32_t format_version;
|
||||
fin.read((char *) &format_version, sizeof(format_version));
|
||||
|
||||
if (format_version != 1) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ")\n",
|
||||
__func__, fname.c_str(), format_version);
|
||||
if (format_version != FILE_VERSION) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
|
||||
__func__, fname.c_str(), format_version, FILE_VERSION);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int n_ff = 0;
|
||||
int n_parts = 0;
|
||||
|
||||
// load hparams
|
||||
{
|
||||
|
@ -145,7 +145,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
|
|||
hparams.n_ctx = n_ctx;
|
||||
|
||||
n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
|
||||
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
|
||||
|
||||
if (n_parts < 1) {
|
||||
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
|
@ -162,12 +165,20 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
|
|||
// load vocab
|
||||
{
|
||||
std::string word;
|
||||
std::vector<char> tmp(64);
|
||||
|
||||
for (int i = 0; i < model.hparams.n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) word.data(), len);
|
||||
if (len > 0) {
|
||||
tmp.resize(len);
|
||||
fin.read(tmp.data(), len);
|
||||
word.assign(tmp.data(), len);
|
||||
} else {
|
||||
word.clear();
|
||||
}
|
||||
|
||||
float score;
|
||||
fin.read((char *) &score, sizeof(score));
|
||||
|
@ -175,10 +186,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
|
|||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
vocab.score[i] = score;
|
||||
|
||||
//if (i < 30000) {
|
||||
// fprintf(stderr, "%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
|
||||
//}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -544,9 +551,9 @@ bool llama_eval(
|
|||
const llama_model & model,
|
||||
const int n_threads,
|
||||
const int n_past,
|
||||
const std::vector<gpt_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token) {
|
||||
const std::vector<llama_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
@ -832,14 +839,14 @@ int main(int argc, char ** argv) {
|
|||
|
||||
int64_t t_load_us = 0;
|
||||
|
||||
gpt_vocab vocab;
|
||||
llama_vocab vocab;
|
||||
llama_model model;
|
||||
|
||||
// load the model
|
||||
{
|
||||
const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
|
||||
if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
@ -864,13 +871,13 @@ int main(int argc, char ** argv) {
|
|||
// Add a space in front of the first character to match OG llama tokenizer behavior
|
||||
params.prompt.insert(0, 1, ' ');
|
||||
// tokenize the prompt
|
||||
std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
|
||||
std::vector<llama_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
|
||||
|
||||
params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
|
||||
|
||||
// prefix & suffix for instruct mode
|
||||
const std::vector<gpt_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
|
||||
const std::vector<gpt_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
|
||||
const std::vector<llama_vocab::id> inp_pfx = ::llama_tokenize(vocab, "\n\n### Instruction:\n\n", true);
|
||||
const std::vector<llama_vocab::id> inp_sfx = ::llama_tokenize(vocab, "\n\n### Response:\n\n", false);
|
||||
|
||||
// in instruct mode, we inject a prefix and a suffix to each input by the user
|
||||
if (params.instruct) {
|
||||
|
@ -912,14 +919,14 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
std::vector<gpt_vocab::id> embd;
|
||||
std::vector<llama_vocab::id> embd;
|
||||
|
||||
// determine the required inference memory per token:
|
||||
size_t mem_per_token = 0;
|
||||
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||
|
||||
int last_n_size = params.repeat_last_n;
|
||||
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
|
||||
std::vector<llama_vocab::id> last_n_tokens(last_n_size);
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||
|
||||
if (params.interactive) {
|
||||
|
@ -958,7 +965,7 @@ int main(int argc, char ** argv) {
|
|||
n_past += embd.size();
|
||||
embd.clear();
|
||||
|
||||
if (embd_inp.size() <= input_consumed) {
|
||||
if ((int) embd_inp.size() <= input_consumed) {
|
||||
// out of user input, sample next token
|
||||
const float top_k = params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
|
@ -967,7 +974,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const int n_vocab = model.hparams.n_vocab;
|
||||
|
||||
gpt_vocab::id id = 0;
|
||||
llama_vocab::id id = 0;
|
||||
|
||||
{
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
@ -995,7 +1002,7 @@ int main(int argc, char ** argv) {
|
|||
--remaining_tokens;
|
||||
} else {
|
||||
// some user input remains from prompt or interaction, forward it to processing
|
||||
while (embd_inp.size() > input_consumed) {
|
||||
while ((int) embd_inp.size() > input_consumed) {
|
||||
embd.push_back(embd_inp[input_consumed]);
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(embd_inp[input_consumed]);
|
||||
|
@ -1020,7 +1027,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// in interactive mode, and not currently processing queued inputs;
|
||||
// check if we should prompt the user for more
|
||||
if (params.interactive && embd_inp.size() <= input_consumed) {
|
||||
if (params.interactive && (int) embd_inp.size() <= input_consumed) {
|
||||
// check for reverse prompt
|
||||
std::string last_output;
|
||||
for (auto id : last_n_tokens) {
|
||||
|
@ -1058,7 +1065,7 @@ int main(int argc, char ** argv) {
|
|||
} while (another_line);
|
||||
if (params.use_color) printf(ANSI_COLOR_RESET);
|
||||
|
||||
std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
|
||||
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
|
||||
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||
|
||||
if (params.instruct) {
|
||||
|
|
BIN
models/ggml-vocab.bin
Normal file
BIN
models/ggml-vocab.bin
Normal file
Binary file not shown.
12
quantize.cpp
12
quantize.cpp
|
@ -44,7 +44,7 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
|
|||
return false;
|
||||
}
|
||||
|
||||
gpt_vocab vocab;
|
||||
llama_vocab vocab;
|
||||
|
||||
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
||||
|
||||
|
@ -64,12 +64,12 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
|
|||
{
|
||||
uint32_t magic;
|
||||
finp.read((char *) &magic, sizeof(magic));
|
||||
if (magic == 0x67676d6c) {
|
||||
if (magic == FILE_MAGIC_UNVERSIONED) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
|
||||
__func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
if (magic != 0x67676d66) {
|
||||
if (magic != FILE_MAGIC) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
|
||||
return false;
|
||||
}
|
||||
|
@ -79,9 +79,9 @@ bool llama_model_quantize(const std::string & fname_inp, const std::string & fna
|
|||
uint32_t format_version;
|
||||
finp.read((char *) &format_version, sizeof(format_version));
|
||||
|
||||
if (format_version != 1) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ")\n",
|
||||
__func__, fname_inp.c_str(), format_version);
|
||||
if (format_version != FILE_VERSION) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
|
||||
__func__, fname_inp.c_str(), format_version, FILE_VERSION);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
4
tests/CMakeLists.txt
Normal file
4
tests/CMakeLists.txt
Normal file
|
@ -0,0 +1,4 @@
|
|||
# Tokenizer regression test.
#
# Builds the executable test-tokenizer-0 from test-tokenizer-0.cpp, links it
# against the `utils` target and registers it as a CTest test of the same
# name. The test is invoked with a single argument: the path to
# models/ggml-vocab.bin, resolved relative to this directory.
set(TEST_TARGET test-tokenizer-0)

add_executable(${TEST_TARGET}
    ${TEST_TARGET}.cpp
)

target_link_libraries(${TEST_TARGET}
    PRIVATE
        utils
)

add_test(
    NAME    ${TEST_TARGET}
    COMMAND $<TARGET_FILE:${TEST_TARGET}> ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin
)
|
69
tests/test-tokenizer-0.cpp
Normal file
69
tests/test-tokenizer-0.cpp
Normal file
|
@ -0,0 +1,69 @@
|
|||
#include "utils.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
// Tokenizer test vectors: maps an input prompt to the exact sequence of token
// ids that tokenizing it is expected to yield.
// NOTE(review): every expected sequence starts with id 1 - presumably the BOS
// token, since main() tokenizes with bos = true; confirm against the tokenizer.
static const std::map<std::string, std::vector<llama_vocab::id>> k_tests = {
    {
        "Hello World",
        { 1, 10994, 2787 }
    },
    {
        " Hello World",
        { 1, 15043, 2787 }
    },
    {
        " Hello World!",
        { 1, 15043, 2787, 29991 }
    },
    {
        " this is 🦙.cpp",
        { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223 }
    },
    {
        "w048 7tuijk dsdfhu",
        { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905 }
    },
    {
        "нещо на Български",
        { 1, 821, 4851, 665, 1386, 29713, 1305 }
    },
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc < 2) {
|
||||
fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string fname = argv[1];
|
||||
|
||||
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
|
||||
|
||||
llama_vocab vocab;
|
||||
|
||||
if (!llama_vocab_load(fname, vocab)) {
|
||||
fprintf(stderr, "%s : failed to load vocab from: '%s'\n", __func__, fname.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int n_vocab = vocab.id_to_token.size();
|
||||
|
||||
if (n_vocab != 32000) {
|
||||
fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab);
|
||||
return 2;
|
||||
}
|
||||
|
||||
for (const auto & test_kv : k_tests) {
|
||||
const auto res = llama_tokenize(vocab, test_kv.first, true);
|
||||
|
||||
bool correct = res.size() == test_kv.second.size();
|
||||
|
||||
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
||||
if (res[i] != test_kv.second[i]) {
|
||||
correct = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!correct) {
|
||||
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
||||
fprintf(stderr, "%s : expected tokens: ", __func__);
|
||||
for (const auto & t : test_kv.second) {
|
||||
fprintf(stderr, "%6d, ", t);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s : got tokens: ", __func__);
|
||||
for (const auto & t : res) {
|
||||
fprintf(stderr, "%6d, ", t);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
179
utils.cpp
179
utils.cpp
|
@ -12,7 +12,7 @@
|
|||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||
#elif !defined(__FreeBSD__) && !defined(__NetBSD__)
|
||||
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
||||
#include <alloca.h>
|
||||
#endif
|
||||
|
||||
|
@ -74,6 +74,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
params.antiprompt.push_back(argv[++i]);
|
||||
} else if (arg == "--ignore-eos") {
|
||||
params.ignore_eos = true;
|
||||
} else if (arg == "--n_parts") {
|
||||
params.n_parts = std::stoi(argv[++i]);
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
gpt_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
|
@ -116,6 +118,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
|
||||
fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n");
|
||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
||||
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
|
||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||
|
@ -240,61 +243,6 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
|||
return result;
|
||||
}
|
||||
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
static size_t utf8_len(char src) {
|
||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||
|
@ -305,7 +253,8 @@ struct llama_sp_symbol {
|
|||
using index = int;
|
||||
index prev;
|
||||
index next;
|
||||
std::string_view text;
|
||||
const char * text;
|
||||
size_t n;
|
||||
};
|
||||
|
||||
struct llama_sp_bigram {
|
||||
|
@ -322,19 +271,23 @@ struct llama_sp_bigram {
|
|||
size_t size;
|
||||
};
|
||||
|
||||
// original implementation:
|
||||
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
||||
struct llama_tokenizer {
|
||||
llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {}
|
||||
llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
|
||||
|
||||
void tokenize(std::string_view text, std::vector<gpt_vocab::id> & output) {
|
||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||
// split string into utf8 chars
|
||||
int index = 0;
|
||||
while (!text.empty()) {
|
||||
size_t offs = 0;
|
||||
while (offs < text.size()) {
|
||||
llama_sp_symbol sym;
|
||||
size_t char_len = std::min(text.size(), utf8_len(text.data()[0]));
|
||||
sym.text = std::string_view(text.data(), char_len);
|
||||
size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
|
||||
sym.text = text.c_str() + offs;
|
||||
sym.n = char_len;
|
||||
offs += char_len;
|
||||
sym.prev = index - 1;
|
||||
text.remove_prefix(char_len);
|
||||
sym.next = text.empty() ? -1 : index + 1;
|
||||
sym.next = offs == text.size() ? -1 : index + 1;
|
||||
index++;
|
||||
symbols_.emplace_back(std::move(sym));
|
||||
}
|
||||
|
@ -353,14 +306,16 @@ struct llama_tokenizer {
|
|||
auto & right_sym = symbols_[bigram.right];
|
||||
|
||||
// if one of the symbols already got merged, skip it.
|
||||
if (left_sym.text.empty() || right_sym.text.empty() ||
|
||||
left_sym.text.size() + right_sym.text.size() != bigram.size) {
|
||||
if (left_sym.n == 0 || right_sym.n == 0 ||
|
||||
left_sym.n + right_sym.n != bigram.size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// merge the right sym into the left one
|
||||
left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size());
|
||||
right_sym.text = std::string_view("");
|
||||
left_sym.n += right_sym.n;
|
||||
right_sym.n = 0;
|
||||
|
||||
//printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
|
||||
|
||||
// remove the right sym from the chain
|
||||
left_sym.next = right_sym.next;
|
||||
|
@ -374,13 +329,13 @@ struct llama_tokenizer {
|
|||
}
|
||||
|
||||
for (int i = 0; i != -1; i = symbols_[i].next) {
|
||||
auto& symbol = symbols_[i];
|
||||
auto token = vocab_.token_to_id.find(std::string(symbol.text));
|
||||
auto & symbol = symbols_[i];
|
||||
auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
|
||||
|
||||
if (token == vocab_.token_to_id.end()) {
|
||||
// output any symbols that did not form tokens as bytes.
|
||||
for (int j = 0; j < symbol.text.size(); ++j) {
|
||||
gpt_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
||||
for (int j = 0; j < (int) symbol.n; ++j) {
|
||||
llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
||||
output.push_back(token_id);
|
||||
}
|
||||
} else {
|
||||
|
@ -395,8 +350,8 @@ private:
|
|||
return;
|
||||
}
|
||||
|
||||
std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size());
|
||||
auto token = vocab_.token_to_id.find(std::string(text));
|
||||
const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
|
||||
auto token = vocab_.token_to_id.find(text);
|
||||
|
||||
if (token == vocab_.token_to_id.end()) {
|
||||
return;
|
||||
|
@ -416,14 +371,52 @@ private:
|
|||
work_queue_.push(bigram);
|
||||
}
|
||||
|
||||
const gpt_vocab & vocab_;
|
||||
const llama_vocab & vocab_;
|
||||
std::vector<llama_sp_symbol> symbols_;
|
||||
llama_sp_bigram::queue work_queue_;
|
||||
};
|
||||
|
||||
std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) {
|
||||
// TODO: temporary code duplication with llama.cpp
|
||||
// will resolve after #77 is merged
|
||||
bool llama_vocab_load(const std::string & fname, llama_vocab & vocab) {
|
||||
std::ifstream fin(fname, std::ios::binary);
|
||||
if (!fin.is_open()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int n_vocab = 0;
|
||||
fin.read((char *) &n_vocab, sizeof(n_vocab));
|
||||
|
||||
std::string word;
|
||||
std::vector<char> tmp(64);
|
||||
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
if (len > 0) {
|
||||
tmp.resize(len);
|
||||
fin.read(tmp.data(), len);
|
||||
word.assign(tmp.data(), len);
|
||||
} else {
|
||||
word.clear();
|
||||
}
|
||||
|
||||
float score;
|
||||
fin.read((char *) &score, sizeof(score));
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
vocab.score[i] = score;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
||||
llama_tokenizer tokenizer(vocab);
|
||||
std::vector<gpt_vocab::id> output;
|
||||
std::vector<llama_vocab::id> output;
|
||||
|
||||
if (text.size() == 0) {
|
||||
return output;
|
||||
|
@ -437,42 +430,22 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_v
|
|||
return output;
|
||||
}
|
||||
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||
|
||||
vocab.token_to_id = ::json_parse(fname);
|
||||
|
||||
for (const auto & kv : vocab.token_to_id) {
|
||||
vocab.id_to_token[kv.second] = kv.first;
|
||||
}
|
||||
|
||||
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
||||
|
||||
// print the vocabulary
|
||||
//for (auto kv : vocab.token_to_id) {
|
||||
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
||||
//}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
|
||||
void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
|
||||
// find the top K tokens
|
||||
std::partial_sort(
|
||||
logits_id.begin(),
|
||||
logits_id.begin() + top_k, logits_id.end(),
|
||||
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
|
||||
[](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
logits_id.resize(top_k);
|
||||
}
|
||||
|
||||
gpt_vocab::id llama_sample_top_p_top_k(
|
||||
const gpt_vocab & vocab,
|
||||
llama_vocab::id llama_sample_top_p_top_k(
|
||||
const llama_vocab & vocab,
|
||||
const float * logits,
|
||||
std::vector<gpt_vocab::id> & last_n_tokens,
|
||||
std::vector<llama_vocab::id> & last_n_tokens,
|
||||
double repeat_penalty,
|
||||
int top_k,
|
||||
double top_p,
|
||||
|
@ -480,7 +453,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
|
|||
std::mt19937 & rng) {
|
||||
int n_logits = vocab.id_to_token.size();
|
||||
|
||||
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
|
||||
std::vector<std::pair<double, llama_vocab::id>> logits_id;
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
{
|
||||
|
|
68
utils.h
68
utils.h
|
@ -13,33 +13,33 @@
|
|||
//
|
||||
|
||||
struct gpt_params {
|
||||
int32_t seed = -1; // RNG seed
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_predict = 128; // new tokens to predict
|
||||
int32_t seed = -1; // RNG seed
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_predict = 128; // new tokens to predict
|
||||
int32_t repeat_last_n = 64; // last n tokens to penalize
|
||||
int32_t n_ctx = 512; //context size
|
||||
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
||||
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
|
||||
int32_t n_ctx = 512; //context size
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.95f;
|
||||
float temp = 0.80f;
|
||||
float repeat_penalty = 1.30f;
|
||||
float repeat_penalty = 1.10f;
|
||||
|
||||
int32_t n_batch = 8; // batch size for prompt processing
|
||||
|
||||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||
std::string prompt = "";
|
||||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||
std::string prompt = "";
|
||||
|
||||
bool random_prompt = false;
|
||||
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
|
||||
bool interactive = false; // interactive mode
|
||||
bool interactive_start = false; // reverse prompt immediately
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||
bool ignore_eos = false; // do not stop generating after eos
|
||||
|
||||
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
bool interactive = false; // interactive mode
|
||||
bool interactive_start = false; // reverse prompt immediately
|
||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||
bool ignore_eos = false; // do not stop generating after eos
|
||||
};
|
||||
|
||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||
|
@ -48,11 +48,19 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
|
|||
|
||||
std::string gpt_random_prompt(std::mt19937 & rng);
|
||||
|
||||
//
|
||||
// Model file parsing
|
||||
//
|
||||
|
||||
#define FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
|
||||
#define FILE_MAGIC 0x67676d66 // 'ggmf' in hex
|
||||
#define FILE_VERSION 1
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
struct gpt_vocab {
|
||||
struct llama_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
|
||||
|
@ -66,34 +74,22 @@ void replace(std::string & str, const std::string & needle, const std::string &
|
|||
// poor-man's JSON parsing
|
||||
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
||||
|
||||
// split text into tokens
|
||||
//
|
||||
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||
//
|
||||
// Regex (Python):
|
||||
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
//
|
||||
// Regex (C++):
|
||||
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||
//
|
||||
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
|
||||
// TODO: temporary until #77 is merged, need this now for some tokenizer tests
|
||||
bool llama_vocab_load(const std::string & fname, llama_vocab & vocab);
|
||||
|
||||
// TODO: this is probably wrong, but I cannot figure out how this tokenizer works ..
|
||||
// ref: https://github.com/google/sentencepiece
|
||||
std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos);
|
||||
|
||||
// load the tokens from encoder.json
|
||||
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
||||
std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos);
|
||||
|
||||
// sample next token given probabilities for each embedding
|
||||
//
|
||||
// - consider only the top K tokens
|
||||
// - from them, consider only the top tokens with cumulative probability > P
|
||||
//
|
||||
gpt_vocab::id llama_sample_top_p_top_k(
|
||||
const gpt_vocab & vocab,
|
||||
llama_vocab::id llama_sample_top_p_top_k(
|
||||
const llama_vocab & vocab,
|
||||
const float * logits,
|
||||
std::vector<gpt_vocab::id> & last_n_tokens,
|
||||
std::vector<llama_vocab::id> & last_n_tokens,
|
||||
double repeat_penalty,
|
||||
int top_k,
|
||||
double top_p,
|
||||
|
@ -101,7 +97,7 @@ gpt_vocab::id llama_sample_top_p_top_k(
|
|||
std::mt19937 & rng);
|
||||
|
||||
// filter to top K tokens from list of logits
|
||||
void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
|
||||
void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k);
|
||||
|
||||
//
|
||||
// Quantization
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue