Merge branch 'master' into add-retrieval-example

This commit is contained in:
Minsoo Cheong 2024-03-24 21:27:49 +09:00
commit 3cb7567607
78 changed files with 10594 additions and 4661 deletions

View file

@ -15,14 +15,133 @@ on:
types: [opened, synchronize, reopened] types: [opened, synchronize, reopened]
paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m'] paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m']
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env: env:
BRANCH_NAME: ${{ github.head_ref || github.ref_name }} BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
GGML_NLOOP: 3 GGML_NLOOP: 3
GGML_N_THREADS: 1 GGML_N_THREADS: 1
jobs: jobs:
macOS-latest-cmake-arm64:
runs-on: macos-14
steps:
- name: Clone
id: checkout
uses: actions/checkout@v3
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON ..
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
- name: Determine tag name
id: tag
shell: bash
run: |
BUILD_NUMBER="$(git rev-list --count HEAD)"
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
else
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
fi
- name: Pack artifacts
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
cp LICENSE ./build/bin/
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v3
with:
path: |
llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip
macOS-latest-cmake-x64:
runs-on: macos-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v3
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON ..
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
- name: Determine tag name
id: tag
shell: bash
run: |
BUILD_NUMBER="$(git rev-list --count HEAD)"
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
else
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
fi
- name: Pack artifacts
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
cp LICENSE ./build/bin/
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v3
with:
path: |
llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
ubuntu-focal-make: ubuntu-focal-make:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
env:
LLAMA_NODE_AVAILABLE: true
LLAMA_PYTHON_AVAILABLE: true
steps: steps:
- name: Clone - name: Clone
@ -35,6 +154,14 @@ jobs:
sudo apt-get update sudo apt-get update
sudo apt-get install build-essential gcc-8 sudo apt-get install build-essential gcc-8
- uses: actions/setup-node@v4
with:
node-version: "20"
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Build - name: Build
id: make_build id: make_build
env: env:
@ -98,6 +225,17 @@ jobs:
cd build cd build
ctest -L main --verbose --timeout 900 ctest -L main --verbose --timeout 900
- name: Test llama2c conversion
id: llama2c_test
run: |
cd build
echo "Fetch tokenizer"
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/tok512.bin
echo "Fetch llama2c model"
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/stories260K.bin
./bin/convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf
./bin/main -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256
# ubuntu-latest-cmake-sanitizer: # ubuntu-latest-cmake-sanitizer:
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
# #
@ -662,6 +800,7 @@ jobs:
windows-latest-cmake-sycl: windows-latest-cmake-sycl:
runs-on: windows-latest runs-on: windows-latest
defaults: defaults:
run: run:
shell: bash shell: bash
@ -670,7 +809,6 @@ jobs:
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/62641e01-1e8d-4ace-91d6-ae03f7f8a71f/w_BaseKit_p_2024.0.0.49563_offline.exe WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/62641e01-1e8d-4ace-91d6-ae03f7f8a71f/w_BaseKit_p_2024.0.0.49563_offline.exe
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel
steps: steps:
- name: Clone - name: Clone
id: checkout id: checkout
@ -685,6 +823,32 @@ jobs:
id: cmake_build id: cmake_build
run: examples/sycl/win-build-sycl.bat run: examples/sycl/win-build-sycl.bat
- name: Determine tag name
id: tag
shell: bash
run: |
BUILD_NUMBER="$(git rev-list --count HEAD)"
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
else
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
fi
- name: Pack artifacts
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip .\build\bin\*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v3
with:
path: |
llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip
ios-xcode-build: ios-xcode-build:
runs-on: macos-latest runs-on: macos-latest
@ -748,6 +912,8 @@ jobs:
- macOS-latest-cmake - macOS-latest-cmake
- windows-latest-cmake - windows-latest-cmake
- windows-latest-cmake-cublas - windows-latest-cmake-cublas
- macOS-latest-cmake-arm64
- macOS-latest-cmake-x64
steps: steps:
- name: Clone - name: Clone

View file

@ -19,5 +19,5 @@ jobs:
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
days-before-pr-stale: -1 days-before-pr-stale: -1
days-before-pr-close: -1 days-before-pr-close: -1
operations-per-run: 1000 operations-per-run: 10000
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}

View file

@ -5,6 +5,10 @@ env:
GGML_NLOOP: 3 GGML_NLOOP: 3
GGML_N_THREADS: 1 GGML_N_THREADS: 1
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
run: run:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04

View file

@ -15,6 +15,10 @@ on:
branches: branches:
- master - master
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
push_to_registry: push_to_registry:
name: Push Docker image to Docker Hub name: Push Docker image to Docker Hub

View file

@ -14,6 +14,10 @@ on:
branches: branches:
- master - master
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
editorconfig: editorconfig:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -17,6 +17,10 @@ on:
types: [opened, synchronize, reopened] types: [opened, synchronize, reopened]
paths: ['**/*.nix', 'flake.lock'] paths: ['**/*.nix', 'flake.lock']
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
nix-build-aarch64: nix-build-aarch64:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -8,6 +8,10 @@ on:
pull_request: pull_request:
types: [opened, synchronize, reopened] types: [opened, synchronize, reopened]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
nix-eval: nix-eval:
strategy: strategy:

View file

@ -16,6 +16,10 @@ on:
- 'requirements.txt' - 'requirements.txt'
- 'requirements/*.txt' - 'requirements/*.txt'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
python-check-requirements: python-check-requirements:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -2,6 +2,10 @@ name: flake8 Lint
on: [push, pull_request] on: [push, pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
flake8-lint: flake8-lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View file

@ -18,6 +18,10 @@ on:
schedule: schedule:
- cron: '0 0 * * *' - cron: '0 0 * * *'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
server: server:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -31,7 +35,6 @@ jobs:
include: include:
- build_type: Release - build_type: Release
sanitizer: "" sanitizer: ""
disabled_on_pr: true
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
container: container:

View file

@ -6,6 +6,10 @@ on:
branches: branches:
- master - master
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs: jobs:
build: build:
strategy: strategy:

5
.gitignore vendored
View file

@ -11,6 +11,7 @@
*.gcda *.gcda
*.dot *.dot
*.bat *.bat
*.tmp
*.metallib *.metallib
*.etag *.etag
*.lastModified *.lastModified
@ -49,6 +50,7 @@ models-mnt
/embedding /embedding
/gguf /gguf
/gguf-llama-simple /gguf-llama-simple
/gguf-split
/gritlm /gritlm
/imatrix /imatrix
/infill /infill
@ -57,6 +59,9 @@ models-mnt
/llava-cli /llava-cli
/lookahead /lookahead
/lookup /lookup
/lookup-create
/lookup-merge
/lookup-stats
/main /main
/metal /metal
/passkey /passkey

View file

@ -99,6 +99,7 @@ option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"llama: max. batch size for using peer access") "llama: max. batch size for using peer access")
option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF) option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
@ -387,6 +388,9 @@ if (LLAMA_CUBLAS)
endif() endif()
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
if (LLAMA_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()
if (LLAMA_STATIC) if (LLAMA_STATIC)
if (WIN32) if (WIN32)
@ -531,6 +535,10 @@ if (LLAMA_HIPBLAS)
add_compile_definitions(GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif() endif()
if (LLAMA_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif()
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})

View file

@ -1,7 +1,7 @@
# Define the default target now so that it is always the first target # Define the default target now so that it is always the first target
BUILD_TARGETS = \ BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \ simple batched batched-bench save-load-state server gguf gguf-split llama-bench libllava.a llava-cli baby-llama beam-search \
retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
# Binaries only useful for tests # Binaries only useful for tests
@ -9,7 +9,8 @@ TEST_TARGETS = \
tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \ tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \ tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \ tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
tests/test-json-schema-to-grammar
# Code coverage output files # Code coverage output files
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
@ -451,9 +452,9 @@ ifdef LLAMA_CUDA_PEER_MAX_BATCH_SIZE
else else
MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128
endif # LLAMA_CUDA_PEER_MAX_BATCH_SIZE endif # LLAMA_CUDA_PEER_MAX_BATCH_SIZE
#ifdef LLAMA_CUDA_CUBLAS ifdef LLAMA_CUDA_NO_PEER_COPY
# MK_NVCCFLAGS += -DGGML_CUDA_CUBLAS MK_NVCCFLAGS += -DGGML_CUDA_NO_PEER_COPY
#endif # LLAMA_CUDA_CUBLAS endif # LLAMA_CUDA_NO_PEER_COPY
ifdef LLAMA_CUDA_CCBIN ifdef LLAMA_CUDA_CCBIN
MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
endif endif
@ -534,6 +535,9 @@ endif # LLAMA_HIP_UMA
ifdef LLAMA_CUDA_FORCE_DMMV ifdef LLAMA_CUDA_FORCE_DMMV
HIPFLAGS += -DGGML_CUDA_FORCE_DMMV HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
endif # LLAMA_CUDA_FORCE_DMMV endif # LLAMA_CUDA_FORCE_DMMV
ifdef LLAMA_CUDA_NO_PEER_COPY
HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY
OBJS += ggml-cuda.o OBJS += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
@ -666,9 +670,15 @@ console.o: common/console.cpp common/console.h
grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
json-schema-to-grammar.o: common/json-schema-to-grammar.cpp common/json-schema-to-grammar.h
$(CXX) $(CXXFLAGS) -c $< -o $@
train.o: common/train.cpp common/train.h train.o: common/train.cpp common/train.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
ngram-cache.o: common/ngram-cache.cpp common/ngram-cache.h
$(CXX) $(CXXFLAGS) -c $< -o $@
libllama.so: llama.o ggml.o $(OBJS) libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
@ -676,7 +686,7 @@ libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
ar rcs libllama.a llama.o ggml.o $(OBJS) $(COMMON_DEPS) ar rcs libllama.a llama.o ggml.o $(OBJS) $(COMMON_DEPS)
clean: clean:
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS) rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
find examples pocs -type f -name "*.o" -delete find examples pocs -type f -name "*.o" -delete
# #
@ -745,7 +755,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp json-schema-to-grammar.o common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
@ -810,9 +820,15 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) lookup: examples/lookup/lookup.cpp ggml.o llama.o ngram-cache.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-merge.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp) -o lookup-merge $(LDFLAGS)
$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-stats.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp) -o lookup-stats $(LDFLAGS)
passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
@ -869,6 +885,10 @@ tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp json-schema-to-grammar.o ggml.o llama.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-grad0: tests/test-grad0.cpp ggml.o $(OBJS) tests/test-grad0: tests/test-grad0.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -29,6 +29,7 @@ For Intel CPU, recommend to use llama.cpp for X86 (Intel MKL building).
## News ## News
- 2024.3 - 2024.3
- A blog is published: **Run LLM on all Intel GPUs Using llama.cpp**: [intel.com](https://www.intel.com/content/www/us/en/developer/articles/technical/run-llm-on-all-gpus-using-llama-cpp-artical.html) or [medium.com](https://medium.com/@jianyu_neo/run-llm-on-all-intel-gpus-using-llama-cpp-fd2e2dcbd9bd).
- New base line is ready: [tag b2437](https://github.com/ggerganov/llama.cpp/tree/b2437). - New base line is ready: [tag b2437](https://github.com/ggerganov/llama.cpp/tree/b2437).
- Support multiple cards: **--split-mode**: [none|layer]; not support [row], it's on developing. - Support multiple cards: **--split-mode**: [none|layer]; not support [row], it's on developing.
- Support to assign main GPU by **--main-gpu**, replace $GGML_SYCL_DEVICE. - Support to assign main GPU by **--main-gpu**, replace $GGML_SYCL_DEVICE.
@ -115,7 +116,7 @@ You can choose between **F16** and **F32** build. F16 is faster for long-prompt
# Or, for F32: # Or, for F32:
docker build -t llama-cpp-sycl -f .devops/main-intel.Dockerfile . docker build -t llama-cpp-sycl -f .devops/main-intel.Dockerfile .
# Note: you can also use the ".devops/main-server.Dockerfile", which compiles the "server" example # Note: you can also use the ".devops/server-intel.Dockerfile", which compiles the "server" example
``` ```
### Run ### Run

View file

@ -17,10 +17,12 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
### Hot topics ### Hot topics
- Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225
- Multi-GPU pipeline parallelizm support https://github.com/ggerganov/llama.cpp/pull/6017 - Multi-GPU pipeline parallelizm support https://github.com/ggerganov/llama.cpp/pull/6017
- Looking for contributions to add Deepseek support: https://github.com/ggerganov/llama.cpp/issues/5981 - Looking for contributions to add Deepseek support: https://github.com/ggerganov/llama.cpp/issues/5981
- Quantization blind testing: https://github.com/ggerganov/llama.cpp/discussions/5962 - Quantization blind testing: https://github.com/ggerganov/llama.cpp/discussions/5962
- Initial Mamba support has been added: https://github.com/ggerganov/llama.cpp/pull/5328 - Initial Mamba support has been added: https://github.com/ggerganov/llama.cpp/pull/5328
- Support loading sharded model, using `gguf-split` CLI https://github.com/ggerganov/llama.cpp/pull/6187
---- ----
@ -165,6 +167,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
- [cztomsik/ava](https://github.com/cztomsik/ava) (MIT) - [cztomsik/ava](https://github.com/cztomsik/ava) (MIT)
- [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal) - [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal)
- [pythops/tenere](https://github.com/pythops/tenere) (AGPL) - [pythops/tenere](https://github.com/pythops/tenere) (AGPL)
- [RecurseChat](https://recurse.chat/) (proprietary)
- [semperai/amica](https://github.com/semperai/amica) - [semperai/amica](https://github.com/semperai/amica)
- [withcatai/catai](https://github.com/withcatai/catai) - [withcatai/catai](https://github.com/withcatai/catai)
- [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT) - [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT)

View file

@ -122,6 +122,7 @@ pub fn build(b: *std.build.Builder) !void {
const console = make.obj("console", "common/console.cpp"); const console = make.obj("console", "common/console.cpp");
const sampling = make.obj("sampling", "common/sampling.cpp"); const sampling = make.obj("sampling", "common/sampling.cpp");
const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp"); const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp");
const json_schema_to_grammar = make.obj("json-schema-to-grammar", "common/json-schema-to-grammar.cpp");
const train = make.obj("train", "common/train.cpp"); const train = make.obj("train", "common/train.cpp");
const clip = make.obj("clip", "examples/llava/clip.cpp"); const clip = make.obj("clip", "examples/llava/clip.cpp");
const llava = make.obj("llava", "examples/llava/llava.cpp"); const llava = make.obj("llava", "examples/llava/llava.cpp");
@ -133,7 +134,7 @@ pub fn build(b: *std.build.Builder) !void {
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train }); _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train }); _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, clip, llava }); const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava });
if (server.target.isWindows()) { if (server.target.isWindows()) {
server.linkSystemLibrary("ws2_32"); server.linkSystemLibrary("ws2_32");
} }

View file

@ -47,6 +47,8 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif() endif()
set(TARGET json-schema-to-grammar)
add_library(${TARGET} OBJECT json-schema-to-grammar.cpp json-schema-to-grammar.h)
set(TARGET common) set(TARGET common)
@ -60,8 +62,11 @@ add_library(${TARGET} STATIC
console.cpp console.cpp
grammar-parser.h grammar-parser.h
grammar-parser.cpp grammar-parser.cpp
json.hpp
train.h train.h
train.cpp train.cpp
ngram-cache.h
ngram-cache.cpp
) )
if (BUILD_SHARED_LIBS) if (BUILD_SHARED_LIBS)

View file

