Merge branch 'ggerganov:master' into master

haopeng 2024-11-27 19:50:29 +08:00 committed by GitHub
commit 2c96bd2466
174 changed files with 6984 additions and 5254 deletions

161 .clang-format Normal file

@ -0,0 +1,161 @@
---
Language: Cpp
AlignAfterOpenBracket: Align
AlignArrayOfStructures: Left
AlignConsecutiveAssignments: AcrossComments
AlignConsecutiveBitFields: AcrossComments
AlignConsecutiveDeclarations: AcrossComments
AlignConsecutiveMacros: AcrossComments
# AlignConsecutiveShortCaseStatements: AcrossComments
AlignEscapedNewlines: Left # LeftWithLastLine
AlignOperands: Align
AlignTrailingComments:
Kind: Always
OverEmptyLines: 1
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: false
# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
BinPackArguments: true
BinPackParameters: true # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach
BraceWrapping:
AfterCaseLabel: true
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
# BreakAdjacentStringLiterals: true
BreakAfterAttributes: Never
BreakBeforeBinaryOperators: None
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: false
# BreakBinaryOperations: Never
BreakConstructorInitializers: AfterColon
# BreakFunctionDefinitionParameters: false
BreakInheritanceList: AfterComma
BreakStringLiterals: true
# BreakTemplateDeclarations: Yes
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: false
DerivePointerAlignment: false
DisableFormat: false
EmptyLineBeforeAccessModifier: Leave
EmptyLineAfterAccessModifier: Never
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
IncludeBlocks: Regroup
IncludeCategories:
- Regex: '^<.*\.h>'
Priority: 1
SortPriority: 0
- Regex: '^<.*'
Priority: 2
SortPriority: 0
- Regex: '.*'
Priority: 3
SortPriority: 0
IncludeIsMainRegex: '([-_](test|unittest))?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false
IndentCaseBlocks: true
IndentCaseLabels: true
IndentExternBlock: NoIndent
IndentGotoLabels: false
IndentPPDirectives: AfterHash
IndentWidth: 4
IndentWrappedFunctionNames: false
InsertBraces: true # NOTE: may lead to incorrect formatting
InsertNewlineAtEOF: true
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
LambdaBodyIndentation: Signature
LineEnding: LF
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: true
ObjCSpaceBeforeProtocolList: true
PPIndentWidth: -1
PackConstructorInitializers: CurrentLine
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Middle
QualifierAlignment: Left
#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict']
RawStringFormats:
- Language: Cpp
Delimiters:
- cc
- CC
- cpp
- Cpp
- CPP
- 'c++'
- 'C++'
CanonicalDelimiter: ''
ReferenceAlignment: Middle
ReflowComments: false # IndentOnly
SeparateDefinitionBlocks: Always
SortIncludes: CaseInsensitive
SortUsingDeclarations: LexicographicNumeric
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyBlock: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: Never
SpacesInContainerLiterals: true
SpacesInLineCommentPrefix:
Minimum: 1
Maximum: -1
SpacesInParentheses: false
SpacesInSquareBrackets: false
SpaceBeforeSquareBrackets: false
Standard: c++17
TabWidth: 4
UseTab: Never
WhitespaceSensitiveMacros: ['STRINGIZE']
...
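For context (not part of the commit), here is a small hand-written C++ fragment showing roughly how code looks under this configuration (attached braces, 4-space indents, middle pointer alignment, aligned consecutive assignments and declarations, a space after C-style casts, mandatory braces on short ifs). The names are illustrative only, and the exact output of clang-format may differ slightly:

#include <cstddef>
#include <cstdlib>

// illustrative struct: AlignConsecutiveDeclarations aligns the member names
struct example_buf {
    float * data;    // PointerAlignment: Middle
    size_t  size;
    size_t  offset;
};

static bool example_buf_init(example_buf * buf, size_t n) {
    buf->size   = n;                                     // AlignConsecutiveAssignments: AcrossComments
    buf->offset = 0;
    buf->data   = (float *) malloc(n * sizeof(float));   // SpaceAfterCStyleCast: true
    if (buf->data == nullptr) {
        return false;  // InsertBraces: true, short ifs never on a single line
    }
    return true;
}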


@ -6,6 +6,9 @@ ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_V
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
# MUSA architecture to build for (defaults to all supported archs)
ARG MUSA_DOCKER_ARCH=default
RUN apt-get update && \
apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1
@ -19,7 +22,11 @@ WORKDIR /app
COPY . .
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
# Use the default MUSA archs if not specified
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release -j$(nproc) && \
cp build/bin/* .


@ -8,6 +8,9 @@ ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
# MUSA architecture to build for (defaults to all supported archs)
ARG MUSA_DOCKER_ARCH=default
RUN apt-get update && \
apt-get install -y build-essential git cmake
@ -15,7 +18,11 @@ WORKDIR /app
COPY . .
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
# Use the default MUSA archs if not specified
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release --target llama-cli -j$(nproc) && \
mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;


@ -8,6 +8,9 @@ ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
# MUSA architecture to build for (defaults to all supported archs)
ARG MUSA_DOCKER_ARCH=default
RUN apt-get update && \
apt-get install -y build-essential git cmake libcurl4-openssl-dev
@ -15,7 +18,11 @@ WORKDIR /app
COPY . .
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
# Use the default MUSA archs if not specified
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
fi && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release --target llama-server -j$(nproc) && \
mkdir -p /app/lib && \
find build -name "*.so" -exec cp {} /app/lib \;


@ -34,7 +34,7 @@ let
# server tests
openai
behave
pytest
prometheus-client
];
in


@ -1,50 +0,0 @@
name: Low Severity Bugs
description: Used to report low severity bugs in llama.cpp (e.g. cosmetic issues, non critical UI glitches)
title: "Bug: "
labels: ["bug-unconfirmed", "low severity"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
Please include information about your system, the steps to reproduce the bug,
and the version of llama.cpp that you are using.
If possible, please provide a minimal code example that reproduces the bug.
- type: textarea
id: what-happened
attributes:
label: What happened?
description: Also tell us, what did you expect to happen?
placeholder: Tell us what you see!
validations:
required: true
- type: textarea
id: version
attributes:
label: Name and Version
description: Which executable and which version of our software are you running? (use `--version` to get a version string)
placeholder: |
$./llama-cli --version
version: 2999 (42b4109e)
built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
validations:
required: true
- type: dropdown
id: operating-system
attributes:
label: What operating system are you seeing the problem on?
multiple: true
options:
- Linux
- Mac
- Windows
- BSD
- Other? (Please let us know in description)
validations:
required: false
- type: textarea
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
render: shell


@ -0,0 +1,77 @@
name: Bug (compilation)
description: Something goes wrong when trying to compile llama.cpp.
title: "Compile bug: "
labels: ["bug-unconfirmed", "compilation"]
body:
- type: markdown
attributes:
value: >
Thanks for taking the time to fill out this bug report!
This issue template is intended for bug reports where the compilation of llama.cpp fails.
Before opening an issue, please confirm that the compilation still fails with `-DGGML_CCACHE=OFF`.
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
by clearing `~/.cache/ccache` (on Linux).
- type: textarea
id: commit
attributes:
label: Git commit
description: Which commit are you trying to compile?
placeholder: |
$git rev-parse HEAD
84a07a17b1b08cf2b9747c633a2372782848a27f
validations:
required: true
- type: dropdown
id: operating-system
attributes:
label: Operating systems
description: Which operating systems do you know to be affected?
multiple: true
options:
- Linux
- Mac
- Windows
- BSD
- Other? (Please let us know in description)
validations:
required: true
- type: dropdown
id: backends
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan]
multiple: true
validations:
required: true
- type: textarea
id: info
attributes:
label: Problem description & steps to reproduce
description: >
Please give us a summary of the problem and tell us how to reproduce it.
If you can narrow down the bug to specific compile flags, that information would be very much appreciated by us.
placeholder: >
I'm trying to compile llama.cpp with CUDA support on a fresh install of Ubuntu and get error XY.
Here are the exact commands that I used: ...
validations:
required: true
- type: textarea
id: first_bad_commit
attributes:
label: First Bad Commit
description: >
If the bug was not present on an earlier version: when did it start appearing?
If possible, please do a git bisect and identify the exact commit that introduced the bug.
validations:
required: false
- type: textarea
id: logs
attributes:
label: Relevant log output
description: >
Please copy and paste any relevant log output, including the command that you entered and any generated text.
This will be automatically formatted into code, so no need for backticks.
render: shell
validations:
required: true


@ -0,0 +1,101 @@
name: Bug (model use)
description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module).
title: "Eval bug: "
labels: ["bug-unconfirmed", "model evaluation"]
body:
- type: markdown
attributes:
value: >
Thanks for taking the time to fill out this bug report!
This issue template is intended for bug reports where the model evaluation results
(i.e. the generated text) are incorrect or llama.cpp crashes during model evaluation.
If you encountered the issue while using an external UI (e.g. ollama),
please reproduce your issue using one of the examples/binaries in this repository.
The `llama-cli` binary can be used for simple and reproducible model inference.
- type: textarea
id: version
attributes:
label: Name and Version
description: Which version of our software are you running? (use `--version` to get a version string)
placeholder: |
$./llama-cli --version
version: 2999 (42b4109e)
built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
validations:
required: true
- type: dropdown
id: operating-system
attributes:
label: Operating systems
description: Which operating systems do you know to be affected?
multiple: true
options:
- Linux
- Mac
- Windows
- BSD
- Other? (Please let us know in description)
validations:
required: true
- type: dropdown
id: backends
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan]
multiple: true
validations:
required: true
- type: textarea
id: hardware
attributes:
label: Hardware
description: Which CPUs/GPUs are you using?
placeholder: >
e.g. Ryzen 5950X + 2x RTX 4090
validations:
required: true
- type: textarea
id: model
attributes:
label: Models
description: >
Which model(s) at which quantization were you using when encountering the bug?
If you downloaded a GGUF file off of Huggingface, please provide a link.
placeholder: >
e.g. Meta LLaMA 3.1 Instruct 8b q4_K_M
validations:
required: false
- type: textarea
id: info
attributes:
label: Problem description & steps to reproduce
description: >
Please give us a summary of the problem and tell us how to reproduce it.
If you can narrow down the bug to specific hardware, compile flags, or command line arguments,
that information would be very much appreciated by us.
placeholder: >
e.g. when I run llama-cli with -ngl 99 I get garbled outputs.
When I use -ngl 0 it works correctly.
Here are the exact commands that I used: ...
validations:
required: true
- type: textarea
id: first_bad_commit
attributes:
label: First Bad Commit
description: >
If the bug was not present on an earlier version: when did it start appearing?
If possible, please do a git bisect and identify the exact commit that introduced the bug.
validations:
required: false
- type: textarea
id: logs
attributes:
label: Relevant log output
description: >
Please copy and paste any relevant log output, including the command that you entered and any generated text.
This will be automatically formatted into code, so no need for backticks.
render: shell
validations:
required: true

81 .github/ISSUE_TEMPLATE/019-bug-misc.yml vendored Normal file

@ -0,0 +1,81 @@
name: Bug (misc.)
description: Something is not working the way it should (and it's not covered by any of the above cases).
title: "Misc. bug: "
labels: ["bug-unconfirmed"]
body:
- type: markdown
attributes:
value: >
Thanks for taking the time to fill out this bug report!
This issue template is intended for miscellaneous bugs that don't fit into any other category.
If you encountered the issue while using an external UI (e.g. ollama),
please reproduce your issue using one of the examples/binaries in this repository.
- type: textarea
id: version
attributes:
label: Name and Version
description: Which version of our software is affected? (You can use `--version` to get a version string.)
placeholder: |
$./llama-cli --version
version: 2999 (42b4109e)
built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
validations:
required: true
- type: dropdown
id: operating-system
attributes:
label: Operating systems
description: Which operating systems do you know to be affected?
multiple: true
options:
- Linux
- Mac
- Windows
- BSD
- Other? (Please let us know in description)
validations:
required: false
- type: dropdown
id: module
attributes:
label: Which llama.cpp modules do you know to be affected?
multiple: true
options:
- Documentation/Github
- libllama (core library)
- llama-cli
- llama-server
- llama-bench
- llama-quantize
- Python/Bash scripts
- Test code
- Other (Please specify in the next section)
validations:
required: false
- type: textarea
id: info
attributes:
label: Problem description & steps to reproduce
description: >
Please give us a summary of the problem and tell us how to reproduce it (if applicable).
validations:
required: true
- type: textarea
id: first_bad_commit
attributes:
label: First Bad Commit
description: >
If the bug was not present on an earlier version and it's not trivial to track down: when did it start appearing?
If possible, please do a git bisect and identify the exact commit that introduced the bug.
validations:
required: false
- type: textarea
id: logs
attributes:
label: Relevant log output
description: >
If applicable, please copy and paste any relevant log output, including the command that you entered and any generated text.
This will be automatically formatted into code, so no need for backticks.
render: shell
validations:
required: false


@ -1,50 +0,0 @@
name: Medium Severity Bug
description: Used to report medium severity bugs in llama.cpp (e.g. Malfunctioning Features but generally still useable)
title: "Bug: "
labels: ["bug-unconfirmed", "medium severity"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
Please include information about your system, the steps to reproduce the bug,
and the version of llama.cpp that you are using.
If possible, please provide a minimal code example that reproduces the bug.
- type: textarea
id: what-happened
attributes:
label: What happened?
description: Also tell us, what did you expect to happen?
placeholder: Tell us what you see!
validations:
required: true
- type: textarea
id: version
attributes:
label: Name and Version
description: Which executable and which version of our software are you running? (use `--version` to get a version string)
placeholder: |
$./llama-cli --version
version: 2999 (42b4109e)
built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
validations:
required: true
- type: dropdown
id: operating-system
attributes:
label: What operating system are you seeing the problem on?
multiple: true
options:
- Linux
- Mac
- Windows
- BSD
- Other? (Please let us know in description)
validations:
required: false
- type: textarea
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
render: shell


@ -1,5 +1,5 @@
name: Enhancement
description: Used to request enhancements for llama.cpp
description: Used to request enhancements for llama.cpp.
title: "Feature Request: "
labels: ["enhancement"]
body:


@ -1,50 +0,0 @@
name: High Severity Bug
description: Used to report high severity bugs in llama.cpp (e.g. Malfunctioning features hindering important common workflow)
title: "Bug: "
labels: ["bug-unconfirmed", "high severity"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
Please include information about your system, the steps to reproduce the bug,
and the version of llama.cpp that you are using.
If possible, please provide a minimal code example that reproduces the bug.
- type: textarea
id: what-happened
attributes:
label: What happened?
description: Also tell us, what did you expect to happen?
placeholder: Tell us what you see!
validations:
required: true
- type: textarea
id: version
attributes:
label: Name and Version
description: Which executable and which version of our software are you running? (use `--version` to get a version string)
placeholder: |
$./llama-cli --version
version: 2999 (42b4109e)
built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
validations:
required: true
- type: dropdown
id: operating-system
attributes:
label: What operating system are you seeing the problem on?
multiple: true
options:
- Linux
- Mac
- Windows
- BSD
- Other? (Please let us know in description)
validations:
required: false
- type: textarea
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
render: shell


@ -1,5 +1,5 @@
name: Research
description: Track new technical research area
description: Track new technical research area.
title: "Research: "
labels: ["research 🔬"]
body:


@ -1,50 +0,0 @@
name: Critical Severity Bug
description: Used to report critical severity bugs in llama.cpp (e.g. Crashing, Corrupted, Dataloss)
title: "Bug: "
labels: ["bug-unconfirmed", "critical severity"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
Please include information about your system, the steps to reproduce the bug,
and the version of llama.cpp that you are using.
If possible, please provide a minimal code example that reproduces the bug.
- type: textarea
id: what-happened
attributes:
label: What happened?
description: Also tell us, what did you expect to happen?
placeholder: Tell us what you see!
validations:
required: true
- type: textarea
id: version
attributes:
label: Name and Version
description: Which executable and which version of our software are you running? (use `--version` to get a version string)
placeholder: |
$./llama-cli --version
version: 2999 (42b4109e)
built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
validations:
required: true
- type: dropdown
id: operating-system
attributes:
label: What operating system are you seeing the problem on?
multiple: true
options:
- Linux
- Mac
- Windows
- BSD
- Other? (Please let us know in description)
validations:
required: false
- type: textarea
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
render: shell


@ -1,5 +1,5 @@
name: Refactor (Maintainers)
description: Used to track refactoring opportunities
description: Used to track refactoring opportunities.
title: "Refactor: "
labels: ["refactor"]
body:

15 .github/labeler.yml vendored

@ -3,19 +3,18 @@ Kompute:
- changed-files:
- any-glob-to-any-file:
- ggml/include/ggml-kompute.h
- ggml/src/ggml-kompute.cpp
- ggml/src/ggml-kompute/**
- README-kompute.md
Apple Metal:
- changed-files:
- any-glob-to-any-file:
- ggml/include/ggml-metal.h
- ggml/src/ggml-metal.cpp
- ggml/src/ggml-metal/**
- README-metal.md
SYCL:
- changed-files:
- any-glob-to-any-file:
- ggml/include/ggml-sycl.h
- ggml/src/ggml-sycl.cpp
- ggml/src/ggml-sycl/**
- docs/backend/SYCL.md
- examples/sycl/**
@ -27,8 +26,8 @@ Nvidia GPU:
Vulkan:
- changed-files:
- any-glob-to-any-file:
- ggml/ggml_vk_generate_shaders.py
- ggml/src/ggml-vulkan*
- ggml/include/ggml-vulkan.h
- ggml/src/ggml-vulkan/**
documentation:
- changed-files:
- any-glob-to-any-file:
@ -75,11 +74,7 @@ server:
ggml:
- changed-files:
- any-glob-to-any-file:
- ggml/include/ggml*.h
- ggml/src/ggml*.c
- ggml/src/ggml*.cpp
- ggml/src/ggml*.h
- ggml-cuda/**
- ggml/**
nix:
- changed-files:
- any-glob-to-any-file:


@ -728,7 +728,7 @@ jobs:
cmake --build build --config ${{ matrix.build }} -j $(nproc)
windows-latest-cmake:
runs-on: windows-2019
runs-on: windows-latest
env:
OPENBLAS_VERSION: 0.3.23
@ -871,37 +871,115 @@ jobs:
path: llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip
name: llama-bin-win-${{ matrix.build }}.zip
windows-latest-cmake-cuda:
ubuntu-latest-cmake-cuda:
runs-on: ubuntu-latest
container: nvidia/cuda:12.6.2-devel-ubuntu24.04
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Install dependencies
env:
DEBIAN_FRONTEND: noninteractive
run: |
apt update
apt install -y cmake build-essential ninja-build libgomp1 git
- name: Build with CMake
run: |
cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89-real -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined -DLLAMA_FATAL_WARNINGS=ON
cmake --build build
windows-2019-cmake-cuda:
runs-on: windows-2019
strategy:
matrix:
cuda: ['12.2.0', '11.7.1']
cuda: ['12.4', '11.7']
build: ['cuda']
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install CUDA toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.15
- name: Install Cuda Toolkit 11.7
if: ${{ matrix.cuda == '11.7' }}
run: |
mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7"
choco install unzip -y
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-11.7.99-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-11.7.99-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-11.7.99-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-11.7.4.6-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-11.7.91-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-11.7.91-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-11.7.101-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-11.7.91-archive.zip"
unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7"
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cudart-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvcc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvrtc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libcublas-windows-x86_64-11.7.4.6-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvtx-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\visual_studio_integration-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvprof-windows-x86_64-11.7.101-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cccl-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y
echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
echo "CUDA_PATH_V11_7=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
- name: Install Cuda Toolkit 12.4
if: ${{ matrix.cuda == '12.4' }}
run: |
mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4"
choco install unzip -y
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip"
curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip"
unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4"
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cudart-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvcc-windows-x86_64-12.4.131-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvrtc-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libcublas-windows-x86_64-12.4.5.8-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvtx-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_profiler_api-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\visual_studio_integration-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvprof-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cccl-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y
echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
- name: Install ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
cuda: ${{ matrix.cuda }}
method: 'network'
sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]'
key: ${{ github.job }}-${{ matrix.cuda }}-${{ matrix.build }}
- name: Install Ninja
id: install_ninja
run: |
choco install ninja
- name: Build
id: cmake_build
shell: cmd
run: |
mkdir build
cd build
cmake .. -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON -DGGML_RPC=ON
cmake --build . --config Release -j $((${env:NUMBER_OF_PROCESSORS} - 1)) -t ggml
cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake -S . -B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON -DGGML_RPC=ON
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
cmake --build build --config Release
- name: Determine tag name
id: tag
@ -930,10 +1008,12 @@ jobs:
name: llama-bin-win-cu${{ matrix.cuda }}-x64.zip
- name: Copy and pack Cuda runtime
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
run: |
echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
echo "Cuda install location: ${{ env.CUDA_PATH }}"
$dst='.\build\bin\cudart\'
robocopy "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll
7z a cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip $dst\*
- name: Upload Cuda runtime
@ -952,7 +1032,7 @@ jobs:
env:
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
steps:
- name: Clone
@ -962,7 +1042,8 @@ jobs:
fetch-depth: 0
- name: Install
run: scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
run: |
scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
- name: Build
id: cmake_build
@ -981,25 +1062,33 @@ jobs:
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
fi
- name: Pack artifacts
- name: Build the release package
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
echo "cp oneAPI running time dll files in ${{ env.ONEAPI_ROOT }} to ./build/bin"
cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_sycl_blas.4.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_sycl_blas.5.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_core.2.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_tbb_thread.2.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/pi_win_proxy_loader.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/pi_level_zero.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl7.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_opencl.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_loader.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_win_proxy_loader.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl8.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin
echo "cp oneAPI running time dll files to ./build/bin done"
7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/*
- name: Upload artifacts
- name: Upload the release package
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v4
with:
@ -1164,7 +1253,7 @@ jobs:
- macOS-latest-make
- macOS-latest-cmake
- windows-latest-cmake
- windows-latest-cmake-cuda
- windows-2019-cmake-cuda
- windows-latest-cmake-hip-release
- macOS-latest-cmake-arm64
- macOS-latest-cmake-x64


@ -10,12 +10,10 @@
name: Publish Docker image
on:
#pull_request:
push:
branches:
- master
paths: ['.github/workflows/docker.yml', '.devops/*.Dockerfile', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal']
workflow_dispatch: # allows manual triggering, useful for debugging
workflow_dispatch: # allows manual triggering
schedule:
# Rebuild daily rather than on every push because it is expensive
- cron: '12 4 * * *'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
@ -29,7 +27,6 @@ permissions:
jobs:
push_to_registry:
name: Push Docker image to Docker Hub
#if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
env:
@ -117,7 +114,7 @@ jobs:
swap-storage: true
- name: Build and push Docker image (tagged + versioned)
if: github.event_name == 'push'
if: ${{ github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
uses: docker/build-push-action@v6
with:
context: .


@ -1,72 +0,0 @@
name: Nix aarch64 builds
on:
workflow_dispatch: # allows manual triggering
schedule:
# Rebuild daily rather than on every push because QEMU is expensive (e.g.
# 1.5h instead of minutes with the cold cache).
#
# randint(0, 59), randint(0, 23)
- cron: '26 12 * * *'
# But also rebuild if we touched any of the Nix expressions:
push:
branches:
- master
paths: ['**/*.nix', 'flake.lock']
pull_request:
types: [opened, synchronize, reopened]
paths: ['**/*.nix', 'flake.lock']
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
# Fine-grant permission
# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token
permissions:
# https://github.com/DeterminateSystems/nix-installer-action?tab=readme-ov-file#with-flakehub
id-token: write
contents: read
jobs:
nix-build-aarch64:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install QEMU
# Copy-paste from https://github.com/orgs/community/discussions/8305#discussioncomment-5888654
run: |
sudo apt-get update
sudo apt-get install -y qemu-user-static qemu-system-aarch64
sudo usermod -a -G kvm $USER
- name: Install Nix
uses: DeterminateSystems/nix-installer-action@v9
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
extra-conf: |
extra-platforms = aarch64-linux
extra-system-features = nixos-test kvm
extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
- uses: DeterminateSystems/magic-nix-cache-action@v2
with:
upstream-cache: https://${{ matrix.cachixName }}.cachix.org
- name: Set-up cachix to push the results to
uses: cachix/cachix-action@v13
with:
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
name: llama-cpp
- name: Show all output paths
run: >
nix run github:nix-community/nix-eval-jobs
-- --gc-roots-dir gcroot
--flake
".#packages.aarch64-linux"
- name: Build
run: >
nix run github:Mic92/nix-fast-build
-- --skip-cached --no-nom
--systems aarch64-linux
--flake
".#checks.aarch64-linux"


@ -1,79 +0,0 @@
name: Nix CI
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
# Fine-grant permission
# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token
permissions:
# https://github.com/DeterminateSystems/nix-installer-action?tab=readme-ov-file#with-flakehub
id-token: write
contents: read
jobs:
nix-eval:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest, macos-latest ]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Nix
uses: DeterminateSystems/nix-installer-action@v9
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
extra-conf: |
extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
- uses: DeterminateSystems/magic-nix-cache-action@v2
with:
upstream-cache: https://${{ matrix.cachixName }}.cachix.org
- name: List all flake outputs
run: nix flake show --all-systems
- name: Show all output paths
run: >
nix run github:nix-community/nix-eval-jobs
-- --gc-roots-dir gcroot
--flake
".#packages.$(nix eval --raw --impure --expr builtins.currentSystem)"
nix-build:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest, macos-latest ]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Nix
uses: DeterminateSystems/nix-installer-action@v9
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
extra-conf: |
extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org
extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=
- uses: DeterminateSystems/magic-nix-cache-action@v2
with:
upstream-cache: https://${{ matrix.cachixName }}.cachix.org
- name: Set-up cachix to push the results to
uses: cachix/cachix-action@v13
with:
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
name: llama-cpp
- name: Build
run: >
nix run github:Mic92/nix-fast-build
-- --skip-cached --no-nom
--flake
".#checks.$(nix eval --raw --impure --expr builtins.currentSystem)"


@ -1,22 +0,0 @@
name: update-flake-lock
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * 0' # runs weekly on Sunday at 00:00
jobs:
lockfile:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Nix
uses: DeterminateSystems/nix-installer-action@main
- name: Update flake.lock
uses: DeterminateSystems/update-flake-lock@main
with:
pr-title: "nix: update flake.lock"
pr-labels: |
nix
pr-reviewers: philiptaron,SomeoneSerge
token: ${{ secrets.FLAKE_TOKEN }}


@ -1,36 +0,0 @@
# Make the flake discoverable on https://flakestry.dev and https://flakehub.com/flakes
name: "Publish a flake to flakestry & flakehub"
on:
push:
tags:
- "*"
workflow_dispatch:
inputs:
tag:
description: "The existing tag to publish"
type: "string"
required: true
jobs:
flakestry-publish:
runs-on: ubuntu-latest
permissions:
id-token: "write"
contents: "read"
steps:
- uses: flakestry/flakestry-publish@main
with:
version: "${{ inputs.tag || github.ref_name }}"
flakehub-publish:
runs-on: "ubuntu-latest"
permissions:
id-token: "write"
contents: "read"
steps:
- uses: "actions/checkout@v4"
with:
ref: "${{ (inputs.tag != null) && format('refs/tags/{0}', inputs.tag) || '' }}"
- uses: "DeterminateSystems/nix-installer-action@main"
- uses: "DeterminateSystems/flakehub-push@main"
with:
visibility: "public"
tag: "${{ inputs.tag }}"


@ -1,6 +1,13 @@
name: flake8 Lint
on: [push, pull_request]
on:
push:
branches:
- master
paths: ['.github/workflows/python-lint.yml', '**/*.py']
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/python-lint.yml', '**/*.py']
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}


@ -122,14 +122,14 @@ jobs:
id: server_integration_tests
run: |
cd examples/server/tests
PORT=8888 ./tests.sh
./tests.sh
- name: Slow tests
id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
run: |
cd examples/server/tests
PORT=8888 ./tests.sh --stop --no-skipped --no-capture --tags slow
SLOW_TESTS=1 ./tests.sh
server-windows:
@ -180,11 +180,12 @@ jobs:
run: |
cd examples/server/tests
$env:PYTHONIOENCODING = ":replace"
behave.exe --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
pytest -v -x
- name: Slow tests
id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
run: |
cd examples/server/tests
behave.exe --stop --no-skipped --no-capture --tags slow
$env:SLOW_TESTS = "1"
pytest -v -x


@ -82,6 +82,7 @@ option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)
# override ggml options
set(GGML_SANITIZE_THREAD ${LLAMA_SANITIZE_THREAD})
@ -163,8 +164,11 @@ if (GGML_TARGET_DEFINES)
list(APPEND GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES})
endif()
get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h)
# all public headers
set(LLAMA_PUBLIC_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h
${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h)
set_target_properties(llama PROPERTIES PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
install(TARGETS llama LIBRARY PUBLIC_HEADER)
configure_package_config_file(


@ -34,6 +34,7 @@ BUILD_TARGETS = \
llama-server \
llama-simple \
llama-simple-chat \
llama-run \
llama-speculative \
llama-tokenize \
llama-vdot \
@ -251,7 +252,7 @@ endif
#
# keep standard at C11 and C++11
MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon
MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon -DGGML_USE_CPU
MK_CFLAGS = -std=c11 -fPIC
MK_CXXFLAGS = -std=c++11 -fPIC
MK_NVCCFLAGS = -std=c++11
@ -290,6 +291,7 @@ endif
# some memory allocation are available on Linux through GNU extensions in libc
ifeq ($(UNAME_S),Linux)
MK_CPPFLAGS += -D_GNU_SOURCE
MK_LDFLAGS += -ldl
endif
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
@ -750,7 +752,7 @@ vulkan-shaders-gen: ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
endif # GGML_VULKAN
ifdef GGML_HIPBLAS
ifdef GGML_HIP
ifeq ($(wildcard /opt/rocm),)
ROCM_PATH ?= /usr
AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
@ -805,7 +807,7 @@ ggml/src/ggml-cuda/%.o: \
ggml/src/ggml-common.h \
ggml/src/ggml-cuda/common.cuh
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
endif # GGML_HIPBLAS
endif # GGML_HIP
ifdef GGML_MUSA
ifeq ($(wildcard /opt/musa),)
@ -813,7 +815,7 @@ ifdef GGML_MUSA
else
MUSA_PATH ?= /opt/musa
endif
MTGPU_TARGETS ?= mp_21 mp_22
MUSA_ARCHITECTURES ?= 21;22
MK_CPPFLAGS += -DGGML_USE_MUSA -DGGML_USE_CUDA
MK_LDFLAGS += -L$(MUSA_PATH)/lib -Wl,-rpath=$(MUSA_PATH)/lib
@ -832,7 +834,8 @@ ifdef GGML_MUSA
CXX := $(MUSA_PATH)/bin/clang++
MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc
MUSAFLAGS += $(addprefix --cuda-gpu-arch=, $(MTGPU_TARGETS))
MUSAFLAGS = -x musa -mtgpu
MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch))
ifdef GGML_CUDA_FORCE_MMQ
MUSAFLAGS += -DGGML_CUDA_FORCE_MMQ
@ -876,14 +879,14 @@ ggml/src/ggml-cuda/ggml-cuda.o: \
ggml/src/ggml-backend-impl.h \
ggml/src/ggml-common.h \
$(wildcard ggml/src/ggml-cuda/*.cuh)
$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -mtgpu -c -o $@ $<
$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -c -o $@ $<
ggml/src/ggml-cuda/%.o: \
ggml/src/ggml-cuda/%.cu \
ggml/include/ggml.h \
ggml/src/ggml-common.h \
ggml/src/ggml-cuda/common.cuh
$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -mtgpu -c -o $@ $<
$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -c -o $@ $<
endif # GGML_MUSA
ifdef GGML_METAL
@ -966,6 +969,7 @@ OBJ_COMMON = \
$(DIR_COMMON)/console.o \
$(DIR_COMMON)/ngram-cache.o \
$(DIR_COMMON)/sampling.o \
$(DIR_COMMON)/speculative.o \
$(DIR_COMMON)/build-info.o \
$(DIR_COMMON)/json-schema-to-grammar.o
@ -1165,6 +1169,11 @@ llama-infill: examples/infill/infill.cpp \
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
llama-run: examples/run/run.cpp \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
llama-simple: examples/simple/simple.cpp \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)


@ -43,7 +43,8 @@ linkerSettings.append(.linkedFramework("Accelerate"))
cSettings.append(
contentsOf: [
.define("GGML_USE_ACCELERATE"),
.define("GGML_USE_METAL")
.define("GGML_USE_METAL"),
.define("GGML_USE_CPU")
]
)
#endif


@ -79,6 +79,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
- [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
- [x] [OLMo](https://allenai.org/olmo)
- [x] [OLMo 2](https://allenai.org/olmo)
- [x] [OLMoE](https://huggingface.co/allenai/OLMoE-1B-7B-0924)
- [x] [Granite models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330)
- [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia)

33 cmake/common.cmake Normal file

@ -0,0 +1,33 @@
function(llama_add_compile_flags)
if (LLAMA_FATAL_WARNINGS)
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
list(APPEND C_FLAGS -Werror)
list(APPEND CXX_FLAGS -Werror)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
add_compile_options(/WX)
endif()
endif()
if (LLAMA_ALL_WARNINGS)
if (NOT MSVC)
list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
-Werror=implicit-int -Werror=implicit-function-declaration)
list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
list(APPEND C_FLAGS ${WARNING_FLAGS})
list(APPEND CXX_FLAGS ${WARNING_FLAGS})
ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
"$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
else()
# todo : msvc
set(C_FLAGS "" PARENT_SCOPE)
set(CXX_FLAGS "" PARENT_SCOPE)
endif()
endif()
endfunction()


@ -3,12 +3,60 @@ set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@)
set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@)
set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@)
set(GGML_STATIC @GGML_STATIC@)
set(GGML_NATIVE @GGML_NATIVE@)
set(GGML_LTO @GGML_LTO@)
set(GGML_CCACHE @GGML_CCACHE@)
set(GGML_AVX @GGML_AVX@)
set(GGML_AVX2 @GGML_AVX2@)
set(GGML_AVX512 @GGML_AVX512@)
set(GGML_AVX512_VBMI @GGML_AVX512_VBMI@)
set(GGML_AVX512_VNNI @GGML_AVX512_VNNI@)
set(GGML_AVX512_BF16 @GGML_AVX512_BF16@)
set(GGML_AMX_TILE @GGML_AMX_TILE@)
set(GGML_AMX_INT8 @GGML_AMX_INT8@)
set(GGML_AMX_BF16 @GGML_AMX_BF16@)
set(GGML_FMA @GGML_FMA@)
set(GGML_LASX @GGML_LASX@)
set(GGML_LSX @GGML_LSX@)
set(GGML_RVV @GGML_RVV@)
set(GGML_SVE @GGML_SVE@)
set(GGML_ACCELERATE @GGML_ACCELERATE@)
set(GGML_OPENMP @GGML_OPENMP@)
set(GGML_CPU_HBM @GGML_CPU_HBM@)
set(GGML_BLAS_VENDOR @GGML_BLAS_VENDOR@)
set(GGML_CUDA_FORCE_MMQ @GGML_CUDA_FORCE_MMQ@)
set(GGML_CUDA_FORCE_CUBLAS @GGML_CUDA_FORCE_CUBLAS@)
set(GGML_CUDA_F16 @GGML_CUDA_F16@)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE @GGML_CUDA_PEER_MAX_BATCH_SIZE@)
set(GGML_CUDA_NO_PEER_COPY @GGML_CUDA_NO_PEER_COPY@)
set(GGML_CUDA_NO_VMM @GGML_CUDA_NO_VMM@)
set(GGML_CUDA_FA_ALL_QUANTS @GGML_CUDA_FA_ALL_QUANTS@)
set(GGML_CUDA_GRAPHS @GGML_CUDA_GRAPHS@)
set(GGML_HIP_UMA @GGML_HIP_UMA@)
set(GGML_VULKAN_CHECK_RESULTS @GGML_VULKAN_CHECK_RESULTS@)
set(GGML_VULKAN_DEBUG @GGML_VULKAN_DEBUG@)
set(GGML_VULKAN_MEMORY_DEBUG @GGML_VULKAN_MEMORY_DEBUG@)
set(GGML_VULKAN_VALIDATE @GGML_VULKAN_VALIDATE@)
set(GGML_OPENMP @GGML_OPENMP@)
set(GGML_VULKAN_DEBUG @GGML_VULKAN_DEBUG@)
set(GGML_VULKAN_MEMORY_DEBUG @GGML_VULKAN_MEMORY_DEBUG@)
set(GGML_VULKAN_SHADER_DEBUG_INFO @GGML_VULKAN_SHADER_DEBUG_INFO@)
set(GGML_VULKAN_PERF @GGML_VULKAN_PERF@)
set(GGML_VULKAN_VALIDATE @GGML_VULKAN_VALIDATE@)
set(GGML_VULKAN_RUN_TESTS @GGML_VULKAN_RUN_TESTS@)
set(GGML_METAL_USE_BF16 @GGML_METAL_USE_BF16@)
set(GGML_METAL_NDEBUG @GGML_METAL_NDEBUG@)
set(GGML_METAL_SHADER_DEBUG @GGML_METAL_SHADER_DEBUG@)
set(GGML_METAL_EMBED_LIBRARY @GGML_METAL_EMBED_LIBRARY@)
set(GGML_METAL_MACOSX_VERSION_MIN @GGML_METAL_MACOSX_VERSION_MIN@)
set(GGML_METAL_STD @GGML_METAL_STD@)
set(GGML_SYCL_F16 @GGML_SYCL_F16@)
set(GGML_SYCL_TARGET @GGML_SYCL_TARGET@)
set(GGML_SYCL_DEVICE_ARCH @GGML_SYCL_DEVICE_ARCH@)
@PACKAGE_INIT@
@ -20,6 +68,7 @@ find_package(Threads REQUIRED)
set(_llama_transient_defines "@GGML_TRANSIENT_DEFINES@")
set(_llama_link_deps "")
set(_llama_link_opts "")
foreach(_ggml_lib ggml ggml-base)
string(REPLACE "-" "_" _ggml_lib_var "${_ggml_lib}_LIBRARY")
find_library(${_ggml_lib_var} ${_ggml_lib}
@ -49,41 +98,63 @@ foreach(backend amx blas cann cpu cuda hip kompute metal musa rpc sycl vulkan)
endif()
endforeach()
if (APPLE AND GGML_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
endif()
if (NOT LLAMA_SHARED_LIB)
if (APPLE AND GGML_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
list(APPEND _llama_link_deps ${ACCELERATE_FRAMEWORK})
endif()
if (GGML_BLAS)
find_package(BLAS REQUIRED)
endif()
if (GGML_OPENMP)
find_package(OpenMP REQUIRED)
list(APPEND _llama_link_deps OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
endif()
if (GGML_CUDA)
find_package(CUDAToolkit REQUIRED)
endif()
if (GGML_CPU_HBM)
find_library(memkind memkind REQUIRED)
list(APPEND _llama_link_deps memkind)
endif()
if (GGML_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
endif()
if (GGML_BLAS)
find_package(BLAS REQUIRED)
list(APPEND _llama_link_deps ${BLAS_LIBRARIES})
list(APPEND _llama_link_opts ${BLAS_LINKER_FLAGS})
endif()
if (GGML_VULKAN)
find_package(Vulkan REQUIRED)
endif()
if (GGML_CUDA)
find_package(CUDAToolkit REQUIRED)
endif()
if (GGML_HIP)
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
endif()
if (GGML_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
list(APPEND _llama_link_deps ${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
endif()
if (GGML_SYCL)
find_package(IntelSYCL REQUIRED)
find_package(MKL REQUIRED)
endif()
if (GGML_VULKAN)
find_package(Vulkan REQUIRED)
list(APPEND _llama_link_deps Vulkan::Vulkan)
endif()
if (GGML_OPENMP)
find_package(OpenMP REQUIRED)
if (GGML_HIP)
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
list(APPEND _llama_link_deps hip::host roc::rocblas roc::hipblas)
endif()
if (GGML_SYCL)
find_package(DNNL)
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
list(APPEND _llama_link_deps DNNL::dnnl)
endif()
if (WIN32)
find_package(IntelSYCL REQUIRED)
find_package(MKL REQUIRED)
list(APPEND _llama_link_deps IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
endif()
endif()
endif()
find_library(llama_LIBRARY llama
@ -97,6 +168,7 @@ set_target_properties(llama
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES "${_llama_link_deps}"
INTERFACE_LINK_OPTIONS "${_llama_link_opts}"
INTERFACE_COMPILE_DEFINITIONS "${_llama_transient_defines}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${llama_LIBRARY}"


@ -2,6 +2,8 @@
find_package(Threads REQUIRED)
llama_add_compile_flags()
# Build info header
#
@ -66,6 +68,8 @@ add_library(${TARGET} STATIC
ngram-cache.h
sampling.cpp
sampling.h
speculative.cpp
speculative.h
)
if (BUILD_SHARED_LIBS)


@ -233,10 +233,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
}
postprocess_cpu_params(params.cpuparams, nullptr);
postprocess_cpu_params(params.cpuparams, nullptr);
postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams);
postprocess_cpu_params(params.draft_cpuparams, &params.cpuparams);
postprocess_cpu_params(params.draft_cpuparams_batch, &params.cpuparams_batch);
postprocess_cpu_params(params.speculative.cpuparams, &params.cpuparams);
postprocess_cpu_params(params.speculative.cpuparams_batch, &params.cpuparams_batch);
if (params.prompt_cache_all && (params.interactive || params.interactive_first)) {
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
@ -251,7 +252,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & antiprompt : params.antiprompt) {
string_process_escapes(antiprompt);
}
for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
}
@ -297,6 +298,27 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
print_options(specific_options);
}
static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & value) {
std::vector<ggml_backend_dev_t> devices;
auto dev_names = string_split<std::string>(value, ',');
if (dev_names.empty()) {
throw std::invalid_argument("no devices specified");
}
if (dev_names.size() == 1 && dev_names[0] == "none") {
devices.push_back(nullptr);
} else {
for (const auto & device : dev_names) {
auto * dev = ggml_backend_dev_by_name(device.c_str());
if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
}
devices.push_back(dev);
}
devices.push_back(nullptr);
}
return devices;
}
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
auto ctx_arg = common_params_parser_init(params, ex, print_usage);
const common_params params_org = ctx_arg.params; // the example can modify the default params
@ -323,13 +345,16 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// load dynamic backends
ggml_backend_load_all();
common_params_context ctx_arg(params);
ctx_arg.print_usage = print_usage;
ctx_arg.ex = ex;
std::string sampler_type_chars;
std::string sampler_type_names;
for (const auto & sampler : params.sparams.samplers) {
for (const auto & sampler : params.sampling.samplers) {
sampler_type_chars += common_sampler_type_to_chr(sampler);
sampler_type_names += common_sampler_type_to_str(sampler) + ";";
}
@ -407,26 +432,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
));
add_opt(common_arg(
{"-td", "--threads-draft"}, "N",
"number of threads to use during generation (default: same as --threads)",
[](common_params & params, int value) {
params.draft_cpuparams.n_threads = value;
if (params.draft_cpuparams.n_threads <= 0) {
params.draft_cpuparams.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-tbd", "--threads-batch-draft"}, "N",
"number of threads to use during batch and prompt processing (default: same as --threads-draft)",
[](common_params & params, int value) {
params.draft_cpuparams_batch.n_threads = value;
if (params.draft_cpuparams_batch.n_threads <= 0) {
params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-C", "--cpu-mask"}, "M",
"CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
@ -515,108 +520,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.cpuparams_batch.poll = value;
}
));
add_opt(common_arg(
{"-Cd", "--cpu-mask-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.draft_cpuparams.mask_valid = true;
if (!parse_cpu_mask(mask, params.draft_cpuparams.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crd", "--cpu-range-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft",
[](common_params & params, const std::string & range) {
params.draft_cpuparams.mask_valid = true;
if (!parse_cpu_range(range, params.draft_cpuparams.cpumask)) {
throw std::invalid_argument("invalid range");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: same as --cpu-strict)",
[](common_params & params, int value) {
params.draft_cpuparams.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.draft_cpuparams.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: same as --poll])",
[](common_params & params, int value) {
params.draft_cpuparams.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Cbd", "--cpu-mask-batch-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.draft_cpuparams_batch.mask_valid = true;
if (!parse_cpu_mask(mask, params.draft_cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crbd", "--cpu-range-batch-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)",
[](common_params & params, const std::string & range) {
params.draft_cpuparams_batch.mask_valid = true;
if (!parse_cpu_range(range, params.draft_cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-batch-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: --cpu-strict-draft)",
[](common_params & params, int value) {
params.draft_cpuparams_batch.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-batch-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams_batch.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-batch-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: --poll-draft)",
[](common_params & params, int value) {
params.draft_cpuparams_batch.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--draft"}, "N",
string_format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft),
[](common_params & params, int value) {
params.n_draft = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
add_opt(common_arg(
{"-ps", "--p-split"}, "N",
string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
[](common_params & params, const std::string & value) {
params.p_split = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-lcs", "--lookup-cache-static"}, "FNAME",
"path to static lookup cache to use for lookup decoding (not updated by generation)",
@ -701,7 +604,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
[](common_params & params) {
params.no_perf = true;
params.sparams.no_perf = true;
params.sampling.no_perf = true;
}
).set_env("LLAMA_ARG_NO_PERF"));
add_opt(common_arg(
@ -883,155 +786,155 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
}
).set_sparam());
add_opt(common_arg(
{"-s", "--seed"}, "SEED",
string_format("RNG seed (default: %d, use random seed for %d)", params.sparams.seed, LLAMA_DEFAULT_SEED),
string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED),
[](common_params & params, const std::string & value) {
params.sparams.seed = std::stoul(value);
params.sampling.seed = std::stoul(value);
}
).set_sparam());
add_opt(common_arg(
{"--sampling-seq"}, "SEQUENCE",
string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
[](common_params & params, const std::string & value) {
params.sparams.samplers = common_sampler_types_from_chars(value);
params.sampling.samplers = common_sampler_types_from_chars(value);
}
).set_sparam());
add_opt(common_arg(
{"--ignore-eos"},
"ignore end of stream token and continue generating (implies --logit-bias EOS-inf)",
[](common_params & params) {
params.sparams.ignore_eos = true;
params.sampling.ignore_eos = true;
}
).set_sparam());
add_opt(common_arg(
{"--penalize-nl"},
string_format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"),
string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
[](common_params & params) {
params.sparams.penalize_nl = true;
params.sampling.penalize_nl = true;
}
).set_sparam());
add_opt(common_arg(
{"--temp"}, "N",
string_format("temperature (default: %.1f)", (double)params.sparams.temp),
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
[](common_params & params, const std::string & value) {
params.sparams.temp = std::stof(value);
params.sparams.temp = std::max(params.sparams.temp, 0.0f);
params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
}
).set_sparam());
add_opt(common_arg(
{"--top-k"}, "N",
string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
[](common_params & params, int value) {
params.sparams.top_k = value;
params.sampling.top_k = value;
}
).set_sparam());
add_opt(common_arg(
{"--top-p"}, "N",
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p),
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
[](common_params & params, const std::string & value) {
params.sparams.top_p = std::stof(value);
params.sampling.top_p = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--min-p"}, "N",
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p),
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
[](common_params & params, const std::string & value) {
params.sparams.min_p = std::stof(value);
params.sampling.min_p = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-probability"}, "N",
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) {
params.sparams.xtc_probability = std::stof(value);
params.sampling.xtc_probability = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-threshold"}, "N",
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sparams.xtc_threshold = std::stof(value);
params.sampling.xtc_threshold = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--typical"}, "N",
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p),
[](common_params & params, const std::string & value) {
params.sparams.typ_p = std::stof(value);
params.sampling.typ_p = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--repeat-last-n"}, "N",
string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n),
string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
[](common_params & params, int value) {
params.sparams.penalty_last_n = value;
params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n);
params.sampling.penalty_last_n = value;
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
}
).set_sparam());
add_opt(common_arg(
{"--repeat-penalty"}, "N",
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat),
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
[](common_params & params, const std::string & value) {
params.sparams.penalty_repeat = std::stof(value);
params.sampling.penalty_repeat = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--presence-penalty"}, "N",
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present),
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present),
[](common_params & params, const std::string & value) {
params.sparams.penalty_present = std::stof(value);
params.sampling.penalty_present = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--frequency-penalty"}, "N",
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq),
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
[](common_params & params, const std::string & value) {
params.sparams.penalty_freq = std::stof(value);
params.sampling.penalty_freq = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dry-multiplier"}, "N",
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
[](common_params & params, const std::string & value) {
params.sparams.dry_multiplier = std::stof(value);
params.sampling.dry_multiplier = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dry-base"}, "N",
string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
string_format("set DRY sampling base value (default: %.2f)", (double)params.sampling.dry_base),
[](common_params & params, const std::string & value) {
float potential_base = std::stof(value);
if (potential_base >= 1.0f)
{
params.sparams.dry_base = potential_base;
params.sampling.dry_base = potential_base;
}
}
).set_sparam());
add_opt(common_arg(
{"--dry-allowed-length"}, "N",
string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
string_format("set allowed length for DRY sampling (default: %d)", params.sampling.dry_allowed_length),
[](common_params & params, int value) {
params.sparams.dry_allowed_length = value;
params.sampling.dry_allowed_length = value;
}
).set_sparam());
add_opt(common_arg(
{"--dry-penalty-last-n"}, "N",
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
[](common_params & params, int value) {
params.sparams.dry_penalty_last_n = value;
params.sampling.dry_penalty_last_n = value;
}
).set_sparam());
add_opt(common_arg(
{"--dry-sequence-breaker"}, "STRING",
string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
params.sparams.dry_sequence_breakers.empty() ? "none" :
std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
params.sparams.dry_sequence_breakers.end(),
std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
params.sampling.dry_sequence_breakers.empty() ? "none" :
std::accumulate(std::next(params.sampling.dry_sequence_breakers.begin()),
params.sampling.dry_sequence_breakers.end(),
std::string("'") + (params.sampling.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sampling.dry_sequence_breakers[0]) + "'",
[](const std::string& a, const std::string& b) {
std::string formatted_b = (b == "\n") ? "\\n" : b;
return a + ", '" + formatted_b + "'";
@ -1040,51 +943,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
static bool defaults_cleared = false;
if (!defaults_cleared) {
params.sparams.dry_sequence_breakers.clear();
params.sampling.dry_sequence_breakers.clear();
defaults_cleared = true;
}
if (value == "none") {
params.sparams.dry_sequence_breakers.clear();
params.sampling.dry_sequence_breakers.clear();
} else {
params.sparams.dry_sequence_breakers.emplace_back(value);
params.sampling.dry_sequence_breakers.emplace_back(value);
}
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
[](common_params & params, const std::string & value) {
params.sparams.dynatemp_range = std::stof(value);
params.sampling.dynatemp_range = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-exp"}, "N",
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent),
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent),
[](common_params & params, const std::string & value) {
params.sparams.dynatemp_exponent = std::stof(value);
params.sampling.dynatemp_exponent = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--mirostat"}, "N",
string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n"
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat),
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
[](common_params & params, int value) {
params.sparams.mirostat = value;
params.sampling.mirostat = value;
}
).set_sparam());
add_opt(common_arg(
{"--mirostat-lr"}, "N",
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta),
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
[](common_params & params, const std::string & value) {
params.sparams.mirostat_eta = std::stof(value);
params.sampling.mirostat_eta = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--mirostat-ent"}, "N",
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau),
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
[](common_params & params, const std::string & value) {
params.sparams.mirostat_tau = std::stof(value);
params.sampling.mirostat_tau = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
@ -1100,7 +1003,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
try {
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
params.sparams.logit_bias.push_back({key, bias});
params.sampling.logit_bias.push_back({key, bias});
} else {
throw std::invalid_argument("invalid input format");
}
@ -1111,9 +1014,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--grammar"}, "GRAMMAR",
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()),
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
[](common_params & params, const std::string & value) {
params.sparams.grammar = value;
params.sampling.grammar = value;
}
).set_sparam());
add_opt(common_arg(
@ -1127,7 +1030,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.sparams.grammar)
std::back_inserter(params.sampling.grammar)
);
}
).set_sparam());
@ -1135,7 +1038,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-j", "--json-schema"}, "SCHEMA",
"JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
[](common_params & params, const std::string & value) {
params.sparams.grammar = json_schema_to_grammar(json::parse(value));
params.sampling.grammar = json_schema_to_grammar(json::parse(value));
}
).set_sparam());
add_opt(common_arg(
@ -1433,6 +1336,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
else { throw std::invalid_argument("invalid value"); }
}
).set_env("LLAMA_ARG_NUMA"));
add_opt(common_arg(
{"-dev", "--device"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading (none = don't offload)\n"
"use --list-devices to see a list of available devices",
[](common_params & params, const std::string & value) {
params.devices = parse_device_list(value);
}
).set_env("LLAMA_ARG_DEVICE"));
add_opt(common_arg(
{"--list-devices"},
"print list of available devices and exit",
[](common_params &) {
printf("Available devices:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
}
exit(0);
}
));
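For reference, the two options above could be combined as follows; the device names are illustrative, since the names reported by --list-devices depend on which backends were compiled in and on the hardware present:
```bash
# print the GPU devices ggml can see, then exit
llama-cli --list-devices

# offload to one of the reported devices (device name is hypothetical)
llama-cli -m model.gguf -ngl 99 -dev CUDA0

# disable offloading entirely
llama-cli -m model.gguf -dev none
```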
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
@ -1444,17 +1371,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_env("LLAMA_ARG_N_GPU_LAYERS"));
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
[](common_params & params, int value) {
params.n_gpu_layers_draft = value;
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-sm", "--split-mode"}, "{none,layer,row}",
"how to split the model across multiple GPUs, one of:\n"
@ -1468,10 +1384,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
} else if (arg_next == "row") {
#ifdef GGML_USE_SYCL
fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n");
exit(1);
#endif // GGML_USE_SYCL
params.split_mode = LLAMA_SPLIT_MODE_ROW;
} else {
throw std::invalid_argument("invalid value");
@ -1593,13 +1505,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model = value;
}
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
add_opt(common_arg(
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.model_draft = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-mu", "--model-url"}, "MODEL_URL",
"model download url (default: unused)",
@ -2037,5 +1942,176 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_LOG_TIMESTAMPS"));
// speculative parameters
add_opt(common_arg(
{"-td", "--threads-draft"}, "N",
"number of threads to use during generation (default: same as --threads)",
[](common_params & params, int value) {
params.speculative.cpuparams.n_threads = value;
if (params.speculative.cpuparams.n_threads <= 0) {
params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-tbd", "--threads-batch-draft"}, "N",
"number of threads to use during batch and prompt processing (default: same as --threads-draft)",
[](common_params & params, int value) {
params.speculative.cpuparams_batch.n_threads = value;
if (params.speculative.cpuparams_batch.n_threads <= 0) {
params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency();
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Cd", "--cpu-mask-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.speculative.cpuparams.mask_valid = true;
if (!parse_cpu_mask(mask, params.speculative.cpuparams.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crd", "--cpu-range-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft",
[](common_params & params, const std::string & range) {
params.speculative.cpuparams.mask_valid = true;
if (!parse_cpu_range(range, params.speculative.cpuparams.cpumask)) {
throw std::invalid_argument("invalid range");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: same as --cpu-strict)",
[](common_params & params, int value) {
params.speculative.cpuparams.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.speculative.cpuparams.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: same as --poll])",
[](common_params & params, int value) {
params.speculative.cpuparams.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Cbd", "--cpu-mask-batch-draft"}, "M",
"Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
[](common_params & params, const std::string & mask) {
params.speculative.cpuparams_batch.mask_valid = true;
if (!parse_cpu_mask(mask, params.speculative.cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"-Crbd", "--cpu-range-batch-draft"}, "lo-hi",
"Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)",
[](common_params & params, const std::string & range) {
params.speculative.cpuparams_batch.mask_valid = true;
if (!parse_cpu_range(range, params.speculative.cpuparams_batch.cpumask)) {
throw std::invalid_argument("invalid cpumask");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--cpu-strict-batch-draft"}, "<0|1>",
"Use strict CPU placement for draft model (default: --cpu-strict-draft)",
[](common_params & params, int value) {
params.speculative.cpuparams_batch.strict_cpu = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--prio-batch-draft"}, "N",
string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams_batch.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
throw std::invalid_argument("invalid value");
}
params.speculative.cpuparams_batch.priority = (enum ggml_sched_priority) prio;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--poll-batch-draft"}, "<0|1>",
"Use polling to wait for draft model work (default: --poll-draft)",
[](common_params & params, int value) {
params.speculative.cpuparams_batch.poll = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--draft-max", "--draft", "--draft-n"}, "N",
string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
[](common_params & params, int value) {
params.speculative.n_max = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--draft-min", "--draft-n-min"}, "N",
string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min),
[](common_params & params, int value) {
params.speculative.n_min = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--draft-p-split"}, "P",
string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split),
[](common_params & params, const std::string & value) {
params.speculative.p_split = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
{"--draft-p-min"}, "P",
string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
[](common_params & params, const std::string & value) {
params.speculative.p_min = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
[](common_params & params, int value) {
params.speculative.n_ctx = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
"use --list-devices to see a list of available devices",
[](common_params & params, const std::string & value) {
params.speculative.devices = parse_device_list(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
[](common_params & params, int value) {
params.speculative.n_gpu_layers = value;
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}
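To illustrate the renamed draft flags registered above, a typical invocation of the speculative example might look like the following (model paths and values are placeholders, not recommendations):
```bash
# target model + draft model, drafting up to 16 tokens per step,
# with both models fully offloaded to GPU
llama-speculative -m target-model.gguf -md draft-model.gguf \
    --draft-max 16 --draft-min 5 --draft-p-min 0.9 \
    -ngl 99 -ngld 99 -c 4096 -cd 4096
```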

View file

@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf << "\n" << std::to_string(i)
<< ":token '" << detokenized << "'"
<< ":pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]);
buf << "\n" << std::to_string(i)
<< ", token '" << detokenized << "'"
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
@ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) {
common_lora_adapters_apply(lctx, iparams.lora_adapters);
}
if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sparams.ignore_eos = false;
params.sampling.ignore_eos = false;
}
if (params.warmup) {
@ -979,9 +979,12 @@ void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_l
}
}
struct llama_model_params common_model_params_to_llama(const common_params & params) {
struct llama_model_params common_model_params_to_llama(common_params & params) {
auto mparams = llama_model_default_params();
if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
@ -1490,6 +1493,66 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequence, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
//
// Vocab utils
//

View file

@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter;
};
using llama_tokens = std::vector<llama_token>;
// build info
extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT;
@ -101,8 +103,8 @@ enum dimre_method {
DIMRE_METHOD_MEAN,
};
// sampler parameters
struct common_sampler_params {
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
int32_t n_prev = 64; // number of previous tokens to remember
@ -153,21 +155,30 @@ struct common_sampler_params {
std::string print() const;
};
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
std::string model = ""; // draft model for speculative decoding // NOLINT
};
struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
@ -180,25 +191,29 @@ struct common_params {
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = 0.1f; // KV cache defragmentation threshold
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct cpu_params draft_cpuparams;
struct cpu_params draft_cpuparams_batch;
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
struct common_sampler_params sparams;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
std::string model = ""; // model path // NOLINT
std::string model_draft = ""; // draft model for speculative decoding // NOLINT
std::string model_alias = "unknown"; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT
@ -451,7 +466,7 @@ struct common_init_result {
struct common_init_result common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama (const common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
@ -461,7 +476,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
//
// Batch utils
//
void common_batch_clear(struct llama_batch & batch);
@ -472,6 +489,16 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longest common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
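A minimal illustration of the two helpers (the token values are arbitrary placeholders, not real vocabulary ids):
```cpp
const llama_tokens a = { 10, 20, 30, 40, 50 };
const llama_tokens b = { 10, 20, 99, 40, 50 };

const size_t lcp = common_lcp(a, b); // 2 -> shared prefix { 10, 20 }
const size_t lcs = common_lcs(a, b); // 2 -> longest contiguous common run ({ 10, 20 } or { 40, 50 })
```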
//
// Vocab utils
//

View file

@ -99,7 +99,7 @@ struct ring_buffer {
};
struct common_sampler {
common_sampler_params params;
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain;
@ -125,7 +125,7 @@ struct common_sampler {
}
};
std::string common_sampler_params::print() const {
std::string common_params_sampling::print() const {
char result[1024];
snprintf(result, sizeof(result),
@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
return std::string(result);
}
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
lparams.no_perf = params.no_perf;
@ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
result.reserve(idxs.size());
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
if (draft[i] != id) {
break;
}
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
}
return result;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}

View file

@ -36,7 +36,7 @@ struct common_sampler;
// llama_sampler API overloads
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
void common_sampler_free(struct common_sampler * gsmpl);
@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// generalized version of common_sampler_sample
//
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
//
// common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
// common_sampler_sample(gsmpl, ctx, idx);
// common_sampler_accept(gsmpl, token, true);
//
// requires: idxs.size() == draft.size() + 1
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
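As a rough sketch of the behaviour described in the comment above (gsmpl, ctx and draft are assumed to be a valid common_sampler, llama_context and vector of drafted tokens whose logits are available at indices 0..draft.size()):
```cpp
// manual form of: common_sampler_sample_and_accept_n(gsmpl, ctx, draft)
std::vector<llama_token> res;
for (size_t i = 0; i <= draft.size(); ++i) {
    const llama_token id = common_sampler_sample(gsmpl, ctx, (int) i);
    common_sampler_accept(gsmpl, id, true);
    res.push_back(id);
    if (i < draft.size() && id != draft[i]) {
        break; // the sampler disagreed with the draft - stop accepting here
    }
}
// res contains between 1 and draft.size() + 1 tokens
```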
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers

270
common/speculative.cpp Normal file
View file

@ -0,0 +1,270 @@
#include "speculative.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include <cstring>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;
llama_batch batch;
llama_tokens prompt;
};
struct common_speculative * common_speculative_init(
struct llama_context * ctx_dft) {
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
};
// TODO: optimize or pass from outside?
#if 0
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 40;
params.top_p = 0.9;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_INFILL,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
#else
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
#endif
return result;
}
void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
llama_batch_free(spec->batch);
delete spec;
}
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft) {
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
const struct llama_model * model_dft = llama_get_model(ctx_dft);
const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(model_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const enum llama_vocab_type vocab_type_dft = llama_vocab_type(model_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt));
LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft));
return false;
}
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", __func__, i,
common_token_to_piece(ctx_tgt, i).c_str(),
common_token_to_piece(ctx_dft, i).c_str());
return false;
}
}
}
return true;
}
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt,
llama_token id_last) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & smpl = spec->smpl;
auto & prompt = spec->prompt;
int reuse_i = 0;
int reuse_n = 0;
const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_tgt.size() &&
i + cur < (int) prompt.size() &&
prompt_tgt[i_start + cur] == prompt[i + cur]) {
cur++;
}
if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
llama_tokens result;
result.reserve(params.n_draft);
if (reuse_n == 0) {
llama_kv_cache_clear(ctx);
prompt.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
result.push_back(prompt[i]);
if (params.n_draft <= (int) result.size()) {
break;
}
}
return result;
}
if (reuse_i > 0) {
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}
if (reuse_n < (int) prompt.size()) {
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end());
}
}
// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
prompt.push_back(prompt_tgt[i]);
}
// we should rarely end up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
llama_decode(ctx, batch);
}
const llama_pos n_past = prompt.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
prompt.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
llama_decode(ctx, batch);
common_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
common_sampler_sample(smpl, ctx, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_sampler_accept(smpl, id, true);
result.push_back(id);
if (params.n_draft <= (int) result.size()) {
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
prompt.push_back(id);
}
return result;
}

28
common/speculative.h Normal file
View file

@ -0,0 +1,28 @@
#pragma once
#include "llama.h"
#include "common.h"
struct common_speculative;
struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;   // min number of matching tokens required to reuse part of the previous draft context
float p_min = 0.9f;  // min probability required to accept a token in the draft
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
void common_speculative_free(struct common_speculative * spec);
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
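A rough usage sketch of this API, assuming ctx_tgt and ctx_dft are already-created llama contexts for the target and draft models, prompt_tgt holds the tokens already processed by the target, and id_last is the last token sampled from it (error handling and the target-side verification loop are omitted):
```cpp
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
    // draft vocab/special tokens do not match the target model - cannot speculate
    return 1;
}

struct common_speculative_params sparams;
sparams.n_draft = 16;   // draft at most 16 tokens per call
sparams.p_min   = 0.9f; // keep only high-confidence draft tokens

struct common_speculative * spec = common_speculative_init(ctx_dft);

// propose a batch of draft tokens continuing prompt_tgt + id_last
llama_tokens draft = common_speculative_gen_draft(spec, sparams, prompt_tgt, id_last);

// ... decode id_last + draft on the target model and accept/reject the draft ...

common_speculative_free(spec);
```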

View file

@ -2707,7 +2707,7 @@ class XLMRobertaModel(BertModel):
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_add_space_prefix(add_prefix)
self.gguf_writer.add_token_type_count(1)
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
if precompiled_charsmap:
self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
@ -3040,9 +3040,9 @@ class OlmoModel(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Olmo1124ForCausalLM")
class Olmo1124Model(Model):
model_arch = gguf.MODEL_ARCH.OLMO_1124
@Model.register("Olmo2ForCausalLM")
class Olmo2Model(Model):
model_arch = gguf.MODEL_ARCH.OLMO2
@Model.register("OlmoeForCausalLM")

View file

@ -221,7 +221,7 @@ You can download it from your Linux distro's package manager or from here: [ROCm
- Using `make`:
```bash
make GGML_HIPBLAS=1
make GGML_HIP=1
```
- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
```bash
@ -249,7 +249,7 @@ You can download it from your Linux distro's package manager or from here: [ROCm
- Using `make` (example for target gfx1030, build with 16 CPU threads):
```bash
make -j16 GGML_HIPBLAS=1 GGML_HIP_UMA=1 AMDGPU_TARGETS=gfx1030
make -j16 GGML_HIP=1 GGML_HIP_UMA=1 AMDGPU_TARGETS=gfx1030
```
- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):

View file

@ -6,19 +6,20 @@ find_package(Threads REQUIRED)
# ...
# flags
llama_add_compile_flags()
# examples
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if (EMSCRIPTEN)
else()
add_subdirectory(cvector-generator)
add_subdirectory(batched-bench)
add_subdirectory(batched)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(eval-callback)
add_subdirectory(export-lora)
add_subdirectory(gbnf-validator)
add_subdirectory(gguf-hash)
add_subdirectory(gguf-split)
@ -27,28 +28,36 @@ else()
add_subdirectory(imatrix)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)
add_subdirectory(lookahead)
add_subdirectory(lookup)
add_subdirectory(main)
add_subdirectory(parallel)
add_subdirectory(passkey)
add_subdirectory(perplexity)
add_subdirectory(quantize-stats)
add_subdirectory(quantize)
add_subdirectory(retrieval)
if (GGML_RPC)
add_subdirectory(rpc)
endif()
if (LLAMA_BUILD_SERVER)
add_subdirectory(server)
endif()
if (GGML_SYCL)
add_subdirectory(sycl)
add_subdirectory(server)
endif()
add_subdirectory(save-load-state)
add_subdirectory(run)
add_subdirectory(simple)
add_subdirectory(simple-chat)
add_subdirectory(speculative)
add_subdirectory(speculative-simple)
add_subdirectory(tokenize)
if (NOT GGML_BACKEND_DL)
# these examples use the backends directly and cannot be built with dynamic loading
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(cvector-generator)
add_subdirectory(export-lora)
add_subdirectory(quantize-stats)
add_subdirectory(llava)
if (GGML_RPC)
add_subdirectory(rpc)
endif()
if (GGML_SYCL)
add_subdirectory(sycl)
endif()
endif()
endif()

View file

@ -68,10 +68,10 @@ int main(int argc, char ** argv) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
if (ctx == NULL) {
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);

View file

@ -5,5 +5,6 @@ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
set(TEST_TARGET test-eval-callback)
add_test(NAME ${TEST_TARGET} COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42 -ngl 0)
add_test(NAME ${TEST_TARGET}
COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42 -ngl 0)
set_property(TEST ${TEST_TARGET} PROPERTY LABELS eval-callback curl)

View file

@ -4,10 +4,17 @@ install(TARGETS ${TARGET} RUNTIME)
# clibs dependencies
include_directories(deps/)
add_library(xxhash OBJECT deps/xxhash/xxhash.c deps/xxhash/xxhash.h)
target_link_libraries(${TARGET} PRIVATE xxhash)
add_library(sha1 OBJECT deps/sha1/sha1.c deps/sha1/sha1.h)
target_link_libraries(${TARGET} PRIVATE sha1)
if (NOT MSVC)
# disable warnings in 3rd party code
target_compile_options(sha1 PRIVATE -w)
endif()
add_library(sha256 OBJECT deps/sha256/sha256.c deps/sha256/sha256.h)
target_link_libraries(${TARGET} PRIVATE sha256)

View file

@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
common_init();
auto & sparams = params.sparams;
auto & sparams = params.sampling;
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });

File diff suppressed because it is too large

View file

@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
LOG("\n");
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
exit(1);

View file

@ -237,7 +237,7 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
LOG_INF("\n");
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
return smpl;
}

View file

@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
// target model sampling context
struct common_sampler * smpl = common_sampler_init(model, params.sparams);
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
// verification n-grams
std::vector<ngram_data> ngrams_cur(G);

View file

@ -21,7 +21,7 @@ int main(int argc, char ** argv){
common_init();
const int n_draft = params.n_draft;
const int n_draft = params.speculative.n_max;
// init llama.cpp
llama_backend_init();
@ -40,6 +40,7 @@ int main(int argc, char ** argv){
common_ngram_cache ngram_cache_context;
common_ngram_cache ngram_cache_dynamic;
common_ngram_cache ngram_cache_static;
int64_t t_draft_flat_us = 0;
int64_t t_draft_us = 0;

View file

@ -22,7 +22,7 @@ int main(int argc, char ** argv){
common_init();
// max. number of additional tokens to draft if match is found
const int n_draft = params.n_draft;
const int n_draft = params.speculative.n_max;
const bool dump_kv_cache = params.dump_kv_cache;
@ -102,7 +102,7 @@ int main(int argc, char ** argv){
bool has_eos = false;
struct common_sampler * smpl = common_sampler_init(model, params.sparams);
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
std::vector<llama_token> draft;

View file

@ -100,7 +100,7 @@ int main(int argc, char ** argv) {
common_init();
auto & sparams = params.sparams;
auto & sparams = params.sampling;
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
@ -165,6 +165,10 @@ int main(int argc, char ** argv) {
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_new");
auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free");
struct ggml_threadpool_params tpp_batch =
ggml_threadpool_params_from_cpu_params(params.cpuparams_batch);
struct ggml_threadpool_params tpp =
@ -174,7 +178,7 @@ int main(int argc, char ** argv) {
struct ggml_threadpool * threadpool_batch = NULL;
if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) {
threadpool_batch = ggml_threadpool_new(&tpp_batch);
threadpool_batch = ggml_threadpool_new_fn(&tpp_batch);
if (!threadpool_batch) {
LOG_ERR("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
return 1;
@ -184,7 +188,7 @@ int main(int argc, char ** argv) {
tpp.paused = true;
}
struct ggml_threadpool * threadpool = ggml_threadpool_new(&tpp);
struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp);
if (!threadpool) {
LOG_ERR("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
return 1;
@ -890,8 +894,8 @@ int main(int argc, char ** argv) {
llama_backend_free();
ggml_threadpool_free(threadpool);
ggml_threadpool_free(threadpool_batch);
ggml_threadpool_free_fn(threadpool);
ggml_threadpool_free_fn(threadpool_batch);
return 0;
}

View file

@ -160,7 +160,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.smpl = common_sampler_init(model, params.sparams);
client.smpl = common_sampler_init(model, params.sampling);
}
std::vector<llama_token> tokens_system;

View file

@ -282,8 +282,8 @@ int main(int argc, char ** argv) {
return a.second > b.second;
});
LOG("Top %d similar chunks:\n", params.sparams.top_k);
for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) {
LOG("Top %d similar chunks:\n", params.sampling.top_k);
for (int i = 0; i < std::min(params.sampling.top_k, (int) chunks.size()); i++) {
LOG("filename: %s\n", chunks[similarities[i].first].filename.c_str());
LOG("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos);
LOG("similarity: %f\n", similarities[i].second);

View file

@ -0,0 +1,5 @@
set(TARGET llama-run)
add_executable(${TARGET} run.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

7
examples/run/README.md Normal file
View file

@ -0,0 +1,7 @@
# llama.cpp/example/run
The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models.
```bash
./llama-run Meta-Llama-3.1-8B-Instruct.gguf
...
```
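Since `run.cpp` (shown further below) also reads piped data from standard input and appends it to the prompt, a non-interactive invocation might look like the following sketch; the model path and prompt are placeholders, and `-m`, `-p`, `-c` and `-ngl` are the flags parsed by this example:
```bash
# Hypothetical one-shot run: pipe text in and pass an explicit prompt.
echo "llama.cpp is a C/C++ library for LLM inference." | \
    ./llama-run -m Meta-Llama-3.1-8B-Instruct.gguf -p "Summarize the following text:" -c 2048 -ngl 99
```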

409
examples/run/run.cpp Normal file
View file

@ -0,0 +1,409 @@
#if defined(_WIN32)
#include <windows.h>
#else
#include <unistd.h>
#endif
#include <climits>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "llama-cpp.h"
typedef std::unique_ptr<char[]> char_array_ptr;
struct Argument {
std::string flag;
std::string help_text;
};
struct Options {
std::string model_path, prompt_non_interactive;
int ngl = 99;
int n_ctx = 2048;
};
class ArgumentParser {
public:
ArgumentParser(const char * program_name) : program_name(program_name) {}
void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
string_args[flag] = &var;
arguments.push_back({flag, help_text});
}
void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
int_args[flag] = &var;
arguments.push_back({flag, help_text});
}
int parse(int argc, const char ** argv) {
for (int i = 1; i < argc; ++i) {
std::string arg = argv[i];
if (string_args.count(arg)) {
if (i + 1 < argc) {
*string_args[arg] = argv[++i];
} else {
fprintf(stderr, "error: missing value for %s\n", arg.c_str());
print_usage();
return 1;
}
} else if (int_args.count(arg)) {
if (i + 1 < argc) {
if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
print_usage();
return 1;
}
} else {
fprintf(stderr, "error: missing value for %s\n", arg.c_str());
print_usage();
return 1;
}
} else {
fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
print_usage();
return 1;
}
}
if (string_args["-m"]->empty()) {
fprintf(stderr, "error: -m is required\n");
print_usage();
return 1;
}
return 0;
}
private:
const char * program_name;
std::unordered_map<std::string, std::string *> string_args;
std::unordered_map<std::string, int *> int_args;
std::vector<Argument> arguments;
int parse_int_arg(const char * arg, int & value) {
char * end;
const long val = std::strtol(arg, &end, 10);
if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
value = static_cast<int>(val);
return 0;
}
return 1;
}
void print_usage() const {
printf("\nUsage:\n");
printf(" %s [OPTIONS]\n\n", program_name);
printf("Options:\n");
for (const auto & arg : arguments) {
printf(" %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
}
printf("\n");
}
};
class LlamaData {
public:
llama_model_ptr model;
llama_sampler_ptr sampler;
llama_context_ptr context;
std::vector<llama_chat_message> messages;
int init(const Options & opt) {
model = initialize_model(opt.model_path, opt.ngl);
if (!model) {
return 1;
}
context = initialize_context(model, opt.n_ctx);
if (!context) {
return 1;
}
sampler = initialize_sampler();
return 0;
}
private:
// Initializes the model and returns a unique pointer to it
llama_model_ptr initialize_model(const std::string & model_path, const int ngl) {
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = ngl;
llama_model_ptr model(llama_load_model_from_file(model_path.c_str(), model_params));
if (!model) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
}
return model;
}
// Initializes the context with the specified parameters
llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = n_ctx;
llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
if (!context) {
fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
}
return context;
}
// Initializes and configures the sampler
llama_sampler_ptr initialize_sampler() {
llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
return sampler;
}
};
// Add a message to `messages` and store its content in `owned_content`
static void add_message(const char * role, const std::string & text, LlamaData & llama_data,
std::vector<char_array_ptr> & owned_content) {
char_array_ptr content(new char[text.size() + 1]);
std::strcpy(content.get(), text.c_str());
llama_data.messages.push_back({role, content.get()});
owned_content.push_back(std::move(content));
}
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const LlamaData & llama_data, std::vector<char> & formatted, const bool append) {
int result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
llama_data.messages.size(), append, formatted.data(), formatted.size());
if (result > static_cast<int>(formatted.size())) {
formatted.resize(result);
result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
llama_data.messages.size(), append, formatted.data(), formatted.size());
}
return result;
}
// Function to tokenize the prompt
static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n");
}
return n_prompt_tokens;
}
// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n");
return 1;
}
return 0;
}
// convert the token to a string
static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) {
char buf[256];
int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true);
if (n < 0) {
GGML_ABORT("failed to convert token to piece\n");
}
piece = std::string(buf, n);
return 0;
}
static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
printf("%s", piece.c_str());
fflush(stdout);
response += piece;
}
// helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
std::vector<llama_token> prompt_tokens;
const int n_prompt_tokens = tokenize_prompt(llama_data.model, prompt, prompt_tokens);
if (n_prompt_tokens < 0) {
return 1;
}
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_token new_token_id;
while (true) {
check_context_size(llama_data.context, batch);
if (llama_decode(llama_data.context.get(), batch)) {
GGML_ABORT("failed to decode\n");
}
// sample the next token and check whether it is an end-of-generation token
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
if (llama_token_is_eog(llama_data.model.get(), new_token_id)) {
break;
}
std::string piece;
if (convert_token_to_string(llama_data.model, new_token_id, piece)) {
return 1;
}
print_word_and_concatenate_to_response(piece, response);
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
}
return 0;
}
static int parse_arguments(const int argc, const char ** argv, Options & opt) {
ArgumentParser parser(argv[0]);
parser.add_argument("-m", opt.model_path, "model");
parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
parser.add_argument("-c", opt.n_ctx, "context_size");
parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
if (parser.parse(argc, argv)) {
return 1;
}
return 0;
}
static int read_user_input(std::string & user) {
std::getline(std::cin, user);
return user.empty(); // Indicate an error or empty input
}
// Function to generate a response based on the prompt
static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response) {
// Set response color
printf("\033[33m");
if (generate(llama_data, prompt, response)) {
fprintf(stderr, "failed to generate response\n");
return 1;
}
// End response with color reset and newline
printf("\n\033[0m");
return 0;
}
// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const LlamaData & llama_data, std::vector<char> & formatted,
const bool is_user_input, int & output_length) {
const int new_len = apply_chat_template(llama_data, formatted, is_user_input);
if (new_len < 0) {
fprintf(stderr, "failed to apply the chat template\n");
return -1;
}
output_length = new_len;
return 0;
}
// Helper function to handle user input
static bool handle_user_input(std::string & user_input, const std::string & prompt_non_interactive) {
if (!prompt_non_interactive.empty()) {
user_input = prompt_non_interactive;
return true; // No need for interactive input
}
printf("\033[32m> \033[0m");
return !read_user_input(user_input); // Returns false if input ends the loop
}
// Main chat loop: read user input, apply the chat template and generate responses
static int chat_loop(LlamaData & llama_data, std::string & prompt_non_interactive) {
std::vector<char_array_ptr> owned_content;
std::vector<char> fmtted(llama_n_ctx(llama_data.context.get()));
int prev_len = 0;
while (true) {
// Get user input
std::string user_input;
if (!handle_user_input(user_input, prompt_non_interactive)) {
break;
}
add_message("user", prompt_non_interactive.empty() ? user_input : prompt_non_interactive, llama_data,
owned_content);
int new_len;
if (apply_chat_template_with_error_handling(llama_data, fmtted, true, new_len) < 0) {
return 1;
}
std::string prompt(fmtted.begin() + prev_len, fmtted.begin() + new_len);
std::string response;
if (generate_response(llama_data, prompt, response)) {
return 1;
}
}
return 0;
}
static void log_callback(const enum ggml_log_level level, const char * text, void *) {
if (level == GGML_LOG_LEVEL_ERROR) {
fprintf(stderr, "%s", text);
}
}
static bool is_stdin_a_terminal() {
#if defined(_WIN32)
HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
DWORD mode;
return GetConsoleMode(hStdin, &mode);
#else
return isatty(STDIN_FILENO);
#endif
}
static std::string read_pipe_data() {
std::ostringstream result;
result << std::cin.rdbuf(); // Read all data from std::cin
return result.str();
}
int main(int argc, const char ** argv) {
Options opt;
if (parse_arguments(argc, argv, opt)) {
return 1;
}
if (!is_stdin_a_terminal()) {
if (!opt.prompt_non_interactive.empty()) {
opt.prompt_non_interactive += "\n\n";
}
opt.prompt_non_interactive += read_pipe_data();
}
llama_log_set(log_callback, nullptr);
LlamaData llama_data;
if (llama_data.init(opt)) {
return 1;
}
if (chat_loop(llama_data, opt.prompt_non_interactive)) {
return 1;
}
return 0;
}

View file

@ -9,7 +9,7 @@ int main(int argc, char ** argv) {
common_params params;
params.prompt = "The quick brown fox";
params.sparams.seed = 1234;
params.sampling.seed = 1234;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
@ -42,7 +42,7 @@ int main(int argc, char ** argv) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
// tokenize prompt
auto tokens = common_tokenize(ctx, params.prompt, true);
@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed));
printf("\nsecond run: %s", params.prompt.c_str());
@ -169,7 +169,7 @@ int main(int argc, char ** argv) {
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed));
printf("\nsingle seq run: %s", params.prompt.c_str());

View file

@ -412,7 +412,7 @@ node index.js
`id_slot`: Assign the completion task to a specific slot. If it is `-1`, the task will be assigned to an idle slot. Default: `-1`
`cache_prompt`: Re-use the KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation), enabling this option can cause nondeterministic results. Default: `false`
`cache_prompt`: Re-use the KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation), enabling this option can cause nondeterministic results. Default: `true`
`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
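As an illustration, these options are passed as fields of the JSON request body. A minimal `/completion` request combining them might look like the following sketch; the prompt and the chosen values are arbitrary examples, not recommended defaults:
```bash
# Illustrative request: host, port and field values are placeholders.
curl -s http://localhost:8080/completion -d '{
  "prompt": "Once upon a time",
  "n_predict": 64,
  "id_slot": -1,
  "cache_prompt": true,
  "samplers": ["top_k", "top_p", "min_p", "temperature"]
}'
```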

View file

@ -81,7 +81,13 @@
<path d="M14.5 3a1 1 0 0 1-1 1H13v9a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2V4h-.5a1 1 0 0 1-1-1V2a1 1 0 0 1 1-1H6a1 1 0 0 1 1-1h2a1 1 0 0 1 1 1h3.5a1 1 0 0 1 1 1zM4.118 4 4 4.059V13a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1V4.059L11.882 4zM2.5 3h11V2h-11z"/>
</svg>
</button>
<button v-if="messages.length > 0" class="btn mr-1" @click="downloadConv(viewingConvId)" :disabled="isGenerating">
<!-- download conversation button -->
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-download" viewBox="0 0 16 16">
<path d="M.5 9.9a.5.5 0 0 1 .5.5v2.5a1 1 0 0 0 1 1h12a1 1 0 0 0 1-1v-2.5a.5.5 0 0 1 1 0v2.5a2 2 0 0 1-2 2H2a2 2 0 0 1-2-2v-2.5a.5.5 0 0 1 .5-.5"/>
<path d="M7.646 11.854a.5.5 0 0 0 .708 0l3-3a.5.5 0 0 0-.708-.708L8.5 10.293V1.5a.5.5 0 0 0-1 0v8.793L5.354 8.146a.5.5 0 1 0-.708.708z"/>
</svg>
</button>
<button class="btn" @click="showConfigDialog = true" :disabled="isGenerating">
<!-- edit config button -->
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-gear" viewBox="0 0 16 16">
@ -526,6 +532,23 @@
this.fetchMessages();
}
},
downloadConv(convId) {
const conversation = StorageUtils.getOneConversation(convId);
if (!conversation) {
alert('Conversation not found.');
return;
}
const conversationJson = JSON.stringify(conversation, null, 2);
const blob = new Blob([conversationJson], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `conversation_${convId}.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
},
async sendMessage() {
if (!this.inputMsg) return;
const currConvId = this.viewingConvId;

View file

@ -2,10 +2,11 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "sampling.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
@ -110,7 +111,7 @@ struct server_static_file {
struct slot_params {
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@ -121,12 +122,21 @@ struct slot_params {
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<std::string> antiprompt;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
};
struct server_slot {
int id;
int id_task = -1;
llama_batch batch_spec;
llama_context * ctx_dft = nullptr;
common_speculative * spec = nullptr;
// the index relative to completion multi-task request
size_t index = 0;
@ -175,7 +185,6 @@ struct server_slot {
// sampling
json json_schema;
struct common_sampler_params sparams;
struct common_sampler * smpl = nullptr;
llama_token sampled;
@ -212,7 +221,7 @@ struct server_slot {
generated_token_probs.clear();
}
bool has_budget(common_params &global_params) {
bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
@ -232,6 +241,10 @@ struct server_slot {
return state != SLOT_STATE_IDLE;
}
bool can_speculate() const {
return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
}
void add_token(const completion_token_output & token) {
if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n");
@ -591,11 +604,14 @@ struct server_response {
};
struct server_context {
common_params params_base;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
std::vector<common_lora_adapter_container> loras;
common_params params;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch = {};
@ -628,27 +644,41 @@ struct server_context {
model = nullptr;
}
if (model_dft) {
llama_free_model(model_dft);
model_dft = nullptr;
}
// Clear any sampling context
for (server_slot & slot : slots) {
if (slot.smpl != nullptr) {
common_sampler_free(slot.smpl);
}
common_sampler_free(slot.smpl);
slot.smpl = nullptr;
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
}
bool load_model(const common_params & params_) {
params = params_;
bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.c_str());
common_init_result llama_init = common_init_from_params(params);
params_base = params;
common_init_result llama_init = common_init_from_params(params_base);
model = llama_init.model;
ctx = llama_init.context;
loras = llama_init.lora_adapters;
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
return false;
}
@ -657,6 +687,41 @@ struct server_context {
add_bos_token = llama_add_bos_token(model);
has_eos_token = !llama_add_eos_token(model);
if (!params_base.speculative.model.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
auto params_dft = params_base;
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
common_init_result llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model;
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
return false;
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
llama_free (llama_init_dft.context);
llama_free_model(llama_init_dft.model);
return false;
}
cparams_dft = common_context_params_to_llama(params_base);
cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
// the context is not needed - we will create one for each slot
llama_free(llama_init_dft.context);
}
return true;
}
@ -674,20 +739,36 @@ struct server_context {
}
void init() {
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
for (int i = 0; i < params.n_parallel; i++) {
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slot.id = i;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
slot.n_predict = params_base.n_predict;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return;
}
slot.spec = common_speculative_init(slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return;
}
}
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
slot.sparams = params.sparams;
slot.params.sampling = params_base.sampling;
slot.callback_on_release = [this](int) {
queue_tasks.pop_deferred_task();
@ -707,7 +788,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
@ -743,7 +824,7 @@ struct server_context {
}
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
// fraction of the common subsequence length compared to the current slot's prompt length
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@ -786,9 +867,11 @@ struct server_context {
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
auto default_sparams = params.sparams;
slot_params defaults;
defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
@ -799,42 +882,48 @@ struct server_context {
slot.oaicompat_model = "";
}
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
//slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent);
slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep);
slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
if (slot.sparams.dry_base < 1.0f)
{
slot.sparams.dry_base = default_sparams.dry_base;
slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
if (slot.params.sampling.dry_base < 1.0f) {
slot.params.sampling.dry_base = defaults.sampling.dry_base;
}
// sequence breakers for DRY
@ -843,8 +932,8 @@ struct server_context {
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (slot.sparams.dry_sequence_breakers.empty()) {
slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (slot.params.sampling.dry_sequence_breakers.empty()) {
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
return false;
}
@ -858,14 +947,14 @@ struct server_context {
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
auto schema = json_value(data, "json_schema", json::object());
slot.params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false;
}
} else {
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
}
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@ -875,10 +964,10 @@ struct server_context {
}
{
slot.sparams.logit_bias.clear();
slot.params.sampling.logit_bias.clear();
if (json_value(data, "ignore_eos", false) && has_eos_token) {
slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
}
const auto & logit_bias = data.find("logit_bias");
@ -899,12 +988,12 @@ struct server_context {
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias.push_back({tok, bias});
slot.params.sampling.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
slot.sparams.logit_bias.push_back({tok, bias});
slot.params.sampling.logit_bias.push_back({tok, bias});
}
}
}
@ -935,16 +1024,16 @@ struct server_context {
sampler_names.emplace_back(name);
}
}
slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
} else if (samplers->is_string()){
std::string sampler_string;
for (const auto & name : *samplers) {
sampler_string += name;
}
slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
}
} else {
slot.sparams.samplers = default_sparams.samplers;
slot.params.sampling.samplers = defaults.sampling.samplers;
}
}
@ -953,7 +1042,7 @@ struct server_context {
common_sampler_free(slot.smpl);
}
slot.smpl = common_sampler_init(model, slot.sparams);
slot.smpl = common_sampler_init(model, slot.params.sampling);
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@ -961,6 +1050,12 @@ struct server_context {
}
}
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
}
slot.state = SLOT_STATE_STARTED;
SLT_INF(slot, "%s", "processing task\n");
@ -978,7 +1073,7 @@ struct server_context {
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
slot.sampled = result.tok;
// search stop word and delete it
@ -1043,7 +1138,7 @@ struct server_context {
}
// check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
slot.stopped_limit = true;
slot.has_next_token = false;
@ -1136,50 +1231,54 @@ struct server_context {
json get_formated_generation(const server_slot & slot) const {
std::vector<std::string> samplers;
samplers.reserve(slot.sparams.samplers.size());
for (const auto & sampler : slot.sparams.samplers) {
samplers.reserve(slot.params.sampling.samplers.size());
for (const auto & sampler : slot.params.sampling.samplers) {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
return json {
{"n_ctx", slot.n_ctx},
{"n_predict", slot.n_predict}, // Server configured n_predict
{"model", params.model_alias},
{"seed", slot.sparams.seed},
{"model", params_base.model_alias},
{"seed", slot.params.sampling.seed},
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
{"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
{"xtc_probability", slot.sparams.xtc_probability},
{"xtc_threshold", slot.sparams.xtc_threshold},
{"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq},
{"dry_multiplier", slot.sparams.dry_multiplier},
{"dry_base", slot.sparams.dry_base},
{"dry_allowed_length", slot.sparams.dry_allowed_length},
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
{"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
{"penalize_nl", slot.sparams.penalize_nl},
{"temperature", slot.params.sampling.temp},
{"dynatemp_range", slot.params.sampling.dynatemp_range},
{"dynatemp_exponent", slot.params.sampling.dynatemp_exponent},
{"top_k", slot.params.sampling.top_k},
{"top_p", slot.params.sampling.top_p},
{"min_p", slot.params.sampling.min_p},
{"xtc_probability", slot.params.sampling.xtc_probability},
{"xtc_threshold", slot.params.sampling.xtc_threshold},
{"typical_p", slot.params.sampling.typ_p},
{"repeat_last_n", slot.params.sampling.penalty_last_n},
{"repeat_penalty", slot.params.sampling.penalty_repeat},
{"presence_penalty", slot.params.sampling.penalty_present},
{"frequency_penalty", slot.params.sampling.penalty_freq},
{"dry_multiplier", slot.params.sampling.dry_multiplier},
{"dry_base", slot.params.sampling.dry_base},
{"dry_allowed_length", slot.params.sampling.dry_allowed_length},
{"dry_penalty_last_n", slot.params.sampling.dry_penalty_last_n},
{"dry_sequence_breakers", slot.params.sampling.dry_sequence_breakers},
{"mirostat", slot.params.sampling.mirostat},
{"mirostat_tau", slot.params.sampling.mirostat_tau},
{"mirostat_eta", slot.params.sampling.mirostat_eta},
{"penalize_nl", slot.params.sampling.penalize_nl},
{"stop", slot.params.antiprompt},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
{"ignore_eos", slot.sparams.ignore_eos},
{"ignore_eos", slot.params.sampling.ignore_eos},
{"stream", slot.params.stream},
//{"logit_bias", slot.sparams.logit_bias},
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
//{"logit_bias", slot.params.sampling.logit_bias},
{"n_probs", slot.params.sampling.n_probs},
{"min_keep", slot.params.sampling.min_keep},
{"grammar", slot.params.sampling.grammar},
{"samplers", samplers},
{"speculative", slot.can_speculate()},
{"speculative.n_max", slot.params.speculative.n_max},
{"speculative.n_min", slot.params.speculative.n_min},
{"speculative.p_min", slot.params.speculative.p_min},
};
}
@ -1216,7 +1315,7 @@ struct server_context {
{"index", slot.index},
};
if (slot.sparams.n_probs > 0) {
if (slot.params.sampling.n_probs > 0) {
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
@ -1249,7 +1348,7 @@ struct server_context {
{"content", !slot.params.stream ? slot.generated_text : ""},
{"id_slot", slot.id},
{"stop", true},
{"model", params.model_alias},
{"model", params_base.model_alias},
{"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
@ -1265,7 +1364,7 @@ struct server_context {
{"index", slot.index},
};
if (slot.sparams.n_probs > 0) {
if (slot.params.sampling.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
@ -1422,10 +1521,10 @@ struct server_context {
data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
params.n_batch,
params.n_predict,
params_base.n_batch,
params_base.n_predict,
slots[0].n_ctx, // TODO: there should be a better way
params.spm_infill,
params_base.spm_infill,
tokenized_prompts[i]
);
create_task(data, tokens);
@ -1798,7 +1897,7 @@ struct server_context {
// TODO: simplify and improve
for (server_slot & slot : slots) {
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
if (!params.ctx_shift) {
if (!params_base.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
slot.release();
@ -1864,7 +1963,7 @@ struct server_context {
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) {
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
@ -1917,7 +2016,7 @@ struct server_context {
continue;
}
} else {
if (!params.ctx_shift) {
if (!params_base.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size
// TODO: there should be a separate parameter that control prompt truncation
// context shift should be applied only during the generation phase
@ -1960,14 +2059,14 @@ struct server_context {
if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params.n_cache_reuse > 0) {
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
while (head_c < slot.cache_tokens.size() &&
head_p < prompt_tokens.size()) {
@ -1980,7 +2079,7 @@ struct server_context {
n_match++;
}
if (n_match >= (size_t) params.n_cache_reuse) {
if (n_match >= (size_t) params_base.n_cache_reuse) {
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
//for (size_t i = head_p; i < head_p + n_match; i++) {
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
@ -2168,8 +2267,9 @@ struct server_context {
continue; // continue loop of slots
}
completion_token_output result;
const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
slot.i_batch = -1;
common_sampler_accept(slot.smpl, id, true);
@ -2180,14 +2280,15 @@ struct server_context {
metrics.on_prompt_eval(slot);
}
completion_token_output result;
result.tok = id;
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
}
@ -2197,9 +2298,67 @@ struct server_context {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
continue;
}
}
// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
continue;
}
slot.i_batch = -1;
llama_token id = slot.sampled;
struct common_speculative_params params_spec;
params_spec.n_draft = slot.params.speculative.n_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
continue;
}
// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
}
}
@ -2697,7 +2856,7 @@ int main(int argc, char ** argv) {
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
json data = {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params.n_parallel },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "chat_template", llama_get_chat_template(ctx_server.model) },
};
@ -2705,7 +2864,7 @@ int main(int argc, char ** argv) {
};
const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params.endpoint_props) {
if (!ctx_server.params_base.endpoint_props) {
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@ -2718,7 +2877,7 @@ int main(int argc, char ** argv) {
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@ -2824,7 +2983,7 @@ int main(int argc, char ** argv) {
// TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params.embedding) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@ -3001,7 +3160,7 @@ int main(int argc, char ** argv) {
};
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params.reranking || ctx_server.params.embedding) {
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

View file

@ -1 +1,2 @@
.venv
tmp

View file

@ -1,19 +1,9 @@
# Server tests
Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development)
and [behave](https://behave.readthedocs.io/en/latest/):
* [issues.feature](./features/issues.feature) Pending issues scenario
* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
* [security.feature](./features/security.feature) Security, CORS and API Key
* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
Python-based server tests using [pytest](https://docs.pytest.org/en/stable/).
Tests target GitHub workflow job runners with 4 vCPUs.
Requests are made using an [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html)/[asyncio](https://docs.python.org/fr/3/library/asyncio.html)-based HTTP client.
Note: If the host machine's inference speed is faster than that of the GitHub runners, the parallel scenarios may randomly fail.
To mitigate this, you can increase the `n_predict` and `kv_size` values.
@ -39,26 +29,19 @@ It's possible to override some scenario steps values with environment variables:
|--------------------------|------------------------------------------------------------------------------------------------|
| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` |
| `DEBUG` | "ON" to enable steps and server verbose mode `--verbose` |
| `DEBUG` | to enable steps and server verbose mode `--verbose` |
| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
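For example, a local run that points at a custom server binary, changes the port, and enables verbose output might look like this sketch (the path and port are illustrative):
```shell
# Override scenario defaults via environment variables (values are examples).
LLAMA_SERVER_BIN_PATH=../../../build/bin/llama-server PORT=8081 DEBUG=1 ./tests.sh
```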
### Run @bug, @wip or @wrong_usage annotated scenario
Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope.
- `@bug` annotation aims to link a scenario with a GitHub issue.
- `@wrong_usage` scenarios are meant to show user issues that are actually expected behavior
- `@wip` to focus on a scenario that is a work in progress
- `@slow` marks a heavy test, disabled by default
To run a scenario annotated with `@bug`, start:
To run slow tests:
```shell
DEBUG=ON ./tests.sh --no-skipped --tags bug --stop
SLOW_TESTS=1 ./tests.sh
```
After changing logic in `steps.py`, ensure that the `@bug` and `@wrong_usage` scenarios are updated.
To run with stdout/stderr display in real time (verbose output, but useful for debugging):
```shell
./tests.sh --no-skipped --tags bug,wrong_usage || echo "should failed but compile"
DEBUG=1 ./tests.sh -s -v -x
```
To see all available arguments, please refer to the [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)

View file

@ -0,0 +1,15 @@
import pytest
from utils import *
# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
@pytest.fixture(autouse=True)
def stop_server_after_each_test():
# do nothing before each test
yield
# stop all servers after each test
instances = set(
server_instances
) # copy the set to prevent 'Set changed size during iteration'
for server in instances:
server.stop()

View file

@ -1,66 +0,0 @@
@llama.cpp
@ctx_shift
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2
And BOS token is 1
And 42 as server seed
And 256 KV cache size
And 32 as batch size
And 2 slots
# the prompt is 301 tokens
# the slot context is 256/2 = 128 tokens
# the prompt is truncated to keep the last 109 tokens
# 64 tokens are generated thanks to shifting the context when it gets full
Scenario: Inference with context shift
And 64 server max tokens to predict
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with no api error
Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
And the completion is truncated
And 109 prompt tokens are processed
Scenario Outline: Inference without context shift
And <n_predict> server max tokens to predict
And disable context shifting
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Hi how are you
"""
And a completion request with no api error
Then <n_token_output> tokens are predicted matching twind|Anna
And the completion is <truncated> truncated
And 8 prompt tokens are processed
Examples:
| n_predict | n_token_output | truncated |
| 64 | 64 | not |
| -1 | 120 | |
Scenario: Inference without context shift (expected error: prompt too long)
And disable context shifting
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with 400 api error

View file

@ -1,113 +0,0 @@
@llama.cpp
@embeddings
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model url https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf
And a model file bert-bge-small.gguf
And a model alias bert-bge-small
And 42 as server seed
And 2 slots
# the bert-bge-small model has context size of 512
# since the generated prompts are as big as the batch size, we need to set the batch size to <= 512
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20
And 128 as batch size
And 128 as ubatch size
And 512 KV cache size
And enable embeddings endpoint
Then the server is starting
Then the server is healthy
Scenario: Embedding
When embeddings are computed for:
"""
What is the capital of Bulgaria ?
"""
Then embeddings are generated
Scenario: Embedding (error: prompt too long)
When embeddings are computed for:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And embeddings request with 500 api error
Scenario: OAI Embeddings compatibility
Given a model bert-bge-small
When an OAI compatible embeddings computation request for:
"""
What is the capital of Spain ?
"""
Then embeddings are generated
Scenario: OAI Embeddings compatibility with multiple inputs
Given a model bert-bge-small
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated
Scenario: Multi users embeddings
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated
Scenario: Multi users OAI compatibility embeddings
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
And a prompt:
"""
What is the biggest US city ?
"""
And a prompt:
"""
What is the capital of Bulgaria ?
"""
And a model bert-bge-small
Given concurrent OAI embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated
Scenario: All embeddings should be the same
Given 10 fixed prompts
And a model bert-bge-small
Given concurrent OAI embedding requests
Then all embeddings are the same

View file

@ -1,71 +0,0 @@
import os
import signal
import socket
import sys
import time
import traceback
from contextlib import closing
from subprocess import TimeoutExpired
def before_scenario(context, scenario):
context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
if context.debug:
print("DEBUG=ON")
print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
port = 8080
if 'PORT' in os.environ:
port = int(os.environ['PORT'])
if is_server_listening("localhost", port):
assert False, "Server already started"
def after_scenario(context, scenario):
try:
if 'server_process' not in context or context.server_process is None:
return
if scenario.status == "failed":
if 'GITHUB_ACTIONS' in os.environ:
print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n")
if os.path.isfile('llama.log'):
with closing(open('llama.log', 'r')) as f:
for line in f:
print(line)
if not is_server_listening(context.server_fqdn, context.server_port):
print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
if context.server_process.poll() is not None:
assert False, f"Server not running pid={context.server_process.pid} ..."
server_graceful_shutdown(context) # SIGINT
try:
context.server_process.wait(0.5)
except TimeoutExpired:
print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...")
context.server_process.kill() # SIGKILL
context.server_process.wait()
while is_server_listening(context.server_fqdn, context.server_port):
time.sleep(0.1)
except Exception:
print("ignoring error in after_scenario:")
traceback.print_exc(file=sys.stdout)
def server_graceful_shutdown(context):
print(f"shutting down server pid={context.server_process.pid} ...")
if os.name == 'nt':
interrupt = signal.CTRL_C_EVENT
else:
interrupt = signal.SIGINT
context.server_process.send_signal(interrupt)
def is_server_listening(server_fqdn, server_port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
result = sock.connect_ex((server_fqdn, server_port))
_is_server_listening = result == 0
if _is_server_listening:
print(f"server is listening on {server_fqdn}:{server_port}...")
return _is_server_listening

View file

@ -1,36 +0,0 @@
@llama.cpp
@infill
Feature: llama.cpp server
# The current model is made by adding FIM tokens to the existing stories260K
# We may want to use a better model in the future, maybe something like SmolLM 360M
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
And a model file test-model-infill.gguf
And a model alias tinyllama-infill
And 42 as server seed
And 1024 as batch size
And 1024 as ubatch size
And 2048 KV cache size
And 64 max tokens to predict
And 0.0 temperature
Then the server is starting
Then the server is healthy
Scenario: Infill without input_extra
Given a prompt "Complete this"
And an infill input extra none none
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird
Scenario: Infill with input_extra
Given a prompt "Complete this"
And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room"

View file

@ -1,5 +0,0 @@
# List of ongoing issues
# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug
@bug
Feature: Issues
# No confirmed issue at the moment

View file

@ -1,36 +0,0 @@
@llama.cpp
@lora
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model url https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf
And a model file stories15M_MOE-F16.gguf
And a model alias stories15M_MOE
And a lora adapter file from https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf
And 42 as server seed
And 1024 as batch size
And 1024 as ubatch size
And 2048 KV cache size
And 64 max tokens to predict
And 0.0 temperature
Then the server is starting
Then the server is healthy
Scenario: Completion LoRA disabled
Given switch off lora adapter 0
Given a prompt:
"""
Look in thy glass
"""
And a completion request with no api error
Then 64 tokens are predicted matching little|girl|three|years|old
Scenario: Completion LoRA enabled
Given switch on lora adapter 0
Given a prompt:
"""
Look in thy glass
"""
And a completion request with no api error
Then 64 tokens are predicted matching eye|love|glass|sun

View file

@ -1,131 +0,0 @@
@llama.cpp
@parallel
Feature: Parallel
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
And a model file test-model-00001-of-00003.gguf
And 42 as server seed
And 128 as batch size
And 256 KV cache size
And 2 slots
And continuous batching
Then the server is starting
Then the server is healthy
Scenario Outline: Multi users completion
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And <n_predict> max tokens to predict
Given concurrent completion requests
Then the server is busy
Then the server is idle
And all slots are idle
Then all prompts are predicted with <n_predict> tokens
Examples:
| n_predict |
| 128 |
Scenario Outline: Multi users OAI completions compatibility
Given a system prompt You are a writer.
And a model tinyllama-2
Given a prompt:
"""
Write a very long book.
"""
And a prompt:
"""
Write another a poem.
"""
And <n_predict> max tokens to predict
And streaming is <streaming>
Given concurrent OAI completions requests
Then the server is busy
Then the server is idle
Then all prompts are predicted with <n_predict> tokens
Examples:
| streaming | n_predict |
| disabled | 128 |
| enabled | 64 |
Scenario Outline: Multi users OAI completions compatibility no v1
Given a system prompt You are a writer.
And a model tinyllama-2
Given a prompt:
"""
Write a very long book.
"""
And a prompt:
"""
Write another a poem.
"""
And <n_predict> max tokens to predict
And streaming is <streaming>
Given concurrent OAI completions requests no v1
Then the server is busy
Then the server is idle
Then all prompts are predicted with <n_predict> tokens
Examples:
| streaming | n_predict |
| disabled | 128 |
| enabled | 64 |
Scenario Outline: Multi users with number of prompts exceeding number of slots
Given a system prompt You are a writer.
And a model tinyllama-2
Given a prompt:
"""
Write a very long book.
"""
And a prompt:
"""
Write another a poem.
"""
And a prompt:
"""
What is LLM?
"""
And a prompt:
"""
The sky is blue and I love it.
"""
And <n_predict> max tokens to predict
And streaming is <streaming>
Given concurrent OAI completions requests
Then the server is busy
Then the server is idle
Then all prompts are predicted with <n_predict> tokens
Examples:
| streaming | n_predict |
| disabled | 128 |
| enabled | 64 |
Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
And 128 max tokens to predict
Given concurrent completion requests
Then the server is busy
Then the server is idle
Then all prompts are predicted

View file

@ -1,56 +0,0 @@
# run with: ./tests.sh --no-skipped --tags passkey
@passkey
@slow
Feature: Passkey / Self-extend with context shift
Background: Server startup
Given a server listening on localhost:8080
# Generates a long text of junk and inserts a secret passkey number inside it.
# Then we query the LLM for the secret passkey.
# see #3856 and #4810
Scenario Outline: Passkey
Given a model file <hf_file> from HF repo <hf_repo>
And <n_batch> as batch size
And <n_junk> as number of junk
And <n_predicted> server max tokens to predict
And 42 as seed
And 0.0 temperature
And <n_ctx> KV cache size
And 1 slots
And <n_ga> group attention factor to extend context size through self-extend
And <n_ga_w> group attention width to extend context size through self-extend
# Can be override with N_GPU_LAYERS
And <ngl> GPU offloaded layers
Then the server is starting
# Higher timeout because the model may need to be downloaded from the internet
Then the server is healthy with timeout 120 seconds
Given available models
Then model 0 is trained on <n_ctx_train> tokens context
Given a prefix prompt:
"""
here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.
"""
And a passkey prompt template:
"""
The pass key is <passkey> Remember it. <passkey> is the pass key.
"""
And a junk suffix prompt:
"""
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
"""
And a suffix prompt:
"""
What is the pass key? The pass key is
"""
Given a "<passkey>" passkey challenge prompt with the passkey inserted every <i_pos> junk
And a completion request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
Examples:
| hf_repo | hf_file | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content |
| TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 4 | 512 | 250 | 50 | 42 | 1 | 42 |
| TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 2 | 512 | 250 | 50 | 42 | 1 | \b((?!42)\w)+\b |
#| TheBloke/Llama-2-7B-GGUF | llama-2-7b.Q2_K.gguf | 4096 | 3 | 16384 | 512 | 4 | 512 | 500 | 300 | 1234 | 5 | 1234 |
#| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768 | 2 | 16384 | 512 | 4 | 512 | 500 | 100 | 0987 | 5 | 0987 |

View file

@ -1,42 +0,0 @@
@llama.cpp
@rerank
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf
And a model file jina-reranker-v1-tiny-en.gguf
And a model alias jina-reranker-v1-tiny-en
And 42 as server seed
And 2 slots
And 512 as batch size
And 512 as ubatch size
And 512 KV cache size
And enable reranking endpoint
Then the server is starting
Then the server is healthy
Scenario: Rerank
Given a rerank query:
"""
Machine learning is
"""
And a rerank document:
"""
A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.
"""
And a rerank document:
"""
Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.
"""
And a rerank document:
"""
Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.
"""
And a rerank document:
"""
Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.
"""
When reranking request
Then reranking results are returned
Then reranking highest score is index 2 and lowest score is index 3

View file

@ -1,118 +0,0 @@
@llama.cpp
@results
Feature: Results
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
And a model file test-model-00001-of-00003.gguf
And 128 as batch size
And 1024 KV cache size
And 128 max tokens to predict
And continuous batching
Scenario Outline: consistent results with same seed
Given <n_slots> slots
And 1.0 temperature
Then the server is starting
Then the server is healthy
Given 4 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42
Given concurrent completion requests
Then the server is busy
Then the server is idle
And all slots are idle
Then all predictions are equal
Examples:
| n_slots |
| 1 |
# FIXME: unified KV cache nondeterminism
# | 2 |
Scenario Outline: different results with different seed
Given <n_slots> slots
And 1.0 temperature
Then the server is starting
Then the server is healthy
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 43
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 44
Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 45
Given concurrent completion requests
Then the server is busy
Then the server is idle
And all slots are idle
Then all predictions are different
Examples:
| n_slots |
| 1 |
| 2 |
Scenario Outline: consistent results with same seed and varying batch size
Given 4 slots
And <temp> temperature
# And 0 as draft
Then the server is starting
Then the server is healthy
Given 1 prompts "Write a very long story about AI." with seed 42
And concurrent completion requests
# Then the server is busy # Not all slots will be utilized.
Then the server is idle
And all slots are idle
Given <n_parallel> prompts "Write a very long story about AI." with seed 42
And concurrent completion requests
# Then the server is busy # Not all slots will be utilized.
Then the server is idle
And all slots are idle
Then all predictions are equal
Examples:
| n_parallel | temp |
| 1 | 0.0 |
| 1 | 1.0 |
# FIXME: unified KV cache nondeterminism
# See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227
# and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574
# and https://github.com/ggerganov/llama.cpp/pull/7347 .
# | 2 | 0.0 |
# | 4 | 0.0 |
# | 2 | 1.0 |
# | 4 | 1.0 |
Scenario Outline: consistent token probs with same seed and prompt
Given <n_slots> slots
And <n_kv> KV cache size
And 1.0 temperature
And <n_predict> max tokens to predict
Then the server is starting
Then the server is healthy
Given 1 prompts "The meaning of life is" with seed 42
And concurrent completion requests
# Then the server is busy # Not all slots will be utilized.
Then the server is idle
And all slots are idle
Given <n_parallel> prompts "The meaning of life is" with seed 42
And concurrent completion requests
# Then the server is busy # Not all slots will be utilized.
Then the server is idle
And all slots are idle
Then all token probabilities are equal
Examples:
| n_slots | n_kv | n_predict | n_parallel |
| 4 | 1024 | 1 | 1 |
# FIXME: unified KV cache nondeterminism
# See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227
# and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574
# and https://github.com/ggerganov/llama.cpp/pull/7347 .
# | 4 | 1024 | 1 | 4 |
# | 4 | 1024 | 100 | 1 |
# This test still fails even the above patches; the first token probabilities are already different.
# | 4 | 1024 | 100 | 4 |

View file

@ -1,68 +0,0 @@
@llama.cpp
@security
Feature: Security
Background: Server startup with an api key defined
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a server api key THIS_IS_THE_KEY
Then the server is starting
Then the server is healthy
Scenario Outline: Completion with some user api key
Given a prompt test
And a user api key <api_key>
And 4 max tokens to predict
And a completion request with <api_error> api error
Examples: Prompts
| api_key | api_error |
| THIS_IS_THE_KEY | no |
| THIS_IS_THE_KEY | no |
| hackeme | raised |
| | raised |
Scenario Outline: OAI Compatibility
Given a system prompt test
And a user prompt test
And a model test
And 2 max tokens to predict
And streaming is disabled
And a user api key <api_key>
Given an OAI compatible chat completions request with <api_error> api error
Examples: Prompts
| api_key | api_error |
| THIS_IS_THE_KEY | no |
| THIS_IS_THE_KEY | no |
| hackme | raised |
Scenario Outline: OAI Compatibility (invalid response formats)
Given a system prompt test
And a user prompt test
And a response format <response_format>
And a model test
And 2 max tokens to predict
And streaming is disabled
Given an OAI compatible chat completions request with raised api error
Examples: Prompts
| response_format |
| {"type": "sound"} |
| {"type": "json_object", "schema": 123} |
| {"type": "json_object", "schema": {"type": 123}} |
| {"type": "json_object", "schema": {"type": "hiccup"}} |
Scenario Outline: CORS Options
Given a user api key THIS_IS_THE_KEY
When an OPTIONS request is sent from <origin>
Then CORS header <cors_header> is set to <cors_header_value>
Examples: Headers
| origin | cors_header | cors_header_value |
| localhost | Access-Control-Allow-Origin | localhost |
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
| origin | Access-Control-Allow-Credentials | true |
| web.mydomain.fr | Access-Control-Allow-Methods | GET, POST |
| web.mydomain.fr | Access-Control-Allow-Headers | * |

View file

@ -1,120 +0,0 @@
@llama.cpp
@server
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2
And BOS token is 1
And 42 as server seed
# KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130
# see --ctx-size and #5568
And 256 KV cache size
And 32 as batch size
And 2 slots
And 64 server max tokens to predict
And prometheus compatible metrics exposed
Then the server is starting
Then the server is healthy
Scenario: Health
Then the server is ready
And all slots are idle
Scenario Outline: Completion
Given a prompt <prompt>
And <n_predict> max tokens to predict
And a completion request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
And the completion is <truncated> truncated
And <n_prompt> prompt tokens are processed
And prometheus metrics are exposed
And metric llamacpp:tokens_predicted is <n_predicted>
Examples: Prompts
| prompt | n_predict | re_content | n_prompt | n_predicted | truncated |
| I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not |
| Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids\|Anna\|forest)+ | 46 | 64 | not |
Scenario: Completion prompt truncated
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with no api error
Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
And the completion is truncated
And 109 prompt tokens are processed
Scenario Outline: OAI Compatibility
Given a model <model>
And a system prompt <system_prompt>
And a user prompt <user_prompt>
And <max_tokens> max tokens to predict
And streaming is <enable_streaming>
Given an OAI compatible chat completions request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
And <n_prompt> prompt tokens are processed
And the completion is <truncated> truncated
Examples: Prompts
| model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated |
| llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | |
Scenario Outline: OAI Compatibility w/ response format
Given a model test
And a system prompt test
And a user prompt test
And a response format <response_format>
And 10 max tokens to predict
Given an OAI compatible chat completions request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
Examples: Prompts
| response_format | n_predicted | re_content |
| {"type": "json_object", "schema": {"const": "42"}} | 6 | "42" |
| {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] |
| {"type": "json_object"} | 10 | \{ " Jacky. |
Scenario: Tokenize / Detokenize
When tokenizing:
"""
What is the capital of France ?
"""
Then tokens can be detokenized
And tokens do not begin with BOS
Scenario: Tokenize w/ BOS
Given adding special tokens
When tokenizing:
"""
What is the capital of Germany?
"""
Then tokens begin with BOS
Given first token is removed
Then tokens can be detokenized
Scenario: Tokenize with pieces
When tokenizing with pieces:
"""
What is the capital of Germany?
"""
Then tokens are given with pieces
Scenario: Models available
Given available models
Then 1 models are supported
Then model 0 is identified by tinyllama-2
Then model 0 is trained on 128 tokens context

View file

@ -1,58 +0,0 @@
@llama.cpp
@slotsave
Feature: llama.cpp server slot management
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And prompt caching is enabled
And 2 slots
And . as slot save path
And 2048 KV cache size
And 42 as server seed
And 24 max tokens to predict
Then the server is starting
Then the server is healthy
Scenario: Save and Restore Slot
# First prompt in slot 1 should be fully processed
Given a user prompt "What is the capital of France?"
And using slot id 1
And a completion request with no api error
Then 24 tokens are predicted matching (Lily|cake)
And 22 prompt tokens are processed
When the slot 1 is saved with filename "slot1.bin"
Then the server responds with status code 200
# Since we have cache, this should only process the last tokens
Given a user prompt "What is the capital of Germany?"
And a completion request with no api error
Then 24 tokens are predicted matching (Thank|special)
And 7 prompt tokens are processed
# Loading the original cache into slot 0,
# we should only be processing 1 prompt token and get the same output
When the slot 0 is restored with filename "slot1.bin"
Then the server responds with status code 200
Given a user prompt "What is the capital of France?"
And using slot id 0
And a completion request with no api error
Then 24 tokens are predicted matching (Lily|cake)
And 1 prompt tokens are processed
# For verification that slot 1 was not corrupted during slot 0 load, same thing
Given a user prompt "What is the capital of Germany?"
And using slot id 1
And a completion request with no api error
Then 24 tokens are predicted matching (Thank|special)
And 1 prompt tokens are processed
Scenario: Erase Slot
Given a user prompt "What is the capital of France?"
And using slot id 1
And a completion request with no api error
Then 24 tokens are predicted matching (Lily|cake)
And 22 prompt tokens are processed
When the slot 1 is erased
Then the server responds with status code 200
Given a user prompt "What is the capital of France?"
And a completion request with no api error
Then 24 tokens are predicted matching (Lily|cake)
And 22 prompt tokens are processed

File diff suppressed because it is too large

View file

@ -1,25 +0,0 @@
# run with: ./tests.sh --no-skipped --tags wrong_usage
@wrong_usage
Feature: Wrong usage of llama.cpp server
#3969 The user must always set --n-predict option
# to cap the number of tokens any completion request can generate
# or pass n_predict/max_tokens in the request.
Scenario: Infinite loop
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And 42 as server seed
And 2048 KV cache size
# Uncomment below to fix the issue
#And 64 server max tokens to predict
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Go to: infinite loop
"""
# Uncomment below to fix the issue
#And 128 max tokens to predict
Given concurrent completion requests
Then the server is idle
Then all prompts are predicted

View file

@ -1,5 +1,5 @@
aiohttp~=3.9.3
behave~=1.2.6
pytest~=8.3.3
huggingface_hub~=0.23.2
numpy~=1.26.4
openai~=1.30.3

View file

@ -4,8 +4,7 @@ set -eu
if [ $# -lt 1 ]
then
# Start @llama.cpp scenario
behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
pytest -v -x
else
behave "$@"
pytest "$@"
fi

View file

@ -0,0 +1,34 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_server_start_simple():
global server
server.start()
res = server.make_request("GET", "/health")
assert res.status_code == 200
def test_server_props():
global server
server.start()
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["total_slots"] == server.n_slots
def test_server_models():
global server
server.start()
res = server.make_request("GET", "/models")
assert res.status_code == 200
assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias

View file

@ -0,0 +1,129 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
[
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
})
assert res.status_code == 200
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
if truncated:
assert choice["finish_reason"] == "length"
else:
assert choice["finish_reason"] == "stop"
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
[
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
]
)
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
global server
server.start()
res = server.make_stream_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"stream": True,
})
content = ""
for data in res:
choice = data["choices"][0]
if choice["finish_reason"] in ["stop", "length"]:
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted
assert "content" not in choice["delta"]
assert match_regex(re_content, content)
# FIXME: not sure why this is incorrect in stream mode
# if truncated:
# assert choice["finish_reason"] == "length"
# else:
# assert choice["finish_reason"] == "stop"
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"]
def test_chat_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=8,
seed=42,
temperature=0.8,
)
print(res)
assert res.choices[0].finish_reason == "stop"
assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content)
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
({"type": "json_object"}, 10, "(\\{|John)+"),
({"type": "sound"}, 0, None),
# invalid response format (expected to fail)
({"type": "json_object", "schema": 123}, 0, None),
({"type": "json_object", "schema": {"type": 123}}, 0, None),
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"response_format": response_format,
})
if re_content is not None:
assert res.status_code == 200
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code != 200
assert "error" in res.body

View file

@ -0,0 +1,223 @@
import pytest
import time
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == n_prompt
assert res.body["timings"]["predicted_n"] == n_predicted
assert res.body["truncated"] == truncated
assert match_regex(re_content, res.body["content"])
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
"stream": True,
})
content = ""
for data in res:
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert match_regex(re_content, content)
else:
content += data["content"]
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for seed in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": seed,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] != last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0, 1.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
global server
server.n_batch = n_batch
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": temperature,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.skip(reason="This test fails on linux, need to be fixed")
def test_cache_vs_nocache_prompt():
global server
server.start()
res_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": True,
})
res_no_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
})
assert res_cache.body["content"] == res_no_cache.body["content"]
def test_completion_with_tokens_input():
global server
server.temperature = 0.0
server.start()
prompt_str = "I believe the meaning of life is"
res = server.make_request("POST", "/tokenize", data={
"content": prompt_str,
"add_special": True,
})
assert res.status_code == 200
tokens = res.body["tokens"]
# single completion
res = server.make_request("POST", "/completion", data={
"prompt": tokens,
})
assert res.status_code == 200
assert type(res.body["content"]) == str
# batch completion
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, tokens],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, prompt_str],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens in one sequence
res = server.make_request("POST", "/completion", data={
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
})
assert res.status_code == 200
assert type(res.body["content"]) == str
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 3),
(2, 2),
(2, 4),
(4, 2), # some slots must be idle
(4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.temperature = 0.0
server.start()
PROMPTS = [
("Write a very long book.", "(very|special|big)+"),
("Write another a poem.", "(small|house)+"),
("What is LLM?", "(Dad|said)+"),
("The sky is blue and I love it.", "(climb|leaf)+"),
("Write another very long music lyrics.", "(friends|step|sky)+"),
("Write a very long joke.", "(cat|Whiskers)+"),
]
def check_slots_status():
should_all_slots_busy = n_requests >= n_slots
time.sleep(0.1)
res = server.make_request("GET", "/slots")
n_busy = sum([1 for slot in res.body if slot["is_processing"]])
if should_all_slots_busy:
assert n_busy == n_slots
else:
assert n_busy <= n_slots
tasks = []
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": prompt,
"seed": 42,
"temperature": 1.0,
})))
tasks.append((check_slots_status, ()))
results = parallel_function_calls(tasks)
# check results
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
res = results[i]
assert res.status_code == 200
assert type(res.body["content"]) == str
assert len(res.body["content"]) > 10
# FIXME: the result is not deterministic when using other slot than slot 0
# assert match_regex(re_content, res.body["content"])
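`parallel_function_calls` is assumed to fan the `(callable, args)` tuples built above out to worker threads and return their results in submission order; a minimal sketch using `ThreadPoolExecutor` (already imported by `utils.py`):
```python
from concurrent.futures import ThreadPoolExecutor

# sketch of the assumed helper: run every (function, args) tuple concurrently
# and return the results in submission order, so results[i] matches tasks[i]
def parallel_function_calls(tasks):
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = [executor.submit(fn, *args) for fn, args in tasks]
        return [f.result() for f in futures]
```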

View file

@ -0,0 +1,67 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.n_ctx = 256
server.n_slots = 2
def test_ctx_shift_enabled():
# the prompt is 301 tokens
# the slot context is 256/2 = 128 tokens
# the prompt is truncated to keep the last 109 tokens
# 64 tokens are generated thanks to shifting the context when it gets full
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == 109
assert res.body["timings"]["predicted_n"] == 64
assert res.body["truncated"] is True
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
(64, 64, False),
(-1, 120, True),
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server
server.disable_ctx_shift = True
server.n_predict = -1
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": "Hi how are you",
})
assert res.status_code == 200
assert res.body["timings"]["predicted_n"] == n_token_output
assert res.body["truncated"] == truncated
def test_ctx_shift_disabled_long_prompt():
global server
server.disable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
})
assert res.status_code != 200
assert "error" in res.body
assert "exceeds the available context size" in res.body["error"]["message"]

View file

@ -0,0 +1,99 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.bert_bge_small()
EPSILON = 1e-3
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.bert_bge_small()
def test_embedding_single():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
def test_embedding_multiple():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 4
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_openai_library_single():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1
def test_embedding_openai_library_multiple():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
])
assert len(res.data) == 4
for d in res.data:
assert len(d.embedding) > 1
def test_embedding_error_prompt_too_long():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
assert "too large" in res.body["error"]["message"]
def test_same_prompt_give_same_result():
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 5
for i in range(1, len(res.body['data'])):
v0 = res.body['data'][0]['embedding']
vi = res.body['data'][i]['embedding']
for x, y in zip(v0, vi):
assert abs(x - y) < EPSILON

View file

@ -0,0 +1,35 @@
import pytest
from utils import *
server = ServerPreset.tinyllama_infill()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama_infill()
def test_infill_without_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"prompt": "Complete this",
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
def test_infill_with_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"prompt": "Complete this",
"input_extra": [{
"filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n"
}],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])

View file

@ -0,0 +1,42 @@
import pytest
import os
from utils import *
server = ServerPreset.stories15m_moe()
LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.stories15m_moe()
# download lora file if needed
file_name = LORA_FILE_URL.split('/').pop()
lora_file = f'../../../{file_name}'
if not os.path.exists(lora_file):
print(f"Downloading {LORA_FILE_URL} to {lora_file}")
with open(lora_file, 'wb') as f:
f.write(requests.get(LORA_FILE_URL).content)
print(f"Done downloading lora file")
server.lora_files = [lora_file]
@pytest.mark.parametrize("scale,re_content", [
# without applying lora, the model should behave like a bedtime story generator
(0.0, "(little|girl|three|years|old)+"),
# with lora, the model should behave like a Shakespearean text generator
(1.0, "(eye|love|glass|sun)+"),
])
def test_lora(scale: float, re_content: str):
global server
server.start()
res_lora_control = server.make_request("POST", "/lora-adapters", data=[
{"id": 0, "scale": scale}
])
assert res_lora_control.status_code == 200
res = server.make_request("POST", "/completion", data={
"prompt": "Look in thy glass",
})
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])

View file

@ -0,0 +1,38 @@
import pytest
from utils import *
server = ServerPreset.jina_reranker_tiny()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.jina_reranker_tiny()
def test_rerank():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
})
assert res.status_code == 200
assert len(res.body["results"]) == 4
most_relevant = res.body["results"][0]
least_relevant = res.body["results"][0]
for doc in res.body["results"]:
if doc["relevance_score"] > most_relevant["relevance_score"]:
most_relevant = doc
if doc["relevance_score"] < least_relevant["relevance_score"]:
least_relevant = doc
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3

View file

@ -0,0 +1,83 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
TEST_API_KEY = "sk-this-is-the-secret-key"
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.api_key = TEST_API_KEY
@pytest.mark.parametrize("endpoint", ["/health", "/models"])
def test_access_public_endpoint(endpoint: str):
global server
server.start()
res = server.make_request("GET", endpoint)
assert res.status_code == 200
assert "error" not in res.body
@pytest.mark.parametrize("api_key", [None, "invalid-key"])
def test_incorrect_api_key(api_key: str):
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {api_key}" if api_key else None,
})
assert res.status_code == 401
assert "error" in res.body
assert res.body["error"]["type"] == "authentication_error"
def test_correct_api_key():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {TEST_API_KEY}",
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_openai_library_correct_api_key():
global server
server.start()
client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a chatbot."},
{"role": "user", "content": "What is the meaning of life?"},
],
)
assert len(res.choices) == 1
@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
("localhost", "Access-Control-Allow-Origin", "localhost"),
("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
("origin", "Access-Control-Allow-Credentials", "true"),
("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
])
def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
global server
server.start()
res = server.make_request("OPTIONS", "/completions", headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Authorization",
})
assert res.status_code == 200
assert cors_header in res.headers
assert res.headers[cors_header] == cors_header_value

View file

@ -0,0 +1,98 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.slot_save_path = "./tmp"
server.temperature = 0.0
def test_slot_save_restore():
global server
server.start()
# First prompt in slot 1 should be fully processed
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# Save state of slot 1
res = server.make_request("POST", "/slots/1?action=save", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_saved"] == 84
# Since we have cache, this should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# Loading the saved cache into slot 0
res = server.make_request("POST", "/slots/0?action=restore", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_restored"] == 84
# Since we have cache, slot 0 should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 0,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# For verification that slot 1 was not corrupted during slot 0 load, same thing should work
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 1
def test_slot_erase():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# erase slot 1
res = server.make_request("POST", "/slots/1?action=erase")
assert res.status_code == 200
# re-run the same prompt, it should process all tokens again
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
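The save/restore/erase cycle exercised above can also be driven manually; a minimal sketch with `requests`, assuming a server started with `--slot-save-path` and reachable on the default host/port:

```python
import requests

base = "http://127.0.0.1:8080"

# persist the KV cache of slot 1 to <slot-save-path>/slot1.bin
res = requests.post(f"{base}/slots/1?action=save", json={"filename": "slot1.bin"})
assert res.status_code == 200

# load the saved cache back into slot 0
res = requests.post(f"{base}/slots/0?action=restore", json={"filename": "slot1.bin"})
assert res.status_code == 200

# clear slot 1 entirely
res = requests.post(f"{base}/slots/1?action=erase")
assert res.status_code == 200
```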

View file

@ -0,0 +1,59 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_tokenize_detokenize():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content
})
assert res_tok.status_code == 200
assert len(res_tok.body["tokens"]) > 5
# detokenize
res_detok = server.make_request("POST", "/detokenize", data={
"tokens": res_tok.body["tokens"],
})
assert res_detok.status_code == 200
assert res_detok.body["content"].strip() == content
def test_tokenize_with_bos():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
bosId = 1
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"add_special": True,
})
assert res_tok.status_code == 200
assert res_tok.body["tokens"][0] == bosId
def test_tokenize_with_pieces():
global server
server.start()
# tokenize
content = "This is a test string with unicode 媽 and emoji 🤗"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"with_pieces": True,
})
assert res_tok.status_code == 200
for token in res_tok.body["tokens"]:
assert "id" in token
assert token["id"] > 0
assert "piece" in token
assert len(token["piece"]) > 0

View file

@ -0,0 +1,377 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]
import subprocess
import os
import re
import json
import sys
import threading
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
Any,
Callable,
ContextManager,
Iterable,
Iterator,
List,
Literal,
Tuple,
Set,
)
from re import RegexFlag
class ServerResponse:
headers: dict
status_code: int
body: dict | Any
class ServerProcess:
# default options
debug: bool = False
server_port: int = 8080
server_host: str = "127.0.0.1"
model_hf_repo: str = "ggml-org/models"
model_hf_file: str = "tinyllamas/stories260K.gguf"
model_alias: str = "tinyllama-2"
temperature: float = 0.8
seed: int = 42
# custom options
model_url: str | None = None
model_file: str | None = None
n_threads: int | None = None
n_gpu_layer: int | None = None
n_batch: int | None = None
n_ubatch: int | None = None
n_ctx: int | None = None
n_ga: int | None = None
n_ga_w: int | None = None
n_predict: int | None = None
n_prompts: int | None = 0
slot_save_path: str | None = None
id_slot: int | None = None
cache_prompt: bool | None = None
n_slots: int | None = None
server_continuous_batching: bool | None = False
server_embeddings: bool | None = False
server_reranking: bool | None = False
server_metrics: bool | None = False
draft: int | None = None
api_key: str | None = None
response_format: str | None = None
lora_files: List[str] | None = None
disable_ctx_shift: bool | None = False
# session variables
process: subprocess.Popen | None = None
def __init__(self):
if "N_GPU_LAYERS" in os.environ:
self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
if "DEBUG" in os.environ:
self.debug = True
if "PORT" in os.environ:
self.server_port = int(os.environ["PORT"])
def start(self, timeout_seconds: int = 10) -> None:
if "LLAMA_SERVER_BIN_PATH" in os.environ:
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
elif os.name == "nt":
server_path = "../../../build/bin/Release/llama-server.exe"
else:
server_path = "../../../build/bin/llama-server"
server_args = [
"--slots", # requires to get slot status via /slots endpoint
"--host",
self.server_host,
"--port",
self.server_port,
"--temp",
self.temperature,
"--seed",
self.seed,
]
if self.model_file:
server_args.extend(["--model", self.model_file])
if self.model_url:
server_args.extend(["--model-url", self.model_url])
if self.model_hf_repo:
server_args.extend(["--hf-repo", self.model_hf_repo])
if self.model_hf_file:
server_args.extend(["--hf-file", self.model_hf_file])
if self.n_batch:
server_args.extend(["--batch-size", self.n_batch])
if self.n_ubatch:
server_args.extend(["--ubatch-size", self.n_ubatch])
if self.n_threads:
server_args.extend(["--threads", self.n_threads])
if self.n_gpu_layer:
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
if self.draft is not None:
server_args.extend(["--draft", self.draft])
if self.server_continuous_batching:
server_args.append("--cont-batching")
if self.server_embeddings:
server_args.append("--embedding")
if self.server_reranking:
server_args.append("--reranking")
if self.server_metrics:
server_args.append("--metrics")
if self.model_alias:
server_args.extend(["--alias", self.model_alias])
if self.n_ctx:
server_args.extend(["--ctx-size", self.n_ctx])
if self.n_slots:
server_args.extend(["--parallel", self.n_slots])
if self.n_predict:
server_args.extend(["--n-predict", self.n_predict])
if self.slot_save_path:
server_args.extend(["--slot-save-path", self.slot_save_path])
if self.n_ga:
server_args.extend(["--grp-attn-n", self.n_ga])
if self.n_ga_w:
server_args.extend(["--grp-attn-w", self.n_ga_w])
if self.debug:
server_args.append("--verbose")
if self.lora_files:
for lora_file in self.lora_files:
server_args.extend(["--lora", lora_file])
if self.disable_ctx_shift:
server_args.extend(["--no-context-shift"])
if self.api_key:
server_args.extend(["--api-key", self.api_key])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
flags = 0
if "nt" == os.name:
flags |= subprocess.DETACHED_PROCESS
flags |= subprocess.CREATE_NEW_PROCESS_GROUP
flags |= subprocess.CREATE_NO_WINDOW
self.process = subprocess.Popen(
[str(arg) for arg in [server_path, *server_args]],
creationflags=flags,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env={**os.environ, "LLAMA_CACHE": "tmp"},
)
server_instances.add(self)
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b""):
print(line.decode("utf-8"), end="", file=out_stream)
thread_stdout = threading.Thread(
target=server_log, args=(self.process.stdout, sys.stdout), daemon=True
)
thread_stdout.start()
thread_stderr = threading.Thread(
target=server_log, args=(self.process.stderr, sys.stderr), daemon=True
)
thread_stderr.start()
print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
# wait for server to start
start_time = time.time()
while time.time() - start_time < timeout_seconds:
try:
response = self.make_request("GET", "/slots", headers={
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
})
if response.status_code == 200:
self.ready = True
return # server is ready
except Exception as e:
pass
print(f"Waiting for server to start...")
time.sleep(0.5)
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
def stop(self) -> None:
server_instances.remove(self)
if self.process:
print(f"Stopping server with pid={self.process.pid}")
self.process.kill()
self.process = None
def make_request(
self,
method: str,
path: str,
data: dict | Any | None = None,
headers: dict | None = None,
) -> ServerResponse:
url = f"http://{self.server_host}:{self.server_port}{path}"
parse_body = False
if method == "GET":
response = requests.get(url, headers=headers)
parse_body = True
elif method == "POST":
response = requests.post(url, headers=headers, json=data)
parse_body = True
elif method == "OPTIONS":
response = requests.options(url, headers=headers)
else:
raise ValueError(f"Unimplemented method: {method}")
result = ServerResponse()
result.headers = dict(response.headers)
result.status_code = response.status_code
result.body = response.json() if parse_body else None
print("Response from server", result.body)
return result
def make_stream_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
) -> Iterator[dict]:
url = f"http://{self.server_host}:{self.server_port}{path}"
if method == "POST":
response = requests.post(url, headers=headers, json=data, stream=True)
else:
raise ValueError(f"Unimplemented method: {method}")
for line_bytes in response.iter_lines():
line = line_bytes.decode("utf-8")
if '[DONE]' in line:
break
elif line.startswith('data: '):
data = json.loads(line[6:])
print("Partial response from server", data)
yield data
server_instances: Set[ServerProcess] = set()
class ServerPreset:
@staticmethod
def tinyllama2() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K.gguf"
server.model_alias = "tinyllama-2"
server.n_ctx = 256
server.n_batch = 32
server.n_slots = 2
server.n_predict = 64
server.seed = 42
return server
@staticmethod
def bert_bge_small() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
server.model_alias = "bert-bge-small"
server.n_ctx = 512
server.n_batch = 128
server.n_ubatch = 128
server.n_slots = 2
server.seed = 42
server.server_embeddings = True
return server
@staticmethod
def tinyllama_infill() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
server.model_alias = "tinyllama-infill"
server.n_ctx = 2048
server.n_batch = 1024
server.n_slots = 1
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
return server
@staticmethod
def stories15m_moe() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/stories15M_MOE"
server.model_hf_file = "stories15M_MOE-F16.gguf"
server.model_alias = "stories15m-moe"
server.n_ctx = 2048
server.n_batch = 1024
server.n_slots = 1
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
return server
@staticmethod
def jina_reranker_tiny() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
server.model_alias = "jina-reranker"
server.model_file = "./tmp/jina-reranker-v1-tiny-en.gguf"
server.n_ctx = 512
server.n_batch = 512
server.n_slots = 1
server.seed = 42
server.server_reranking = True
return server
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
"""
Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
Example usage:
results = parallel_function_calls([
(func1, (arg1, arg2)),
(func2, (arg3, arg4)),
])
"""
results = [None] * len(function_list)
exceptions = []
def worker(index, func, args):
try:
result = func(*args)
results[index] = result
except Exception as e:
exceptions.append((index, str(e)))
with ThreadPoolExecutor() as executor:
futures = []
for i, (func, args) in enumerate(function_list):
future = executor.submit(worker, i, func, args)
futures.append(future)
# Wait for all futures to complete
for future in as_completed(futures):
pass
# Check if there were any exceptions
if exceptions:
print("Exceptions occurred:")
for index, error in exceptions:
print(f"Function at index {index}: {error}")
return results
def match_regex(regex: str, text: str) -> bool:
return (
re.compile(
regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
).search(text)
is not None
)
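Taken together, the helpers above are used as in the tests earlier in this diff; a minimal, self-contained sketch (endpoint and payload taken from those tests):

```python
from utils import ServerPreset, match_regex

# start the tiny test model defined by the preset
server = ServerPreset.tinyllama2()
server.start()

# run a completion and check the response shape
res = server.make_request("POST", "/completion", data={
    "prompt": "What is the capital of France?",
    "cache_prompt": True,
})
assert res.status_code == 200
assert match_regex(".+", res.body["content"])

server.stop()
```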

View file

@ -24,7 +24,6 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::ordered_json;
using llama_tokens = std::vector<llama_token>;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
@ -439,62 +438,6 @@ static std::string gen_chatcmplid() {
// other common utils
//
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
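For quick experimentation, the two-row dynamic-programming update used by `longest_common_subsequence` above can be transcribed to Python (a sketch only; token sequences are modeled as plain lists of ints):

```python
def lcs_length(a: list[int], b: list[int]) -> int:
    # mirror of the C++ helper: two rows instead of a full 2D matrix
    if not a or not b:
        return 0
    prev_row = [0] * (len(b) + 1)
    curr_row = [0] * (len(b) + 1)
    max_length = 0
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            if a[i - 1] == b[j - 1]:
                # extend the run ending at the previous pair of elements
                curr_row[j] = prev_row[j - 1] + 1
                max_length = max(max_length, curr_row[j])
            else:
                curr_row[j] = 0
        prev_row = curr_row[:]
    return max_length
```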

View file

@ -62,6 +62,9 @@ int main(int argc, char ** argv) {
}
}, nullptr);
// load dynamic backends
ggml_backend_load_all();
// initialize the model
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = ngl;

View file

@ -74,6 +74,10 @@ int main(int argc, char ** argv) {
}
}
// load dynamic backends
ggml_backend_load_all();
// initialize the model
llama_model_params model_params = llama_model_default_params();

View file

@ -0,0 +1,5 @@
set(TARGET llama-speculative-simple)
add_executable(${TARGET} speculative-simple.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -0,0 +1,12 @@
# llama.cpp/examples/speculative-simple
Demonstration of basic greedy speculative decoding
```bash
./bin/llama-speculative-simple \
-m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
-md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
-f test.txt -c 0 -ngl 99 --color \
--sampling-seq k --top-k 1 -fa --temp 0.0 \
-ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9
```
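Conceptually, each iteration drafts a few cheap candidate tokens and verifies them with a single pass of the target model. The Python sketch below is purely illustrative; the `draft_model`/`target_model` objects and their methods are hypothetical stand-ins, not llama.cpp APIs:

```python
def speculative_step(target_model, draft_model, context, n_draft, p_min):
    # 1) the cheap draft model proposes up to n_draft greedy candidates,
    #    stopping as soon as its confidence drops below p_min
    draft = []
    for _ in range(n_draft):
        tok, prob = draft_model.sample_greedy(context + draft)
        if prob < p_min:
            break
        draft.append(tok)

    # 2) the target model evaluates [last_token, draft...] in one batch;
    #    accept the longest draft prefix it agrees with, plus one extra token
    logits = target_model.forward(context[-1:] + draft)
    accepted = []
    for i, tok in enumerate(draft):
        if target_model.sample(logits[i]) != tok:
            break
        accepted.append(tok)
    accepted.append(target_model.sample(logits[len(accepted)]))
    return accepted
```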

View file

@ -0,0 +1,265 @@
#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "speculative.h"
#include "log.h"
#include "llama.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
int main(int argc, char ** argv) {
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1;
}
if (params.n_predict < -1) {
LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
return 1;
}
common_init();
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model_tgt = NULL;
llama_model * model_dft = NULL;
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the target model
common_init_result llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt.model;
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.n_ctx = params.speculative.n_ctx;
params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
params.n_gpu_layers = params.speculative.n_gpu_layers;
if (params.speculative.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
}
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
common_init_result llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
return 1;
}
// Tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx_tgt, params.prompt, true, true);
if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
return 1;
}
if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
return 1;
}
LOG("\n\n");
for (auto id : inp) {
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
}
// how many tokens to draft each time
int n_draft = params.speculative.n_max;
int n_draft_min = params.speculative.n_min;
float p_min = params.speculative.p_min;
int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;
// used to determine end of generation
bool has_eos = false;
// ================================================
// everything until here is standard initialization
// the relevant stuff for speculative decoding starts here
const auto t_enc_start = ggml_time_us();
// target model sampling context
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
// note: keep the last token separate!
llama_token id_last = inp.back();
// all tokens currently in the target context
llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
int n_past = inp.size() - 1;
// init the speculator
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft;
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
params_spec.p_min = p_min;
struct common_speculative * spec = common_speculative_init(ctx_dft);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
const auto t_enc_end = ggml_time_us();
const auto t_dec_start = ggml_time_us();
while (true) {
// optionally, generate draft tokens that can be appended to the target batch
//
// this is the most important part of the speculation. the more probable tokens that are provided here
// the better the performance will be. in theory, this computation can be performed asynchronously and even
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
// from a cache or lookup tables.
//
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
// do not waste time on small drafts
if (draft.size() < (size_t) n_draft_min) {
draft.clear();
}
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
}
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
}
// sample from the full target batch and return the accepted tokens based on the target sampler
//
// for each token to be accepted, the sampler would have to sample that same token
// in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
// available logits from the batch and sample the next token until we run out of logits or the sampler
// disagrees with the draft
//
const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
//LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
n_past += ids.size() - 1;
n_drafted += draft.size(); // note: we ignore the discarded small drafts
n_accept += ids.size() - 1;
n_predict += ids.size();
// process the accepted tokens and update contexts
//
// this is the standard token post-processing that we normally do
// in this case, we do it for a group of accepted tokens at once
//
for (size_t i = 0; i < ids.size(); ++i) {
prompt_tgt.push_back(id_last);
id_last = ids[i];
if (llama_token_is_eog(model_tgt, id_last)) {
has_eos = true;
break;
}
const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
if (params.use_color && i + 1 < ids.size()) {
LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
} else {
LOG("%s", token_str.c_str());
}
}
LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}
}
auto t_dec_end = ggml_time_us();
const int n_input = inp.size();
LOG("\n\n");
LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_INF("\n");
LOG_INF("n_draft = %d\n", n_draft);
LOG_INF("n_predict = %d\n", n_predict);
LOG_INF("n_drafted = %d\n", n_drafted);
LOG_INF("n_accept = %d\n", n_accept);
LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_INF("\n");
LOG_INF("draft:\n\n");
llama_perf_context_print(ctx_dft);
LOG_INF("\n");
LOG_INF("target:\n\n");
common_perf_print(ctx_tgt, smpl);
common_sampler_free(smpl);
common_speculative_free(spec);
llama_free(ctx_tgt);
llama_free_model(model_tgt);
llama_free(ctx_dft);
llama_free_model(model_dft);
llama_backend_free();
LOG("\n\n");
return 0;
}

View file

@ -12,7 +12,7 @@
#include <string>
#include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
@ -33,7 +33,7 @@ int main(int argc, char ** argv) {
common_params params;
// needed to get candidate probs even for temp <= 0.0
params.sparams.n_probs = 128;
params.sampling.n_probs = 128;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1;
@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
common_init();
if (params.model_draft.empty()) {
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@ -55,9 +55,9 @@ int main(int argc, char ** argv) {
const int n_seq_dft = params.n_parallel;
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split;
const float p_draft_split = params.speculative.p_split;
std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
std::uniform_real_distribution<> u_dist;
// init llama.cpp
@ -76,13 +76,14 @@ int main(int argc, char ** argv) {
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.draft_cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.n_gpu_layers = params.speculative.n_gpu_layers;
if (params.speculative.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
}
params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
common_init_result llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
@ -170,7 +171,7 @@ int main(int argc, char ** argv) {
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
int n_draft = params.n_draft;
int n_draft = params.speculative.n_max;
int n_predict = 0;
int n_drafted = 0;
@ -183,14 +184,14 @@ int main(int argc, char ** argv) {
bool has_eos = false;
// target model sampling context (reuse the llama_context's sampling instance)
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
for (int s = 0; s < n_seq_dft; ++s) {
// allocate llama_sampler for each draft sequence
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
}
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
@ -230,7 +231,7 @@ int main(int argc, char ** argv) {
// for stochastic sampling, attempt to match the token with the drafted tokens
{
bool accept = false;
if (params.sparams.temp > 0) {
if (params.sampling.temp > 0) {
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
@ -494,7 +495,7 @@ int main(int argc, char ** argv) {
// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);

Some files were not shown because too many files have changed in this diff.