@ -39,6 +39,9 @@
#endif #endif
#if defined(LLAMA_USE_CURL) #if defined(LLAMA_USE_CURL)
#include <curl/curl.h> #include <curl/curl.h>
#include <curl/easy.h>
#include <thread>
#include <future>
#endif #endif
#if defined(_MSC_VER) #if defined(_MSC_VER)
@ -61,7 +64,7 @@
#else #else
#include <sys/syslimits.h> #include <sys/syslimits.h>
#endif #endif
#define LLAMA_CURL_MAX_PATH_LENGTH PATH_MAX #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
#define LLAMA_CURL_MAX_HEADER_LENGTH 256 #define LLAMA_CURL_MAX_HEADER_LENGTH 256
#endif // LLAMA_USE_CURL #endif // LLAMA_USE_CURL
@ -101,7 +104,7 @@ int32_t get_num_physical_cores() {
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
} }
void process_escapes(std::string& input) { void process_escapes(std::string & input) {
std::size_t input_len = input.length(); std::size_t input_len = input.length();
std::size_t output_idx = 0; std::size_t output_idx = 0;
@ -154,8 +157,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
return result; return result;
} }
bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, bool & invalid_param) { bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
std::string arg = argv[i];
llama_sampling_params& sparams = params.sparams; llama_sampling_params& sparams = params.sparams;
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
@ -648,14 +650,6 @@ bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, b
params.model = argv[i]; params.model = argv[i];
return true; return true;
} }
if (arg == "-mu" || arg == "--model-url") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_url = argv[i];
return true;
}
if (arg == "-md" || arg == "--model-draft") { if (arg == "-md" || arg == "--model-draft") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -672,6 +666,30 @@ bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, b
params.model_alias = argv[i]; params.model_alias = argv[i];
return true; return true;
} }
if (arg == "-mu" || arg == "--model-url") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_url = argv[i];
return true;
}
if (arg == "-hfr" || arg == "--hf-repo") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.hf_repo = argv[i];
return true;
}
if (arg == "-hff" || arg == "--hf-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.hf_file = argv[i];
return true;
}
if (arg == "--lora") { if (arg == "--lora") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -948,6 +966,22 @@ bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int & i, b
} }
return true; return true;
} }
if (arg == "-lcs" || arg == "--lookup-cache-static") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.lookup_cache_static = argv[i];
return true;
}
if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.lookup_cache_dynamic = argv[i];
return true;
}
if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -1201,13 +1235,15 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::replace(arg.begin(), arg.end(), '_', '-'); std::replace(arg.begin(), arg.end(), '_', '-');
} }
if (!gpt_params_find_arg(argc, argv, params, i, invalid_param)) { if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
throw std::invalid_argument("error: unknown argument: " + arg); throw std::invalid_argument("error: unknown argument: " + arg);
} }
} }
if (invalid_param) { if (invalid_param) {
throw std::invalid_argument("error: invalid parameter for argument: " + arg); throw std::invalid_argument("error: invalid parameter for argument: " + arg);
} }
if (params.prompt_cache_all && if (params.prompt_cache_all &&
(params.interactive || params.interactive_first || (params.interactive || params.interactive_first ||
params.instruct)) { params.instruct)) {
@ -1215,6 +1251,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
} }
// short-hand to avoid specifying --hf-file -> default it to --model
if (!params.hf_repo.empty() && params.hf_file.empty()) {
params.hf_file = params.model;
}
if (params.escape) { if (params.escape) {
process_escapes(params.prompt); process_escapes(params.prompt);
process_escapes(params.input_prefix); process_escapes(params.input_prefix);
@ -1404,12 +1445,20 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" -m FNAME, --model FNAME\n"); printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str()); printf(" model path (default: %s)\n", params.model.c_str());
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: %s)\n", params.model_url.c_str());
printf(" -md FNAME, --model-draft FNAME\n"); printf(" -md FNAME, --model-draft FNAME\n");
printf(" draft model for speculative decoding\n"); printf(" draft model for speculative decoding (default: unused)\n");
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: unused)\n");
printf(" -hfr REPO, --hf-repo REPO\n");
printf(" Hugging Face model repository (default: unused)\n");
printf(" -hff FILE, --hf-file FILE\n");
printf(" Hugging Face model file (default: unused)\n");
printf(" -ld LOGDIR, --logdir LOGDIR\n"); printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n"); printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" -lcs FNAME, --lookup-cache-static FNAME\n");
printf(" path to static lookup cache to use for lookup decoding (not updated by generation)\n");
printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n");
printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
@ -1590,6 +1639,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "q4_1") { if (s == "q4_1") {
return GGML_TYPE_Q4_1; return GGML_TYPE_Q4_1;
} }
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
if (s == "q5_0") { if (s == "q5_0") {
return GGML_TYPE_Q5_0; return GGML_TYPE_Q5_0;
} }
@ -1653,25 +1705,13 @@ void llama_batch_add(
#ifdef LLAMA_USE_CURL #ifdef LLAMA_USE_CURL
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, static bool llama_download_file(CURL * curl, const char * url, const char * path) {
struct llama_model_params params) { bool force_download = false;
// Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) {
fprintf(stderr, "%s: invalid model_url\n", __func__);
return NULL;
}
// Initialize libcurl globally
auto curl = curl_easy_init();
if (!curl) {
fprintf(stderr, "%s: error initializing libcurl\n", __func__);
return NULL;
}
// Set the URL, allow to follow http redirection // Set the URL, allow to follow http redirection
curl_easy_setopt(curl, CURLOPT_URL, model_url); curl_easy_setopt(curl, CURLOPT_URL, url);
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
#if defined(_WIN32) #if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows. // operating system. Currently implemented under MS-Windows.
@ -1680,16 +1720,16 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
// Check if the file already exists locally // Check if the file already exists locally
struct stat model_file_info; struct stat model_file_info;
auto file_exists = (stat(path_model, &model_file_info) == 0); auto file_exists = (stat(path, &model_file_info) == 0);
// If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files // If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files
char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
char etag_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0}; char etag_path[PATH_MAX] = {0};
snprintf(etag_path, sizeof(etag_path), "%s.etag", path_model); snprintf(etag_path, sizeof(etag_path), "%s.etag", path);
char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
char last_modified_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0}; char last_modified_path[PATH_MAX] = {0};
snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path_model); snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path);
if (file_exists) { if (file_exists) {
auto * f_etag = fopen(etag_path, "r"); auto * f_etag = fopen(etag_path, "r");
@ -1697,7 +1737,7 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
if (!fgets(etag, sizeof(etag), f_etag)) { if (!fgets(etag, sizeof(etag), f_etag)) {
fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path); fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path);
} else { } else {
fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, etag_path, etag); fprintf(stderr, "%s: previous file found %s: %s\n", __func__, etag_path, etag);
} }
fclose(f_etag); fclose(f_etag);
} }
@ -1707,7 +1747,7 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) { if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) {
fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path); fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path);
} else { } else {
fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, last_modified_path, fprintf(stderr, "%s: previous file found %s: %s\n", __func__, last_modified_path,
last_modified); last_modified);
} }
fclose(f_last_modified); fclose(f_last_modified);
@ -1725,6 +1765,11 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata; llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;
// Convert header field name to lowercase
for (size_t i = 0; i < n_items && buffer[i] != ':'; ++i) {
buffer[i] = tolower(buffer[i]);
}
const char * etag_prefix = "etag: "; const char * etag_prefix = "etag: ";
if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) { if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) {
strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF
@ -1747,7 +1792,7 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
if (res != CURLE_OK) { if (res != CURLE_OK) {
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
return NULL; return false;
} }
long http_code = 0; long http_code = 0;
@ -1755,30 +1800,34 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
if (http_code != 200) { if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed // HEAD not supported, we don't know if the file has changed
// force trigger downloading // force trigger downloading
file_exists = false; force_download = true;
fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code); fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
} }
} }
// If the ETag or the Last-Modified headers are different: trigger a new download // If the ETag or the Last-Modified headers are different: trigger a new download
if (!file_exists || strcmp(etag, headers.etag) != 0 || strcmp(last_modified, headers.last_modified) != 0) { bool should_download = !file_exists
char path_model_temporary[LLAMA_CURL_MAX_PATH_LENGTH] = {0}; || force_download
snprintf(path_model_temporary, sizeof(path_model_temporary), "%s.downloadInProgress", path_model); || (strlen(headers.etag) > 0 && strcmp(etag, headers.etag) != 0)
|| (strlen(headers.last_modified) > 0 && strcmp(last_modified, headers.last_modified) != 0);
if (should_download) {
char path_temporary[PATH_MAX] = {0};
snprintf(path_temporary, sizeof(path_temporary), "%s.downloadInProgress", path);
if (file_exists) { if (file_exists) {
fprintf(stderr, "%s: deleting previous downloaded model file: %s\n", __func__, path_model); fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path);
if (remove(path_model) != 0) { if (remove(path) != 0) {
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path_model); fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path);
return NULL; return false;
} }
} }
// Set the output file // Set the output file
auto * outfile = fopen(path_model_temporary, "wb"); auto * outfile = fopen(path_temporary, "wb");
if (!outfile) { if (!outfile) {
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model); fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path);
return NULL; return false;
} }
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
@ -1792,15 +1841,30 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
// display download progress // display download progress
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// start the download // start the download
fprintf(stderr, "%s: downloading model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
model_url, path_model, headers.etag, headers.last_modified); llama_download_hide_password_in_url(url).c_str(), path, headers.etag, headers.last_modified);
auto res = curl_easy_perform(curl); auto res = curl_easy_perform(curl);
if (res != CURLE_OK) { if (res != CURLE_OK) {
fclose(outfile); fclose(outfile);
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
return NULL; return false;
} }
long http_code = 0; long http_code = 0;
@ -1809,7 +1873,7 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
fclose(outfile); fclose(outfile);
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code); fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code);
return NULL; return false;
} }
// Clean up // Clean up
@ -1821,7 +1885,7 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
if (etag_file) { if (etag_file) {
fputs(headers.etag, etag_file); fputs(headers.etag, etag_file);
fclose(etag_file); fclose(etag_file);
fprintf(stderr, "%s: model etag saved %s: %s\n", __func__, etag_path, headers.etag); fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
} }
} }
@ -1831,42 +1895,177 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
if (last_modified_file) { if (last_modified_file) {
fputs(headers.last_modified, last_modified_file); fputs(headers.last_modified, last_modified_file);
fclose(last_modified_file); fclose(last_modified_file);
fprintf(stderr, "%s: model last modified saved %s: %s\n", __func__, last_modified_path, fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
headers.last_modified); headers.last_modified);
} }
} }
if (rename(path_model_temporary, path_model) != 0) { if (rename(path_temporary, path) != 0) {
curl_easy_cleanup(curl);
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path);
return false;
}
}
return true;
}
struct llama_model * llama_load_model_from_url(
const char * model_url,
const char * path_model,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) {
fprintf(stderr, "%s: invalid model_url\n", __func__);
return NULL;
}
// Initialize libcurl
auto * curl = curl_easy_init();
if (!curl) {
fprintf(stderr, "%s: error initializing libcurl\n", __func__);
return NULL;
}
if (!curl) {
fprintf(stderr, "%s: error initializing libcurl\n", __func__);
return NULL;
}
if (!llama_download_file(curl, model_url, path_model)) {
return NULL;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params);
if (!ctx_gguf) {
fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model);
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_model_temporary, path_model);
return NULL; return NULL;
} }
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
} }
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) {
fprintf(stderr, "\n%s: unexpected model file name: %s"
" n_split=%d\n", __func__, path_model, n_split);
return NULL;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) {
fprintf(stderr, "\n%s: unexpected model url: %s"
" n_split=%d\n", __func__, model_url, n_split);
return NULL;
}
}
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (int idx = 1; idx < n_split; idx++) {
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
auto * curl = curl_easy_init();
bool res = llama_download_file(curl, split_url, split_path);
curl_easy_cleanup(curl);
return res;
}, idx));
}
// Wait for all downloads to complete
for (auto & f : futures_download) {
if (!f.get()) {
return NULL;
}
}
}
return llama_load_model_from_file(path_model, params); return llama_load_model_from_file(path_model, params);
} }
struct llama_model * llama_load_model_from_hf(
const char * repo,
const char * model,
const char * path_model,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//
std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += model;
return llama_load_model_from_url(model_url.c_str(), path_model, params);
}
#else #else
struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/, struct llama_model * llama_load_model_from_url(
struct llama_model_params /*params*/) { const char * /*model_url*/,
const char * /*path_model*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr; return nullptr;
} }
struct llama_model * llama_load_model_from_hf(
const char * /*repo*/,
const char * /*model*/,
const char * /*path_model*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
}
#endif // LLAMA_USE_CURL #endif // LLAMA_USE_CURL
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) { std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params); auto mparams = llama_model_params_from_gpt_params(params);
llama_model * model = nullptr; llama_model * model = nullptr;
if (!params.model_url.empty()) {
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
} else if (!params.model_url.empty()) {
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
} else { } else {
model = llama_load_model_from_file(params.model.c_str(), mparams); model = llama_load_model_from_file(params.model.c_str(), mparams);
} }
if (model == NULL) { if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr); return std::make_tuple(nullptr, nullptr);
@ -1906,7 +2105,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
} }
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]); const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]);
int err = llama_model_apply_lora_from_file(model, int err = llama_model_apply_lora_from_file(model,
lora_adapter.c_str(), lora_adapter.c_str(),

View file

@ -88,18 +88,22 @@ struct gpt_params {
// // sampling parameters // // sampling parameters
struct llama_sampling_params sparams; struct llama_sampling_params sparams;
std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_url = ""; // model url to download std::string model_draft = ""; // draft model for speculative decoding
std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias
std::string model_alias = "unknown"; // model alias std::string model_url = ""; // model url to download
std::string prompt = ""; std::string hf_repo = ""; // HF repo
std::string prompt_file = ""; // store the external prompt file name std::string hf_file = ""; // HF file
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string prompt = "";
std::string input_prefix = ""; // string to prefix user inputs with std::string prompt_file = ""; // store the external prompt file name
std::string input_suffix = ""; // string to suffix user inputs with std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files std::string logdir = ""; // directory in which to save YAML log files
std::string logits_file = ""; // file for saving *all* logits std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
std::string logits_file = ""; // file for saving *all* logits
std::vector<llama_model_kv_override> kv_overrides; std::vector<llama_model_kv_override> kv_overrides;
@ -139,7 +143,7 @@ struct gpt_params {
bool interactive_first = false; // wait for user input immediately bool interactive_first = false; // wait for user input immediately
bool multiline_input = false; // reverse the usage of `\` bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = false; // insert new sequences for decoding on-the-fly bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens bool ignore_eos = false; // ignore generated EOS tokens
@ -194,8 +198,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
struct llama_model_params params); struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
// Batch utils // Batch utils
@ -304,3 +308,10 @@ struct llama_control_vector_load_info {
// Load control vectors, scale each by strength, and add them together. // Load control vectors, scale each by strength, and add them together.
// On error, returns {-1, empty} // On error, returns {-1, empty}
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos); llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
//
// Split utils
//
static const char * const LLM_KV_SPLIT_NO = "split.no";
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

View file

@ -0,0 +1,721 @@
#include "json-schema-to-grammar.h"
#include <algorithm>
#include <fstream>
#include <map>
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using json = nlohmann::ordered_json;
const std::string SPACE_RULE = "\" \"?";
std::unordered_map<std::string, std::string> PRIMITIVE_RULES = {
{"boolean", "(\"true\" | \"false\") space"},
{"number", "(\"-\"? ([0-9] | [1-9] [0-9]*)) (\".\" [0-9]+)? ([eE] [-+]? [0-9]+)? space"},
{"integer", "(\"-\"? ([0-9] | [1-9] [0-9]*)) space"},
{"value", "object | array | string | number | boolean"},
{"object", "\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space"},
{"array", "\"[\" space ( value (\",\" space value)* )? \"]\" space"},
{"uuid", "\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space"},
{"string", " \"\\\"\" (\n"
" [^\"\\\\] |\n"
" \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])\n"
" )* \"\\\"\" space"},
{"null", "\"null\" space"}
};
std::vector<std::string> OBJECT_RULE_NAMES = {"object", "array", "string", "number", "boolean", "null", "value"};
std::unordered_map<std::string, std::string> DATE_RULES = {
{"date", "[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )"},
{"time", "([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )"},
{"date-time", "date \"T\" time"},
{"date-string", "\"\\\"\" date \"\\\"\" space"},
{"time-string", "\"\\\"\" time \"\\\"\" space"},
{"date-time-string", "\"\\\"\" date-time \"\\\"\" space"}
};
static bool is_reserved_name(const std::string & name) {
static std::unordered_set<std::string> RESERVED_NAMES;
if (RESERVED_NAMES.empty()) {
RESERVED_NAMES.insert("root");
for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
for (const auto &p : DATE_RULES) RESERVED_NAMES.insert(p.first);
}
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
}
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
std::ostringstream result;
if (begin != end) {
result << *begin;
for (Iterator it = begin + 1; it != end; ++it) {
result << separator << *it;
}
}
return result.str();
}
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
tokens.push_back(str.substr(start));
return tokens;
}
static std::string repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match;
std::string result;
std::string::const_iterator searchStart(input.cbegin());
std::string::const_iterator searchEnd(input.cend());
while (std::regex_search(searchStart, searchEnd, match, regex)) {
result.append(searchStart, searchStart + match.position());
result.append(replacement(match));
searchStart = match.suffix().first;
}
result.append(searchStart, searchEnd);
return result;
}
static std::string format_literal(const std::string & literal) {
std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
char c = match.str()[0];
return GRAMMAR_LITERAL_ESCAPES.at(c);
});
return "\"" + escaped + "\"";
}
class SchemaConverter {
private:
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
std::unordered_map<std::string, json> _refs;
std::unordered_set<std::string> _refs_being_resolved;
std::vector<std::string> _errors;
std::vector<std::string> _warnings;
std::string _add_rule(const std::string & name, const std::string & rule) {
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
_rules[esc_name] = rule;
return esc_name;
} else {
int i = 0;
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
i++;
}
std::string key = esc_name + std::to_string(i);
_rules[key] = rule;
return key;
}
}
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
std::vector<std::string> rules;
for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
}
return join(rules.begin(), rules.end(), " | ");
}
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
if (!(pattern.front() == '^' && pattern.back() == '$')) {
_errors.push_back("Pattern must start with '^' and end with '$'");
return "";
}
std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
std::unordered_map<std::string, std::string> sub_rule_ids;
size_t i = 0;
size_t length = sub_pattern.length();
using literal_or_rule = std::pair<std::string, bool>;
auto to_rule = [&](const literal_or_rule & ls) {
auto is_literal = ls.second;
auto s = ls.first;
return is_literal ? "\"" + s + "\"" : s;
};
std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
size_t start = i;
std::vector<literal_or_rule> seq;
auto get_dot = [&]() {
std::string rule;
if (_dotall) {
rule = "[\\U00000000-\\U0010FFFF]";
} else {
rule = "[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]";
}
return _add_rule("dot", rule);
};
// Joins the sequence, merging consecutive literals together.
auto join_seq = [&]() {
std::vector<literal_or_rule> ret;
std::string literal;
auto flush_literal = [&]() {
if (literal.empty()) {
return false;
}
ret.push_back(std::make_pair(literal, true));
literal.clear();
return true;
};
for (const auto & item : seq) {
auto is_literal = item.second;
if (is_literal) {
literal += item.first;
} else {
flush_literal();
ret.push_back(item);
}
}
flush_literal();
std::vector<std::string> results;
for (const auto & item : ret) {
results.push_back(to_rule(item));
}
return std::make_pair(join(results.begin(), results.end(), " "), false);
};
while (i < length) {
char c = sub_pattern[i];
if (c == '.') {
seq.push_back(std::make_pair(get_dot(), false));
i++;
} else if (c == '(') {
i++;
if (i < length) {
if (sub_pattern[i] == '?') {
_warnings.push_back("Unsupported pattern syntax");
}
}
seq.push_back(std::make_pair("(" + to_rule(transform()) + ")", false));
} else if (c == ')') {
i++;
if (start > 0 && sub_pattern[start - 1] != '(') {
_errors.push_back("Unbalanced parentheses");
}
return join_seq();
} else if (c == '[') {
std::string square_brackets = std::string(1, c);
i++;
while (i < length && sub_pattern[i] != ']') {
if (sub_pattern[i] == '\\') {
square_brackets += sub_pattern.substr(i, 2);
i += 2;
} else {
square_brackets += sub_pattern[i];
i++;
}
}
if (i >= length) {
_errors.push_back("Unbalanced square brackets");
}
square_brackets += ']';
i++;
seq.push_back(std::make_pair(square_brackets, false));
} else if (c == '|') {
seq.push_back(std::make_pair("|", false));
i++;
} else if (c == '*' || c == '+' || c == '?') {
seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
i++;
} else if (c == '{') {
std::string curly_brackets = std::string(1, c);
i++;
while (i < length && sub_pattern[i] != '}') {
curly_brackets += sub_pattern[i];
i++;
}
if (i >= length) {
_errors.push_back("Unbalanced curly brackets");
}
curly_brackets += '}';
i++;
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
int min_times = 0;
int max_times = std::numeric_limits<int>::max();
try {
if (nums.size() == 1) {
min_times = max_times = std::stoi(nums[0]);
} else if (nums.size() != 2) {
_errors.push_back("Wrong number of values in curly brackets");
} else {
if (!nums[0].empty()) {
min_times = std::stoi(nums[0]);
}
if (!nums[1].empty()) {
max_times = std::stoi(nums[1]);
}
}
} catch (const std::invalid_argument & e) {
_errors.push_back("Invalid number in curly brackets");
return std::make_pair("", false);
}
auto &last = seq.back();
auto &sub = last.first;
auto sub_is_literal = last.second;
if (min_times == 0 && max_times == std::numeric_limits<int>::max()) {
sub += "*";
} else if (min_times == 0 && max_times == 1) {
sub += "?";
} else if (min_times == 1 && max_times == std::numeric_limits<int>::max()) {
sub += "+";
} else {
if (!sub_is_literal) {
std::string & sub_id = sub_rule_ids[sub];
if (sub_id.empty()) {
sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
}
sub = sub_id;
}
std::string result;
if (sub_is_literal && min_times > 0) {
result = "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\"";
} else {
for (int j = 0; j < min_times; j++) {
if (j > 0) {
result += " ";
}
result += sub;
}
}
if (min_times > 0 && min_times < max_times) {
result += " ";
}
if (max_times == std::numeric_limits<int>::max()) {
result += sub + "*";
} else {
for (int j = min_times; j < max_times; j++) {
if (j > min_times) {
result += " ";
}
result += sub + "?";
}
}
seq.back().first = result;
seq.back().second = false;
}
} else {
std::string literal;
auto is_non_literal = [&](char c) {
return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
};
while (i < length) {
if (sub_pattern[i] == '\\' && i < length - 1) {
char next = sub_pattern[i + 1];
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
i++;
literal += sub_pattern[i];
i++;
} else {
literal += sub_pattern.substr(i, 2);
i += 2;
}
} else if (sub_pattern[i] == '"') {
literal += "\\\"";
i++;
} else if (!is_non_literal(sub_pattern[i]) &&
(i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
literal += sub_pattern[i];
i++;
} else {
break;
}
}
if (!literal.empty()) {
seq.push_back(std::make_pair(literal, true));
}
}
}
return join_seq();
};
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
}
std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
ref_name = visit(resolved, ref_name);
_refs_being_resolved.erase(ref);
}
return ref_name;
}
std::string _build_object_rule(
const std::vector<std::pair<std::string, json>> & properties,
const std::unordered_set<std::string> & required,
const std::string & name,
const json & additional_properties)
{
std::vector<std::string> required_props;
std::vector<std::string> optional_props;
std::unordered_map<std::string, std::string> prop_kv_rule_names;
for (const auto & kv : properties) {
const auto &prop_name = kv.first;
const auto &prop_schema = kv.second;
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
prop_kv_rule_names[prop_name] = _add_rule(
name + (name.empty() ? "" : "-") + prop_name + "-kv",
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
);
if (required.find(prop_name) != required.end()) {
required_props.push_back(prop_name);
} else {
optional_props.push_back(prop_name);
}
}
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
std::string kv_rule = _add_rule(sub_name + "-kv", _add_rule("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
prop_kv_rule_names["*"] = kv_rule;
optional_props.push_back("*");
}
std::string rule = "\"{\" space ";
for (size_t i = 0; i < required_props.size(); i++) {
if (i > 0) {
rule += " \",\" space ";
}
rule += prop_kv_rule_names[required_props[i]];
}
if (!optional_props.empty()) {
rule += " (";
if (!required_props.empty()) {
rule += " \",\" space ( ";
}
std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
std::string res;
if (ks.empty()) {
return res;
}
std::string k = ks[0];
std::string kv_rule_name = prop_kv_rule_names[k];
if (k == "*") {
res = _add_rule(
name + (name.empty() ? "" : "-") + "additional-kvs",
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
);
} else if (first_is_optional) {
res = "( \",\" space " + kv_rule_name + " )?";
} else {
res = kv_rule_name;
}
if (ks.size() > 1) {
res += " " + _add_rule(
name + (name.empty() ? "" : "-") + k + "-rest",
get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
);
}
return res;
};
for (size_t i = 0; i < optional_props.size(); i++) {
if (i > 0) {
rule += " | ";
}
rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
}
if (!required_props.empty()) {
rule += " )";
}
rule += " )?";
}
rule += " \"}\" space";
return rule;
}
public:
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
{
_rules["space"] = SPACE_RULE;
}
void resolve_refs(json & schema, const std::string & url) {
/*
* Resolves all $ref fields in the given schema, fetching any remote schemas,
* replacing each $ref with absolute reference URL and populates _refs with the
* respective referenced (sub)schema dictionaries.
*/
std::function<void(json &)> visit_refs = [&](json & n) {
if (n.is_array()) {
for (auto & x : n) {
visit_refs(x);
}
} else if (n.is_object()) {
if (n.contains("$ref")) {
std::string ref = n["$ref"];
if (_refs.find(ref) == _refs.end()) {
json target;
if (ref.find("https://") == 0) {
std::string base_url = ref.substr(0, ref.find('#'));
auto it = _refs.find(base_url);
if (it != _refs.end()) {
target = it->second;
} else {
// Fetch the referenced schema and resolve its refs
auto referenced = _fetch_json(ref);
resolve_refs(referenced, base_url);
_refs[base_url] = referenced;
}
if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
return;
}
} else if (ref.find("#/") == 0) {
target = schema;
n["$ref"] = url + ref;
ref = url + ref;
} else {
_errors.push_back("Unsupported ref: " + ref);
return;
}
std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel];
}
_refs[ref] = target;
}
} else {
for (auto & kv : n.items()) {
visit_refs(kv.value());
}
}
}
};
visit_refs(schema);
}
std::string _generate_constant_rule(const json & value) {
return format_literal(value.dump());
}
std::string visit(const json & schema, const std::string & name) {
json schema_type = schema.contains("type") ? schema["type"] : json();
std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
if (schema.contains("$ref")) {
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
} else if (schema_type.is_array()) {
std::vector<json> schema_types;
for (const auto & t : schema_type) {
schema_types.push_back({{"type", t}});
}
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
} else if (schema.contains("const")) {
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
} else if (schema.contains("enum")) {
std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
std::unordered_set<std::string> required;
if (schema.contains("required") && schema["required"].is_array()) {
for (const auto & item : schema["required"]) {
if (item.is_string()) {
required.insert(item.get<std::string>());
}
}
}
std::vector<std::pair<std::string, json>> properties;
if (schema.contains("properties")) {
for (const auto & prop : schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
}
}
return _add_rule(rule_name,
_build_object_rule(
properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties;
std::string hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) {
add_component(_refs[comp_schema["$ref"]], is_required);
} else if (comp_schema.contains("properties")) {
for (const auto & prop : comp_schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
if (is_required) {
required.insert(prop.key());
}
}
} else {
// todo warning
}
};
for (auto & t : schema["allOf"]) {
if (t.contains("anyOf")) {
for (auto & tt : t["anyOf"]) {
add_component(tt, false);
}
} else {
add_component(t, true);
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
if (items.is_array()) {
std::string rule = "\"[\" space ";
for (size_t i = 0; i < items.size(); i++) {
if (i > 0) {
rule += " \",\" space ";
}
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
}
rule += " \"]\" space";
return _add_rule(rule_name, rule);
} else {
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
std::string list_item_operator = "( \",\" space " + item_rule_name + " )";
std::string successive_items;
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : -1;
if (min_items > 0) {
successive_items += repeat(list_item_operator, min_items - 1);
min_items--;
}
if (max_items >= 0 && max_items > min_items) {
successive_items += repeat(list_item_operator + "?", max_items - min_items - 1);
} else {
successive_items += list_item_operator + "*";
}
std::string rule;
if (min_items == 0) {
rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space";
} else {
rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space";
}
return _add_rule(rule_name, rule);
}
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name);
} else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
return _add_rule(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
} else if ((schema_type.is_null() || schema_type == "string") && DATE_RULES.find(schema_format) != DATE_RULES.end()) {
for (const auto & kv : DATE_RULES) {
_add_rule(kv.first, kv.second);
}
return schema_format + "-string";
} else if (schema.empty() || schema_type == "object") {
for (const auto & n : OBJECT_RULE_NAMES) {
_add_rule(n, PRIMITIVE_RULES.at(n));
}
return _add_rule(rule_name, "object");
} else {
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
_errors.push_back("Unrecognized schema: " + schema.dump());
return "";
}
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return _add_rule(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
}
}
void check_errors() {
if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
}
}
std::string format_grammar() {
std::stringstream ss;
for (const auto & kv : _rules) {
ss << kv.first << " ::= " << kv.second << std::endl;
}
return ss.str();
}
};
std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
converter.visit(copy, "");
converter.check_errors();
return converter.format_grammar();
}

View file

@ -0,0 +1,4 @@
#pragma once
#include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);

280
common/ngram-cache.cpp Normal file
View file

@ -0,0 +1,280 @@
#include "ngram-cache.h"
#include "log.h"
#include <fstream>
void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
std::vector<llama_token> & inp, int nnew, bool print_progress) {
const int64_t t_start_ms = ggml_time_ms();
const int64_t inp_size = inp.size();
const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
int64_t n_done = 0;
for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
const int64_t i_start = std::max(inp_size - nnew, ngram_size);
for (int64_t i = i_start; i < inp_size; ++i) {
const int64_t ngram_start = i - ngram_size;
llama_ngram ngram(&inp[ngram_start], ngram_size);
const llama_token token = inp[i];
llama_ngram_cache::iterator part_it = ngram_cache.find(ngram);
if (part_it == ngram_cache.end()) {
llama_ngram_cache_part part;
part.emplace(token, 1);
ngram_cache.emplace(ngram, part);
} else {
llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
if (token_count_it == part_it->second.end()) {
part_it->second.emplace(token, 1);
} else {
token_count_it->second++;
}
}
++n_done;
if (print_progress && n_done % 10000000 == 0) {
const int64_t t_now_ms = ggml_time_ms();
const int64_t eta_ms = (inp_size*(ngram_max-ngram_min+1) - n_done) * (t_now_ms - t_start_ms) / n_done;
const int64_t eta_min = eta_ms / (60*1000);
const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
fprintf(stderr, "%s: %" PRId64 "/%" PRId64 " done, ETA: %02" PRId64 ":%02" PRId64 "\n", __func__, n_done, n_todo, eta_min, eta_s);
}
}
}
}
// Helper function to get a token from the combined, speculative sequence of inp and draft.
static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
}
// If sample size or percentage are below these thresholds the draft is aborted early:
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
// Helper function that tries to draft a token from only the static ngram cache:
static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
if (part_static_it == nc_static.end()) {
return -1;
}
const llama_ngram_cache_part part_static = part_static_it->second;
int max_count_static = 0;
int sum_count_static = 0;
llama_token max_token = -1;
for (std::pair<llama_token, int> token_count_static : part_static) {
const llama_token token = token_count_static.first;
const int32_t count_static = token_count_static.second;
if (count_static > max_count_static) {
max_token = token;
max_count_static = count_static;
}
sum_count_static += count_static;
}
if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
return -1;
}
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
return -1;
}
return max_token;
}
// Try to draft a token from primary cache (context/dynamic), validate with static cache:
static llama_token try_draft(
llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
const int * min_sample_size, const int * min_percent) {
llama_token drafted_token = -1;
for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
const llama_ngram ngram_primary = ngrams_primary[i];
llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
if (part_primary_it == nc_primary.end()) {
continue;
}
const llama_ngram_cache_part part_primary = part_primary_it->second;
int max_count_primary = 0;
int max_count_static = 0;
int sum_count_primary = 0;
llama_token max_token = -1;
for (std::pair<llama_token, int> token_count_primary : part_primary) {
const llama_token token = token_count_primary.first;
llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
const int32_t count_primary = token_count_primary.second;
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
if (count_primary*count_static > max_count_primary*max_count_static) {
max_token = token;
max_count_primary = count_primary;
max_count_static = count_static;
}
sum_count_primary += count_primary;
}
if (sum_count_primary < min_sample_size[i]) {
continue;
}
if (100*max_count_primary < min_percent[i]*sum_count_primary) {
continue;;
}
drafted_token = max_token;
}
return drafted_token;
}
void llama_ngram_cache_draft(
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
) {
GGML_ASSERT(draft.size() == 1);
const int inp_size = inp.size();
if (inp_size < LLAMA_NGRAM_STATIC) {
return;
}
while ((int) draft.size()-1 < n_draft) {
llama_token drafted_token = -1;
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
llama_ngram ngram_static;
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
}
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
llama_ngram_cache_part part_static;
if (part_static_it != nc_static.end()) {
part_static = part_static_it->second;
}
// cd = context + dynamic
std::vector<llama_ngram> ngrams_cd;
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
llama_ngram ngram_cd;
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
}
ngrams_cd.push_back(ngram_cd);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_static, ngram_static);
}
if (drafted_token == -1) {
break;
}
LOG(" - draft candidate: token=%d\n", drafted_token);
draft.push_back(drafted_token);
}
}
void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename) {
std::ofstream file_out(filename, std::ios::binary);
for (std::pair<llama_ngram, llama_ngram_cache_part> item : ngram_cache) {
const llama_ngram ngram = item.first;
llama_ngram_cache_part token_counts = item.second;
GGML_ASSERT(!token_counts.empty());
const int32_t ntokens = token_counts.size();
GGML_ASSERT(ntokens > 0);
file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(llama_ngram));
file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
for (std::pair<llama_token, int32_t> item2 : token_counts) {
const llama_token token = item2.first;
const int32_t count = item2.second;
GGML_ASSERT(count > 0);
file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
}
}
}
llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
std::ifstream hashmap_file(filename, std::ios::binary);
if (!hashmap_file) {
throw std::ifstream::failure("Unable to open file " + filename);
}
llama_ngram_cache ngram_cache;
llama_ngram ngram;
int32_t ntokens;
llama_token token;
int32_t count;
char * ngramc = reinterpret_cast<char*>(&ngram);
char * ntokensc = reinterpret_cast<char*>(&ntokens);
char * tokenc = reinterpret_cast<char*>(&token);
char * countc = reinterpret_cast<char*>(&count);
while(hashmap_file.read(ngramc, sizeof(llama_ngram))) {
GGML_ASSERT(!hashmap_file.eof());
GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
GGML_ASSERT(ntokens > 0);
llama_ngram_cache_part token_counts;
for (int i = 0; i < ntokens; ++i) {
GGML_ASSERT(!hashmap_file.eof());
GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
GGML_ASSERT(!hashmap_file.eof());
GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
GGML_ASSERT(count > 0);
token_counts.emplace(token, count);
}
ngram_cache.emplace(ngram, token_counts);
}
GGML_ASSERT(hashmap_file.eof());
return ngram_cache;
}
void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
for (std::pair<llama_ngram, llama_ngram_cache_part> ngram_part : ngram_cache_add) {
const llama_ngram ngram = ngram_part.first;
llama_ngram_cache_part part = ngram_part.second;
llama_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
if (part_merged_it == ngram_cache_target.end()) {
ngram_cache_target.emplace(ngram, part);
continue;
}
for (std::pair<llama_token, int32_t> token_count : part) {
const llama_token token = token_count.first;
const int32_t count = token_count.second;
GGML_ASSERT(count > 0);
llama_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
if (token_count_merged_it == part_merged_it->second.end()) {
part_merged_it->second.emplace(token, count);
continue;
}
token_count_merged_it->second += count;
}
}
}

94
common/ngram-cache.h Normal file
View file

@ -0,0 +1,94 @@
#pragma once
#include "llama.h"
#include <unordered_map>
#include <string>
#include <vector>
#define LLAMA_NGRAM_MIN 1
#define LLAMA_NGRAM_MAX 4
#define LLAMA_NGRAM_STATIC 2
// Data structures to map n-grams to empirical token probabilities:
struct llama_ngram {
llama_token tokens[LLAMA_NGRAM_MAX];
llama_ngram() {
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
tokens[i] = -1;
}
}
llama_ngram(const llama_token * input, const int ngram_size) {
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
tokens[i] = i < ngram_size ? input[i] : -1;
}
}
bool operator==(const llama_ngram & other) const {
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
if (tokens[i] != other.tokens[i]) {
return false;
}
}
return true;
}
};
struct llama_ngram_hash_function {
size_t operator()(const llama_ngram & ngram) const {
size_t hash = 0;
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
hash ^= std::hash<llama_token>{}(ngram.tokens[i]);
}
return hash;
}
};
// token -> number of times token has been seen
typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
// n-gram -> empirical distribution of following tokens
typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;
// Update an ngram cache with tokens.
// ngram_cache: the cache to modify.
// ngram_min/ngram_max: the min/max size of the ngrams to extract from inp_data.
// inp_data: the token sequence with which to update ngram_cache.
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
// print_progress: whether to print progress to stderr.
//
// In order to get correct results inp_data can ONLY BE APPENDED TO.
// Changes in the middle need a complete rebuild.
void llama_ngram_cache_update(
llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
// Try to draft tokens from ngram caches.
// inp: the tokens generated so far.
// draft: the token sequence to draft. Expected to initially contain the previously sampled token.
// n_draft: maximum number of tokens to add to draft.
// ngram_min/gram_max: the min/max size of the ngrams in nc_context and nc_dynamic.
// nc_context: ngram cache based on current context.
// nc_dynamic: ngram cache based on previous user generations.
// nc_static: ngram cache generated from a large text corpus, used for validation.
void llama_ngram_cache_draft(
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);
// Save an ngram cache to a file.
// ngram_cache: the ngram cache to save.
// filename: the path under which to save the ngram cache.
void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename);
// Load an ngram cache saved with llama_ngram_cache_save.
// filename: the path from which to load the ngram cache.
// returns: an ngram cache containing the information saved to filename.
llama_ngram_cache llama_ngram_cache_load(std::string & filename);
// Merge two ngram caches.
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
// ngram_cache_add: the ngram cache to add to ngram_cache_target.
void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add);

View file

@ -168,77 +168,20 @@ static llama_token llama_sampling_sample_impl(
bool is_resampling) { // Add a parameter to indicate if we are resampling bool is_resampling) { // Add a parameter to indicate if we are resampling
const llama_sampling_params & params = ctx_sampling->params; const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const float temp = params.temp; const float temp = params.temp;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat; const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau; const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta; const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
std::vector<float> original_logits;
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
if (!is_resampling) {
GGML_ASSERT(!original_logits.empty());
}
llama_token id = 0; llama_token id = 0;
// Get a pointer to the logits // Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx); float * logits = llama_get_logits_ith(ctx_main, idx);
// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;
if (!is_resampling) {
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}
// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
}
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}
// If we are in the resampling phase, apply grammar checks before sampling logic
if (is_resampling && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}
if (temp < 0.0) { if (temp < 0.0) {
// greedy sampling, with probs // greedy sampling, with probs
llama_sample_softmax(ctx_main, &cur_p); llama_sample_softmax(ctx_main, &cur_p);
@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
return id; return id;
} }
static llama_token_data_array llama_sample_probability_distribution_impl( static llama_token_data_array llama_sampling_prepare_impl(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx) { const int idx,
bool apply_grammar,
std::vector<float> * original_logits) {
const llama_sampling_params & params = ctx_sampling->params; const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
const float penalty_repeat = params.penalty_repeat; const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq; const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present; const float penalty_present = params.penalty_present;
const bool penalize_nl = params.penalize_nl; const bool penalize_nl = params.penalize_nl;
auto & prev = ctx_sampling->prev; auto & prev = ctx_sampling->prev;
@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
// Get a pointer to the logits // Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx); float * logits = llama_get_logits_ith(ctx_main, idx);
// Declare original_logits at the beginning of the function scope if (apply_grammar && original_logits != NULL) {
std::vector<float> original_logits; // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
*original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
}
// apply params.logit_bias map // apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
} }
} }
// apply grammar checks // apply grammar checks before sampling logic
if (ctx_sampling->grammar != NULL) { if (apply_grammar && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
} }
llama_sample_softmax(ctx_main, &cur_p);
return cur_p; return cur_p;
} }
@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
} }
llama_token_data_array llama_sampling_probability_distribution( llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx) { const int idx,
return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx); bool apply_grammar,
std::vector<float> * original_logits) {
return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
} }
void llama_sampling_accept( void llama_sampling_accept(

View file

@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
int idx = 0); int idx = 0);
// returns the probability that token of given id will be sampled // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
llama_token_data_array llama_sampling_probability_distribution( llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
int idx = 0); int idx = 0,
bool apply_grammar = true,
std::vector<float> * original_logits = nullptr);
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,

View file

@ -93,31 +93,42 @@ class Model(ABC):
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_context_length(n_ctx)
print(f"gguf: context length = {n_ctx}")
n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_embd = self.find_hparam(["hidden_size", "n_embd"])
self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_embedding_length(n_embd)
print(f"gguf: embedding length = {n_embd}")
if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
self.gguf_writer.add_feed_forward_length(n_ff) self.gguf_writer.add_feed_forward_length(n_ff)
print(f"gguf: feed forward length = {n_ff}")
n_head = self.find_hparam(["num_attention_heads", "n_head"]) n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count(n_head)
print(f"gguf: head count = {n_head}")
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_head_count_kv(n_head_kv)
print(f"gguf: key-value head count = {n_head_kv}")
if (rope_theta := self.hparams.get("rope_theta")) is not None: if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta) self.gguf_writer.add_rope_freq_base(rope_theta)
print(f"gguf: rope theta = {rope_theta}")
if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
print(f"gguf: rms norm epsilon = {f_rms_eps}")
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_eps(f_norm_eps) self.gguf_writer.add_layer_norm_eps(f_norm_eps)
print(f"gguf: layer norm epsilon = {f_norm_eps}")
if (n_experts := self.hparams.get("num_local_experts")) is not None: if (n_experts := self.hparams.get("num_local_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts) self.gguf_writer.add_expert_count(n_experts)
print(f"gguf: expert count = {n_experts}")
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used) self.gguf_writer.add_expert_used_count(n_experts_used)
print(f"gguf: experts used count = {n_experts_used}")
self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_file_type(self.ftype)
print(f"gguf: file type = {self.ftype}")
def write_tensors(self): def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
@ -1051,6 +1062,21 @@ class MixtralModel(Model):
self._set_vocab_sentencepiece() self._set_vocab_sentencepiece()
@Model.register("GrokForCausalLM")
class GrokModel(Model):
model_arch = gguf.MODEL_ARCH.GROK
def set_vocab(self):
self._set_vocab_sentencepiece()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_name("Grok")
@Model.register("MiniCPMForCausalLM") @Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model): class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM model_arch = gguf.MODEL_ARCH.MINICPM

View file

@ -48,6 +48,8 @@ int main(int argc, char ** argv) {
params.prompt = "Hello my name is"; params.prompt = "Hello my name is";
} }
process_escapes(params.prompt);
// init LLM // init LLM
llama_backend_init(); llama_backend_init();
@ -78,7 +80,7 @@ int main(int argc, char ** argv) {
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234; ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req; ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel); ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_seq_max = n_parallel; ctx_params.n_seq_max = n_parallel;
ctx_params.n_threads = params.n_threads; ctx_params.n_threads = params.n_threads;

View file

@ -21,6 +21,8 @@ An example command using a model from [karpathy/tinyllamas](https://huggingface.
`$ ./convert-llama2c-to-ggml --copy-vocab-from-model llama-2-7b-chat.gguf.q2_K.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.gguf.bin` `$ ./convert-llama2c-to-ggml --copy-vocab-from-model llama-2-7b-chat.gguf.q2_K.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.gguf.bin`
Note: The vocabulary for `stories260K.bin` should be its own tokenizer `tok512.bin` found in [karpathy/tinyllamas/stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K).
Now you can use the model with a command like: Now you can use the model with a command like:
`$ ./main -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256` `$ ./main -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256`

View file

@ -1,6 +1,7 @@
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama.h"
#include "common.h" #include "common.h"
#include "log.h"
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
@ -78,111 +79,101 @@ typedef struct {
struct TransformerWeights { struct TransformerWeights {
// token embedding table // token embedding table
float* token_embedding_table; // (vocab_size, dim) std::vector<float> token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms // weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights std::vector<float> rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim) std::vector<float> rms_ffn_weight; // (layer, dim)
// weights for matmuls // weights for matmuls
float* wq; // (layer, dim, dim) std::vector<float> wq; // (layer, dim, dim)
float* wk; // (layer, dim, dim) std::vector<float> wk; // (layer, dim, dim)
float* wv; // (layer, dim, dim) std::vector<float> wv; // (layer, dim, dim)
float* wo; // (layer, dim, dim) std::vector<float> wo; // (layer, dim, dim)
// weights for ffn // weights for ffn
float* w1; // (layer, hidden_dim, dim) std::vector<float> w1; // (layer, hidden_dim, dim)
float* w2; // (layer, dim, hidden_dim) std::vector<float> w2; // (layer, dim, hidden_dim)
float* w3; // (layer, hidden_dim, dim) std::vector<float> w3; // (layer, hidden_dim, dim)
// final rmsnorm // final rmsnorm
float* rms_final_weight; // (dim,) std::vector<float> rms_final_weight; // (dim,)
// freq_cis for RoPE relatively positional embeddings // freq_cis for RoPE relatively positional embeddings
// float* freq_cis_real; // (seq_len, dim/2) // std::vector<float> freq_cis_real; // (seq_len, dim/2)
// float* freq_cis_imag; // (seq_len, dim/2) // std::vector<float> freq_cis_imag; // (seq_len, dim/2)
// (optional) classifier weights for the logits, on the last layer // (optional) classifier weights for the logits, on the last layer
float* wcls; std::vector<float> wcls;
~TransformerWeights() {
delete[] token_embedding_table;
delete[] rms_att_weight;
delete[] rms_ffn_weight;
delete[] wq;
delete[] wk;
delete[] wv;
delete[] wo;
delete[] w1;
delete[] w2;
delete[] w3;
delete[] rms_final_weight;
delete[] wcls;
}
}; };
static void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) { static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_weights) {
// we calloc instead of malloc to keep valgrind happy const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads;
w->token_embedding_table = new float[p->vocab_size * p->dim](); try {
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); w->token_embedding_table.resize(p->vocab_size * p->dim);
LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
w->rms_att_weight = new float[p->n_layers * p->dim](); w->rms_att_weight.resize(p->n_layers * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim); LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim);
w->rms_ffn_weight = new float[p->n_layers * p->dim](); w->rms_ffn_weight.resize(p->n_layers * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim); LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim);
w->wq = new float[p->n_layers * p->dim * p->dim](); w->wq.resize(p->n_layers * p->dim * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
w->wk = new float[p->n_layers * p->dim * p->dim](); w->wk.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
w->wv = new float[p->n_layers * p->dim * p->dim](); w->wv.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
w->wo = new float[p->n_layers * p->dim * p->dim](); w->wo.resize(p->n_layers * p->dim * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
w->w1 = new float[p->n_layers * p->hidden_dim * p->dim](); w->w1.resize(p->n_layers * p->hidden_dim * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
w->w2 = new float[p->n_layers * p->hidden_dim * p->dim](); w->w2.resize(p->n_layers * p->hidden_dim * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim); LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim);
w->w3 = new float[p->n_layers * p->hidden_dim * p->dim](); w->w3.resize(p->n_layers * p->hidden_dim * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
w->rms_final_weight = new float[p->dim](); w->rms_final_weight.resize(p->dim);
printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim); LOG("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
if (shared_weights) { if (shared_weights) {
w->wcls = NULL; w->wcls = {};
} else { } else {
w->wcls = new float[p->vocab_size * p->dim](); w->wcls.resize(p->vocab_size * p->dim);
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
}
}
catch (std::length_error &) {
die("Invalid configuration. Failed to allocate memory for weights");
} }
} }
static int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) { static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FILE * f, bool shared_weights) {
if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1; if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1;
if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1; if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1;
if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1; if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1;
if (fread(w->wk, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1; if (fread(w->wk.data(), sizeof(float), w->wk.size(), f) != w->wk.size()) return 1;
if (fread(w->wv, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1; if (fread(w->wv.data(), sizeof(float), w->wv.size(), f) != w->wv.size()) return 1;
if (fread(w->wo, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1; if (fread(w->wo.data(), sizeof(float), w->wo.size(), f) != w->wo.size()) return 1;
if (fread(w->rms_ffn_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1; if (fread(w->rms_ffn_weight.data(), sizeof(float), w->rms_ffn_weight.size(), f) != w->rms_ffn_weight.size()) return 1;
if (fread(w->w1, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1; if (fread(w->w1.data(), sizeof(float), w->w1.size(), f) != w->w1.size()) return 1;
if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1; if (fread(w->w2.data(), sizeof(float), w->w2.size(), f) != w->w2.size()) return 1;
if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1; if (fread(w->w3.data(), sizeof(float), w->w3.size(), f) != w->w3.size()) return 1;
if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1; if (fread(w->rms_final_weight.data(), sizeof(float), w->rms_final_weight.size(), f) != w->rms_final_weight.size()) return 1;
// Skip freq_cis_real & freq_cis_imag // Skip freq_cis_real & freq_cis_imag
int head_size = p->dim / p->n_heads; int head_size = p->dim / p->n_heads;
fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR); fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
if (!shared_weights && fread(w->wcls, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1; if (!shared_weights && fread(w->wcls.data(), sizeof(float), w->wcls.size(), f) != w->wcls.size()) return 1;
// Check we didn't forget to read anything // Check we didn't forget to read anything
auto curr = ftell(f); auto curr = ftell(f);
fseek(f, 0, SEEK_END); fseek(f, 0, SEEK_END);
auto end = ftell(f); auto end = ftell(f);
if (curr != end) { if (curr != end) {
printf("Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", curr, end); LOG("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end);
return 1; return 1;
} }
@ -190,20 +181,20 @@ static int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bo
} }
static void print_sample_weights(TransformerWeights *w){ static void print_sample_weights(TransformerWeights *w){
printf("----- Quick print of first of the weight vales of all the variables\n"); LOG("----- Quick print of first of the weight vales of all the variables\n");
printf("%f\n", w->token_embedding_table[0]); LOG("%f\n", w->token_embedding_table[0]);
printf("%f\n", w->rms_att_weight[0]); LOG("%f\n", w->rms_att_weight[0]);
printf("%f\n", w->rms_ffn_weight[0]); LOG("%f\n", w->rms_ffn_weight[0]);
printf("%f\n", w->wq[0]); LOG("%f\n", w->wq[0]);
printf("%f\n", w->wk[0]); LOG("%f\n", w->wk[0]);
printf("%f\n", w->wv[0]); LOG("%f\n", w->wv[0]);
printf("%f\n", w->wo[0]); LOG("%f\n", w->wo[0]);
printf("%f\n", w->w1[0]); LOG("%f\n", w->w1[0]);
printf("%f\n", w->w2[0]); LOG("%f\n", w->w2[0]);
printf("%f\n", w->w3[0]); LOG("%f\n", w->w3[0]);
printf("%f\n", w->rms_att_weight[0]); LOG("%f\n", w->rms_att_weight[0]);
if (w->wcls) printf("%f\n", w->wcls[0]); if (!w->wcls.empty()) LOG("%f\n", w->wcls[0]);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -225,14 +216,16 @@ struct llama_vocab {
}; };
struct my_llama_hparams { struct my_llama_hparams {
uint32_t n_vocab = 32000; uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; // this is provided as user input? uint32_t n_ctx = 512; // this is provided as user input?
uint32_t n_embd = 4096; uint32_t n_embd = 4096;
uint32_t n_ff = 11008; uint32_t n_ff = 11008;
uint32_t n_mult = 4; uint32_t n_mult = 4;
uint32_t n_head = 32; uint32_t n_head = 32;
uint32_t n_layer = 32; uint32_t n_head_kv = 32;
uint32_t n_rot = 64; uint32_t n_layer = 32;
uint32_t n_rot = 64;
bool operator!=(const my_llama_hparams& other) const { bool operator!=(const my_llama_hparams& other) const {
return memcmp(this, &other, sizeof(my_llama_hparams)); return memcmp(this, &other, sizeof(my_llama_hparams));
} }
@ -325,14 +318,30 @@ struct train_params {
}; };
static void print_params(struct my_llama_hparams * params) { static void print_params(struct my_llama_hparams * params) {
printf("%s: n_vocab: %u\n", __func__, params->n_vocab); LOG("%s: n_vocab: %u\n", __func__, params->n_vocab);
printf("%s: n_ctx: %u\n", __func__, params->n_ctx); LOG("%s: n_ctx: %u\n", __func__, params->n_ctx);
printf("%s: n_embd: %u\n", __func__, params->n_embd); LOG("%s: n_embd: %u\n", __func__, params->n_embd);
printf("%s: n_mult: %u\n", __func__, params->n_mult); LOG("%s: n_mult: %u\n", __func__, params->n_mult);
printf("%s: n_head: %u\n", __func__, params->n_head); LOG("%s: n_head: %u\n", __func__, params->n_head);
printf("%s: n_ff: %u\n", __func__, params->n_ff); LOG("%s: n_head_kv: %u\n", __func__, params->n_head_kv);
printf("%s: n_layer: %u\n", __func__, params->n_layer); LOG("%s: n_ff: %u\n", __func__, params->n_ff);
printf("%s: n_rot: %u\n", __func__, params->n_rot); LOG("%s: n_layer: %u\n", __func__, params->n_layer);
LOG("%s: n_rot: %u\n", __func__, params->n_rot);
}
static void print_tensor_info(const struct ggml_context * ctx) {
for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
LOG("%s: Allocating ", __func__);
int64_t total = 1;
int i = 0;
for (; i < ggml_n_dims(t); ++i) {
if (i > 0) LOG("x ");
LOG("[%" PRId64 "] ", t->ne[i]);
total *= t->ne[i];
}
if (i > 1) LOG("= [%" PRId64 "] ", total);
LOG("float space for %s\n", ggml_get_name(t));
}
} }
static void init_model(struct my_llama_model * model) { static void init_model(struct my_llama_model * model) {
@ -342,6 +351,8 @@ static void init_model(struct my_llama_model * model) {
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab; const uint32_t n_vocab = hparams.n_vocab;
const uint32_t n_multiqueries = hparams.n_head_kv <= 0 || hparams.n_head_kv >= hparams.n_head ? 1 : hparams.n_head / hparams.n_head_kv;
const uint32_t n_ff = hparams.n_ff; const uint32_t n_ff = hparams.n_ff;
struct ggml_context * ctx = model->ctx; struct ggml_context * ctx = model->ctx;
@ -350,25 +361,8 @@ static void init_model(struct my_llama_model * model) {
model->train_tokens = 0; model->train_tokens = 0;
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
printf("[%s:GG] Allocating [%u] x [%u] = [%u] float space for model->tok_embeddings\n",__func__,n_embd , n_vocab, n_embd * n_vocab);
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
printf("[%s:GG] Allocating [%u] float space for model->norm\n",__func__,n_embd);
model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for model->output\n",__func__,n_embd, n_vocab, n_embd * n_vocab);
// printing the per-layer allocations here so we dont print in the for loop.
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wq for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wk for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wv for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wo for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
printf("[%s:GG] Allocating [%u] float space for layer.ffn_norm for [%u] layers\n",__func__,n_embd, n_layer);
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w1 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w2 for [%u] layers\n",__func__, n_embd, n_ff, n_ff * n_embd, n_layer);
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w3 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
ggml_set_name(model->tok_embeddings, "tok_embeddings.weight"); ggml_set_name(model->tok_embeddings, "tok_embeddings.weight");
ggml_set_name(model->norm, "norm.weight"); ggml_set_name(model->norm, "norm.weight");
@ -383,8 +377,8 @@ static void init_model(struct my_llama_model * model) {
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
@ -406,6 +400,8 @@ static void init_model(struct my_llama_model * model) {
ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str()); ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str());
ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str()); ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str());
} }
print_tensor_info(ctx);
} }
static float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) { static float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
@ -421,9 +417,9 @@ static int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
static void print_row(struct ggml_tensor * probs, int i) { static void print_row(struct ggml_tensor * probs, int i) {
for (int k = 0; k < probs->ne[0]; ++k) { for (int k = 0; k < probs->ne[0]; ++k) {
float p = get_f32_2d(probs, k, i); float p = get_f32_2d(probs, k, i);
printf(" %f", p); LOG(" %f", p);
} }
printf("\n"); LOG("\n");
} }
static void print_matrix(struct ggml_tensor * probs) { static void print_matrix(struct ggml_tensor * probs) {
@ -431,33 +427,12 @@ static void print_matrix(struct ggml_tensor * probs) {
for (int i = 0; i < probs->ne[1]; ++i) { for (int i = 0; i < probs->ne[1]; ++i) {
for (int k = 0; k < probs->ne[0]; ++k) { for (int k = 0; k < probs->ne[0]; ++k) {
float p = get_f32_2d(probs, k, i); float p = get_f32_2d(probs, k, i);
printf(" %.2f", p); LOG(" %.2f", p);
} }
printf("\n"); LOG("\n");
} }
} }
#ifdef __GNUC__
#ifdef __MINGW32__
__attribute__((format(gnu_printf, 1, 2)))
#else
__attribute__((format(printf, 1, 2)))
#endif
#endif
static std::string format(const char * fmt, ...) {
va_list ap, ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
GGML_ASSERT(size >= 0 && size < INT_MAX);
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
GGML_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
}
struct llama_file { struct llama_file {
// use FILE * so we don't have to re-open the file to mmap // use FILE * so we don't have to re-open the file to mmap
FILE * fp; FILE * fp;
@ -549,8 +524,9 @@ static std::string llama_escape_whitespaces(const std::string & text) {
return out.str(); return out.str();
} }
static void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) { static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) {
if (is_ggml_file(filename)) { if (is_ggml_file(filename)) {
LOG("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
struct ggml_context * ctx_data = NULL; struct ggml_context * ctx_data = NULL;
struct gguf_init_params params = { struct gguf_init_params params = {
@ -578,6 +554,9 @@ static void load_vocab(const char *filename, Config *config, struct llama_vocab
const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
if (n_vocab != static_cast<uint32_t>(config->vocab_size)) {
die_fmt("vocab size mismatch: (gguf) %u != (llama2c) %d", n_vocab, config->vocab_size);
}
vocab->id_to_token.resize(n_vocab); vocab->id_to_token.resize(n_vocab);
@ -595,7 +574,7 @@ static void load_vocab(const char *filename, Config *config, struct llama_vocab
gguf_free(ctx); gguf_free(ctx);
} else { } else {
// assume llama2.c vocabulary // assume llama2.c vocabulary
printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename); LOG("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename);
llama_file file(filename, "rb"); llama_file file(filename, "rb");
if (!file.fp) { if (!file.fp) {
die_fmt("%s: %s", strerror(errno), filename); die_fmt("%s: %s", strerror(errno), filename);
@ -638,38 +617,15 @@ static void load_vocab(const char *filename, Config *config, struct llama_vocab
} }
static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) { static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) {
int ct; int size = 1;
switch (ggml_n_dims(gg_weights)) { for (int dim = 0; dim < ggml_n_dims(gg_weights); ++dim) {
case 1: size *= gg_weights->ne[dim];
ct = 0; }
for (int i0 = 0; i0 < gg_weights->ne[0]; i0++){ for (int ct = 0; ct < size; ++ct) {
float * ptr = (float *) ((char *) gg_weights->data + i0*gg_weights->nb[0]); int64_t i0 = 0; int64_t i1 = 0;
*ptr = karpathy_weights[ct]; int64_t i2 = 0; int64_t i3 = 0;
ct++; ggml_unravel_index(gg_weights, ct, &i0, &i1, &i2, &i3);
} ggml_set_f32_nd(gg_weights, i0, i1, i2, i3, karpathy_weights[ct]);
break;
case 2:
ct = 0;
for (int i1 = 0; i1 < gg_weights->ne[1]; i1++) {
for (int i0 = 0; i0 < gg_weights->ne[0]; i0++) {
float * ptr = (float *) ((char *) gg_weights->data + i0*gg_weights->nb[0] + i1*gg_weights->nb[1]);
*ptr = karpathy_weights[ct];
ct++;
}
}
break;
case 3:
ct = 0;
for (int i2 = 0; i2 < gg_weights->ne[2]; i2++) {
for (int i1 = 0; i1 < gg_weights->ne[1]; i1++) {
for (int i0 = 0; i0 < gg_weights->ne[0]; i0++) {
float * ptr = (float *) ((char *) gg_weights->data + i0*gg_weights->nb[0] + i1*gg_weights->nb[1] + i2*gg_weights->nb[2]);
*ptr = karpathy_weights[ct];
ct++;
}
}
}
break;
} }
} }
@ -679,16 +635,18 @@ static void save_as_llama_model(
// convert AK weights into GG weights one by one. // convert AK weights into GG weights one by one.
// w->token_embedding_table -> model->tok_embeddings // w->token_embedding_table -> model->tok_embeddings
// float* -> struct ggml_tensor // float* -> struct ggml_tensor
convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table); convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table.data());
convert_weights_ak_to_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table); convert_weights_ak_to_gg(model->output, !w->wcls.empty() ? w->wcls.data() : w->token_embedding_table.data());
convert_weights_ak_to_gg(model->norm, w->rms_final_weight); convert_weights_ak_to_gg(model->norm, w->rms_final_weight.data());
//print_row(model->norm, 0); //print_row(model->norm, 0);
// for rms-att-weight // for rms-att-weight
int row_length = model->hparams.n_embd; int row_length = model->hparams.n_embd;
int n_ff = model->hparams.n_ff; int n_ff = model->hparams.n_ff;
const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 1 : model->hparams.n_head / model->hparams.n_head_kv;
for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
auto & layer = model->layers[i]; auto & layer = model->layers[i];
// 1d // 1d
@ -697,9 +655,10 @@ static void save_as_llama_model(
// from 3d matrix layer x dim x dim to 2d matrix dim x dim // from 3d matrix layer x dim x dim to 2d matrix dim x dim
convert_weights_ak_to_gg(layer.wq , &w->wq[i*row_length*row_length]); convert_weights_ak_to_gg(layer.wq , &w->wq[i*row_length*row_length]);
convert_weights_ak_to_gg(layer.wk , &w->wk[i*row_length*row_length]);
convert_weights_ak_to_gg(layer.wv , &w->wv[i*row_length*row_length]);
convert_weights_ak_to_gg(layer.wo , &w->wo[i*row_length*row_length]); convert_weights_ak_to_gg(layer.wo , &w->wo[i*row_length*row_length]);
// from 3d matrix layer x dim x dim to 2d matrix dim x dim / n_multiqueries
convert_weights_ak_to_gg(layer.wk , &w->wk[i*row_length*row_length/n_multiqueries]);
convert_weights_ak_to_gg(layer.wv , &w->wv[i*row_length*row_length/n_multiqueries]);
convert_weights_ak_to_gg(layer.w1 , &w->w1[i*row_length*n_ff]); convert_weights_ak_to_gg(layer.w1 , &w->w1[i*row_length*n_ff]);
convert_weights_ak_to_gg(layer.w2 , &w->w2[i*n_ff*row_length]); convert_weights_ak_to_gg(layer.w2 , &w->w2[i*n_ff*row_length]);
@ -736,8 +695,8 @@ static void save_as_llama_model(
gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd); gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd);
gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff); gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff);
gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head); gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
// n_head_kv is optional, default to n_head gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
// gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, ...); gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, model->hparams.n_head_kv);
gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer); gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer);
gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot); gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot);
gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f); gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
@ -789,12 +748,12 @@ static void save_as_llama_model(
static struct train_params get_default_train_params() { static struct train_params get_default_train_params() {
struct train_params params; struct train_params params;
params.fn_vocab_model = "models/7B/ggml-model-f16.gguf"; params.fn_vocab_model = "models/7B/ggml-model-f16.gguf";
params.fn_llama2c_output_model = "ak_llama_model.bin"; params.fn_llama2c_output_model = "ak_llama_model.bin";
params.fn_train_data = "shakespeare.txt"; params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.bin"; params.fn_checkpoint_in = "checkpoint.bin";
params.fn_checkpoint_out = "checkpoint.bin"; params.fn_checkpoint_out = "checkpoint.bin";
params.fn_model_out = "ggml-checkpoint-f32.bin"; params.fn_model_out = "ggml-checkpoint-f32.bin";
params.seed = -1; params.seed = -1;
@ -829,8 +788,8 @@ static struct train_params get_default_train_params() {
params.adam_alpha = 1e-3f; params.adam_alpha = 1e-3f;
params.adam_decay = 1e-3f; params.adam_decay = 1e-3f;
params.mem_model_gb = 2; params.mem_model_gb = 2;
params.mem_compute_gb = 24; params.mem_compute_gb = 24;
params.mem_compute0_gb = 8; params.mem_compute0_gb = 8;
params.mem_compute1_gb = 2; params.mem_compute1_gb = 2;
@ -916,19 +875,30 @@ int main(int argc, char ** argv) {
if (!params_parse(argc, argv, &params)) { if (!params_parse(argc, argv, &params)) {
return 1; return 1;
} }
log_set_target(stdout);
Config config; Config config;
TransformerWeights weights = {}; TransformerWeights weights = {};
{ {
FILE *file = fopen(params.fn_llama2c_model, "rb"); LOG("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model);
if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; } FILE *file = fopen(params.fn_llama2c_model, "r");
if (!file) {
LOG("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model);
return 1;
}
// read in the config header // read in the config header
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; } if (fread(&config, sizeof(Config), 1, file) != 1) {
LOG("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model);
return 1;
}
auto shared_weights = config.vocab_size > 0; auto shared_weights = config.vocab_size > 0;
config.vocab_size = abs(config.vocab_size); config.vocab_size = abs(config.vocab_size);
// read in the Transformer weights // read in the Transformer weights
malloc_weights(&weights, &config, shared_weights); alloc_weights(&weights, &config, shared_weights);
if(checkpoint_init_weights(&weights, &config, file, shared_weights)) { return 1; } if (checkpoint_init_weights(&weights, &config, file, shared_weights)) {
LOG("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model);
return 1;
}
fclose(file); fclose(file);
} }
@ -936,15 +906,18 @@ int main(int argc, char ** argv) {
load_vocab(params.fn_vocab_model, &config, &vocab); load_vocab(params.fn_vocab_model, &config, &vocab);
struct my_llama_model model; struct my_llama_model model;
model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx); model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx);
model.hparams.n_ctx = params.n_ctx; model.hparams.n_ctx = params.n_ctx;
model.hparams.n_embd = config.dim; //params.n_embd; model.hparams.n_embd = config.dim; //params.n_embd;
model.hparams.n_ff = config.hidden_dim; model.hparams.n_ff = config.hidden_dim;
model.hparams.n_mult = 32;//params.n_mult; model.hparams.n_mult = 32;//params.n_mult;
model.hparams.n_head = config.n_heads; //params.n_head; model.hparams.n_head = config.n_heads; //params.n_head;
model.hparams.n_layer = config.n_layers; //params.n_layer; model.hparams.n_head_kv = config.n_kv_heads;
model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head); model.hparams.n_layer = config.n_layers; //params.n_layer;
model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
print_params(&model.hparams); print_params(&model.hparams);
struct ggml_init_params lcparams; struct ggml_init_params lcparams;
lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
lcparams.mem_buffer = NULL; lcparams.mem_buffer = NULL;
@ -956,7 +929,7 @@ int main(int argc, char ** argv) {
model.name = basename(params.fn_llama2c_model); model.name = basename(params.fn_llama2c_model);
save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model); save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model);
printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model); LOG("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model);
ggml_free(model.ctx); ggml_free(model.ctx);
return 0; return 0;

View file

@ -1,32 +1,31 @@
#include "llama.h" #include "llama.h"
#include "ggml.h"
#include "common.h" #include "common.h"
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <fstream> #include <fstream>
#include <ios>
#include <string> #include <string>
#include <vector> #include <vector>
#include <stdio.h> #include <stdio.h>
#include <fcntl.h>
#include <string.h> #include <string.h>
#include <climits>
#include <stdexcept>
#if defined(_WIN32)
#include <windows.h>
#ifndef PATH_MAX
#define PATH_MAX MAX_PATH
#endif
#include <io.h>
#endif
enum split_operation : uint8_t { enum split_operation : uint8_t {
SPLIT_OP_SPLIT, SPLIT_OP_SPLIT,
SPLIT_OP_MERGE, SPLIT_OP_MERGE,
}; };
static const char * const LLM_KV_GENERAL_SPLIT_I_SPLIT = "general.split";
static const char * const LLM_KV_GENERAL_SPLIT_N_SPLIT = "general.split_count";
static const int SPLIT_FILENAME_MAX = 256;
static const char * const SPLIT_FILENAME_FORMAT = "%s-%05d-of-%05d.gguf";
struct split_params { struct split_params {
split_operation operation = SPLIT_OP_SPLIT; split_operation operation = SPLIT_OP_SPLIT;
int n_split_tensors = 128; int n_split_tensors = 128;
@ -116,13 +115,13 @@ static bool split_params_parse(int argc, const char ** argv, split_params & para
try { try {
if (!split_params_parse_ex(argc, argv, params)) { if (!split_params_parse_ex(argc, argv, params)) {
split_print_usage(argv[0]); split_print_usage(argv[0]);
exit(1); exit(EXIT_FAILURE);
} }
} }
catch (const std::invalid_argument & ex) { catch (const std::invalid_argument & ex) {
fprintf(stderr, "%s\n", ex.what()); fprintf(stderr, "%s\n", ex.what());
split_print_usage(argv[0]); split_print_usage(argv[0]);
exit(1); exit(EXIT_FAILURE);
} }
return result; return result;
} }
@ -134,12 +133,6 @@ static void zeros(std::ofstream & file, size_t n) {
} }
} }
static std::string split_file_name(const std::string & path, int i_split, int n_split) {
char f_split[SPLIT_FILENAME_MAX] = {0};
snprintf(f_split, sizeof(f_split), SPLIT_FILENAME_FORMAT, path.c_str(), i_split + 1, n_split);
return std::string(f_split);
}
struct split_strategy { struct split_strategy {
const split_params params; const split_params params;
std::ifstream & f_input; std::ifstream & f_input;
@ -180,8 +173,9 @@ struct split_strategy {
if (i_split == 0) { if (i_split == 0) {
gguf_set_kv(ctx_out, ctx_gguf); gguf_set_kv(ctx_out, ctx_gguf);
} }
gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_I_SPLIT, i_split); gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_NO, i_split);
gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_N_SPLIT, n_split); gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_COUNT, n_split);
gguf_set_val_i32(ctx_out, LLM_KV_SPLIT_TENSORS_COUNT, n_tensors);
// populate the original tensors, so we get an initial metadata // populate the original tensors, so we get an initial metadata
for (int i = i_split * params.n_split_tensors; i < n_tensors && i < (i_split + 1) * params.n_split_tensors; ++i) { for (int i = i_split * params.n_split_tensors; i < n_tensors && i < (i_split + 1) * params.n_split_tensors; ++i) {
@ -189,10 +183,11 @@ struct split_strategy {
gguf_add_tensor(ctx_out, meta); gguf_add_tensor(ctx_out, meta);
} }
auto split_name = split_file_name(params.output, i_split, n_split); char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), params.output.c_str(), i_split, n_split);
fprintf(stderr, "%s: %s ...", __func__, split_name.c_str()); fprintf(stderr, "%s: %s ...", __func__, split_path);
fout = std::ofstream(split_name, std::ios::binary); fout = std::ofstream(split_path, std::ios::binary);
fout.exceptions(std::ofstream::failbit); // fail fast on write errors fout.exceptions(std::ofstream::failbit); // fail fast on write errors
auto meta_size = gguf_get_meta_size(ctx_out); auto meta_size = gguf_get_meta_size(ctx_out);
@ -250,19 +245,23 @@ static void gguf_split(const split_params & split_params) {
std::ifstream f_input(split_params.input.c_str(), std::ios::binary); std::ifstream f_input(split_params.input.c_str(), std::ios::binary);
if (!f_input.is_open()) { if (!f_input.is_open()) {
fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_params.input.c_str()); fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_params.input.c_str());
exit(1); exit(EXIT_FAILURE);
} }
auto * ctx_gguf = gguf_init_from_file(split_params.input.c_str(), params); auto * ctx_gguf = gguf_init_from_file(split_params.input.c_str(), params);
if (!ctx_gguf) { if (!ctx_gguf) {
fprintf(stderr, "%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str()); fprintf(stderr, "%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
exit(1); exit(EXIT_FAILURE);
} }
split_strategy strategy(split_params, f_input, ctx_gguf, ctx_meta); split_strategy strategy(split_params, f_input, ctx_gguf, ctx_meta);
char first_split_path[PATH_MAX] = {0};
llama_split_path(first_split_path, sizeof(first_split_path),
split_params.output.c_str(), strategy.i_split, strategy.n_split);
fprintf(stderr, "%s: %s -> %s (%d tensors per file)\n", fprintf(stderr, "%s: %s -> %s (%d tensors per file)\n",
__func__, split_params.input.c_str(), __func__, split_params.input.c_str(),
split_file_name(split_params.output, strategy.i_split, strategy.n_split).c_str(), first_split_path,
split_params.n_split_tensors); split_params.n_split_tensors);
strategy.split_start(); strategy.split_start();
@ -298,7 +297,9 @@ static void gguf_merge(const split_params & split_params) {
std::vector<ggml_context *> ctx_metas; std::vector<ggml_context *> ctx_metas;
std::vector<gguf_context *> ctx_ggufs; std::vector<gguf_context *> ctx_ggufs;
std::string split_prefix; char split_path[PATH_MAX] = {0};
strncpy(split_path, split_params.input.c_str(), sizeof(split_path) - 1);
char split_prefix[PATH_MAX] = {0};
// First pass to find KV and tensors metadata // First pass to find KV and tensors metadata
for (int i_split = 0; i_split < n_split; i_split++) { for (int i_split = 0; i_split < n_split; i_split++) {
@ -309,89 +310,66 @@ static void gguf_merge(const split_params & split_params) {
/*.ctx = */ &ctx_meta, /*.ctx = */ &ctx_meta,
}; };
auto split_name = split_params.input;
if (i_split > 0) { if (i_split > 0) {
split_name = split_file_name(split_prefix, i_split, n_split); llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
} }
fprintf(stderr, "%s: reading metadata %s ...", __func__, split_name.c_str()); fprintf(stderr, "%s: reading metadata %s ...", __func__, split_path);
auto * ctx_gguf = gguf_init_from_file(split_name.c_str(), params); auto * ctx_gguf = gguf_init_from_file(split_path, params);
if (!ctx_gguf) { if (!ctx_gguf) {
fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str()); fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, split_params.input.c_str());
exit(1); exit(EXIT_FAILURE);
} }
ctx_ggufs.push_back(ctx_gguf); ctx_ggufs.push_back(ctx_gguf);
ctx_metas.push_back(ctx_meta); ctx_metas.push_back(ctx_meta);
if (i_split == 0) { if (i_split == 0) {
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_GENERAL_SPLIT_N_SPLIT); auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split < 0) { if (key_n_split < 0) {
fprintf(stderr, fprintf(stderr,
"\n%s: input file does not contain %s metadata\n", "\n%s: input file does not contain %s metadata\n",
__func__, __func__,
LLM_KV_GENERAL_SPLIT_N_SPLIT); LLM_KV_SPLIT_COUNT);
gguf_free(ctx_gguf); gguf_free(ctx_gguf);
ggml_free(ctx_meta);
gguf_free(ctx_out); gguf_free(ctx_out);
fout.close(); fout.close();
exit(1); exit(EXIT_FAILURE);
} }
n_split = gguf_get_val_u8(ctx_gguf, key_n_split); n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
if (n_split < 1) { if (n_split < 1) {
fprintf(stderr, fprintf(stderr,
"\n%s: input file does not contain a valid split count %d\n", "\n%s: input file does not contain a valid split count %d\n",
__func__, __func__,
n_split); n_split);
gguf_free(ctx_gguf); gguf_free(ctx_gguf);
ggml_free(ctx_meta);
gguf_free(ctx_out); gguf_free(ctx_out);
fout.close(); fout.close();
exit(1); exit(EXIT_FAILURE);
}
// Verify the file naming and extract split_prefix
if (!llama_split_prefix(split_prefix, sizeof (split_prefix), split_path, i_split, n_split)) {
fprintf(stderr, "\n%s: unexpected input file name: %s"
" i_split=%d"
" n_split=%d\n", __func__,
split_path, i_split, n_split);
gguf_free(ctx_gguf);
ggml_free(ctx_meta);
gguf_free(ctx_out);
fout.close();
exit(EXIT_FAILURE);
} }
// Do not trigger merge if we try to merge again the output // Do not trigger merge if we try to merge again the output
gguf_set_val_u8(ctx_out, LLM_KV_GENERAL_SPLIT_N_SPLIT, 0); gguf_set_val_u16(ctx_gguf, LLM_KV_SPLIT_COUNT, 0);
// Set metadata from the first split // Set metadata from the first split
gguf_set_kv(ctx_out, ctx_gguf); gguf_set_kv(ctx_out, ctx_gguf);
} }
// Verify the file naming
{
int i_split_file = 0;
int n_split_file = 0;
const char * i_split_format = "-00000-of-00000.gguf";
if (split_name.size() < strlen(i_split_format)) {
fprintf(stderr, "\n%s: unexpected input file name: %s\n", __func__, split_params.input.c_str());
for (auto * _ctx_gguf : ctx_ggufs) {
gguf_free(_ctx_gguf);
}
gguf_free(ctx_out);
fout.close();
exit(1);
}
split_prefix = split_name.substr(0, split_name.size() - strlen(i_split_format));
const char * split_name_c_str = split_name.c_str();
int n_part = sscanf(&split_name_c_str[0] + split_prefix.size(), "-%d-of-%d", &i_split_file, &n_split_file);
if (n_part != 2 || i_split_file - 1 != i_split || n_split_file != n_split) {
fprintf(stderr, "\n%s: unexpected input file name: %s"
" i_split=%d i_split_file=%d"
" n_split=%d n_split_file=%d\n", __func__,
split_params.input.c_str(),
i_split, i_split_file,
n_split, n_split_file);
for (auto * _ctx_gguf : ctx_ggufs) {
gguf_free(_ctx_gguf);
}
gguf_free(ctx_out);
fout.close();
exit(1);
}
}
auto n_tensors = gguf_get_n_tensors(ctx_gguf); auto n_tensors = gguf_get_n_tensors(ctx_gguf);
for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) { for (int i_tensor = 0; i_tensor < n_tensors; i_tensor++) {
const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor); const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
@ -411,18 +389,19 @@ static void gguf_merge(const split_params & split_params) {
// Write tensors data // Write tensors data
for (int i_split = 0; i_split < n_split; i_split++) { for (int i_split = 0; i_split < n_split; i_split++) {
auto split_name = split_file_name(split_prefix, i_split, n_split); llama_split_path(split_path, sizeof(split_path), split_prefix, i_split, n_split);
std::ifstream f_input(split_name.c_str(), std::ios::binary); std::ifstream f_input(split_path, std::ios::binary);
if (!f_input.is_open()) { if (!f_input.is_open()) {
fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_name.c_str()); fprintf(stderr, "%s: failed to open input GGUF from %s\n", __func__, split_path);
for (auto * _ctx_gguf : ctx_ggufs) { for (uint32_t i = 0; i < ctx_ggufs.size(); i++) {
gguf_free(_ctx_gguf); gguf_free(ctx_ggufs[i]);
ggml_free(ctx_metas[i]);
} }
gguf_free(ctx_out); gguf_free(ctx_out);
fout.close(); fout.close();
exit(1); exit(EXIT_FAILURE);
} }
fprintf(stderr, "%s: writing tensors %s ...", __func__, split_name.c_str()); fprintf(stderr, "%s: writing tensors %s ...", __func__, split_path);
auto * ctx_gguf = ctx_ggufs[i_split]; auto * ctx_gguf = ctx_ggufs[i_split];
auto * ctx_meta = ctx_metas[i_split]; auto * ctx_meta = ctx_metas[i_split];
@ -481,8 +460,8 @@ int main(int argc, const char ** argv) {
break; break;
case SPLIT_OP_MERGE: gguf_merge(params); case SPLIT_OP_MERGE: gguf_merge(params);
break; break;
default:split_print_usage(argv[0]); default: split_print_usage(argv[0]);
exit(1); exit(EXIT_FAILURE);
} }
return 0; return 0;

View file

@ -0,0 +1,74 @@
# Usage:
#! ./server -m some-model.gguf &
#! pip install pydantic
#! python json-schema-pydantic-example.py
from pydantic import BaseModel, TypeAdapter
from annotated_types import MinLen
from typing import Annotated, List, Optional
import json, requests
if True:
def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs):
'''
Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
(llama.cpp server, llama-cpp-python, Anyscale / Together...)
The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
'''
if response_model:
type_adapter = TypeAdapter(response_model)
schema = type_adapter.json_schema()
messages = [{
"role": "system",
"content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}"
}] + messages
response_format={"type": "json_object", "schema": schema}
data = requests.post(endpoint, headers={"Content-Type": "application/json"},
json=dict(messages=messages, response_format=response_format, **kwargs)).json()
if 'error' in data:
raise Exception(data['error']['message'])
content = data["choices"][0]["message"]["content"]
return type_adapter.validate_json(content) if type_adapter else content
else:
# This alternative branch uses Instructor + OpenAI client lib.
# Instructor support streamed iterable responses, retry & more.
# (see https://python.useinstructor.com/)
#! pip install instructor openai
import instructor, openai
client = instructor.patch(
openai.OpenAI(api_key="123", base_url="http://localhost:8080"),
mode=instructor.Mode.JSON_SCHEMA)
create_completion = client.chat.completions.create
if __name__ == '__main__':
class QAPair(BaseModel):
question: str
concise_answer: str
justification: str
class PyramidalSummary(BaseModel):
title: str
summary: str
question_answers: Annotated[List[QAPair], MinLen(2)]
sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]]
print("# Summary\n", create_completion(
model="...",
response_model=PyramidalSummary,
messages=[{
"role": "user",
"content": f"""
You are a highly efficient corporate document summarizer.
Create a pyramidal summary of an imaginary internal document about our company processes
(starting high-level, going down to each sub sections).
Keep questions short, and answers even shorter (trivia / quizz style).
"""
}]))

View file

@ -1,8 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import itertools
import json import json
import re import re
import sys import sys
from typing import Any, Dict, List, Set, Tuple, Union
# whitespace is constrained to a single space char to prevent model "running away" in # whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality? # whitespace. Also maybe improves generation quality?
@ -12,26 +14,54 @@ PRIMITIVE_RULES = {
'boolean': '("true" | "false") space', 'boolean': '("true" | "false") space',
'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', 'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space', 'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space',
'value' : 'object | array | string | number | boolean',
'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
'array' : '"[" space ( value ("," space value)* )? "]" space',
'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space',
'string': r''' "\"" ( 'string': r''' "\"" (
[^"\\] | [^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space ''', )* "\"" space''',
'null': '"null" space', 'null': '"null" space',
} }
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value']
# TODO: support "uri", "email" string formats
DATE_RULES = {
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
'date-time': 'date "T" time',
'date-string': '"\\"" date "\\"" space',
'time-string': '"\\"" time "\\"" space',
'date-time-string': '"\\"" date-time "\\"" space',
}
RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'} GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])'
TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits
class SchemaConverter: class SchemaConverter:
def __init__(self, prop_order): def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._prop_order = prop_order self._prop_order = prop_order
self._allow_fetch = allow_fetch
self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {'space': SPACE_RULE} self._rules = {'space': SPACE_RULE}
self._refs = {}
self._refs_being_resolved = set()
def _format_literal(self, literal): def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
) )
return f'"{escaped}"' return f'"{escaped}"'
@ -41,78 +71,420 @@ class SchemaConverter:
key = esc_name key = esc_name
else: else:
i = 0 i = 0
while f'{esc_name}{i}' in self._rules: while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
i += 1 i += 1
key = f'{esc_name}{i}' key = f'{esc_name}{i}'
self._rules[key] = rule self._rules[key] = rule
return key return key
def resolve_refs(self, schema: dict, url: str):
'''
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
'''
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get('$ref')
if ref is not None and ref not in self._refs:
if ref.startswith('https://'):
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests
frag_split = ref.split('#')
base_url = frag_split[0]
target = self._refs.get(base_url)
if target is None:
target = self.resolve_refs(requests.get(ref).json(), base_url)
self._refs[base_url] = target
if len(frag_split) == 1 or frag_split[-1] == '':
return target
elif ref.startswith('#/'):
target = schema
ref = f'{url}{ref}'
n['$ref'] = ref
else:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
for v in n.values():
visit(v)
return n
return visit(schema)
def _generate_union_rule(self, name, alt_schemas):
return ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
for i, alt_schema in enumerate(alt_schemas)
))
def _visit_pattern(self, pattern, name):
'''
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
'''
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
sub_rule_ids = {}
i = 0
length = len(pattern)
def to_rule(s: Tuple[str, bool]) -> str:
(txt, is_literal) = s
return "\"" + txt + "\"" if is_literal else txt
def transform() -> Tuple[str, bool]:
'''
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
'''
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
seq: list[Tuple[str, bool]] = []
def get_dot():
if self._dotall:
rule = '[\\U00000000-\\U0010FFFF]'
else:
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'
return self._add_rule(f'dot', rule)
def join_seq():
nonlocal seq
ret = []
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
if is_literal:
ret.append((''.join(x[0] for x in g), True))
else:
ret.extend(g)
if len(ret) == 1:
return ret[0]
return (' '.join(to_rule(x) for x in seq), False)
while i < length:
c = pattern[i]
if c == '.':
seq.append((get_dot(), False))
i += 1
elif c == '(':
i += 1
if i < length:
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
seq.append((f'({to_rule(transform())})', False))
elif c == ')':
i += 1
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
return join_seq()
elif c == '[':
square_brackets = c
i += 1
while i < length and pattern[i] != ']':
if pattern[i] == '\\':
square_brackets += pattern[i:i+2]
i += 2
else:
square_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
square_brackets += ']'
i += 1
seq.append((square_brackets, False))
elif c == '|':
seq.append(('|', False))
i += 1
elif c in ('*', '+', '?'):
seq[-1] = (to_rule(seq[-1]) + c, False)
i += 1
elif c == '{':
curly_brackets = c
i += 1
while i < length and pattern[i] != '}':
curly_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
curly_brackets += '}'
i += 1
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
min_times = 0
max_times = None
try:
if len(nums) == 1:
min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
except ValueError:
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
(sub, sub_is_literal) = seq[-1]
if min_times == 0 and max_times is None:
seq[-1] = (f'{sub}*', False)
elif min_times == 0 and max_times == 1:
seq[-1] = (f'{sub}?', False)
elif min_times == 1 and max_times is None:
seq[-1] = (f'{sub}+', False)
else:
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id
seq[-1] = (
' '.join(
([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) +
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])),
False
)
else:
literal = ''
while i < length:
if pattern[i] == '\\' and i < length - 1:
next = pattern[i + 1]
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
i += 1
literal += pattern[i]
i += 1
else:
literal += pattern[i:i+2]
i += 2
elif pattern[i] == '"' and not self._raw_pattern:
literal += '\\"'
i += 1
elif pattern[i] not in NON_LITERAL_SET and \
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
literal += pattern[i]
i += 1
else:
break
if literal:
seq.append((literal, True))
return join_seq()
return self._add_rule(
name,
to_rule(transform()) if self._raw_pattern \
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name
def _generate_constant_rule(self, value):
return self._format_literal(json.dumps(value))
def visit(self, schema, name): def visit(self, schema, name):
schema_type = schema.get('type') schema_type = schema.get('type')
rule_name = name or 'root' schema_format = schema.get('format')
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
if 'oneOf' in schema or 'anyOf' in schema: if (ref := schema.get('$ref')) is not None:
rule = ' | '.join(( return self._add_rule(rule_name, self._resolve_ref(ref))
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
for i, alt_schema in enumerate(schema.get('oneOf') or schema['anyOf']) elif 'oneOf' in schema or 'anyOf' in schema:
)) return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
return self._add_rule(rule_name, rule)
elif isinstance(schema_type, list):
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
elif 'const' in schema: elif 'const' in schema:
return self._add_rule(rule_name, self._format_literal(schema['const'])) return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
elif 'enum' in schema: elif 'enum' in schema:
rule = ' | '.join((self._format_literal(v) for v in schema['enum'])) rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
return self._add_rule(rule_name, rule) return self._add_rule(rule_name, rule)
elif schema_type == 'object' and 'properties' in schema: elif schema_type in (None, 'object') and \
# TODO: `required` keyword ('properties' in schema or \
prop_order = self._prop_order ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
prop_pairs = sorted( required = set(schema.get('required', []))
schema['properties'].items(), properties = list(schema.get('properties', {}).items())
# sort by position in prop_order (if specified) then by key return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
elif schema_type in (None, 'object') and 'allOf' in schema:
required = set()
properties = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
comp_schema = self._refs[ref]
if 'properties' in comp_schema:
for prop_name, prop_schema in comp_schema['properties'].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
for t in schema['allOf']:
if 'anyOf' in t:
for tt in t['anyOf']:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
if isinstance(items, list):
return self._add_rule(
rule_name,
'"[" space ' +
' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)) +
' "]" space')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
list_item_operator = f'( "," space {item_rule_name} )'
successive_items = ""
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
if min_items > 0:
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
if min_items == 0:
rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space'
else:
rule = f'"[" space {item_rule_name} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_rule(
'root' if rule_name == 'root' else schema_format,
PRIMITIVE_RULES['uuid']
) )
rule = '"{" space' elif schema_type in (None, 'string') and schema_format in DATE_RULES:
for i, (prop_name, prop_schema) in enumerate(prop_pairs): for t, r in DATE_RULES.items():
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') self._add_rule(t, r)
if i > 0: return schema_format + '-string'
rule += ' "," space'
rule += fr' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
rule += ' "}" space'
return self._add_rule(rule_name, rule) elif (schema_type == 'object') or (len(schema) == 0):
for n in OBJECT_RULE_NAMES:
elif schema_type == 'array' and 'items' in schema: self._add_rule(n, PRIMITIVE_RULES[n])
# TODO `prefixItems` keyword return self._add_rule(rule_name, 'object')
item_rule_name = self.visit(schema['items'], f'{name}{"-" if name else ""}item')
list_item_operator = f'("," space {item_rule_name})'
successive_items = ""
min_items = schema.get("minItems", 0)
if min_items > 0:
first_item = f"({item_rule_name})"
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
else:
first_item = f"({item_rule_name})?"
max_items = schema.get("maxItems")
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
rule = f'"[" space {first_item} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
else: else:
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return self._add_rule( return self._add_rule(
'root' if rule_name == 'root' else schema_type, 'root' if rule_name == 'root' else schema_type,
PRIMITIVE_RULES[schema_type] PRIMITIVE_RULES[schema_type]
) )
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
prop_order = self._prop_order
# sort by position in prop_order (if specified) then by original order
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
prop_kv_rule_names = {}
for prop_name, prop_schema in properties:
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
)
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
if additional_properties == True or isinstance(additional_properties, dict):
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*")
rule = '"{" space '
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
if optional_props:
rule += ' ('
if required_props:
rule += ' "," space ( '
def get_recursive_refs(ks, first_is_optional):
[k, *rest] = ks
kv_rule_name = prop_kv_rule_names[k]
if k == '*':
res = self._add_rule(
f'{name}{"-" if name else ""}additional-kvs',
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
)
elif first_is_optional:
res = f'( "," space {kv_rule_name} )?'
else:
res = kv_rule_name
if len(rest) > 0:
res += ' ' + self._add_rule(
f'{name}{"-" if name else ""}{k}-rest',
get_recursive_refs(rest, first_is_optional=True)
)
return res
rule += ' | '.join(
get_recursive_refs(optional_props[i:], first_is_optional=False)
for i in range(len(optional_props))
)
if required_props:
rule += ' )'
rule += ' )?'
rule += ' "}" space'
return rule
def format_grammar(self): def format_grammar(self):
return '\n'.join((f'{name} ::= {rule}' for name, rule in self._rules.items())) return '\n'.join(
f'{name} ::= {rule}'
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
)
def main(args_in = None): def main(args_in = None):
@ -129,16 +501,47 @@ def main(args_in = None):
type=lambda s: s.split(','), type=lambda s: s.split(','),
help=''' help='''
comma-separated property names defining the order of precedence for object properties; comma-separated property names defining the order of precedence for object properties;
properties not specified here are given lower precedence than those that are, and are properties not specified here are given lower precedence than those that are, and
sorted alphabetically are kept in their original order from the schema. Required properties are always
given precedence over optional properties.
''' '''
) )
parser.add_argument(
'--allow-fetch',
action='store_true',
default=False,
help='Whether to allow fetching referenced schemas over HTTPS')
parser.add_argument(
'--dotall',
action='store_true',
default=False,
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
parser.add_argument(
'--raw-pattern',
action='store_true',
default=False,
help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
schema = json.load(sys.stdin if args.schema == '-' else open(args.schema)) if args.schema.startswith('https://'):
prop_order = {name: idx for idx, name in enumerate(args.prop_order)} url = args.schema
converter = SchemaConverter(prop_order) import requests
schema = requests.get(url).json()
elif args.schema == '-':
url = 'stdin'
schema = json.load(sys.stdin)
else:
url = f'file://{args.schema}'
with open(args.schema) as f:
schema = json.load(f)
converter = SchemaConverter(
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
allow_fetch=args.allow_fetch,
dotall=args.dotall,
raw_pattern=args.raw_pattern)
schema = converter.resolve_refs(schema, url)
converter.visit(schema, '') converter.visit(schema, '')
print(converter.format_grammar()) print(converter.format_grammar())

View file

@ -249,6 +249,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "q5_1") { if (s == "q5_1") {
return GGML_TYPE_Q5_1; return GGML_TYPE_Q5_1;
} }
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
return GGML_TYPE_COUNT; return GGML_TYPE_COUNT;
} }

View file

@ -3,3 +3,21 @@ add_executable(${TARGET} lookup.cpp)
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11) target_compile_features(${TARGET} PRIVATE cxx_std_11)
set(TARGET lookup-create)
add_executable(${TARGET} lookup-create.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
set(TARGET lookup-merge)
add_executable(${TARGET} lookup-merge.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
set(TARGET lookup-stats)
add_executable(${TARGET} lookup-stats.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -0,0 +1,43 @@
#include "ggml.h"
#include "llama.h"
#include "common.h"
#include "ngram-cache.h"
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
int main(int argc, char ** argv){
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model = NULL;
llama_context * ctx = NULL;
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
GGML_ASSERT(model != nullptr);
// tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
fprintf(stderr, "%s: tokenization done\n", __func__);
llama_ngram_cache ngram_cache;
llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);
}

View file

@ -0,0 +1,47 @@
#include "ggml.h"
#include "llama.h"
#include "common.h"
#include "ngram-cache.h"
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
static void print_usage() {
fprintf(stderr, "Merges multiple lookup cache files into a single one.\n");
fprintf(stderr, "Usage: lookup-merge [--help] lookup_part_1.bin lookup_part_2.bin ... lookup_merged.bin\n");
}
int main(int argc, char ** argv){
if (argc < 3) {
print_usage();
exit(1);
}
std::vector<std::string> args;
args.resize(argc-1);
for (int i = 0; i < argc-1; ++i) {
args[i] = argv[i+1];
if (args[i] == "-h" || args[i] == "--help") {
print_usage();
exit(0);
}
}
fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
for (size_t i = 1; i < args.size()-1; ++i) {
fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
}
fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
llama_ngram_cache_save(ngram_cache_merged, args.back());
}

View file

@ -0,0 +1,163 @@
#include "ggml.h"
#include "common.h"
#include "llama.h"
#include "log.h"
#include "ngram-cache.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
#include <unordered_map>
int main(int argc, char ** argv){
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
const int n_draft = params.n_draft;
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model = NULL;
llama_context * ctx = NULL;
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
LOG("add_bos tgt: %d\n", add_bos);
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
llama_ngram_cache ngram_cache_context;
llama_ngram_cache ngram_cache_dynamic;
llama_ngram_cache ngram_cache_static;
int64_t t_draft_flat_us = 0;
int64_t t_draft_us = 0;
{
const int64_t t_start_draft_us = ggml_time_us();
if (!params.lookup_cache_static.empty()) {
try {
ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
} catch (std::ifstream::failure const &) {
fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}
t_draft_flat_us += ggml_time_us() - t_start_draft_us;
}
const int n_input = inp.size();
const int n_ctx = params.n_ctx;
int n_drafted = 0;
int n_accept = 0;
const int64_t t_start_ms = ggml_time_ms();
// Iterate over input tokens in chunks of size n_ctx.
// Each chunk is treated as if a sequential generation but with pre-determined tokens to ensure reproducibility.
for (int i_start = 0; i_start + n_ctx < n_input; i_start += n_ctx) {
const std::vector<llama_token> inp_slice(inp.begin() + i_start, inp.begin() + i_start + n_ctx);
std::vector<llama_token> pseudo_output;
pseudo_output.push_back(inp_slice[0]);
while ((int) pseudo_output.size() < n_ctx) {
// Simulate drafting and decoding from draft:
std::vector<llama_token> draft;
draft.push_back(pseudo_output.back());
{
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
n_drafted += draft.size() - 1;
for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
const llama_token ground_truth = inp_slice[pseudo_output.size()];
const llama_token drafted = draft[j];
if (ground_truth != drafted) {
break;
}
++n_accept;
pseudo_output.push_back(ground_truth);
{
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
}
// After each simulated batch decoding simulate the sampling of a single token:
if ((int) pseudo_output.size() < n_ctx) {
pseudo_output.push_back(inp_slice[pseudo_output.size()]);
{
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
}
draft.erase(draft.begin());
}
if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
const int64_t t_now_ms = ggml_time_ms();
const int64_t eta_ms = (n_input - i_start) * (t_now_ms - t_start_ms) / i_start;
const int64_t eta_min = eta_ms / (60*1000);
const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
LOG_TEE("lookup-stats: %d/%d done, ETA: %02" PRId64 ":%02" PRId64 "\n", i_start, n_input, eta_min, eta_s);
}
// After each chunk, update the dynamic ngram cache with the context ngram cache:
llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
ngram_cache_context.clear();
}
LOG_TEE("\n");
LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_input - n_input % n_ctx);
LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
fprintf(stderr, "\n\n");
return 0;
}

View file

@ -1,12 +1,15 @@
#include "common.h"
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama.h"
#include "common.h"
#include "ngram-cache.h"
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#include <fstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map>
int main(int argc, char ** argv){ int main(int argc, char ** argv){
gpt_params params; gpt_params params;
@ -15,11 +18,7 @@ int main(int argc, char ** argv){
return 1; return 1;
} }
// max/min n-grams size to search for in prompt // max. number of additional tokens to draft if match is found
const int ngram_max = 4;
const int ngram_min = 1;
// length of the candidate / draft sequence, if match is found
const int n_draft = params.n_draft; const int n_draft = params.n_draft;
const bool dump_kv_cache = params.dump_kv_cache; const bool dump_kv_cache = params.dump_kv_cache;
@ -39,6 +38,8 @@ int main(int argc, char ** argv){
// load the model // load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt // tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model); const bool add_bos = llama_should_add_bos_token(model);
@ -47,6 +48,35 @@ int main(int argc, char ** argv){
std::vector<llama_token> inp; std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
llama_ngram_cache ngram_cache_context;
llama_ngram_cache ngram_cache_dynamic;
llama_ngram_cache ngram_cache_static;
int64_t t_draft_flat_us = 0;
int64_t t_draft_us = 0;
{
// Fill up context ngram cache with tokens from user input:
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
if (!params.lookup_cache_static.empty()) {
try {
ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
} catch (std::ifstream::failure const &) {
fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}
t_draft_flat_us += ggml_time_us() - t_start_draft_us;
}
const int max_context_size = llama_n_ctx(ctx); const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4; const int max_tokens_list_size = max_context_size - 4;
@ -76,8 +106,6 @@ int main(int argc, char ** argv){
int n_drafted = 0; int n_drafted = 0;
int n_accept = 0; int n_accept = 0;
int64_t t_draft_us = 0;
int n_past = inp.size(); int n_past = inp.size();
bool has_eos = false; bool has_eos = false;
@ -129,6 +157,12 @@ int main(int argc, char ** argv){
++n_past; ++n_past;
++i_dft; ++i_dft;
inp.push_back(id); inp.push_back(id);
{
// Update context ngram cache with the newly accepted token:
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
if (params.use_color) { if (params.use_color) {
// color accepted draft token // color accepted draft token
@ -149,6 +183,12 @@ int main(int argc, char ** argv){
draft.clear(); draft.clear();
draft.push_back(id); draft.push_back(id);
inp.push_back(id); inp.push_back(id);
{
// Update context ngram cache with the newly accepted token:
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
break; break;
} }
@ -163,44 +203,19 @@ int main(int argc, char ** argv){
llama_batch_clear(batch_tgt); llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
// generate n_pred tokens through prompt lookup // Draft already contains a single token sampled from the model:
auto prompt_lookup = [&]() -> void { GGML_ASSERT(draft.size() == 1);
const int inp_size = inp.size(); GGML_ASSERT(draft[0] == inp.back());
for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
const llama_token * ngram = &inp[inp_size - ngram_size];
for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
bool match = true;
for (int j = 0; j < ngram_size; ++j) {
if (inp[i + j] != ngram[j]) {
match = false;
break;
}
}
if (match) {
const int startIdx = i + ngram_size;
const int endIdx = startIdx + n_draft;
if (endIdx < inp_size) {
for (int j = startIdx; j < endIdx; ++j) {
LOG(" - draft candidate %d: %d\n", j, inp[j]);
draft.push_back(inp[j]);
llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
++n_drafted;
}
return;
}
}
}
}
return;
};
const int64_t t_start_draft_us = ggml_time_us(); const int64_t t_start_draft_us = ggml_time_us();
prompt_lookup(); llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
for (size_t i = 1; i < draft.size(); ++i) {
llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
}
t_draft_us += ggml_time_us() - t_start_draft_us; t_draft_us += ggml_time_us() - t_start_draft_us;
n_drafted += draft.size() - 1;
llama_decode(ctx, batch_tgt); llama_decode(ctx, batch_tgt);
++n_past; ++n_past;
@ -210,19 +225,24 @@ int main(int argc, char ** argv){
auto t_dec_end = ggml_time_us(); auto t_dec_end = ggml_time_us();
// Update dynamic ngram cache with context ngram cache and save it to disk:
llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
llama_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
LOG_TEE("\n\n"); LOG_TEE("\n\n");
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_TEE("\n"); LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft); LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted); LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n", LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us)); t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("\ntarget:\n"); LOG_TEE("\ntarget:\n");
llama_print_timings(ctx); llama_print_timings(ctx);

View file

@ -189,6 +189,18 @@ static void prepare_imatrix(const std::string& imatrix_file,
} }
} }
static ggml_type parse_ggml_type(const char * arg) {
ggml_type result = GGML_TYPE_COUNT;
for (int j = 0; j < GGML_TYPE_COUNT; ++j) {
auto type = ggml_type(j);
const auto * name = ggml_type_name(type);
if (name && strcmp(arg, name) == 0) {
result = type; break;
}
}
return result;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
if (argc < 3) { if (argc < 3) {
usage(argv[0]); usage(argv[0]);
@ -203,6 +215,18 @@ int main(int argc, char ** argv) {
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
params.quantize_output_tensor = false; params.quantize_output_tensor = false;
} else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) {
if (arg_idx < argc-1) {
params.output_tensor_type = parse_ggml_type(argv[++arg_idx]);
} else {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--token-embedding-type") == 0) {
if (arg_idx < argc-1) {
params.token_embedding_type = parse_ggml_type(argv[++arg_idx]);
} else {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) { } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
params.allow_requantize = true; params.allow_requantize = true;
} else if (strcmp(argv[arg_idx], "--pure") == 0) { } else if (strcmp(argv[arg_idx], "--pure") == 0) {

View file

@ -0,0 +1,20 @@
import json, subprocess, sys, os
assert len(sys.argv) >= 2
[_, pattern, *rest] = sys.argv
print(subprocess.check_output(
[
"python",
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"json-schema-to-grammar.py"),
*rest,
"-",
"--raw-pattern",
],
text=True,
input=json.dumps({
"type": "string",
"pattern": pattern,
}, indent=2)))

View file

@ -2,12 +2,16 @@ set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF) option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h) add_executable(${TARGET}
server.cpp
utils.hpp
httplib.h
)
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}> SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
) )
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common json-schema-to-grammar ${CMAKE_THREAD_LIBS_INIT})
if (LLAMA_SERVER_SSL) if (LLAMA_SERVER_SSL)
find_package(OpenSSL REQUIRED) find_package(OpenSSL REQUIRED)
target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto) target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)

View file

@ -16,17 +16,20 @@ The project is under active development, and we are [looking for feedback and co
**Command line options:** **Command line options:**
- `--threads N`, `-t N`: Set the number of threads to use during generation. - `--threads N`, `-t N`: Set the number of threads to use during generation. Not used if model layers are offloaded to GPU. The server is using batching, this parameter is used only if one token is to be processed on CPU backend.
- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation. - `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation. Not used if model layers are offloaded to GPU.
- `--threads-http N`: number of threads in the http server pool to process requests (default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)`) - `--threads-http N`: number of threads in the http server pool to process requests (default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)`)
- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`). - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
- `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf). - `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (default: unused).
- `-hfr REPO, --hf-repo REPO`: Hugging Face model repository (default: unused).
- `-hff FILE, --hf-file FILE`: Hugging Face model file (default: unused).
- `-a ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses. - `-a ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
- `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096. - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096.
- `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. - `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `512`. - `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `2048`.
- `-ub N`, `--ubatch-size N`: physical maximum batch size. Default: `512`.
- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended. - `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended.
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped. - `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. - `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed.
@ -57,7 +60,7 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included. - `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
- `--metrics`: enable prometheus `/metrics` compatible endpoint (default: disabled) - `--metrics`: enable prometheus `/metrics` compatible endpoint (default: disabled)
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name (default: template taken from model's metadata). We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name (default: template taken from model's metadata). We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
- `--log-disable`: Output logs to stdout only, default: enabled. - `--log-disable`: Output logs to stdout only, not to `llama.log`. default: enabled.
- `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json) - `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json)
**If compiled with `LLAMA_SERVER_SSL=ON`** **If compiled with `LLAMA_SERVER_SSL=ON`**
@ -260,7 +263,7 @@ node index.js
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
`slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1) `id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)
`cache_prompt`: Re-use previously cached prompt from the last request if possible. This may prevent re-caching the prompt from scratch. (default: false) `cache_prompt`: Re-use previously cached prompt from the last request if possible. This may prevent re-caching the prompt from scratch. (default: false)

View file

@ -26,8 +26,9 @@ const propOrder = grammarJsonSchemaPropOrder
let grammar = null let grammar = null
if (grammarJsonSchemaFile) { if (grammarJsonSchemaFile) {
const schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8')) let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
const converter = new SchemaConverter(propOrder) const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true})
schema = await converter.resolveRefs(schema, grammarJsonSchemaFile)
converter.visit(schema, '') converter.visit(schema, '')
grammar = converter.formatGrammar() grammar = converter.formatGrammar()
} }

View file

@ -483,4 +483,4 @@ unsigned char completion_js[] = {
0x20, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x20, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f,
0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x3b, 0x0a, 0x7d, 0x0a 0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x3b, 0x0a, 0x7d, 0x0a
}; };
unsigned int completion_js_len = 5796; size_t completion_js_len = 5796;

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -630,14 +630,16 @@
const grammarJsonSchemaPropOrder = signal('') const grammarJsonSchemaPropOrder = signal('')
const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value
const convertJSONSchemaGrammar = () => { const convertJSONSchemaGrammar = async () => {
try { try {
const schema = JSON.parse(params.value.grammar) let schema = JSON.parse(params.value.grammar)
const converter = new SchemaConverter( const converter = new SchemaConverter({
grammarJsonSchemaPropOrder.value prop_order: grammarJsonSchemaPropOrder.value
.split(',') .split(',')
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}) .reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}),
) allow_fetch: true,
})
schema = await converter.resolveRefs(schema, 'input')
converter.visit(schema, '') converter.visit(schema, '')
params.value = { params.value = {
...params.value, ...params.value,

File diff suppressed because one or more lines are too long

View file

@ -1,112 +1,538 @@
// WARNING: This file was ported from json-schema-to-grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '" "?'; const SPACE_RULE = '" "?';
const PRIMITIVE_RULES = { const PRIMITIVE_RULES = {
boolean: '("true" | "false") space', boolean: '("true" | "false") space',
number: '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', number: '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
integer: '("-"? ([0-9] | [1-9] [0-9]*)) space', integer: '("-"? ([0-9] | [1-9] [0-9]*)) space',
value: 'object | array | string | number | boolean',
object: '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
array: '"[" space ( value ("," space value)* )? "]" space',
uuid: '"\\"" ' + [8, 4, 4, 4, 12].map(n => [...new Array(n)].map(_ => '[0-9a-fA-F]').join('')).join(' "-" ') + ' "\\"" space',
string: ` "\\"" ( string: ` "\\"" (
[^"\\\\] | [^"\\\\] |
"\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\\"" space`, )* "\\"" space`,
null: '"null" space', null: '"null" space',
}; };
const OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value'];
// TODO: support "uri", "email" string formats
const DATE_RULES = {
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
'date-time': 'date "T" time',
'date-string': '"\\"" date "\\"" space',
'time-string': '"\\"" time "\\"" space',
'date-time-string': '"\\"" date-time "\\"" space',
};
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...DATE_RULES};
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g; const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g; const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g;
const GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'}; const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' };
const NON_LITERAL_SET = new Set('|.()[]{}*+?');
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('[]()|{}*+?');
export class SchemaConverter { export class SchemaConverter {
constructor(propOrder) { constructor(options) {
this._propOrder = propOrder || {}; this._propOrder = options.prop_order || {};
this._rules = new Map(); this._allowFetch = options.allow_fetch || false;
this._rules.set('space', SPACE_RULE); this._dotall = options.dotall || false;
this._rules = {'space': SPACE_RULE};
this._refs = {};
this._refsBeingResolved = new Set();
} }
_formatLiteral(literal) { _formatLiteral(literal) {
const escaped = JSON.stringify(literal).replace( const escaped = literal.replace(
GRAMMAR_LITERAL_ESCAPE_RE, GRAMMAR_LITERAL_ESCAPE_RE,
m => GRAMMAR_LITERAL_ESCAPES[m] m => GRAMMAR_LITERAL_ESCAPES[m]
); );
return `"${escaped}"`; return `"${escaped}"`;
} }
_formatRangeChar(literal) {
return JSON.stringify(literal).slice(1, -1).replace(
GRAMMAR_RANGE_LITERAL_ESCAPE_RE,
m => GRAMMAR_LITERAL_ESCAPES[m]
);
}
_addRule(name, rule) { _addRule(name, rule) {
let escName = name.replace(INVALID_RULE_CHARS_RE, '-'); let escName = name.replace(INVALID_RULE_CHARS_RE, '-');
let key = escName; let key = escName;
if (this._rules.has(escName)) { if (escName in this._rules) {
if (this._rules.get(escName) === rule) { if (this._rules[escName] === rule) {
return key; return key;
} }
let i = 0; let i = 0;
while (this._rules.has(`${escName}${i}`)) { while ((`${escName}${i}` in this._rules) && (this._rules[`${escName}${i}`] !== rule)) {
i += 1; i += 1;
} }
key = `${escName}${i}`; key = `${escName}${i}`;
} }
this._rules.set(key, rule); this._rules[key] = rule;
return key; return key;
} }
async resolveRefs(schema, url) {
const visit = async (n) => {
if (Array.isArray(n)) {
return Promise.all(n.map(visit));
} else if (typeof n === 'object' && n !== null) {
let ref = n.$ref;
let target;
if (ref !== undefined && !this._refs[ref]) {
if (ref.startsWith('https://')) {
if (!this._allowFetch) {
throw new Error('Fetching remote schemas is not allowed (use --allow-fetch for force)');
}
const fetch = (await import('node-fetch')).default;
const fragSplit = ref.split('#');
const baseUrl = fragSplit[0];
target = this._refs[baseUrl];
if (!target) {
target = await this.resolveRefs(await fetch(ref).then(res => res.json()), baseUrl);
this._refs[baseUrl] = target;
}
if (fragSplit.length === 1 || fragSplit[fragSplit.length - 1] === '') {
return target;
}
} else if (ref.startsWith('#/')) {
target = schema;
ref = `${url}${ref}`;
n.$ref = ref;
} else {
throw new Error(`Unsupported ref ${ref}`);
}
const selectors = ref.split('#')[1].split('/').slice(1);
for (const sel of selectors) {
if (!target || !(sel in target)) {
throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
}
target = target[sel];
}
this._refs[ref] = target;
} else {
await Promise.all(Object.values(n).map(visit));
}
}
return n;
};
return visit(schema);
}
_generateUnionRule(name, altSchemas) {
return altSchemas
.map((altSchema, i) => this.visit(altSchema, `${name ?? ''}${name ? '-' : 'alternative-'}${i}`))
.join(' | ');
}
_visitPattern(pattern, name) {
if (!pattern.startsWith('^') || !pattern.endsWith('$')) {
throw new Error('Pattern must start with "^" and end with "$"');
}
pattern = pattern.slice(1, -1);
const subRuleIds = {};
let i = 0;
const length = pattern.length;
const getDot = () => {
let rule;
if (this._dotall) {
rule = '[\\U00000000-\\U0010FFFF]';
} else {
// Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]';
}
return this._addRule('dot', rule);
};
const toRule = ([s, isLiteral]) => isLiteral ? "\"" + s + "\"" : s;
const transform = () => {
const start = i;
// For each component of this sequence, store its string representation and whether it's a literal.
// We only need a flat structure here to apply repetition operators to the last item, and
// to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
// (GBNF's syntax is luckily very close to regular expressions!)
const seq = [];
const joinSeq = () => {
const ret = [];
for (const [isLiteral, g] of groupBy(seq, x => x[1])) {
if (isLiteral) {
ret.push([[...g].map(x => x[0]).join(''), true]);
} else {
ret.push(...g);
}
}
if (ret.length === 1) {
return ret[0];
}
return [ret.map(x => toRule(x)).join(' '), false];
};
while (i < length) {
const c = pattern[i];
if (c === '.') {
seq.push([getDot(), false]);
i += 1;
} else if (c === '(') {
i += 1;
if (i < length) {
if (pattern[i] === '?') {
throw new Error(`Unsupported pattern syntax "${pattern[i]}" at index ${i} of /${pattern}/`);
}
}
seq.push([`(${toRule(transform())})`, false]);
} else if (c === ')') {
i += 1;
if (start <= 0 || pattern[start - 1] !== '(') {
throw new Error(`Unbalanced parentheses; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
return joinSeq();
} else if (c === '[') {
let squareBrackets = c;
i += 1;
while (i < length && pattern[i] !== ']') {
if (pattern[i] === '\\') {
squareBrackets += pattern.slice(i, i + 2);
i += 2;
} else {
squareBrackets += pattern[i];
i += 1;
}
}
if (i >= length) {
throw new Error(`Unbalanced square brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
squareBrackets += ']';
i += 1;
seq.push([squareBrackets, false]);
} else if (c === '|') {
seq.push(['|', false]);
i += 1;
} else if (c === '*' || c === '+' || c === '?') {
seq[seq.length - 1] = [toRule(seq[seq.length - 1]) + c, false];
i += 1;
} else if (c === '{') {
let curlyBrackets = c;
i += 1;
while (i < length && pattern[i] !== '}') {
curlyBrackets += pattern[i];
i += 1;
}
if (i >= length) {
throw new Error(`Unbalanced curly brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
curlyBrackets += '}';
i += 1;
const nums = curlyBrackets.slice(1, -1).split(',').map(s => s.trim());
let minTimes, maxTimes;
if (nums.length === 1) {
minTimes = parseInt(nums[0], 10);
maxTimes = minTimes;
} else {
if (nums.length !== 2) {
throw new Error(`Invalid quantifier ${curlyBrackets}`);
}
minTimes = nums[0] ? parseInt(nums[0], 10) : 0;
maxTimes = nums[1] ? parseInt(nums[1], 10) : Infinity;
}
let [sub, subIsLiteral] = seq[seq.length - 1];
if (minTimes === 0 && maxTimes === Infinity) {
seq[seq.length - 1] = [`${sub}*`, false];
} else if (minTimes === 0 && maxTimes === 1) {
seq[seq.length - 1] = [`${sub}?`, false];
} else if (minTimes === 1 && maxTimes === Infinity) {
seq[seq.length - 1] = [`${sub}+`, false];
} else {
if (!subIsLiteral) {
let id = subRuleIds[sub];
if (id === undefined) {
id = this._addRule(`${name}-${Object.keys(subRuleIds).length + 1}`, sub);
subRuleIds[sub] = id;
}
sub = id;
}
const repeatedSub = Array.from({ length: minTimes }, () => subIsLiteral ? `"${sub.slice(1, -1).repeat(minTimes)}"` : sub);
const optionalSub = maxTimes !== undefined ? Array.from({ length: maxTimes - minTimes }, () => `${sub}?`) : [`${sub}*`];
seq[seq.length - 1] = [repeatedSub.concat(optionalSub).join(' '), false];
}
} else {
let literal = '';
while (i < length) {
if (pattern[i] === '\\' && i < length - 1) {
const next = pattern[i + 1];
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.has(next)) {
i += 1;
literal += pattern[i];
i += 1;
} else {
literal += pattern.slice(i, i + 2);
i += 2;
}
} else if (pattern[i] === '"') {
literal += '\\"';
i += 1;
} else if (!NON_LITERAL_SET.has(pattern[i]) &&
(i === length - 1 || literal === '' || pattern[i + 1] === '.' || !NON_LITERAL_SET.has(pattern[i+1]))) {
literal += pattern[i];
i += 1;
} else {
break;
}
}
if (literal !== '') {
seq.push([literal, true]);
}
}
}
return joinSeq();
};
return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space")
}
_resolveRef(ref) {
let refName = ref.split('/').pop();
if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
this._refsBeingResolved.add(ref);
const resolved = this._refs[ref];
refName = this.visit(resolved, refName);
this._refsBeingResolved.delete(ref);
}
return refName;
}
_generateConstantRule(value) {
return this._formatLiteral(JSON.stringify(value));
}
visit(schema, name) { visit(schema, name) {
const schemaType = schema.type; const schemaType = schema.type;
const ruleName = name || 'root'; const schemaFormat = schema.format;
const ruleName = name in RESERVED_NAMES ? name + '-' : name == '' ? 'root' : name;
if (schema.oneOf || schema.anyOf) { const ref = schema.$ref;
const rule = (schema.oneOf || schema.anyOf).map((altSchema, i) => if (ref !== undefined) {
this.visit(altSchema, `${name}${name ? "-" : ""}${i}`) return this._addRule(ruleName, this._resolveRef(ref));
).join(' | '); } else if (schema.oneOf || schema.anyOf) {
return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf));
return this._addRule(ruleName, rule); } else if (Array.isArray(schemaType)) {
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t }))));
} else if ('const' in schema) { } else if ('const' in schema) {
return this._addRule(ruleName, this._formatLiteral(schema.const)); return this._addRule(ruleName, this._generateConstantRule(schema.const));
} else if ('enum' in schema) { } else if ('enum' in schema) {
const rule = schema.enum.map(v => this._formatLiteral(v)).join(' | '); const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | ');
return this._addRule(ruleName, rule); return this._addRule(ruleName, rule);
} else if (schemaType === 'object' && 'properties' in schema) { } else if ((schemaType === undefined || schemaType === 'object') &&
// TODO: `required` keyword (from python implementation) ('properties' in schema ||
const propOrder = this._propOrder; ('additionalProperties' in schema && schema.additionalProperties !== true))) {
const propPairs = Object.entries(schema.properties).sort((a, b) => { const required = new Set(schema.required || []);
// sort by position in prop_order (if specified) then by key const properties = Object.entries(schema.properties ?? {});
const orderA = typeof propOrder[a[0]] === 'number' ? propOrder[a[0]] : Infinity; return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
const orderB = typeof propOrder[b[0]] === 'number' ? propOrder[b[0]] : Infinity; } else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
return orderA - orderB || a[0].localeCompare(b[0]); const required = new Set();
}); const properties = [];
const addComponent = (compSchema, isRequired) => {
let rule = '"{" space'; const ref = compSchema.$ref;
propPairs.forEach(([propName, propSchema], i) => { if (ref !== undefined) {
const propRuleName = this.visit(propSchema, `${name}${name ? "-" : ""}${propName}`); compSchema = this._refs[ref];
if (i > 0) {
rule += ' "," space';
} }
rule += ` ${this._formatLiteral(propName)} space ":" space ${propRuleName}`;
});
rule += ' "}" space';
return this._addRule(ruleName, rule); if ('properties' in compSchema) {
} else if (schemaType === 'array' && 'items' in schema) { for (const [propName, propSchema] of Object.entries(compSchema.properties)) {
// TODO `prefixItems` keyword (from python implementation) properties.push([propName, propSchema]);
const itemRuleName = this.visit(schema.items, `${name}${name ? "-" : ""}item`); if (isRequired) {
const rule = `"[" space (${itemRuleName} ("," space ${itemRuleName})*)? "]" space`; required.add(propName);
return this._addRule(ruleName, rule); }
}
}
};
for (const t of schema.allOf) {
if ('anyOf' in t) {
for (const tt of t.anyOf) {
addComponent(tt, false);
}
} else {
addComponent(t, true);
}
}
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, /* additionalProperties= */ false));
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
const items = schema.items ?? schema.prefixItems;
if (Array.isArray(items)) {
return this._addRule(
ruleName,
'"[" space ' +
items.map((item, i) => this.visit(item, `${name ?? ''}${name ? '-' : ''}tuple-${i}`)).join(' "," space ') +
' "]" space'
);
} else {
const itemRuleName = this.visit(items, `${name ?? ''}${name ? '-' : ''}item`);
const listItemOperator = `( "," space ${itemRuleName} )`;
let successiveItems = '';
let minItems = schema.minItems || 0;
const maxItems = schema.maxItems;
if (minItems > 0) {
successiveItems = listItemOperator.repeat(minItems - 1);
minItems--;
}
if (maxItems !== undefined && maxItems > minItems) {
successiveItems += `${listItemOperator}?`.repeat(maxItems - minItems - 1);
} else {
successiveItems += `${listItemOperator}*`;
}
const rule = minItems === 0
? `"[" space ( ${itemRuleName} ${successiveItems} )? "]" space`
: `"[" space ${itemRuleName} ${successiveItems} "]" space`;
return this._addRule(ruleName, rule);
}
} else if ((schemaType === undefined || schemaType === 'string') && 'pattern' in schema) {
return this._visitPattern(schema.pattern, ruleName);
} else if ((schemaType === undefined || schemaType === 'string') && /^uuid[1-5]?$/.test(schema.format || '')) {
return this._addRule(
ruleName === 'root' ? 'root' : schemaFormat,
PRIMITIVE_RULES['uuid'])
} else if ((schemaType === undefined || schemaType === 'string') && schema.format in DATE_RULES) {
for (const [t, r] of Object.entries(DATE_RULES)) {
this._addRule(t, r);
}
return schemaFormat + '-string';
} else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) {
for (const n of OBJECT_RULE_NAMES) {
this._addRule(n, PRIMITIVE_RULES[n]);
}
return this._addRule(ruleName, 'object');
} else { } else {
if (!PRIMITIVE_RULES[schemaType]) { if (!(schemaType in PRIMITIVE_RULES)) {
throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`); throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`);
} }
return this._addRule( // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
ruleName === 'root' ? 'root' : schemaType, return this._addRule(ruleName === 'root' ? 'root' : schemaType, PRIMITIVE_RULES[schemaType]);
PRIMITIVE_RULES[schemaType] }
}
_buildObjectRule(properties, required, name, additionalProperties) {
const propOrder = this._propOrder;
// sort by position in prop_order (if specified) then by original order
const sortedProps = properties.map(([k]) => k).sort((a, b) => {
const orderA = propOrder[a] || Infinity;
const orderB = propOrder[b] || Infinity;
return orderA - orderB || properties.findIndex(([k]) => k === a) - properties.findIndex(([k]) => k === b);
});
const propKvRuleNames = {};
for (const [propName, propSchema] of properties) {
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
propKvRuleNames[propName] = this._addRule(
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
`${this._formatLiteral(JSON.stringify(propName))} space ":" space ${propRuleName}`
); );
} }
const requiredProps = sortedProps.filter(k => required.has(k));
const optionalProps = sortedProps.filter(k => !required.has(k));
if (typeof additionalProperties === 'object' || additionalProperties === true) {
const subName = `${name ?? ''}${name ? '-' : ''}additional`;
const valueRule = this.visit(additionalProperties === true ? {} : additionalProperties, `${subName}-value`);
propKvRuleNames['*'] = this._addRule(
`${subName}-kv`,
`${this._addRule('string', PRIMITIVE_RULES['string'])} ":" space ${valueRule}`);
optionalProps.push('*');
}
let rule = '"{" space ';
rule += requiredProps.map(k => propKvRuleNames[k]).join(' "," space ');
if (optionalProps.length > 0) {
rule += ' (';
if (requiredProps.length > 0) {
rule += ' "," space ( ';
}
const getRecursiveRefs = (ks, firstIsOptional) => {
const [k, ...rest] = ks;
const kvRuleName = propKvRuleNames[k];
let res;
if (k === '*') {
res = this._addRule(
`${name ?? ''}${name ? '-' : ''}additional-kvs`,
`${kvRuleName} ( "," space ` + kvRuleName + ` )*`
)
} else if (firstIsOptional) {
res = `( "," space ${kvRuleName} )?`;
} else {
res = kvRuleName;
}
if (rest.length > 0) {
res += ' ' + this._addRule(
`${name ?? ''}${name ? '-' : ''}${k}-rest`,
getRecursiveRefs(rest, true)
);
}
return res;
};
rule += optionalProps.map((_, i) => getRecursiveRefs(optionalProps.slice(i), false)).join(' | ');
if (requiredProps.length > 0) {
rule += ' )';
}
rule += ' )?';
}
rule += ' "}" space';
return rule;
} }
formatGrammar() { formatGrammar() {
let grammar = ''; let grammar = '';
this._rules.forEach((rule, name) => { for (const [name, rule] of Object.entries(this._rules).sort(([a], [b]) => a.localeCompare(b))) {
grammar += `${name} ::= ${rule}\n`; grammar += `${name} ::= ${rule}\n`;
}); }
return grammar; return grammar;
} }
} }
// Helper function to group elements by a key function
function* groupBy(iterable, keyFn) {
let lastKey = null;
let group = [];
for (const element of iterable) {
const key = keyFn(element);
if (lastKey !== null && key !== lastKey) {
yield [lastKey, group];
group = [];
}
group.push(element);
lastKey = key;
}
if (group.length > 0) {
yield [lastKey, group];
}
}

View file

@ -1,6 +1,7 @@
#include "utils.hpp" #include "utils.hpp"
#include "common.h" #include "common.h"
#include "json-schema-to-grammar.h"
#include "llama.h" #include "llama.h"
#include "grammar-parser.h" #include "grammar-parser.h"
@ -29,7 +30,7 @@
#include <signal.h> #include <signal.h>
#include <memory> #include <memory>
using json = nlohmann::json; using json = nlohmann::ordered_json;
bool server_verbose = false; bool server_verbose = false;
bool server_log_json = true; bool server_log_json = true;
@ -178,6 +179,7 @@ struct server_slot {
llama_token sampled; llama_token sampled;
struct llama_sampling_params sparams; struct llama_sampling_params sparams;
llama_sampling_context * ctx_sampling = nullptr; llama_sampling_context * ctx_sampling = nullptr;
json json_schema;
int32_t ga_i = 0; // group-attention state int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor int32_t ga_n = 1; // group-attention factor
@ -845,7 +847,17 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.seed = json_value(data, "seed", default_params.seed); slot.params.seed = json_value(data, "seed", default_params.seed);
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false;
}
} else {
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
}
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
@ -1235,7 +1247,7 @@ struct server_context {
{"penalize_nl", slot.sparams.penalize_nl}, {"penalize_nl", slot.sparams.penalize_nl},
{"stop", slot.params.antiprompt}, {"stop", slot.params.antiprompt},
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
{"n_keep", params.n_keep}, {"n_keep", slot.params.n_keep},
{"ignore_eos", ignore_eos}, {"ignore_eos", ignore_eos},
{"stream", slot.params.stream}, {"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias}, {"logit_bias", slot.sparams.logit_bias},
@ -1746,7 +1758,7 @@ struct server_context {
} }
// process in chunks of params.n_batch // process in chunks of params.n_batch
int32_t n_batch = llama_n_batch(ctx); int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch // next, batch any pending prompts without exceeding n_batch
@ -2196,7 +2208,11 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" -m FNAME, --model FNAME\n"); printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str()); printf(" model path (default: %s)\n", params.model.c_str());
printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: %s)\n", params.model_url.c_str()); printf(" model download url (default: unused)\n");
printf(" -hfr REPO, --hf-repo REPO\n");
printf(" Hugging Face model repository (default: unused)\n");
printf(" -hff FILE, --hf-file FILE\n");
printf(" Hugging Face model file (default: unused)\n");
printf(" -a ALIAS, --alias ALIAS\n"); printf(" -a ALIAS, --alias ALIAS\n");
printf(" set an alias for the model, will be added as `model` field in completion response\n"); printf(" set an alias for the model, will be added as `model` field in completion response\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@ -2213,7 +2229,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n");
printf(" -spf FNAME, --system-prompt-file FNAME\n"); printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n"); printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n"); printf(" -ctk TYPE, --cache-type-k TYPE\n");
@ -2325,6 +2341,18 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
break; break;
} }
params.model_url = argv[i]; params.model_url = argv[i];
} else if (arg == "-hfr" || arg == "--hf-repo") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.hf_repo = argv[i];
} else if (arg == "-hff" || arg == "--hf-file") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.hf_file = argv[i];
} else if (arg == "-a" || arg == "--alias") { } else if (arg == "-a" || arg == "--alias") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;

View file

@ -4,7 +4,8 @@ Feature: Parallel
Background: Server startup Background: Server startup
Given a server listening on localhost:8080 Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
And a model file test-model-00001-of-00003.gguf
And 42 as server seed And 42 as server seed
And 128 as batch size And 128 as batch size
And 256 KV cache size And 256 KV cache size

View file

@ -37,6 +37,22 @@ Feature: Security
| llama.cpp | no | | llama.cpp | no |
| hackme | raised | | hackme | raised |
Scenario Outline: OAI Compatibility (invalid response formats)
Given a system prompt test
And a user prompt test
And a response format <response_format>
And a model test
And 2 max tokens to predict
And streaming is disabled
Given an OAI compatible chat completions request with raised api error
Examples: Prompts
| response_format |
| {"type": "sound"} |
| {"type": "json_object", "schema": 123} |
| {"type": "json_object", "schema": {"type": 123}} |
| {"type": "json_object", "schema": {"type": "hiccup"}} |
Scenario Outline: CORS Options Scenario Outline: CORS Options
Given a user api key llama.cpp Given a user api key llama.cpp

View file

@ -4,8 +4,8 @@ Feature: llama.cpp server
Background: Server startup Background: Server startup
Given a server listening on localhost:8080 Given a server listening on localhost:8080
And a model url https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K.gguf And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file stories260K.gguf And a model file test-model.gguf
And a model alias tinyllama-2 And a model alias tinyllama-2
And 42 as server seed And 42 as server seed
# KV Cache corresponds to the total amount of tokens # KV Cache corresponds to the total amount of tokens
@ -70,6 +70,22 @@ Feature: llama.cpp server
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | | | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | |
Scenario Outline: OAI Compatibility w/ response format
Given a model test
And a system prompt test
And a user prompt test
And a response format <response_format>
And 10 max tokens to predict
Given an OAI compatible chat completions request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
Examples: Prompts
| response_format | n_predicted | re_content |
| {"type": "json_object", "schema": {"const": "42"}} | 5 | "42" |
| {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] |
| {"type": "json_object"} | 10 | \{ " Jacky. |
Scenario: Tokenize / Detokenize Scenario: Tokenize / Detokenize
When tokenizing: When tokenizing:
""" """

View file

@ -16,7 +16,6 @@ import numpy as np
import openai import openai
from behave import step from behave import step
from behave.api.async_step import async_run_until_complete from behave.api.async_step import async_run_until_complete
from huggingface_hub import hf_hub_download
from prometheus_client import parser from prometheus_client import parser
@ -39,6 +38,8 @@ def step_server_config(context, server_fqdn, server_port):
context.model_alias = None context.model_alias = None
context.model_file = None context.model_file = None
context.model_hf_repo = None
context.model_hf_file = None
context.model_url = None context.model_url = None
context.n_batch = None context.n_batch = None
context.n_ubatch = None context.n_ubatch = None
@ -59,6 +60,7 @@ def step_server_config(context, server_fqdn, server_port):
context.seed = None context.seed = None
context.server_seed = None context.server_seed = None
context.user_api_key = None context.user_api_key = None
context.response_format = None
context.tasks_result = [] context.tasks_result = []
context.concurrent_tasks = [] context.concurrent_tasks = []
@ -67,9 +69,9 @@ def step_server_config(context, server_fqdn, server_port):
@step('a model file {hf_file} from HF repo {hf_repo}') @step('a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file, hf_repo): def step_download_hf_model(context, hf_file, hf_repo):
context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file) context.model_hf_repo = hf_repo
if context.debug: context.model_hf_file = hf_file
print(f"model file: {context.model_file}") context.model_file = os.path.basename(hf_file)
@step('a model file {model_file}') @step('a model file {model_file}')
@ -269,6 +271,11 @@ def step_max_tokens(context, max_tokens):
context.n_predict = max_tokens context.n_predict = max_tokens
@step('a response format {response_format}')
def step_response_format(context, response_format):
context.response_format = json.loads(response_format)
@step('streaming is {enable_streaming}') @step('streaming is {enable_streaming}')
def step_streaming(context, enable_streaming): def step_streaming(context, enable_streaming):
context.enable_streaming = enable_streaming == 'enabled' context.enable_streaming = enable_streaming == 'enabled'
@ -384,6 +391,9 @@ async def step_oai_chat_completions(context, api_error):
enable_streaming=context.enable_streaming enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None, if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
seed=await completions_seed(context), seed=await completions_seed(context),
user_api_key=context.user_api_key user_api_key=context.user_api_key
@ -443,6 +453,8 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None, if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None, if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
seed=await completions_seed(context), seed=await completions_seed(context),
user_api_key=context.user_api_key user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None) if hasattr(context, 'user_api_key') else None)
@ -463,6 +475,8 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None, if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None, if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
seed=context.seed seed=context.seed
if hasattr(context, 'seed') else if hasattr(context, 'seed') else
context.server_seed context.server_seed
@ -745,6 +759,7 @@ async def oai_chat_completions(user_prompt,
model=None, model=None,
n_predict=None, n_predict=None,
enable_streaming=None, enable_streaming=None,
response_format=None,
seed=None, seed=None,
user_api_key=None, user_api_key=None,
expect_api_error=None): expect_api_error=None):
@ -770,6 +785,8 @@ async def oai_chat_completions(user_prompt,
"stream": enable_streaming, "stream": enable_streaming,
"seed": seed "seed": seed
} }
if response_format is not None:
payload['response_format'] = response_format
completion_response = { completion_response = {
'content': '', 'content': '',
'timings': { 'timings': {
@ -830,6 +847,7 @@ async def oai_chat_completions(user_prompt,
model=model, model=model,
max_tokens=n_predict, max_tokens=n_predict,
stream=enable_streaming, stream=enable_streaming,
response_format=payload.get('response_format'),
seed=seed seed=seed
) )
except openai.error.AuthenticationError as e: except openai.error.AuthenticationError as e:
@ -1062,6 +1080,10 @@ def start_server_background(context):
server_args.extend(['--model', context.model_file]) server_args.extend(['--model', context.model_file])
if context.model_url: if context.model_url:
server_args.extend(['--model-url', context.model_url]) server_args.extend(['--model-url', context.model_url])
if context.model_hf_repo:
server_args.extend(['--hf-repo', context.model_hf_repo])
if context.model_hf_file:
server_args.extend(['--hf-file', context.model_hf_file])
if context.n_batch: if context.n_batch:
server_args.extend(['--batch-size', context.n_batch]) server_args.extend(['--batch-size', context.n_batch])
if context.n_ubatch: if context.n_ubatch:

View file

@ -12,7 +12,7 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::json; using json = nlohmann::ordered_json;
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type { enum error_type {
@ -95,8 +95,8 @@ static inline void server_log(const char *level, const char *function, int line,
const std::string str = ss.str(); const std::string str = ss.str();
printf("%.*s\n", (int)str.size(), str.data()); printf("%.*s\n", (int)str.size(), str.data());
fflush(stdout);
} }
fflush(stdout);
} }
// //
@ -373,10 +373,21 @@ static json oaicompat_completion_params_parse(
llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
llama_params["n_keep"] = json_value(body, "n_keep", 0); llama_params["n_keep"] = json_value(body, "n_keep", 0);
if (body.count("grammar") != 0) { if (body.contains("grammar")) {
llama_params["grammar"] = json_value(body, "grammar", json::object()); llama_params["grammar"] = json_value(body, "grammar", json::object());
} }
if (body.contains("response_format")) {
auto response_format = json_value(body, "response_format", json::object());
if (response_format.contains("type")) {
if (response_format["type"] == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
} else {
throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
}
}
}
// Handle 'stop' field // Handle 'stop' field
if (body.contains("stop") && body["stop"].is_string()) { if (body.contains("stop") && body["stop"].is_string()) {
llama_params["stop"] = json::array({body["stop"].get<std::string>()}); llama_params["stop"] = json::array({body["stop"].get<std::string>()});

View file

@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
if (params.sparams.temp > 0) { if (params.sparams.temp > 0) {
// stochastic verification // stochastic verification
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
llama_sample_softmax(ctx_tgt, &dist_tgt);
float p_tgt = 0, p_dft = 0; float p_tgt = 0, p_dft = 0;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size()); // GGML_ASSERT(dist_tgt.size() == dist_dft.size());

28
examples/ts-type-to-grammar.sh Executable file
View file

@ -0,0 +1,28 @@
#!/bin/bash
#
# ./examples/ts-type-to-grammar.sh "{a:string,b:string,c?:string}"
# python examples/json-schema-to-grammar.py https://json.schemastore.org/tsconfig.json
#
set -euo pipefail
readonly type="$1"
# Create a temporary directory
TMPDIR=""
trap 'rm -fR "$TMPDIR"' EXIT
TMPDIR=$(mktemp -d)
DTS_FILE="$TMPDIR/type.d.ts"
SCHEMA_FILE="$TMPDIR/schema.json"
echo "export type MyType = $type" > "$DTS_FILE"
# This is a fork of typescript-json-schema, actively maintained as of March 2024:
# https://github.com/vega/ts-json-schema-generator
npx ts-json-schema-generator --unstable --no-top-ref --path "$DTS_FILE" --type MyType -e none > "$SCHEMA_FILE"
# Alternative, not actively maintained as of March 2024:
# https://github.com/YousefED/typescript-json-schema
# npx typescript-json-schema --defaultProps --required "$DTS_FILE" MyType | tee "$SCHEMA_FILE" >&2
./examples/json-schema-to-grammar.py "$SCHEMA_FILE"

View file

@ -771,7 +771,11 @@ GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t
if (src_ctx->device == dst_ctx->device) { if (src_ctx->device == dst_ctx->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread)); CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
} else { } else {
#ifdef GGML_CUDA_NO_PEER_COPY
return false;
#else
CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread)); CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
#endif
} }
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
return true; return true;
@ -6757,6 +6761,123 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
} }
} }
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_0 * dsti = (block_q5_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_1 * dsti = (block_q5_1 *) cdsti;
float min = xi[0];
float max = xi[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = xi[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (xi[0 + j] - min)*id;
const float x1 = (xi[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
dsti->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = xi[0 + j]*xi[0 + j];
const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -8490,6 +8611,39 @@ static void ggml_cpy_f32_q4_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0;
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f16_f16_cuda( static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -9303,7 +9457,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16 #ifdef GGML_CUDA_F16
cuda_pool_alloc<half> src1_dfloat_a; ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
half * src1_dfloat = nullptr; // dfloat == half half * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 = bool src1_convert_f16 =
@ -10888,6 +11042,12 @@ static void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * s
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@ -11166,19 +11326,23 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device); GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device);
GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device); GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device);
// copy on src stream
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
} else {
#ifdef GGML_CUDA_NO_PEER_COPY
return false;
#else
CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
#endif
}
// record event on src stream
if (!cuda_ctx_src->copy_event) { if (!cuda_ctx_src->copy_event) {
ggml_cuda_set_device(cuda_ctx_src->device); ggml_cuda_set_device(cuda_ctx_src->device);
CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming)); CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
} }
// copy on src stream
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
} else {
CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
}
// record event on src stream
CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream())); CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
// wait on dst stream for the copy to complete // wait on dst stream for the copy to complete
@ -11304,6 +11468,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
return true; return true;
} }
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true; return true;
} }
@ -11365,6 +11538,9 @@ GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const
} }
static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) { static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
#ifdef GGML_CUDA_NO_PEER_COPY
return nullptr;
#else
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device); ggml_cuda_set_device(cuda_ctx->device);
@ -11376,6 +11552,7 @@ static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend)
/* .backend = */ backend, /* .backend = */ backend,
/* .context = */ event, /* .context = */ event,
}; };
#endif
} }
static void ggml_backend_cuda_event_free(ggml_backend_event_t event) { static void ggml_backend_cuda_event_free(ggml_backend_event_t event) {
@ -11481,7 +11658,7 @@ GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, si
} }
GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
return false; return false;
} }
@ -11498,6 +11675,10 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
} }
GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) { GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
return;
}
cudaError_t err = cudaHostUnregister(buffer); cudaError_t err = cudaHostUnregister(buffer);
if (err != cudaSuccess) { if (err != cudaSuccess) {
// clear the error // clear the error

View file

@ -173,8 +173,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16, GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F32, GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_CONCAT,
@ -598,8 +599,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
return true; return true;
default: default:
return false; return false;
@ -1387,6 +1392,14 @@ static enum ggml_status ggml_metal_graph_compute(
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) { switch (src0->type) {
@ -1701,6 +1714,14 @@ static enum ggml_status ggml_metal_graph_compute(
ne20 % 32 == 0 && ne20 >= 64 && ne20 % 32 == 0 && ne20 >= 64 &&
ne11 > ne11_mm_min) { ne11 > ne11_mm_min) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (src2->type) { switch (src2->type) {
@ -2431,13 +2452,14 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
switch (dstt) { switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
default: GGML_ASSERT(false && "not implemented"); default: GGML_ASSERT(false && "not implemented");
}; };
} break; } break;

View file

@ -2388,6 +2388,242 @@ kernel void kernel_cpy_f32_q4_1(
} }
} }
kernel void kernel_cpy_f32_q5_0(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK5_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
dst_data[i00/QK5_0].d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK5_0/2 + j]*id;
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst_data[i00/QK5_0].qh[j] = qh8[j];
}
}
}
kernel void kernel_cpy_f32_q5_1(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float max = src[0];
float min = src[0];
for (int j = 1; j < QK5_1; j++) {
const float v = src[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dst_data[i00/QK5_1].d = d;
dst_data[i00/QK5_1].m = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst_data[i00/QK5_1].qh[j] = qh8[j];
}
}
}
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
kernel void kernel_cpy_f32_iq4_nl(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / kvalues_iq4nl_f[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl_f[xi0];
const float v1 = kvalues_iq4nl_f[xi1];
const float w0 = src[0 + j]*src[0 + j];
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
}
}
kernel void kernel_concat( kernel void kernel_concat(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
@ -4220,10 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
} }
} }
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
void kernel_mul_mv_iq4_nl_f32_impl( void kernel_mul_mv_iq4_nl_f32_impl(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
@ -5528,18 +5760,18 @@ typedef void (mat_mm_t)(
threadgroup uchar *, threadgroup uchar *,
uint3, uint, uint); uint3, uint, uint);
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>; template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
@ -5588,18 +5820,18 @@ typedef void (mat_mm_id_t)(
threadgroup uchar *, threadgroup uchar *,
uint3, uint, uint); uint3, uint, uint);
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>; template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>; template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;

View file

@ -11705,9 +11705,8 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l, ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
float * scales, float * weight, uint8_t * L, float * scales, float * weight, uint8_t * L,
const int8_t * values, const int8_t * values,
const float * quant_weights) { const float * quant_weights,
const int ntry) {
const int ntry = 7;
float sigma2 = 0; float sigma2 = 0;
for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j]; for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
@ -11719,6 +11718,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
float max_scale = 0, amax_scale = 0; float max_scale = 0, amax_scale = 0;
for (int ib = 0; ib < super_block_size/block_size; ++ib) { for (int ib = 0; ib < super_block_size/block_size; ++ib) {
const float * xb = x + ib*block_size; const float * xb = x + ib*block_size;
uint8_t * Lb = L + ib*block_size;
if (quant_weights) { if (quant_weights) {
const float * qw = quant_weights + ib*block_size; const float * qw = quant_weights + ib*block_size;
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
@ -11736,12 +11736,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
scales[ib] = 0; scales[ib] = 0;
continue; continue;
} }
float d = -max/values[0]; float d = ntry > 0 ? -max/values[0] : max/values[0];
float id = 1/d; float id = 1/d;
float sumqx = 0, sumq2 = 0; float sumqx = 0, sumq2 = 0;
for (int j = 0; j < block_size; ++j) { for (int j = 0; j < block_size; ++j) {
float al = id*xb[j]; float al = id*xb[j];
int l = best_index_int8(16, values, al); int l = best_index_int8(16, values, al);
Lb[j] = l;
float q = values[l]; float q = values[l];
float w = weight[j]; float w = weight[j];
sumqx += w*q*xb[j]; sumqx += w*q*xb[j];
@ -11796,9 +11797,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
} }
} else { } else {
dh[0] = GGML_FP32_TO_FP16(scales[0]); dh[0] = GGML_FP32_TO_FP16(scales[0]);
float id = scales[0] ? 1/scales[0] : 0; if (ntry > 0) {
for (int j = 0; j < super_block_size; ++j) { float id = scales[0] ? 1/scales[0] : 0;
L[j] = best_index_int8(16, values, id*x[j]); for (int j = 0; j < super_block_size; ++j) {
L[j] = best_index_int8(16, values, id*x[j]);
}
} }
} }
@ -11823,7 +11826,7 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
for (int ibl = 0; ibl < nblock; ++ibl) { for (int ibl = 0; ibl < nblock; ++ibl) {
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL; const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
&scale, weight, L, kvalues_iq4nl, qw); &scale, weight, L, kvalues_iq4nl, qw, 7);
} }
src += n_per_row; src += n_per_row;
qrow += nblock*sizeof(block_iq4_nl); qrow += nblock*sizeof(block_iq4_nl);
@ -11832,14 +11835,23 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
} }
void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) { void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
assert(k % QK4_NL == 0); GGML_ASSERT(k%QK4_NL == 0);
block_iq4_nl * restrict y = vy; int nblock = k/QK4_NL;
quantize_row_iq4_nl_reference(x, y, k); uint8_t L[QK4_NL];
float weight[QK4_NL];
uint16_t unused_h;
uint8_t * unused_l = NULL;
float scale;
block_iq4_nl * iq4 = (block_iq4_nl *)vy;
for (int ibl = 0; ibl < nblock; ++ibl) {
quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
&scale, weight, L, kvalues_iq4nl, NULL, -1);
}
} }
void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) { void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
assert(k % QK4_NL == 0); assert(k % QK4_NL == 0);
quantize_iq4_nl(x, y, 1, k, NULL); quantize_row_iq4_nl(x, y, k);
} }
size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
@ -11857,7 +11869,7 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
for (int ibl = 0; ibl < nblock; ++ibl) { for (int ibl = 0; ibl < nblock; ++ibl) {
const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL; const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l, quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
scales, weight, L, kvalues_iq4nl, qw); scales, weight, L, kvalues_iq4nl, qw, 7);
} }
src += n_per_row; src += n_per_row;
qrow += nblock*sizeof(block_iq4_xs); qrow += nblock*sizeof(block_iq4_xs);

View file

@ -740,11 +740,7 @@ namespace dpct
sycl::queue &default_queue() sycl::queue &default_queue()
{ {
#ifdef DPCT_USM_LEVEL_NONE
return out_of_order_queue();
#else
return in_order_queue(); return in_order_queue();
#endif // DPCT_USM_LEVEL_NONE
} }
void queues_wait_and_throw() void queues_wait_and_throw()
@ -763,11 +759,7 @@ namespace dpct
sycl::queue *create_queue(bool enable_exception_handler = false) sycl::queue *create_queue(bool enable_exception_handler = false)
{ {
#ifdef DPCT_USM_LEVEL_NONE
return create_out_of_order_queue(enable_exception_handler);
#else
return create_in_order_queue(enable_exception_handler); return create_in_order_queue(enable_exception_handler);
#endif // DPCT_USM_LEVEL_NONE
} }
sycl::queue *create_queue(sycl::context context, sycl::device device, sycl::queue *create_queue(sycl::context context, sycl::device device,
@ -1075,11 +1067,6 @@ namespace dpct
static pointer_access_attribute get_pointer_attribute(sycl::queue &q, static pointer_access_attribute get_pointer_attribute(sycl::queue &q,
const void *ptr) const void *ptr)
{ {
#ifdef DPCT_USM_LEVEL_NONE
return mem_mgr::instance().is_device_ptr(ptr)
? pointer_access_attribute::device_only
: pointer_access_attribute::host_only;
#else
switch (sycl::get_pointer_type(ptr, q.get_context())) switch (sycl::get_pointer_type(ptr, q.get_context()))
{ {
case sycl::usm::alloc::unknown: case sycl::usm::alloc::unknown:
@ -1090,7 +1077,6 @@ namespace dpct
case sycl::usm::alloc::host: case sycl::usm::alloc::host:
return pointer_access_attribute::host_device; return pointer_access_attribute::host_device;
} }
#endif
} }
template <typename ArgT> template <typename ArgT>
@ -1273,11 +1259,7 @@ namespace dpct
static inline void *dpct_malloc(size_t size, sycl::queue &q) static inline void *dpct_malloc(size_t size, sycl::queue &q)
{ {
#ifdef DPCT_USM_LEVEL_NONE
return mem_mgr::instance().mem_alloc(size * sizeof(byte_t));
#else
return sycl::malloc_device(size, q.get_device(), q.get_context()); return sycl::malloc_device(size, q.get_device(), q.get_context());
#endif // DPCT_USM_LEVEL_NONE
} }
#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) #define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))
@ -1301,25 +1283,7 @@ namespace dpct
static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,
valueT value, size_t size) valueT value, size_t size)
{ {
#ifdef DPCT_USM_LEVEL_NONE
auto &mm = mem_mgr::instance();
assert(mm.is_device_ptr(dev_ptr));
auto alloc = mm.translate_ptr(dev_ptr);
size_t offset = (valueT *)dev_ptr - (valueT *)alloc.alloc_ptr;
return q.submit([&](sycl::handler &cgh)
{
auto r = sycl::range<1>(size);
auto o = sycl::id<1>(offset);
auto new_buffer = alloc.buffer.reinterpret<valueT>(
sycl::range<1>(alloc.size / sizeof(valueT)));
sycl::accessor<valueT, 1, sycl::access_mode::write,
sycl::access::target::device>
acc(new_buffer, cgh, r, o);
cgh.fill(acc, value); });
#else
return q.fill(dev_ptr, value, size); return q.fill(dev_ptr, value, size);
#endif // DPCT_USM_LEVEL_NONE
} }
/** /**
@ -1413,72 +1377,8 @@ namespace dpct
{ {
if (!size) if (!size)
return sycl::event{}; return sycl::event{};
#ifdef DPCT_USM_LEVEL_NONE
auto &mm = mem_mgr::instance();
auto real_direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
switch (real_direction)
{
case host_to_host:
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
cgh.host_task([=] { std::memcpy(to_ptr, from_ptr, size); }); });
case host_to_device:
{
auto alloc = mm.translate_ptr(to_ptr);
size_t offset = (byte_t *)to_ptr - alloc.alloc_ptr;
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
auto r = sycl::range<1>(size);
auto o = sycl::id<1>(offset);
sycl::accessor<byte_t, 1, sycl::access_mode::write,
sycl::access::target::device>
acc(alloc.buffer, cgh, r, o);
cgh.copy(from_ptr, acc); });
}
case device_to_host:
{
auto alloc = mm.translate_ptr(from_ptr);
size_t offset = (byte_t *)from_ptr - alloc.alloc_ptr;
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
auto r = sycl::range<1>(size);
auto o = sycl::id<1>(offset);
sycl::accessor<byte_t, 1, sycl::access_mode::read,
sycl::access::target::device>
acc(alloc.buffer, cgh, r, o);
cgh.copy(acc, to_ptr); });
}
case device_to_device:
{
auto to_alloc = mm.translate_ptr(to_ptr);
auto from_alloc = mm.translate_ptr(from_ptr);
size_t to_offset = (byte_t *)to_ptr - to_alloc.alloc_ptr;
size_t from_offset = (byte_t *)from_ptr - from_alloc.alloc_ptr;
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
auto r = sycl::range<1>(size);
auto to_o = sycl::id<1>(to_offset);
auto from_o = sycl::id<1>(from_offset);
sycl::accessor<byte_t, 1, sycl::access_mode::write,
sycl::access::target::device>
to_acc(to_alloc.buffer, cgh, r, to_o);
sycl::accessor<byte_t, 1, sycl::access_mode::read,
sycl::access::target::device>
from_acc(from_alloc.buffer, cgh, r, from_o);
cgh.copy(from_acc, to_acc); });
}
default:
throw std::runtime_error("dpct_memcpy: invalid direction value");
}
#else
return q.memcpy(to_ptr, from_ptr, size, dep_events); return q.memcpy(to_ptr, from_ptr, size, dep_events);
GGML_UNUSED(direction); GGML_UNUSED(direction);
#endif // DPCT_USM_LEVEL_NONE
} }
// Get actual copy range and make sure it will not exceed range. // Get actual copy range and make sure it will not exceed range.
@ -1618,45 +1518,15 @@ namespace dpct
break; break;
} }
case device_to_device: case device_to_device:
#ifdef DPCT_USM_LEVEL_NONE event_list.push_back(q.submit([&](sycl::handler &cgh){
{ cgh.depends_on(dep_events);
auto &mm = mem_mgr::instance(); cgh.parallel_for<class dpct_memcpy_3d_detail>(
auto to_alloc = mm.translate_ptr(to_surface); size,
auto from_alloc = mm.translate_ptr(from_surface); [=](sycl::id<3> id) {
size_t to_offset = (byte_t *)to_surface - to_alloc.alloc_ptr; to_surface[get_offset(id, to_slice, to_range.get(0))] =
size_t from_offset = (byte_t *)from_surface - from_alloc.alloc_ptr; from_surface[get_offset(id, from_slice, from_range.get(0))];
event_list.push_back(q.submit([&](sycl::handler &cgh) }); }));
{ break;
cgh.depends_on(dep_events);
auto to_o = sycl::id<1>(to_offset);
auto from_o = sycl::id<1>(from_offset);
sycl::accessor<byte_t, 1, sycl::access_mode::write,
sycl::access::target::device>
to_acc(to_alloc.buffer, cgh,
get_copy_range(size, to_slice, to_range.get(0)), to_o);
sycl::accessor<byte_t, 1, sycl::access_mode::read,
sycl::access::target::device>
from_acc(from_alloc.buffer, cgh,
get_copy_range(size, from_slice, from_range.get(0)), from_o);
cgh.parallel_for<class dpct_memcpy_3d_detail_usmnone>(
size,
[=](sycl::id<3> id) {
to_acc[get_offset(id, to_slice, to_range.get(0))] =
from_acc[get_offset(id, from_slice, from_range.get(0))];
}); }));
}
#else
event_list.push_back(q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
cgh.parallel_for<class dpct_memcpy_3d_detail>(
size,
[=](sycl::id<3> id) {
to_surface[get_offset(id, to_slice, to_range.get(0))] =
from_surface[get_offset(id, from_slice, from_range.get(0))];
}); }));
#endif
break;
default: default:
throw std::runtime_error("dpct_memcpy: invalid direction value"); throw std::runtime_error("dpct_memcpy: invalid direction value");
} }
@ -1754,11 +1624,7 @@ namespace dpct
{ {
if (ptr) if (ptr)
{ {
#ifdef DPCT_USM_LEVEL_NONE
detail::mem_mgr::instance().mem_free(ptr);
#else
sycl::free(ptr, q.get_context()); sycl::free(ptr, q.get_context());
#endif // DPCT_USM_LEVEL_NONE
} }
} }
@ -1766,11 +1632,7 @@ namespace dpct
inline auto get_memory(const void *x) inline auto get_memory(const void *x)
{ {
T *new_x = reinterpret_cast<T *>(const_cast<void *>(x)); T *new_x = reinterpret_cast<T *>(const_cast<void *>(x));
#ifdef DPCT_USM_LEVEL_NONE
return dpct::get_buffer<std::remove_cv_t<T>>(new_x);
#else
return new_x; return new_x;
#endif
} }
template <typename T> template <typename T>
@ -2222,72 +2084,8 @@ namespace dpct
{ {
if (!size) if (!size)
return sycl::event{}; return sycl::event{};
#ifdef DPCT_USM_LEVEL_NONE
auto &mm = mem_mgr::instance();
auto real_direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
switch (real_direction)
{
case host_to_host:
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
cgh.host_task([=] { std::memcpy(to_ptr, from_ptr, size); }); });
case host_to_device:
{
auto alloc = mm.translate_ptr(to_ptr);
size_t offset = (byte_t *)to_ptr - alloc.alloc_ptr;
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
auto r = sycl::range<1>(size);
auto o = sycl::id<1>(offset);
sycl::accessor<byte_t, 1, sycl::access_mode::write,
sycl::access::target::device>
acc(alloc.buffer, cgh, r, o);
cgh.copy(from_ptr, acc); });
}
case device_to_host:
{
auto alloc = mm.translate_ptr(from_ptr);
size_t offset = (byte_t *)from_ptr - alloc.alloc_ptr;
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
auto r = sycl::range<1>(size);
auto o = sycl::id<1>(offset);
sycl::accessor<byte_t, 1, sycl::access_mode::read,
sycl::access::target::device>
acc(alloc.buffer, cgh, r, o);
cgh.copy(acc, to_ptr); });
}
case device_to_device:
{
auto to_alloc = mm.translate_ptr(to_ptr);
auto from_alloc = mm.translate_ptr(from_ptr);
size_t to_offset = (byte_t *)to_ptr - to_alloc.alloc_ptr;
size_t from_offset = (byte_t *)from_ptr - from_alloc.alloc_ptr;
return q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
auto r = sycl::range<1>(size);
auto to_o = sycl::id<1>(to_offset);
auto from_o = sycl::id<1>(from_offset);
sycl::accessor<byte_t, 1, sycl::access_mode::write,
sycl::access::target::device>
to_acc(to_alloc.buffer, cgh, r, to_o);
sycl::accessor<byte_t, 1, sycl::access_mode::read,
sycl::access::target::device>
from_acc(from_alloc.buffer, cgh, r, from_o);
cgh.copy(from_acc, to_acc); });
}
default:
throw std::runtime_error("dpct_memcpy: invalid direction value");
}
#else
return q.memcpy(to_ptr, from_ptr, size, dep_events); return q.memcpy(to_ptr, from_ptr, size, dep_events);
GGML_UNUSED(direction); GGML_UNUSED(direction);
#endif // DPCT_USM_LEVEL_NONE
} }
// Get actual copy range and make sure it will not exceed range. // Get actual copy range and make sure it will not exceed range.
@ -2427,34 +2225,6 @@ namespace dpct
break; break;
} }
case device_to_device: case device_to_device:
#ifdef DPCT_USM_LEVEL_NONE
{
auto &mm = mem_mgr::instance();
auto to_alloc = mm.translate_ptr(to_surface);
auto from_alloc = mm.translate_ptr(from_surface);
size_t to_offset = (byte_t *)to_surface - to_alloc.alloc_ptr;
size_t from_offset = (byte_t *)from_surface - from_alloc.alloc_ptr;
event_list.push_back(q.submit([&](sycl::handler &cgh)
{
cgh.depends_on(dep_events);
auto to_o = sycl::id<1>(to_offset);
auto from_o = sycl::id<1>(from_offset);
sycl::accessor<byte_t, 1, sycl::access_mode::write,
sycl::access::target::device>
to_acc(to_alloc.buffer, cgh,
get_copy_range(size, to_slice, to_range.get(0)), to_o);
sycl::accessor<byte_t, 1, sycl::access_mode::read,
sycl::access::target::device>
from_acc(from_alloc.buffer, cgh,
get_copy_range(size, from_slice, from_range.get(0)), from_o);
cgh.parallel_for<class dpct_memcpy_3d_detail_usmnone>(
size,
[=](sycl::id<3> id) {
to_acc[get_offset(id, to_slice, to_range.get(0))] =
from_acc[get_offset(id, from_slice, from_range.get(0))];
}); }));
}
#else
event_list.push_back(q.submit([&](sycl::handler &cgh) event_list.push_back(q.submit([&](sycl::handler &cgh)
{ {
cgh.depends_on(dep_events); cgh.depends_on(dep_events);
@ -2464,7 +2234,6 @@ namespace dpct
to_surface[get_offset(id, to_slice, to_range.get(0))] = to_surface[get_offset(id, to_slice, to_range.get(0))] =
from_surface[get_offset(id, from_slice, from_range.get(0))]; from_surface[get_offset(id, from_slice, from_range.get(0))];
}); })); }); }));
#endif
break; break;
default: default:
throw std::runtime_error("dpct_memcpy: invalid direction value"); throw std::runtime_error("dpct_memcpy: invalid direction value");
@ -2655,9 +2424,6 @@ namespace dpct
void *c[], library_data_t c_type, int ldc, void *c[], library_data_t c_type, int ldc,
int batch_size, library_data_t scaling_type) int batch_size, library_data_t scaling_type)
{ {
#ifdef DPCT_USM_LEVEL_NONE
throw std::runtime_error("this API is unsupported when USM level is none");
#else
if (scaling_type == library_data_t::real_float && if (scaling_type == library_data_t::real_float &&
c_type == library_data_t::complex_float) c_type == library_data_t::complex_float)
{ {
@ -2792,7 +2558,6 @@ namespace dpct
default: default:
throw std::runtime_error("the combination of data type is unsupported"); throw std::runtime_error("the combination of data type is unsupported");
} }
#endif
} }
/// Computes a batch of matrix-matrix product with general matrices. /// Computes a batch of matrix-matrix product with general matrices.
@ -3131,24 +2896,9 @@ namespace dpct
template <size_t D = Dimension> template <size_t D = Dimension>
typename std::enable_if<D == 1, T>::type &operator[](size_t index) { typename std::enable_if<D == 1, T>::type &operator[](size_t index) {
init(); init();
#ifdef DPCT_USM_LEVEL_NONE
return dpct::get_buffer<typename std::enable_if<D == 1, T>::type>(
_device_ptr)
.template get_access<sycl::access_mode::read_write>()[index];
#else
return _device_ptr[index]; return _device_ptr[index];
#endif // DPCT_USM_LEVEL_NONE
} }
#ifdef DPCT_USM_LEVEL_NONE
/// Get sycl::accessor for the device memory object when usm is not used.
accessor_t get_access(sycl::handler &cgh) {
return get_buffer(_device_ptr)
.template reinterpret<T, Dimension>(_range)
.template get_access<detail::memory_traits<Memory, T>::mode,
detail::memory_traits<Memory, T>::target>(cgh);
}
#else
/// Get dpct::accessor with dimension info for the device memory object /// Get dpct::accessor with dimension info for the device memory object
/// when usm is used and dimension is greater than 1. /// when usm is used and dimension is greater than 1.
template <size_t D = Dimension> template <size_t D = Dimension>
@ -3156,7 +2906,6 @@ namespace dpct
get_access(sycl::handler &cgh) { get_access(sycl::handler &cgh) {
return dpct_accessor_t((T *)_device_ptr, _range); return dpct_accessor_t((T *)_device_ptr, _range);
} }
#endif // DPCT_USM_LEVEL_NONE
private: private:
device_memory(value_t *memory_ptr, size_t size) device_memory(value_t *memory_ptr, size_t size)
@ -3201,15 +2950,6 @@ namespace dpct
/// Default constructor /// Default constructor
device_memory() : base(1) {} device_memory() : base(1) {}
#ifdef DPCT_USM_LEVEL_NONE
/// Get sycl::accessor for the device memory object when usm is not used.
accessor_t get_access(sycl::handler &cgh) {
auto buf = get_buffer(base::get_ptr())
.template reinterpret<T, 1>(sycl::range<1>(1));
return accessor_t(buf, cgh);
}
#endif // DPCT_USM_LEVEL_NONE
}; };
} // namespace detail } // namespace detail
@ -13181,7 +12921,7 @@ int get_work_group_size(int user_device_id) {
return prop.get_max_work_group_size(); return prop.get_max_work_group_size();
} }
void ggml_init_sycl() try { static void ggml_init_sycl() try {
static bool initialized = false; static bool initialized = false;
if (!initialized) { if (!initialized) {
@ -16677,6 +16417,7 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
}; };
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device_index) { ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device_index) {
ggml_init_sycl();
if (device_index>=g_device_count or device_index<0) { if (device_index>=g_device_count or device_index<0) {
printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
device_index, g_device_count-1); device_index, g_device_count-1);
@ -17046,6 +16787,7 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface
}; };
GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) { GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
ggml_init_sycl();
// FIXME: this is not thread safe // FIXME: this is not thread safe
static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map; static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
@ -17379,6 +17121,13 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
UNUSED(backend); UNUSED(backend);
} }
GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
const int min_batch_size = 32;
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
GGML_UNUSED(backend);
}
static ggml_backend_i ggml_backend_sycl_interface = { static ggml_backend_i ggml_backend_sycl_interface = {
/* .get_name = */ ggml_backend_sycl_name, /* .get_name = */ ggml_backend_sycl_name,
/* .free = */ ggml_backend_sycl_free, /* .free = */ ggml_backend_sycl_free,
@ -17392,7 +17141,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
/* .graph_plan_compute = */ NULL, /* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_sycl_graph_compute, /* .graph_compute = */ ggml_backend_sycl_graph_compute,
/* .supports_op = */ ggml_backend_sycl_supports_op, /* .supports_op = */ ggml_backend_sycl_supports_op,
/* .offload_op = */ NULL, /* .offload_op = */ ggml_backend_sycl_offload_op,
/* .event_new = */ NULL, /* .event_new = */ NULL,
/* .event_free = */ NULL, /* .event_free = */ NULL,
/* .event_record = */ NULL, /* .event_record = */ NULL,
@ -17406,7 +17155,7 @@ static ggml_guid_t ggml_backend_sycl_guid() {
} }
GGML_CALL ggml_backend_t ggml_backend_sycl_init(int device) { GGML_CALL ggml_backend_t ggml_backend_sycl_init(int device) {
ggml_init_sycl(); // TODO: remove from ggml.c ggml_init_sycl();
check_allow_gpu_index(device); check_allow_gpu_index(device);

View file

@ -16,16 +16,22 @@ extern "C" {
#define GGML_SYCL_MAX_DEVICES 48 #define GGML_SYCL_MAX_DEVICES 48
#define GGML_SYCL_NAME "SYCL" #define GGML_SYCL_NAME "SYCL"
GGML_API void ggml_init_sycl(void); // backend API
GGML_API bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
GGML_API ggml_backend_t ggml_backend_sycl_init(int device); GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
// devide buffer
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device); GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
// split tensor buffer that splits matrices by rows across multiple devices
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
GGML_API void ggml_backend_sycl_print_sycl_devices(void); GGML_API void ggml_backend_sycl_print_sycl_devices(void);
GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len); GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len);
GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size); GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_count(); GGML_API GGML_CALL int ggml_backend_sycl_get_device_count();
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total); GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id); GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id);
@ -34,6 +40,10 @@ GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index); GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index);
GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id); GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id);
GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode(); GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode();
// SYCL doesn't support registering host memory, keep here for reference
// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

76
ggml.c
View file

@ -3,6 +3,7 @@
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml.h"
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW #include <malloc.h> // using malloc.h with MSC/MINGW
@ -43,6 +44,10 @@
#if defined(_WIN32) #if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h> #include <windows.h>
typedef volatile LONG atomic_int; typedef volatile LONG atomic_int;
@ -286,8 +291,6 @@ inline static void * ggml_calloc(size_t num, size_t size) {
#include "ggml-opencl.h" #include "ggml-opencl.h"
#elif defined(GGML_USE_VULKAN) #elif defined(GGML_USE_VULKAN)
#include "ggml-vulkan.h" #include "ggml-vulkan.h"
#elif defined(GGML_USE_SYCL)
#include "ggml-sycl.h"
#endif #endif
// floating point type used to accumulate sums // floating point type used to accumulate sums
@ -430,6 +433,57 @@ int64_t ggml_cycles_per_ms(void) {
#define ggml_perf_cycles_per_ms() 0 #define ggml_perf_cycles_per_ms() 0
#endif #endif
//
// cross-platform UTF-8 file paths
//
#ifdef _WIN32
static wchar_t * ggml_mbstowcs(const char * mbs) {
int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0);
if (!wlen) {
errno = EINVAL;
return NULL;
}
wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t));
wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen);
if (!wlen) {
GGML_FREE(wbuf);
errno = EINVAL;
return NULL;
}
return wbuf;
}
#endif
FILE * ggml_fopen(const char * fname, const char * mode) {
#ifdef _WIN32
FILE * file = NULL;
// convert fname (UTF-8)
wchar_t * wfname = ggml_mbstowcs(fname);
if (wfname) {
// convert mode (ANSI)
wchar_t * wmode = GGML_MALLOC(strlen(mode) + 1);
wchar_t * wmode_p = wmode;
do {
*wmode_p++ = (wchar_t)*mode;
} while (*mode++);
// open file
file = _wfopen(wfname, wmode);
GGML_FREE(wfname);
GGML_FREE(wmode);
}
return file;
#else
return fopen(fname, mode);
#endif
}
// //
// cache line // cache line
// //
@ -2642,8 +2696,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
ggml_cl_init(); ggml_cl_init();
#elif defined(GGML_USE_VULKAN) #elif defined(GGML_USE_VULKAN)
ggml_vk_init_cpu_assist(); ggml_vk_init_cpu_assist();
#elif defined(GGML_USE_SYCL)
ggml_init_sycl();
#endif #endif
ggml_setup_op_has_task_pass(); ggml_setup_op_has_task_pass();
@ -16059,12 +16111,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU); GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU);
#endif // GGML_USE_VULKAN #endif // GGML_USE_VULKAN
#ifdef GGML_USE_SYCL
bool skip_cpu = ggml_sycl_compute_forward(params, tensor);
if (skip_cpu) {
return;
}
#endif // GGML_USE_SYCL
switch (tensor->op) { switch (tensor->op) {
case GGML_OP_DUP: case GGML_OP_DUP:
{ {
@ -18739,7 +18785,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
// write binary data // write binary data
{ {
FILE * fout = fopen(fname, "wb"); FILE * fout = ggml_fopen(fname, "wb");
if (!fout) { if (!fout) {
fprintf(stderr, "%s: failed to open %s\n", __func__, fname); fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
@ -18877,7 +18923,7 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context *
// read file into data // read file into data
{ {
FILE * fin = fopen(fname, "rb"); FILE * fin = ggml_fopen(fname, "rb");
if (!fin) { if (!fin) {
fprintf(stderr, "%s: failed to open %s\n", __func__, fname); fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
return result; return result;
@ -19213,7 +19259,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node,
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
char color[16]; char color[16];
FILE * fp = fopen(filename, "w"); FILE * fp = ggml_fopen(filename, "w");
GGML_ASSERT(fp); GGML_ASSERT(fp);
fprintf(fp, "digraph G {\n"); fprintf(fp, "digraph G {\n");
@ -20531,7 +20577,7 @@ struct gguf_context * gguf_init_empty(void) {
} }
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
FILE * file = fopen(fname, "rb"); FILE * file = ggml_fopen(fname, "rb");
if (!file) { if (!file) {
return NULL; return NULL;
} }
@ -21486,7 +21532,7 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
} }
void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
FILE * file = fopen(fname, "wb"); FILE * file = ggml_fopen(fname, "wb");
if (!file) { if (!file) {
GGML_ASSERT(false && "failed to open file for writing"); GGML_ASSERT(false && "failed to open file for writing");
} }

8
ggml.h
View file

@ -214,9 +214,10 @@
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif #endif
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h> #include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#define GGML_FILE_MAGIC 0x67676d6c // "ggml" #define GGML_FILE_MAGIC 0x67676d6c // "ggml"
#define GGML_FILE_VERSION 1 #define GGML_FILE_VERSION 1
@ -708,6 +709,9 @@ extern "C" {
GGML_API void ggml_print_backtrace(void); GGML_API void ggml_print_backtrace(void);
// accepts a UTF-8 path, even on Windows
GGML_API FILE * ggml_fopen(const char * fname, const char * mode);
GGML_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems GGML_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node

View file

@ -100,6 +100,7 @@ class MODEL_ARCH(IntEnum):
LLAMA = auto() LLAMA = auto()
FALCON = auto() FALCON = auto()
BAICHUAN = auto() BAICHUAN = auto()
GROK = auto()
GPT2 = auto() GPT2 = auto()
GPTJ = auto() GPTJ = auto()
GPTNEOX = auto() GPTNEOX = auto()
@ -167,6 +168,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama", MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon", MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan", MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GROK: "grok",
MODEL_ARCH.GPT2: "gpt2", MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj", MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox", MODEL_ARCH.GPTNEOX: "gptneox",
@ -251,6 +253,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_UP_EXP,
], ],
MODEL_ARCH.GROK: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.GPTNEOX: [ MODEL_ARCH.GPTNEOX: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View file

@ -23,6 +23,7 @@ class TensorNameMap:
"model.embedding", # mamba-qbert "model.embedding", # mamba-qbert
"backbone.embedding", # mamba "backbone.embedding", # mamba
"backbone.embeddings", # mamba-hf "backbone.embeddings", # mamba-hf
"transformer.in_out_embed", # Grok
), ),
# Token type embeddings # Token type embeddings
@ -66,6 +67,7 @@ class TensorNameMap:
"lm_head.ln", # phi2 "lm_head.ln", # phi2
"model.norm_f", # mamba-qbert "model.norm_f", # mamba-qbert
"backbone.norm_f", # mamba "backbone.norm_f", # mamba
"transformer.rms_norm", # Grok
), ),
# Rope frequencies # Rope frequencies
@ -93,6 +95,7 @@ class TensorNameMap:
"model.layers.{bid}.attention_norm", # internlm2 "model.layers.{bid}.attention_norm", # internlm2
"model.layers.{bid}.norm", # mamba-qbert "model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba "backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
), ),
# Attention norm 2 # Attention norm 2
@ -116,32 +119,35 @@ class TensorNameMap:
# Attention query # Attention query
MODEL_TENSOR.ATTN_Q: ( MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf "model.layers.{bid}.self_attn.q_proj", # llama-hf
"layers.{bid}.attention.wq", # llama-pth "layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert "encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j "transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq" # internlm2 "model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
), ),
# Attention key # Attention key
MODEL_TENSOR.ATTN_K: ( MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf "model.layers.{bid}.self_attn.k_proj", # llama-hf
"layers.{bid}.attention.wk", # llama-pth "layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert "encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j "transformer.h.{bid}.attn.k_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk" # internlm2 "model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
), ),
# Attention value # Attention value
MODEL_TENSOR.ATTN_V: ( MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf "model.layers.{bid}.self_attn.v_proj", # llama-hf
"layers.{bid}.attention.wv", # llama-pth "layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert "encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j "transformer.h.{bid}.attn.v_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv" # internlm2 "model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
), ),
# Attention output # Attention output
@ -162,12 +168,14 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2 "model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert "encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
), ),
# Attention output norm # Attention output norm
MODEL_TENSOR.ATTN_OUT_NORM: ( MODEL_TENSOR.ATTN_OUT_NORM: (
"encoder.layer.{bid}.attention.output.LayerNorm", # bert "encoder.layer.{bid}.attention.output.LayerNorm", # bert
"encoder.layers.{bid}.norm1", # nomic-bert "encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
), ),
# Rotary embeddings # Rotary embeddings
@ -190,11 +198,13 @@ class TensorNameMap:
"model.layers.{bid}.ln2", # yi "model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2 "h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2 "model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
), ),
MODEL_TENSOR.FFN_GATE_INP: ( MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral "layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral "model.layers.{bid}.block_sparse_moe.gate", # mixtral
"transformer.decoder_layer.{bid}.router" # Grok
), ),
# Feed-forward up # Feed-forward up
@ -223,6 +233,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
"transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok
), ),
# AWQ-activation gate # AWQ-activation gate
@ -243,6 +254,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
"transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok
), ),
# Feed-forward down # Feed-forward down
@ -270,6 +282,8 @@ class TensorNameMap:
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
"transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok
), ),
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (
@ -287,8 +301,9 @@ class TensorNameMap:
), ),
MODEL_TENSOR.LAYER_OUT_NORM: ( MODEL_TENSOR.LAYER_OUT_NORM: (
"encoder.layer.{bid}.output.LayerNorm", # bert "encoder.layer.{bid}.output.LayerNorm", # bert
"encoder.layers.{bid}.norm2", # nomic-bert "encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
), ),
MODEL_TENSOR.SSM_IN: ( MODEL_TENSOR.SSM_IN: (

871
llama.cpp

File diff suppressed because it is too large Load diff

26
llama.h
View file

@ -275,13 +275,15 @@ extern "C" {
// model quantization parameters // model quantization parameters
typedef struct llama_model_quantize_params { typedef struct llama_model_quantize_params {
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype enum llama_ftype ftype; // quantize to this llama_ftype
bool allow_requantize; // allow quantizing non-f32/f16 tensors enum ggml_type output_tensor_type; // output tensor type
bool quantize_output_tensor; // quantize output.weight enum ggml_type token_embedding_type; // itoken embeddings tensor type
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool pure; // quantize all tensors to the default type bool quantize_output_tensor; // quantize output.weight
void * imatrix; // pointer to importance matrix data bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
void * imatrix; // pointer to importance matrix data
} llama_model_quantize_params; } llama_model_quantize_params;
// grammar types // grammar types
@ -960,6 +962,16 @@ extern "C" {
int32_t n_past, int32_t n_past,
int32_t n_predict); int32_t n_predict);
/// @details Build a split GGUF final path for this chunk.
/// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
// Returns the split_path length.
LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count);
/// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
/// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
// Returns the split_prefix length.
LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
// Performance information // Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);

BIN
retrieval Executable file

Binary file not shown.

10
scripts/get-wikitext-103.sh Executable file
View file

@ -0,0 +1,10 @@
#!/bin/bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
echo "Usage:"
echo ""
echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]"
echo ""
exit 0

View file

@ -1,66 +1,75 @@
function(llama_build_executable source) # Builds and runs a test source file.
get_filename_component(TEST_TARGET ${source} NAME_WE) # Optional args:
# - NAME: name of the executable & test target (defaults to the source file name without extension)
# - LABEL: label for the test (defaults to main)
# - ARGS: arguments to pass to the test executable
# - WORKING_DIRECTORY
function(llama_test source)
include(CMakeParseArguments)
set(options)
set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
set(multiValueArgs ARGS)
cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (NOT DEFINED LLAMA_TEST_LABEL)
set(LLAMA_TEST_LABEL "main")
endif()
if (NOT DEFINED LLAMA_TEST_WORKING_DIRECTORY)
set(LLAMA_TEST_WORKING_DIRECTORY .)
endif()
if (DEFINED LLAMA_TEST_NAME)
set(TEST_TARGET ${LLAMA_TEST_NAME})
else()
get_filename_component(TEST_TARGET ${source} NAME_WE)
endif()
add_executable(${TEST_TARGET} ${source} get-model.cpp) add_executable(${TEST_TARGET} ${source} get-model.cpp)
install(TARGETS ${TEST_TARGET} RUNTIME) install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE common) target_link_libraries(${TEST_TARGET} PRIVATE common json-schema-to-grammar)
add_test(
NAME ${TEST_TARGET}
WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
COMMAND $<TARGET_FILE:${TEST_TARGET}>
${LLAMA_TEST_ARGS})
set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${LLAMA_TEST_LABEL})
endfunction() endfunction()
function(llama_test_executable name source) # llama_test(test-double-float.cpp) # SLOW
get_filename_component(TEST_TARGET ${source} NAME_WE) llama_test(test-quantize-fns.cpp)
add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN}) llama_test(test-quantize-perf.cpp)
set_property(TEST ${name} PROPERTY LABELS "main") llama_test(test-sampling.cpp)
endfunction() llama_test(test-chat-template.cpp)
function(llama_build_and_test_executable source) llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_build_and_test_executable_with_label(${source} "main") llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
endfunction()
function(llama_build_and_test_executable_with_label source label) llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
get_filename_component(TEST_TARGET ${source} NAME_WE) llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
add_executable(${TEST_TARGET} ${source} get-model.cpp)
install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE common)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${label})
endfunction()
# llama_build_and_test_executable(test-double-float.cpp) # SLOW llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_build_and_test_executable(test-quantize-fns.cpp) llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
llama_build_and_test_executable(test-quantize-perf.cpp) llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
llama_build_and_test_executable(test-sampling.cpp) llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-stablelm-3b-4e1t ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
llama_build_and_test_executable(test-chat-template.cpp) llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt-neox ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt2.gguf)
#llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-bloom ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
llama_build_executable(test-tokenizer-0-llama.cpp) llama_test(test-grammar-parser.cpp)
llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) llama_test(test-llama-grammar.cpp)
llama_test(test-grad0.cpp)
# llama_test(test-opt.cpp) # SLOW
llama_test(test-backend-ops.cpp)
llama_build_executable(test-tokenizer-0-falcon.cpp) llama_test(test-rope.cpp)
llama_test_executable (test-tokenizer-0-falcon test-tokenizer-0-falcon.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_build_executable(test-tokenizer-1-llama.cpp) llama_test(test-model-load-cancel.cpp LABEL "model")
llama_test_executable (test-tokenizer-1-llama test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) llama_test(test-autorelease.cpp LABEL "model")
llama_test_executable (test-tokenizer-1-baichuan test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
llama_build_executable(test-tokenizer-1-bpe.cpp) llama_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf) target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
llama_test_executable (test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
llama_test_executable (test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
llama_test_executable (test-tokenizer-1-stablelm-3b-4e1t test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
llama_test_executable (test-tokenizer-1-gpt-neox test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
llama_test_executable (test-tokenizer-1-refact test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test_executable (test-tokenizer-1-starcoder test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
llama_test_executable (test-tokenizer-1-gpt2 test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt2.gguf)
# llama_test_executable (test-tokenizer-1-bloom test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
llama_build_and_test_executable(test-grammar-parser.cpp)
llama_build_and_test_executable(test-llama-grammar.cpp)
llama_build_and_test_executable(test-grad0.cpp)
# llama_build_and_test_executable(test-opt.cpp) # SLOW
llama_build_and_test_executable(test-backend-ops.cpp)
llama_build_and_test_executable(test-rope.cpp)
llama_build_and_test_executable_with_label(test-model-load-cancel.cpp "model")
llama_build_and_test_executable_with_label(test-autorelease.cpp "model")
# dummy executable - not installed # dummy executable - not installed
get_filename_component(TEST_TARGET test-c.c NAME_WE) get_filename_component(TEST_TARGET test-c.c NAME_WE)

View file

@ -0,0 +1,10 @@
import { readFileSync } from "fs"
import { SchemaConverter } from "../examples/server/public/json-schema-to-grammar.mjs"
const [, , file] = process.argv
const url = `file://${file}`
let schema = JSON.parse(readFileSync(file, "utf8"));
const converter = new SchemaConverter({})
schema = await converter.resolveRefs(schema, url)
converter.visit(schema, '')
console.log(converter.formatGrammar())

View file

@ -2091,6 +2091,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
} }
} }
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 64, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
for (ggml_type type_a : all_types) { for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {2, 4, 8}) { for (int n_mats : {2, 4, 8}) {

View file

@ -0,0 +1,842 @@
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <fstream>
#include <sstream>
#include <regex>
#include "json-schema-to-grammar.h"
#include "grammar-parser.h"
static std::string trim(const std::string & source) {
std::string s(source);
s.erase(0,s.find_first_not_of(" \n\r\t"));
s.erase(s.find_last_not_of(" \n\r\t")+1);
return std::regex_replace(s, std::regex("(^|\n)[ \t]+"), "$1");
}
enum TestCaseStatus {
SUCCESS,
FAILURE
};
struct TestCase {
TestCaseStatus expected_status;
std::string name;
std::string schema;
std::string expected_grammar;
void _print_failure_header() const {
fprintf(stderr, "#\n# Test '%s' failed.\n#\n%s\n", name.c_str(), schema.c_str());
}
void verify(const std::string & actual_grammar) const {
if (trim(actual_grammar) != trim(expected_grammar)) {
_print_failure_header();
fprintf(stderr, "# EXPECTED:\n%s\n# ACTUAL:\n%s\n", expected_grammar.c_str(), actual_grammar.c_str());
assert(false);
}
}
void verify_expectation_parseable() const {
try {
auto state = grammar_parser::parse(expected_grammar.c_str());
if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
}
} catch (const std::runtime_error & ex) {
_print_failure_header();
fprintf(stderr, "# GRAMMAR ERROR: %s\n", ex.what());
assert(false);
}
}
void verify_status(TestCaseStatus status) const {
if (status != expected_status) {
_print_failure_header();
fprintf(stderr, "# EXPECTED STATUS: %s\n", expected_status == SUCCESS ? "SUCCESS" : "FAILURE");
fprintf(stderr, "# ACTUAL STATUS: %s\n", status == SUCCESS ? "SUCCESS" : "FAILURE");
assert(false);
}
}
};
static void write(const std::string & file, const std::string & content) {
std::ofstream f;
f.open(file.c_str());
f << content.c_str();
f.close();
}
static std::string read(const std::string & file) {
std::ostringstream actuals;
actuals << std::ifstream(file.c_str()).rdbuf();
return actuals.str();
}
static void test_all(const std::string & lang, std::function<void(const TestCase &)> runner) {
fprintf(stderr, "#\n# Testing JSON schema conversion (%s)\n#\n", lang.c_str());
auto test = [&](const TestCase & tc) {
fprintf(stderr, "- %s%s\n", tc.name.c_str(), tc.expected_status == FAILURE ? " (failure expected)" : "");
runner(tc);
};
test({
FAILURE,
"unknown type",
R"""({
"type": "kaboom"
})""",
""
});
test({
FAILURE,
"invalid type",
R"""({
"type": 123
})""",
""
});
test({
SUCCESS,
"empty schema (object)",
"{}",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
null ::= "null" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
value ::= object | array | string | number | boolean
)"""
});
test({
SUCCESS,
"exotic formats",
R"""({
"items": [
{ "format": "date" },
{ "format": "uuid" },
{ "format": "time" },
{ "format": "date-time" }
]
})""",
R"""(
date ::= [0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )
date-string ::= "\"" date "\"" space
date-time ::= date "T" time
date-time-string ::= "\"" date-time "\"" space
root ::= "[" space date-string "," space uuid "," space time-string "," space date-time-string "]" space
space ::= " "?
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
time-string ::= "\"" time "\"" space
uuid ::= "\"" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "\"" space
)"""
});
test({
SUCCESS,
"string",
R"""({
"type": "string"
})""",
R"""(
root ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"boolean",
R"""({
"type": "boolean"
})""",
R"""(
root ::= ("true" | "false") space
space ::= " "?
)"""
});
test({
SUCCESS,
"integer",
R"""({
"type": "integer"
})""",
R"""(
root ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
space ::= " "?
)"""
});
test({
SUCCESS,
"string const",
R"""({
"const": "foo"
})""",
R"""(
root ::= "\"foo\""
space ::= " "?
)"""
});
test({
SUCCESS,
"non-string const",
R"""({
"const": 123
})""",
R"""(
root ::= "123"
space ::= " "?
)"""
});
test({
SUCCESS,
"non-string enum",
R"""({
"enum": ["red", "amber", "green", null, 42, ["foo"]]
})""",
R"""(
root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
space ::= " "?
)"""
});
test({
SUCCESS,
"tuple1",
R"""({
"prefixItems": [{ "type": "string" }]
})""",
R"""(
root ::= "[" space string "]" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"tuple2",
R"""({
"prefixItems": [{ "type": "string" }, { "type": "number" }]
})""",
R"""(
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "[" space string "," space number "]" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"number",
R"""({
"type": "number"
})""",
R"""(
root ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
space ::= " "?
)"""
});
test({
SUCCESS,
"minItems",
R"""({
"items": {
"type": "boolean"
},
"minItems": 2
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean ( "," space boolean )( "," space boolean )* "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"maxItems 1",
R"""({
"items": {
"type": "boolean"
},
"maxItems": 1
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space ( boolean )? "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"maxItems 2",
R"""({
"items": {
"type": "boolean"
},
"maxItems": 2
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space ( boolean ( "," space boolean )? )? "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"min + maxItems",
R"""({
"items": {
"type": ["number", "integer"]
},
"minItems": 3,
"maxItems": 5
})""",
R"""(
integer ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
item ::= number | integer
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "[" space item ( "," space item )( "," space item )( "," space item )?( "," space item )? "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"simple regexp",
R"""({
"type": "string",
"pattern": "^abc?d*efg+(hij)?kl$"
})""",
R"""(
root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"regexp escapes",
R"""({
"type": "string",
"pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
})""",
R"""(
root ::= "\"" "[]{}()|+*?" "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"regexp quote",
R"""({
"type": "string",
"pattern": "^\"$"
})""",
R"""(
root ::= "\"" "\"" "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"regexp",
R"""({
"type": "string",
"pattern": "^(\\([0-9]{1,3}\\))?[0-9]{3}-[0-9]{4} and...$"
})""",
R"""(
dot ::= [\U00000000-\x09\x0B\x0C\x0E-\U0010FFFF]
root ::= "\"" ("(" root-1 root-1? root-1? ")")? root-1 root-1 root-1 "-" root-1 root-1 root-1 root-1 " and" dot dot dot "\"" space
root-1 ::= [0-9]
space ::= " "?
)"""
});
test({
SUCCESS,
"required props in original order",
R"""({
"type": "object",
"properties": {
"b": {"type": "string"},
"c": {"type": "string"},
"a": {"type": "string"}
},
"required": [
"a",
"b",
"c"
],
"additionalProperties": false,
"definitions": {}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"1 optional prop",
R"""({
"properties": {
"a": {
"type": "string"
}
},
"additionalProperties": false
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
root ::= "{" space (a-kv )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"N optional props",
R"""({
"properties": {
"a": {"type": "string"},
"b": {"type": "string"},
"c": {"type": "string"}
},
"additionalProperties": false
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
a-rest ::= ( "," space b-kv )? b-rest
b-kv ::= "\"b\"" space ":" space string
b-rest ::= ( "," space c-kv )?
c-kv ::= "\"c\"" space ":" space string
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"required + optional props each in original order",
R"""({
"properties": {
"b": {"type": "string"},
"a": {"type": "string"},
"d": {"type": "string"},
"c": {"type": "string"}
},
"required": ["a", "b"],
"additionalProperties": false
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
d-kv ::= "\"d\"" space ":" space string
d-rest ::= ( "," space c-kv )?
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"additional props",
R"""({
"type": "object",
"additionalProperties": {"type": "array", "items": {"type": "number"}}
})""",
R"""(
additional-kv ::= string ":" space additional-value
additional-kvs ::= additional-kv ( "," space additional-kv )*
additional-value ::= "[" space ( number ( "," space number )* )? "]" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (additional-kvs )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"additional props (true)",
R"""({
"type": "object",
"additionalProperties": true
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
null ::= "null" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
value ::= object | array | string | number | boolean
)"""
});
test({
SUCCESS,
"additional props (implicit)",
R"""({
"type": "object"
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
null ::= "null" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
value ::= object | array | string | number | boolean
)"""
});
test({
SUCCESS,
"empty w/o additional props",
R"""({
"type": "object",
"additionalProperties": false
})""",
R"""(
root ::= "{" space "}" space
space ::= " "?
)"""
});
test({
SUCCESS,
"required + additional props",
R"""({
"type": "object",
"properties": {
"a": {"type": "number"}
},
"required": ["a"],
"additionalProperties": {"type": "string"}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
additional-kv ::= string ":" space string
additional-kvs ::= additional-kv ( "," space additional-kv )*
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv ( "," space ( additional-kvs ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"optional + additional props",
R"""({
"type": "object",
"properties": {
"a": {"type": "number"}
},
"additionalProperties": {"type": "number"}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
a-rest ::= additional-kvs
additional-kv ::= string ":" space number
additional-kvs ::= additional-kv ( "," space additional-kv )*
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"required + optional + additional props",
R"""({
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"}
},
"required": ["a"],
"additionalProperties": {"type": "number"}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
additional-kv ::= string ":" space number
additional-kvs ::= additional-kv ( "," space additional-kv )*
b-kv ::= "\"b\"" space ":" space number
b-rest ::= additional-kvs
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"top-level $ref",
R"""({
"$ref": "#/definitions/MyType",
"definitions": {
"MyType": {
"type": "object",
"properties": {
"a": {
"type": "string"
}
},
"required": [
"a"
],
"additionalProperties": false
}
}
})""",
R"""(
MyType ::= "{" space MyType-a-kv "}" space
MyType-a-kv ::= "\"a\"" space ":" space string
root ::= MyType
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"anyOf",
R"""({
"anyOf": [
{"$ref": "#/definitions/foo"},
{"$ref": "#/definitions/bar"}
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
}
},
"type": "object"
})""",
R"""(
alternative-0 ::= foo
alternative-1 ::= bar
bar ::= "{" space (bar-b-kv )? "}" space
bar-b-kv ::= "\"b\"" space ":" space number
foo ::= "{" space (foo-a-kv )? "}" space
foo-a-kv ::= "\"a\"" space ":" space number
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= alternative-0 | alternative-1
space ::= " "?
)"""
});
test({
SUCCESS,
"mix of allOf, anyOf and $ref (similar to https://json.schemastore.org/tsconfig.json)",
R"""({
"allOf": [
{"$ref": "#/definitions/foo"},
{"$ref": "#/definitions/bar"},
{
"anyOf": [
{"$ref": "#/definitions/baz"},
{"$ref": "#/definitions/bam"}
]
}
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
},
"bam": {
"properties": {"c": {"type": "number"}}
},
"baz": {
"properties": {"d": {"type": "number"}}
}
},
"type": "object"
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
b-kv ::= "\"b\"" space ":" space number
c-kv ::= "\"c\"" space ":" space number
d-kv ::= "\"d\"" space ":" space number
d-rest ::= ( "," space c-kv )?
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
)"""
});
test({
SUCCESS,
"conflicting names",
R"""({
"type": "object",
"properties": {
"number": {
"type": "object",
"properties": {
"number": {
"type": "object",
"properties": {
"root": {
"type": "number"
}
},
"required": [
"root"
],
"additionalProperties": false
}
},
"required": [
"number"
],
"additionalProperties": false
}
},
"required": [
"number"
],
"additionalProperties": false,
"definitions": {}
})""",
R"""(
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
number- ::= "{" space number-number-kv "}" space
number-kv ::= "\"number\"" space ":" space number-
number-number ::= "{" space number-number-root-kv "}" space
number-number-kv ::= "\"number\"" space ":" space number-number
number-number-root-kv ::= "\"root\"" space ":" space number
root ::= "{" space number-kv "}" space
space ::= " "?
)"""
});
}
int main() {
fprintf(stderr, "LLAMA_NODE_AVAILABLE = %s\n", getenv("LLAMA_NODE_AVAILABLE") ? "true" : "false");
fprintf(stderr, "LLAMA_PYTHON_AVAILABLE = %s\n", getenv("LLAMA_PYTHON_AVAILABLE") ? "true" : "false");
test_all("C++", [](const TestCase & tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
fprintf(stderr, "Error: %s\n", ex.what());
tc.verify_status(FAILURE);
}
});
if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python --version") == 0)) {
test_all("Python", [](const TestCase & tc) {
write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system(
"python ./examples/json-schema-to-grammar.py test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
tc.verify(read("test-grammar-output.tmp"));
});
} else {
fprintf(stderr, "\033[33mWARNING: Python not found, skipping Python JSON schema -> grammar tests.\n\033[0m");
}
if (getenv("LLAMA_NODE_AVAILABLE") || (std::system("node --version") == 0)) {
test_all("JavaScript", [](const TestCase & tc) {
write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system(
"node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
tc.verify(read("test-grammar-output.tmp"));
});
} else {
fprintf(stderr, "\033[33mWARNING: Node not found, skipping JavaScript JSON schema -> grammar tests.\n\033[0m");
}
test_all("Check Expectations Validity", [](const TestCase & tc) {
if (tc.expected_status == SUCCESS) {
tc.verify_expectation_parseable();
}
});
}