Merge branch 'master' into cuda-releases

This commit is contained in:
Olivier Chafik 2025-02-04 15:57:59 +00:00
commit 178ad4e8c9
87 changed files with 5073 additions and 1310 deletions

View file

@ -59,16 +59,14 @@ jobs:
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake .. \
cmake -B build \
-DCMAKE_BUILD_RPATH="@loader_path" \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_CURL=ON \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DGGML_RPC=ON
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
- name: Test
id: cmake_test
@ -199,13 +197,11 @@ jobs:
- name: Build
id: cmake_build
run: |
mkdir build
cd build
cmake .. \
cmake -B build \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_CURL=ON \
-DGGML_RPC=ON
cmake --build . --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc)
- name: Test
id: cmake_test
@ -283,26 +279,52 @@ jobs:
id: cmake_build
if: ${{ matrix.sanitizer != 'THREAD' }}
run: |
mkdir build
cd build
cmake .. \
cmake -B build \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Build (no OpenMP)
id: cmake_build_no_openmp
if: ${{ matrix.sanitizer == 'THREAD' }}
run: |
mkdir build
cd build
cmake .. \
cmake -B build \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DGGML_OPENMP=OFF
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
ubuntu-latest-llguidance:
runs-on: ubuntu-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential
- name: Build
id: cmake_build
run: |
mkdir build
cd build
cmake .. \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_LLGUIDANCE=ON
cmake --build . --config Release -j $(nproc)
- name: Test
id: cmake_test
@ -335,11 +357,9 @@ jobs:
- name: Build
id: cmake_build
run: |
mkdir build
cd build
cmake .. \
cmake -B build \
-DGGML_RPC=ON
cmake --build . --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc)
- name: Test
id: cmake_test
@ -372,11 +392,9 @@ jobs:
- name: Build
id: cmake_build
run: |
mkdir build
cd build
cmake .. \
cmake -B build \
-DGGML_VULKAN=ON
cmake --build . --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc)
- name: Test
id: cmake_test
@ -493,13 +511,11 @@ jobs:
id: cmake_build
run: |
source /opt/intel/oneapi/setvars.sh
mkdir build
cd build
cmake .. \
cmake -B build \
-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx
cmake --build . --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc)
ubuntu-22-cmake-sycl-fp16:
runs-on: ubuntu-22.04
@ -543,14 +559,12 @@ jobs:
id: cmake_build
run: |
source /opt/intel/oneapi/setvars.sh
mkdir build
cd build
cmake .. \
cmake -B build \
-DGGML_SYCL=ON \
-DCMAKE_C_COMPILER=icx \
-DCMAKE_CXX_COMPILER=icpx \
-DGGML_SYCL_F16=ON
cmake --build . --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc)
macOS-latest-cmake-ios:
runs-on: macos-latest
@ -576,9 +590,7 @@ jobs:
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -G Xcode .. \
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_EXAMPLES=OFF \
@ -587,7 +599,7 @@ jobs:
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
macOS-latest-cmake-tvos:
runs-on: macos-latest
@ -613,9 +625,7 @@ jobs:
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -G Xcode .. \
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_EXAMPLES=OFF \
@ -624,7 +634,7 @@ jobs:
-DCMAKE_SYSTEM_NAME=tvOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
macOS-latest-swift:
runs-on: macos-latest
@ -654,17 +664,15 @@ jobs:
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -G Xcode .. \
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
sudo cmake --install . --config Release
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
sudo cmake --install build --config Release
- name: xcodebuild for swift package
id: xcodebuild
@ -689,6 +697,7 @@ jobs:
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-msys2
variant: sccache
evict-old-files: 1d
- name: Setup ${{ matrix.sys }}
@ -763,6 +772,7 @@ jobs:
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-${{ matrix.build }}
variant: sccache
evict-old-files: 1d
- name: Clone Kompute submodule
@ -804,21 +814,19 @@ jobs:
run: |
git clone https://github.com/KhronosGroup/OpenCL-Headers
cd OpenCL-Headers
mkdir build && cd build
cmake .. `
cmake -B build `
-DBUILD_TESTING=OFF `
-DOPENCL_HEADERS_BUILD_TESTING=OFF `
-DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF `
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
cmake --build . --target install
cmake --build build --target install
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader
cd OpenCL-ICD-Loader
mkdir build-arm64-release && cd build-arm64-release
cmake .. `
cmake -B build-arm64-release `
-A arm64 `
-DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" `
-DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release"
cmake --build . --target install --config release
cmake --build build-arm64-release --target install --config release
- name: Build
id: cmake_build
@ -1026,6 +1034,7 @@ jobs:
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: ${{ github.job }}-${{ matrix.cuda }}-${{ matrix.build }}
variant: sccache
evict-old-files: 1d
- name: Install Cuda Toolkit 11.7
@ -1167,6 +1176,7 @@ jobs:
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: windows-latest-cmake-sycl
variant: sccache
evict-old-files: 1d
- name: Install
@ -1355,9 +1365,7 @@ jobs:
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -G Xcode .. \
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_EXAMPLES=OFF \
@ -1366,8 +1374,8 @@ jobs:
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
sudo cmake --install . --config Release
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
sudo cmake --install build --config Release
- name: xcodebuild for swift package
id: xcodebuild

View file

@ -17,7 +17,7 @@ jobs:
steps:
- uses: actions/stale@v5
with:
exempt-issue-labels: "refactor,help wanted,good first issue,research,bug"
exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap"
days-before-issue-stale: 30
days-before-issue-close: 14
stale-issue-label: "stale"

83
AUTHORS
View file

@ -1,4 +1,4 @@
# date: Thu Nov 28 20:46:15 EET 2024
# date: Tue Feb 4 13:04:05 EET 2025
# this file is auto-generated by scripts/gen-authors.sh
0cc4m <picard12@live.de>
@ -20,6 +20,8 @@ Adithya Balaji <adithya.b94@gmail.com>
AdithyanI <adithyan.i4internet@gmail.com>
Adrian <smith.adriane@gmail.com>
Adrian Hesketh <a-h@users.noreply.github.com>
Adrien Gallouët <adrien@gallouet.fr>
Adrien Gallouët <angt@huggingface.co>
Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com>
Ahmet Zeer <ahmed.zeer@std.yildiz.edu.tr>
AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>
@ -55,6 +57,7 @@ Ananta Bastola <anantarajbastola@gmail.com>
Anas Ahouzi <112881240+aahouzi@users.noreply.github.com>
András Salamon <ott2@users.noreply.github.com>
Andreas (Andi) Kunar <andreask@msn.com>
Andreas Kieslinger <47689530+aendk@users.noreply.github.com>
Andrei <abetlen@gmail.com>
Andrew Canis <andrew.canis@gmail.com>
Andrew Downing <andrew2085@gmail.com>
@ -91,13 +94,17 @@ Ben Siraphob <bensiraphob@gmail.com>
Ben Williams <ben@719ben.com>
Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com>
Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com>
Benson Wong <mostlygeek@gmail.com>
Bernat Vadell <hounter.caza@gmail.com>
Bernhard M. Wiedemann <githubbmwprimary@lsmod.de>
Bert Wagner <github@bertwagner.com>
Billel Mokeddem <billel.mokeddem.ml@gmail.com>
Bingan <70050083+binganao@users.noreply.github.com>
Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com>
Bodo Graumann <mail@bodograumann.de>
Bono Lv <lvscar@users.noreply.github.com>
Borislav Stanimirov <b.stanimirov@abv.bg>
Borislav Stanimirov <b@ibob.bg>
Branden Butler <bwtbutler@hotmail.com>
Brandon Squizzato <35474886+bsquizz@users.noreply.github.com>
Brian <mofosyne@gmail.com>
@ -117,6 +124,7 @@ Casey Primozic <casey@cprimozic.net>
Casey Primozic <me@ameo.link>
CausalLM <148736309+CausalLM@users.noreply.github.com>
Cebtenzzre <cebtenzzre@gmail.com>
CentricStorm <CentricStorm@users.noreply.github.com>
Chad Brewbaker <crb002@gmail.com>
Changyeon Kim <cyzero.kim@samsung.com>
Chao Jiang <jc19chaoj@zoho.com>
@ -131,12 +139,15 @@ Chris Kuehl <ckuehl@ckuehl.me>
Christian Demsar <christian@github.email.demsar.us>
Christian Demsar <crasm@git.vczf.us>
Christian Falch <875252+chrfalch@users.noreply.github.com>
Christian Kastner <ckk@kvr.at>
Christian Kögler <ck3d@gmx.de>
Christian Köhnenkamp <cvk5@me.com>
Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com>
Christopher Nielsen <62156882+mascguy@users.noreply.github.com>
Clark Saben <76020733+csaben@users.noreply.github.com>
Clint Herron <hanclinto@gmail.com>
Conrad Kramer <conrad@conradkramer.com>
Corentin REGAL <corentin.regal@gmail.com>
CrispStrobe <154636388+CrispStrobe@users.noreply.github.com>
Csaba Kecskemeti <csaba.kecskemeti@gmail.com>
Cuong Trinh Manh <nguoithichkhampha@gmail.com>
@ -176,6 +187,7 @@ Dibakar Gope <dibakar.gope@arm.com>
Didzis Gosko <didzis@users.noreply.github.com>
Diego Devesa <slarengh@gmail.com>
Diogo Teles Sant'Anna <diogoteles@google.com>
Djip007 <3705339+Djip007@users.noreply.github.com>
Djip007 <djip.perois@free.fr>
Don Mahurin <dmahurin@users.noreply.github.com>
DooWoong Lee (David) <manics99@naver.com>
@ -193,6 +205,7 @@ Edward Taylor <edeetee@gmail.com>
Elaine <elaine.zosa@gmail.com>
Elbios <141279586+Elbios@users.noreply.github.com>
Elton Kola <eltonkola@gmail.com>
Emreerdog <34742675+Emreerdog@users.noreply.github.com>
Engininja2 <139037756+Engininja2@users.noreply.github.com>
Equim <sayaka@ekyu.moe>
Eric Curtin <ecurtin@redhat.com>
@ -233,6 +246,7 @@ Fred Douglas <43351173+fredlas@users.noreply.github.com>
Frederik Vogel <Schaltfehler@users.noreply.github.com>
Gabe Goodhart <gabe.l.hart@gmail.com>
Gabe Goodhart <ghart@us.ibm.com>
Gaetan Bisson <gaetan@fenua.org>
GainLee <perfecter.gen@gmail.com>
Galunid <karolek1231456@gmail.com>
Gary Linscott <glinscott@gmail.com>
@ -249,6 +263,7 @@ Guillaume "Vermeille" Sanchez <Guillaume.V.Sanchez@gmail.com>
Guillaume Wenzek <gwenzek@users.noreply.github.com>
Guoliang Hua <32868157+nbcsm@users.noreply.github.com>
Guoteng <32697156+SolenoidWGT@users.noreply.github.com>
Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com>
Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com>
Haggai Nuchi <h.nuchi@gmail.com>
Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com>
@ -259,11 +274,13 @@ Haoxiang Fei <tonyfettes@tonyfettes.com>
Harald Fernengel <harald.fernengel@here.com>
Hatsune Miku <129688334+at8u@users.noreply.github.com>
HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com>
Haus1 <haus.xda@gmail.com>
Henk Poley <HenkPoley@gmail.com>
Henri Vasserman <henv@hot.ee>
Henrik Forstén <henrik.forsten@gmail.com>
Herman Semenov <GermanAizek@yandex.ru>
Hesen Peng <hesen.peng@gmail.com>
HimariO <dsfhe49854@gmail.com>
Hoang Nguyen <hugo53@users.noreply.github.com>
Hong Bo PENG <penghb@cn.ibm.com>
Hongyu Ouyang <96765450+casavaca@users.noreply.github.com>
@ -280,6 +297,7 @@ Icecream95 <the.real.icecream95@gmail.com>
Ido S <ido.pluto@gmail.com>
IgnacioFDM <ignaciofdm@gmail.com>
Igor Okulist <okigan@gmail.com>
Ihar Hrachyshka <ihrachys@redhat.com>
Ikko Eltociear Ashimine <eltociear@gmail.com>
Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com>
Ionoclast Laboratories <brigham@ionoclast.com>
@ -289,12 +307,14 @@ Ivan <nekotekina@gmail.com>
Ivan Filipov <159561759+vanaka11@users.noreply.github.com>
Ivan Komarov <Ivan.Komarov@dfyz.info>
Ivan Stepanov <ivanstepanovftw@gmail.com>
JFLFY2255 <JFLFY2255@163.com>
JH23X <165871467+JH23X@users.noreply.github.com>
Jack Mousseau <jack@software.inc>
Jack Mousseau <jmousseau@users.noreply.github.com>
JackJollimore <130917767+JackJollimore@users.noreply.github.com>
Jaeden Amero <jaeden@patater.com>
Jaemin Son <woalsdnd@gmail.com>
Jafar Uruç <jafar.uruc@gmail.com>
Jag Chadha <jagtesh@gmail.com>
Jakub N <jakubniemczyk97@gmail.com>
James A Capozzoli <157492257+jac-jim@users.noreply.github.com>
@ -315,6 +335,7 @@ Jeffrey Morgan <jmorganca@gmail.com>
Jeffrey Quesnelle <emozilla@nousresearch.com>
Jeroen Mostert <jeroen.mostert@cm.com>
Jesse Jojo Johnson <williamsaintgeorge@gmail.com>
Jett Janiak <jettjaniak@gmail.com>
Jeximo <jeximo@gmail.com>
Jhen-Jie Hong <iainst0409@gmail.com>
Jiahao Li <liplus17@163.com>
@ -343,6 +364,7 @@ Josh Ramer <josh.ramer@icloud.com>
Joyce <joycebrum@google.com>
Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
Judd <foldl@users.noreply.github.com>
Juk Armstrong <69222624+jukofyork@users.noreply.github.com>
Julius Arkenberg <arki05@users.noreply.github.com>
Jun Hee Yoo <contact.jhyoo@gmail.com>
Jun Jie <71215065+junnjiee16@users.noreply.github.com>
@ -357,6 +379,7 @@ Justine Tunney <jtunney@mozilla.com>
Juuso Alasuutari <juuso.alasuutari@gmail.com>
KASR <karim.asrih@gmail.com>
Kamil Tomšík <info@tomsik.cz>
Karol Kontny <82021046+kkontny@users.noreply.github.com>
Karsten Weiss <knweiss@gmail.com>
Karthick <j.karthic2004@gmail.com>
Karthik Kumar Viswanathan <195178+guilt@users.noreply.github.com>
@ -376,6 +399,7 @@ Kolen Cheung <ickc@users.noreply.github.com>
Konstantin Herud <konstantin.herud@denkbares.com>
Konstantin Zhuravlyov <konstantin.zhuravlyov@amd.com>
Kunshang Ji <kunshang.ji@intel.com>
Kyle Bruene <KyleBruene@users.noreply.github.com>
Kyle Liang <liangmanlai@gmail.com>
Kyle Mistele <kyle@mistele.com>
Kylin <56434533+KyL0N@users.noreply.github.com>
@ -394,6 +418,7 @@ Liu Jia <jia3.liu@intel.com>
LoganDark <github@logandark.mozmail.com>
Loïc Carrère <loic.carrere@gmail.com>
LostRuins <39025047+LostRuins@users.noreply.github.com>
LostRuins Concedo <39025047+LostRuins@users.noreply.github.com>
Luciano <lucianostrika44@gmail.com>
Luo Tian <lt@basecity.com>
Lyle Dean <dean@lyle.dev>
@ -423,6 +448,7 @@ MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com>
Mateusz Charytoniuk <mateusz.charytoniuk@protonmail.com>
Matheus C. França <matheus-catarino@hotmail.com>
Matheus Gabriel Alves Silva <matheusgasource@gmail.com>
Mathieu Baudier <mbaudier@argeo.org>
Mathieu Geli <mathieu.geli@gmail.com>
Mathieu Nayrolles <MathieuNls@users.noreply.github.com>
Mathijs Henquet <mathijs.henquet@gmail.com>
@ -444,6 +470,7 @@ Meng, Hengyu <hengyu.meng@intel.com>
Mengqing Cao <cmq0113@163.com>
Merrick Christensen <merrick.christensen@gmail.com>
Michael Coppola <m18coppola@gmail.com>
Michael Engel <mengel@redhat.com>
Michael Francis <edude03@gmail.com>
Michael Hueschen <m@mhueschen.dev>
Michael Kesper <mkesper@schokokeks.org>
@ -452,7 +479,9 @@ Michael Podvitskiy <podvitskiymichael@gmail.com>
Michael Potter <NanoTekGuy@Gmail.com>
Michael de Gans <michael.john.degans@gmail.com>
Michaël de Vries <vriesdemichael@gmail.com>
Michał Moskal <michal@moskal.me>
Michał Tuszyński <srgtuszy@gmail.com>
Michelle Tan <41475767+MichelleTanPY@users.noreply.github.com>
Mihai <mihai.chirculescu@yahoo.com>
Mike <ytianhui2004@gmail.com>
Mikko Juola <mikjuo@gmail.com>
@ -477,6 +506,7 @@ Neo Zhang <14088817+arthw@users.noreply.github.com>
Neo Zhang <zhang.jianyu@outlook.com>
Neo Zhang Jianyu <jianyu.zhang@intel.com>
Neuman Vong <neuman.vong@gmail.com>
NeverLucky <92274250+nvrxq@users.noreply.github.com>
Nexes the Old <124105151+Nexesenex@users.noreply.github.com>
Nexesenex <124105151+Nexesenex@users.noreply.github.com>
Niall Coates <1349685+Niall-@users.noreply.github.com>
@ -484,11 +514,15 @@ Nicholai Tukanov <nicholaitukanov@gmail.com>
Nico Bosshard <nico@bosshome.ch>
Nicolai Weitkemper <kontakt@nicolaiweitkemper.de>
Nicolás Pérez <nicolas_perez@brown.edu>
Nicolò Scipione <nicolo.scipione@codeplay.com>
Nigel Bosch <pnigelb@gmail.com>
Nikita Sarychev <42014488+sARY77@users.noreply.github.com>
Niklas Korz <niklas@niklaskorz.de>
NikolaiLyssogor <59844691+NikolaiLyssogor@users.noreply.github.com>
Nikolaos Pothitos <pothitos@di.uoa.gr>
Nikolas <127742645+nneubacher@users.noreply.github.com>
Nindaleth <Nindaleth@users.noreply.github.com>
Nuno <rare-magma@posteo.eu>
OSecret <135510162+OLSecret@users.noreply.github.com>
Oleksandr Nikitin <oleksandr@tvori.info>
Oleksii Maryshchenko <oleksii.maryshchenko@gmail.com>
@ -504,6 +538,7 @@ Pavel Zloi <github.com@drteam.rocks>
Pavol Rusnak <pavol@rusnak.io>
Paweł Wodnicki <151604+32bitmicro@users.noreply.github.com>
Pedro Cuenca <pedro@huggingface.co>
Peter <peter277@users.noreply.github.com>
Peter Sugihara <peter@campsh.com>
Phil H <5756783+phiharri@users.noreply.github.com>
Philip Taron <philip.taron@gmail.com>
@ -529,9 +564,12 @@ Rand Xie <randxiexyy29@gmail.com>
Randall Fitzgerald <randall@dasaku.net>
Random Fly <renfei8@live.cn>
Reinforce-II <fate@eastal.com>
Rémy Oudompheng <oudomphe@phare.normalesup.org>
Ren Xuancheng <jklj077@users.noreply.github.com>
Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com>
Reza Kakhki <rezakakhki.de@gmail.com>
RhinoDevel <RhinoDevel@users.noreply.github.com>
Riccardo Orlando <Riccorl@users.noreply.github.com>
Riceball LEE <snowyu.lee@gmail.com>
Rich Dougherty <rich@rd.nz>
Richard Kiss <him@richardkiss.com>
@ -544,6 +582,8 @@ Riley Stewart <ristew@users.noreply.github.com>
Rinne <AsakusaRinne@gmail.com>
Rinne <liu_yaohui1998@126.com>
Robert Brisita <986796+rbrisita@users.noreply.github.com>
Robert Collins <roberto.tomas.cuentas@gmail.com>
Robert Ormandi <52251610+ormandi@users.noreply.github.com>
Robert Sung-wook Shin <edp1096@users.noreply.github.com>
Robey Holderith <robey@flaminglunchbox.net>
Robyn <robyngraf@users.noreply.github.com>
@ -559,7 +599,9 @@ Roni <sulpher@gmx.net>
Ronny Brendel <ronnybrendel@gmail.com>
Ronsor <ronsor@ronsor.pw>
Rowan Hart <rowanbhart@gmail.com>
Ruan <47767371+ruanych@users.noreply.github.com>
Ruchira Hasaranga <ruchira66@gmail.com>
Rudi Servo <rudiservo@gmail.com>
Ruixin Huang <18860020911@163.com>
Rune <43761327+Rune-AI@users.noreply.github.com>
RunningLeon <maningsheng@sensetime.com>
@ -623,12 +665,14 @@ Steven Roussey <sroussey@gmail.com>
Steward Garcia <57494570+FSSRepo@users.noreply.github.com>
StrangeBytesDev <141275258+StrangeBytesDev@users.noreply.github.com>
Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com>
Sukriti Sharma <Ssukriti@users.noreply.github.com>
SuperUserNameMan <yoann@terminajones.com>
Sutou Kouhei <kou@cozmixng.org>
Tai Duc Nguyen <taiducnguyen.drexel@gmail.com>
Taikono-Himazin <kazu@po.harenet.ne.jp>
Tameem <113388789+AhmadTameem@users.noreply.github.com>
Tamotsu Takahashi <ttakah+github@gmail.com>
Tei Home <taiteitonghome@proton.me>
Thái Hoàng Tâm <75922889+RoyalHeart@users.noreply.github.com>
Thatcher Chamberlin <j.thatcher.c@gmail.com>
Theia Vogel <theia@vgel.me>
@ -640,6 +684,7 @@ Tim Miller <drasticactions@users.noreply.github.com>
Tim Wang <overocean@gmail.com>
Timmy Knight <r2d2fish@gmail.com>
Timothy Cronin <40186632+4imothy@users.noreply.github.com>
Ting Lou <louting@189.cn>
Ting Lou <ting.lou@gmail.com>
Ting Sun <suntcrick@gmail.com>
Tobias Lütke <tobi@shopify.com>
@ -661,6 +706,7 @@ Uzo Nweke <uzoechi@gmail.com>
Vaibhav Srivastav <vaibhavs10@gmail.com>
Val Kharitonov <mail@kharvd.com>
Valentin Konovalov <valle.ketsujin@gmail.com>
Valentin Mamedov <45292985+Inf1delis@users.noreply.github.com>
Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com>
Vali Malinoiu <0x4139@gmail.com>
Victor Nogueira <felladrin@gmail.com>
@ -673,13 +719,17 @@ Vladimir Malyutin <first-leon@yandex.ru>
Vladimir Zorin <vladimir@deviant.guru>
VoidIsVoid <343750470@qq.com>
Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com>
Wang Qin <37098874+wangqin0@users.noreply.github.com>
Wang Ran (汪然) <wangr@smail.nju.edu.cn>
WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com>
Weird Constructor <weirdconstructor@gmail.com>
Welby Seely <welbyseely@gmail.com>
Wentai Zhang <rchardx@gmail.com>
WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com>
William Tambellini <william.tambellini@gmail.com>
William Tambellini <wtambellini@sdl.com>
Willy Tarreau <w@1wt.eu>
Woof Dog <197125663+woof-dog@users.noreply.github.com>
Wouter <9594229+DifferentialityDevelopment@users.noreply.github.com>
Wu Jian Ping <wujjpp@hotmail.com>
Wu Jian Ping <wujp@greatld.com>
@ -692,6 +742,7 @@ Xie Yanbo <xieyanbo@gmail.com>
Xingchen Song(宋星辰) <xingchensong1996@163.com>
Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com>
Xuan Son Nguyen <thichthat@gmail.com>
Xuan-Son Nguyen <thichthat@gmail.com>
Yaiko <elyaiko@hotmail.com>
Yann Follet <131855179+YannFollet@users.noreply.github.com>
Yaroslav <yaroslav.yashin@me.com>
@ -702,7 +753,9 @@ Yoshi Suhara <y.suhara@gmail.com>
Yoshi Suhara <ysuhara@nvidia.com>
Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com>
Yüg <eugeniosegalaweb@gmail.com>
Yui <dev@sleepyyui.com>
Yun Dou <dixyes@gmail.com>
Yuri Khrustalev <ykhrustalev@users.noreply.github.com>
Yusuf Kağan Hanoğlu <hanoglu@yahoo.com>
Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com>
@ -714,18 +767,23 @@ Zhang Peiyuan <a1286225768@gmail.com>
Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com>
Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com>
Zhiyuan Li <lizhiyuan@uniartisan.com>
Zhiyuan Li <uniartisan2017@gmail.com>
ZhouYuChen <zhouyuchen@naver.com>
Ziad Ben Hadj-Alouane <zied.benhadjalouane@gmail.com>
Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com>
Zsapi <martin1.zsapka@gmail.com>
a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com>
a3sh <38979186+A3shTnT@users.noreply.github.com>
adel boussaken <netdur@gmail.com>
afrideva <95653597+afrideva@users.noreply.github.com>
ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com>
agray3 <agray3@users.noreply.github.com>
akawrykow <142945436+akawrykow@users.noreply.github.com>
alek3y <44779186+alek3y@users.noreply.github.com>
alexpinel <93524949+alexpinel@users.noreply.github.com>
alonfaraj <alonfaraj@gmail.com>
alwqx <kenan3015@gmail.com>
amd-dwang <dong.wang@amd.com>
amd-lalithnc <lalithnc@amd.com>
amritahs-ibm <amritahs@linux.vnet.ibm.com>
andrijdavid <david@geek.mg>
@ -737,6 +795,7 @@ arch-btw <57669023+arch-btw@users.noreply.github.com>
arcrank <arcrank@gmail.com>
ardfork <134447697+ardfork@users.noreply.github.com>
arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com>
aryantandon01 <80969509+aryantandon01@users.noreply.github.com>
at8u <129688334+at8u@users.noreply.github.com>
automaticcat <daogiatuank54@gmail.com>
awatuna <23447591+awatuna@users.noreply.github.com>
@ -751,12 +810,14 @@ bryanSwk <93190252+bryanSwk@users.noreply.github.com>
bsilvereagle <bsilvereagle@users.noreply.github.com>
bssrdf <merlintiger@hotmail.com>
byte-6174 <88070277+byte-6174@users.noreply.github.com>
cduk <19917266+cduk@users.noreply.github.com>
cebtenzzre <cebtenzzre@gmail.com>
chaihahaha <chai836275709@gmail.com>
chiranko <96988916+chiranko@users.noreply.github.com>
clibdev <52199778+clibdev@users.noreply.github.com>
clyang <clyang@clyang.net>
cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com>
codezjx <code.zjx@gmail.com>
coezbek <c.oezbek@gmail.com>
comex <comexk@gmail.com>
compilade <113953597+compilade@users.noreply.github.com>
@ -780,14 +841,17 @@ drbh <david.richard.holtz@gmail.com>
ds5t5 <145942675+ds5t5@users.noreply.github.com>
dylan <canardleteer@users.noreply.github.com>
eastriver <lee@eastriver.dev>
ebraminio <ebrahim@gnu.org>
ebraminio <ebraminio@gmail.com>
eiery <19350831+eiery@users.noreply.github.com>
eric8607242 <e0928021388@gmail.com>
fairydreaming <166155368+fairydreaming@users.noreply.github.com>
fengerhu1 <2748250768@qq.com>
fj-y-saito <85871716+fj-y-saito@users.noreply.github.com>
fraxy-v <65565042+fraxy-v@users.noreply.github.com>
github-actions[bot] <github-actions[bot]@users.noreply.github.com>
gliptic <gliptic@users.noreply.github.com>
gn64 <yukikaze.jp@gmail.com>
goerch <jhr.walter@t-online.de>
grahameth <96447521+grahameth@users.noreply.github.com>
gtygo <gtydoit@gmail.com>
@ -812,10 +876,12 @@ icppWorld <124377669+icppWorld@users.noreply.github.com>
igarnier <igarnier@protonmail.com>
intelmatt <61025942+intelmatt@users.noreply.github.com>
iohub <rickyang.pro@gmail.com>
issixx <46835150+issixx@users.noreply.github.com>
jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com>
jaime-m-p <167997752+jaime-m-p@users.noreply.github.com>
jameswu2014 <545426914@qq.com>
jdomke <28772296+jdomke@users.noreply.github.com>
jiahao su <damow890@gmail.com>
jiez <373447296@qq.com>
jneem <joeneeman@gmail.com>
joecryptotoo <80373433+joecryptotoo@users.noreply.github.com>
@ -828,6 +894,7 @@ junchao-loongson <68935141+junchao-loongson@users.noreply.github.com>
jwj7140 <32943891+jwj7140@users.noreply.github.com>
k.h.lai <adrian.k.h.lai@outlook.com>
kaizau <kaizau@users.noreply.github.com>
kallewoof <kalle.alm@gmail.com>
kalomaze <66376113+kalomaze@users.noreply.github.com>
kang <tpdns9032100@gmail.com>
katsu560 <118887472+katsu560@users.noreply.github.com>
@ -835,6 +902,7 @@ kchro3 <62481661+kchro3@users.noreply.github.com>
khimaros <me@khimaros.com>
kiltyj <kiltyj@gmail.com>
klosax <131523366+klosax@users.noreply.github.com>
krystiancha <krystian@krystianch.com>
kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
kunnis <kunnis@users.noreply.github.com>
kuronekosaiko <EvanChanJ@163.com>
@ -847,6 +915,8 @@ ldwang <ftgreat@163.com>
le.chang <cljs118@126.com>
leejet <leejet714@gmail.com>
leo-pony <nengjunma@outlook.com>
lexasub <lexakopp2212@gmail.com>
lhez <quic_lih@quicinc.com>
limitedAtonement <limitedAtonement@users.noreply.github.com>
liuwei-git <14815172+liuwei-git@users.noreply.github.com>
lon <114724657+longregen@users.noreply.github.com>
@ -855,10 +925,13 @@ ltoniazzi <61414566+ltoniazzi@users.noreply.github.com>
luoyu-intel <yu.luo@intel.com>
m3ndax <adrian.goessl@outlook.com>
maddes8cht <55592906+maddes8cht@users.noreply.github.com>
mahorozte <41834471+mahorozte@users.noreply.github.com>
makomk <makosoft@googlemail.com>
manikbhandari <mbbhandarimanik2@gmail.com>
maor-ps <154728172+maor-ps@users.noreply.github.com>
mashdragon <122402293+mashdragon@users.noreply.github.com>
matiaslin <45382001+matiaslin@users.noreply.github.com>
matt23654 <matthew.webber@protonmail.com>
matteo <matteogeniaccio@yahoo.it>
mdrokz <mohammadmunshi@gmail.com>
mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com>
@ -868,6 +941,7 @@ mmyjona <jonathan.gonse@gmail.com>
momonga <115213907+mmnga@users.noreply.github.com>
momonga <146910567+mmngays@users.noreply.github.com>
moritzbrantner <31051084+moritzbrantner@users.noreply.github.com>
musoles <135031143+musoles@users.noreply.github.com>
mzcu <milos.cubrilo@gmail.com>
nanahi <130121847+na-na-hi@users.noreply.github.com>
ngc92 <7938269+ngc92@users.noreply.github.com>
@ -885,6 +959,7 @@ oobabooga <112222186+oobabooga@users.noreply.github.com>
opparco <parco.opaai@gmail.com>
ostix360 <55257054+ostix360@users.noreply.github.com>
pculliton <phillipculliton@gmail.com>
peidaqi <peidaqi@gmail.com>
pengxin99 <pengxin.yuan@intel.com>
perserk <perserk@gmail.com>
piDack <104877312+piDack@users.noreply.github.com>
@ -892,10 +967,12 @@ pmysl <piotr.myslinski@outlook.com>
postmasters <namnguyen@google.com>
pudepiedj <pudepiedj@gmail.com>
qingfengfenga <41416092+qingfengfenga@users.noreply.github.com>
qingy1337 <qxli2@students.everettcc.edu>
qouoq <qouoq@fastmail.com>
qunash <anzoria@gmail.com>
rabidcopy <rabidcopy@yahoo.com>
rankaiyx <rankaiyx@rankaiyx.com>
redbeard <bharrington@alticon.net>
rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com>
rhuddleston <ryan.huddleston@percona.com>
rimoliga <53384203+rimoliga@users.noreply.github.com>
@ -912,6 +989,7 @@ sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com>
slaren <2141330+slaren@users.noreply.github.com>
slaren <slarengh@gmail.com>
snadampal <87143774+snadampal@users.noreply.github.com>
someone13574 <81528246+someone13574@users.noreply.github.com>
standby24x7 <standby24x7@gmail.com>
staviq <staviq@gmail.com>
stduhpf <stephduh@live.fr>
@ -931,6 +1009,7 @@ uint256_t <konndennsa@gmail.com>
uint256_t <maekawatoshiki1017@gmail.com>
unbounded <haakon@likedan.net>
uvos <devnull@uvos.xyz>
uvos <philipp@uvos.xyz>
valiray <133289098+valiray@users.noreply.github.com>
vb <vaibhavs10@gmail.com>
vik <vikhyatk@gmail.com>
@ -951,6 +1030,7 @@ xaedes <xaedes@googlemail.com>
xctan <axunlei@gmail.com>
xloem <0xloem@gmail.com>
yangli2 <yangli2@gmail.com>
ymcki <84055651+ymcki@users.noreply.github.com>
yuiseki <yuiseki@gmail.com>
yuri@FreeBSD <yurivict@users.noreply.github.com>
zakkor <edward.partenie@gmail.com>
@ -963,4 +1043,5 @@ zrm <trustiosity.zrm@gmail.com>
杨朱 · Kiki <baofa.fan@daocloud.io>
源文雨 <41315874+fumiama@users.noreply.github.com>
蕭澧邦 <45505768+shou692199@users.noreply.github.com>
谢乃闻 <sienaiwun@users.noreply.github.com>
Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com>

View file

@ -80,6 +80,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)

View file

@ -596,7 +596,7 @@ ifdef GGML_RPC
OBJ_GGML_EXT += ggml/src/ggml-rpc.o
endif # GGML_RPC
OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu))
OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
ifdef GGML_CUDA_FA_ALL_QUANTS

View file

@ -96,7 +96,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Bitnet b1.58 models](https://huggingface.co/1bitLLM)
- [x] [Flan T5](https://huggingface.co/models?search=flan-t5)
- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
@ -117,6 +117,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM)
- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2)
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)
- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge)
- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
</details>
@ -135,6 +136,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- Rust (more features): [edgenai/llama_cpp-rs](https://github.com/edgenai/llama_cpp-rs)
- Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
- Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs)
- Rust (automated build from crates.io): [ShelbyJenkins/llm_client](https://github.com/ShelbyJenkins/llm_client)
- C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)
- C#/VB.NET (more features - community license): [LM-Kit.NET](https://docs.lm-kit.com/lm-kit-net/index.html)
- Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s)

View file

@ -65,6 +65,7 @@ add_library(${TARGET} STATIC
console.h
json-schema-to-grammar.cpp
json.hpp
llguidance.cpp
log.cpp
log.h
minja.hpp
@ -91,6 +92,33 @@ if (LLAMA_CURL)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
endif ()
if (LLAMA_LLGUIDANCE)
include(ExternalProject)
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
# v0.6.12:
GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
PREFIX ${CMAKE_BINARY_DIR}/llguidance
SOURCE_DIR ${LLGUIDANCE_SRC}
BUILD_IN_SOURCE TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND cargo build --release
INSTALL_COMMAND ""
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
UPDATE_COMMAND ""
)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
add_library(llguidance STATIC IMPORTED)
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
add_dependencies(llguidance llguidance_ext)
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
endif ()
target_include_directories(${TARGET} PUBLIC .)
target_compile_features (${TARGET} PUBLIC cxx_std_17)
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)

View file

@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
default:
throw std::runtime_error("Unknown chat format");
}
@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
schemas.push_back({
{"type", "object"},
{"properties", {
{"tool_call_id", {
{"type", "string"},
// Command-R's template expects an integer string.
{"pattern", "^[0-9]{1,10}$"},
}},
{"tool_name", {
{"type", "string"},
{"const", function["name"]},
}},
{"parameters", function["parameters"]},
}},
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
});
});
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
}, grammar_options);
data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
data.preserved_tokens = {
"<|START_RESPONSE|>",
"<|END_RESPONSE|>",
"<|START_THINKING|>",
"<|END_THINKING|>",
"<|END_ACTION|>",
};
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
return data;
}
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
std::smatch match;
common_chat_msg result;
result.role = "assistant";
if (std::regex_match(input, match, response_regex)) {
result.content = match[1].str();
} else if (std::regex_match(input, match, thought_action_regex)) {
result.tool_plan = match[1].str();
auto actions_str = match[2].str();
auto actions = json::parse(actions_str);
for (const auto & action : actions) {
result.tool_calls.push_back({
/* .name = */ action["tool_name"],
/* .arguments = */ action["parameters"].dump(),
/* .id = */ action["tool_call_id"],
});
}
} else {
LOG_ERR("Failed to parse command_r output");
result.content = input;
}
return result;
}
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n```json\\n\" " + args_rule + " \"```<tool▁call▁end>\""));
});
data.grammar_triggers.push_back({"<tool▁calls▁begin>", /* .at_start = */ false});
data.preserved_tokens = {
"<tool▁sep>",
"<tool▁call▁end>",
};
builder.add_rule("root", "\"<tool▁calls▁begin>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
// Not really a trigger but need to print this special token to get a successful parse.
data.grammar_triggers.push_back({"</tool_call>", /* .at_start = */ false});
data.preserved_tokens = { "</tool_call>" };
}, grammar_options);
data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs);
}
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
return common_chat_params_init_command_r7b(tmpl, inputs);
}
return common_chat_params_init_generic(tmpl, inputs);
}
@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
return common_chat_parse_hermes_2_pro(input);
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
return common_chat_parse_firefunction_v2(input);
case COMMON_CHAT_FORMAT_COMMAND_R7B:
return common_chat_parse_command_r7b(input);
default:
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
}

View file

@ -32,6 +32,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@ -42,6 +43,7 @@ struct common_chat_params {
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};

View file

@ -1869,11 +1869,19 @@ std::string common_chat_format_example(const common_chat_template & tmpl, bool u
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
auto vocab = llama_model_get_vocab(model);
std::string default_template_src = chat_template_override;
std::string template_tool_use_src = chat_template_override;
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
@ -1886,21 +1894,21 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
default_template_src = CHATML_TEMPLATE_SRC;
}
}
std::string token_bos;
std::string token_eos;
// TODO: update logic that adds BOS and EOS tokens to the tokenized prompt, in favour of the template.
#if 0
auto vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
@ -1912,15 +1920,25 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
#endif
try {
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
};
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
return {
has_explicit_template,
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
nullptr,
};
}
}
//

View file

@ -4,6 +4,7 @@
#include "llama-cpp.h"
#include <set>
#include <string>
#include <vector>
#include <sstream>
@ -163,6 +164,7 @@ struct common_params_sampling {
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@ -621,6 +623,7 @@ struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
std::string tool_plan = "";
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

View file

@ -991,7 +991,14 @@ public:
}
};
std::string json_schema_to_grammar(const json & schema) {
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
return "%llguidance {}\nstart: %json " + schema.dump();
}
#else
(void)force_gbnf;
#endif // LLAMA_USE_LLGUIDANCE
return build_grammar([&](const common_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);

View file

@ -5,7 +5,8 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;

270
common/llguidance.cpp Normal file
View file

@ -0,0 +1,270 @@
#include "sampling.h"
#include "log.h"
#ifdef LLAMA_USE_LLGUIDANCE
# include "llguidance.h"
# include <cmath>
struct llama_sampler_llg {
const llama_vocab * vocab;
std::string grammar_kind;
std::string grammar_data;
LlgTokenizer * tokenizer;
LlgConstraint * grammar;
LlgMaskResult llg_res;
bool has_llg_res;
};
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
const char * grammar_data) {
LlgConstraintInit cinit;
llg_constraint_init_set_defaults(&cinit, tokenizer);
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
if (log_level && *log_level) {
cinit.log_stderr_level = atoi(log_level);
}
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
if (llg_get_error(c)) {
LOG_ERR("llg error: %s\n", llg_get_error(c));
llg_free_constraint(c);
return nullptr;
}
return c;
}
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
return "llguidance";
}
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
LlgCommitResult res;
llg_commit_token(ctx->grammar, token, &res);
ctx->has_llg_res = false;
}
}
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
if (!ctx->has_llg_res) {
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
ctx->has_llg_res = true;
} else {
LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
llg_free_constraint(ctx->grammar);
ctx->grammar = nullptr;
}
}
if (ctx->has_llg_res) {
if (ctx->llg_res.is_stop) {
for (size_t i = 0; i < cur_p->size; ++i) {
if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY;
}
}
} else {
const uint32_t * mask = ctx->llg_res.sample_mask;
for (size_t i = 0; i < cur_p->size; ++i) {
auto token = cur_p->data[i].id;
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
cur_p->data[i].logit = -INFINITY;
}
}
}
}
}
}
static void llama_sampler_llg_reset(llama_sampler * smpl) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (!ctx->grammar) {
return;
}
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
llg_free_constraint(ctx->grammar);
ctx->grammar = grammar_new;
ctx->has_llg_res = false;
}
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
// copy the state
{
auto * result_ctx = (llama_sampler_llg *) result->ctx;
if (ctx->grammar) {
result_ctx->grammar_kind = ctx->grammar_kind;
result_ctx->grammar_data = ctx->grammar_data;
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
}
}
return result;
}
static void llama_sampler_llg_free(llama_sampler * smpl) {
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
llg_free_constraint(ctx->grammar);
llg_free_tokenizer(ctx->tokenizer);
}
delete ctx;
}
static llama_sampler_i llama_sampler_llg_i = {
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
};
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
uint32_t * output_tokens, size_t output_tokens_len) {
const llama_vocab * vocab = (const llama_vocab *) user_data;
int r = 0;
try {
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
true);
} catch (const std::exception & e) {
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
}
if (r < 0) {
return -r;
}
return r;
}
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
// TODO store the tokenizer in the vocab somehow
static const llama_vocab * vocab_cache;
static LlgTokenizer * tokenizer_cache;
if (vocab_cache == vocab) {
return llg_clone_tokenizer(tokenizer_cache);
}
auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL) {
tok_eos = llama_vocab_eos(vocab);
}
size_t vocab_size = llama_vocab_n_tokens(vocab);
auto token_lens = new uint32_t[vocab_size];
// we typically have ~7 bytes per token; let's go on the safe side here
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
auto token_bytes = new uint8_t[token_bytes_size];
size_t offset = 0;
for (size_t i = 0; i < vocab_size; i++) {
size_t max_token = 1024;
if (token_bytes_size - offset < max_token) {
GGML_ABORT("token_bytes buffer too small\n");
}
llama_token token = i;
auto dp = (char *) token_bytes + offset;
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size == 0) {
size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size != 0) {
*dp = '\xff'; // special token prefix marker
size += 1;
}
}
token_lens[i] = size;
offset += size;
}
LlgTokenizerInit tinit = {
/* .vocab_size = */ (uint32_t) vocab_size,
/* .tok_eos = */ (uint32_t) tok_eos,
/* .token_lens = */ token_lens,
/* .token_bytes = */ token_bytes,
/* .tokenizer_json = */ nullptr,
/* .tokenize_assumes_string = */ true,
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
/* .use_approximate_greedy_tokenize_fn = */ false,
/* .tokenize_user_data = */ vocab,
};
char error_buffer[1024];
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
delete[] token_bytes;
delete[] token_lens;
if (tokenizer == nullptr) {
LOG_ERR("llg tokenizer error: %s\n", error_buffer);
return tokenizer;
}
if (tokenizer_cache) {
llg_free_tokenizer(tokenizer_cache);
}
vocab_cache = vocab;
tokenizer_cache = tokenizer;
return llg_clone_tokenizer(tokenizer_cache);
}
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
const char * grammar_data) {
auto * ctx = new llama_sampler_llg;
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
*ctx = {
/* .vocab = */ vocab,
/* .grammar_kind = */ grammar_kind,
/* .grammar_data = */ grammar_data,
/* .tokenizer = */ tokenizer,
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
} else {
*ctx = {
/* .vocab = */ vocab,
/* .grammar_kind = */ {},
/* .grammar_data = */ {},
/* .tokenizer = */ nullptr,
/* .grammar = */ nullptr,
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
}
return new llama_sampler{
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};
}
#else
llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
return nullptr;
}
#endif // LLAMA_USE_LLGUIDANCE

View file

@ -14,16 +14,6 @@ void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity;
}
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"
#define LOG_COL_GREEN "\033[32m"
#define LOG_COL_YELLOW "\033[33m"
#define LOG_COL_BLUE "\033[34m"
#define LOG_COL_MAGENTA "\033[35m"
#define LOG_COL_CYAN "\033[36m"
#define LOG_COL_WHITE "\033[37m"
static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}

View file

@ -2,6 +2,16 @@
#include "ggml.h" // for ggml_log_level
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"
#define LOG_COL_GREEN "\033[32m"
#define LOG_COL_YELLOW "\033[33m"
#define LOG_COL_BLUE "\033[34m"
#define LOG_COL_MAGENTA "\033[35m"
#define LOG_COL_CYAN "\033[36m"
#define LOG_COL_WHITE "\033[37m"
#ifndef __GNUC__
# define LOG_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__)

View file

@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
class TemplateToken {
public:
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
static std::string typeToString(Type t) {
switch (t) {
@ -714,6 +714,8 @@ public:
case Type::EndFilter: return "endfilter";
case Type::Generation: return "generation";
case Type::EndGeneration: return "endgeneration";
case Type::Break: return "break";
case Type::Continue: return "continue";
}
return "Unknown";
}
@ -815,6 +817,22 @@ struct CommentTemplateToken : public TemplateToken {
CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
};
enum class LoopControlType { Break, Continue };
class LoopControlException : public std::runtime_error {
public:
LoopControlType control_type;
LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {}
LoopControlException(LoopControlType control_type)
: std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")),
control_type(control_type) {}
};
struct LoopControlTemplateToken : public TemplateToken {
LoopControlType control_type;
LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {}
};
class TemplateNode {
Location location_;
protected:
@ -825,6 +843,12 @@ public:
void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
try {
do_render(out, context);
} catch (const LoopControlException & e) {
// TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
std::ostringstream err;
err << e.what();
if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
throw LoopControlException(err.str(), e.control_type);
} catch (const std::exception & e) {
std::ostringstream err;
err << e.what();
@ -897,6 +921,15 @@ public:
}
};
class LoopControlNode : public TemplateNode {
LoopControlType control_type_;
public:
LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {}
void do_render(std::ostringstream &, const std::shared_ptr<Context> &) const override {
throw LoopControlException(control_type_);
}
};
class ForNode : public TemplateNode {
std::vector<std::string> var_names;
std::shared_ptr<Expression> iterable;
@ -961,7 +994,12 @@ public:
loop.set("last", i == (n - 1));
loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
body->render(out, loop_context);
try {
body->render(out, loop_context);
} catch (const LoopControlException & e) {
if (e.control_type == LoopControlType::Break) break;
if (e.control_type == LoopControlType::Continue) continue;
}
}
}
};
@ -2159,7 +2197,7 @@ private:
static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
@ -2291,6 +2329,9 @@ private:
} else if (keyword == "endfilter") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
} else if (keyword == "break" || keyword == "continue") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
} else {
throw std::runtime_error("Unexpected block: " + keyword);
}
@ -2414,6 +2455,8 @@ private:
children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
} else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
// Ignore comments
} else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
|| dynamic_cast<EndSetTemplateToken*>(token.get())
|| dynamic_cast<EndMacroTemplateToken*>(token.get())

View file

@ -156,13 +156,25 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
for (const auto & str : params.grammar_trigger_words) {
trigger_words.push_back(str.word.c_str());
}
struct llama_sampler * grmr;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ params.grammar_lazy
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"),
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},

View file

@ -102,3 +102,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);

View file

@ -648,7 +648,7 @@ class Model:
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
res = "jina-v2-code"
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b" or chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
@ -4513,7 +4513,7 @@ class JaisModel(Model):
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
@Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(Model):
model_arch = gguf.MODEL_ARCH.CHATGLM
@ -4619,47 +4619,15 @@ class ChatGLMModel(Model):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams["padded_vocab_size"]
vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"])
assert max(tokenizer.get_vocab().values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
assert len(merged) >= 2 and len(merged) <= 7
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
added_vocab = tokenizer.get_added_vocab()
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.merges = merges
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
# only add special tokens when they were not already loaded from config.json
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
@ -4670,16 +4638,20 @@ class ChatGLMModel(Model):
def set_gguf_parameters(self):
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
n_head_kv = self.hparams.get("multi_query_group_num", self.hparams.get("num_key_value_heads", n_head))
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_dimension_count(64)
if "attention_dim" in self.hparams:
rope_dim = self.hparams["attention_dim"]
else:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_add_bos_token(False)
rope_freq = 10000
if "rope_ratio" in self.hparams:
@ -4689,7 +4661,7 @@ class ChatGLMModel(Model):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.endswith(".rotary_pos_emb.inv_freq"):
if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."):
return []
name = name.removeprefix("transformer.")

51
docs/llguidance.md Normal file
View file

@ -0,0 +1,51 @@
# LLGuidance Support in llama.cpp
[LLGuidance](https://github.com/guidance-ai/llguidance) is a library for constrained decoding (also called constrained sampling or structured outputs) for Large Language Models (LLMs). Initially developed as the backend for the [Guidance](https://github.com/guidance-ai/guidance) library, it can also be used independently.
LLGuidance supports JSON Schemas and arbitrary context-free grammars (CFGs) written in a [variant](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md) of Lark syntax. It is [very fast](https://github.com/guidance-ai/jsonschemabench/tree/main/maskbench) and has [excellent](https://github.com/guidance-ai/llguidance/blob/main/docs/json_schema.md) JSON Schema coverage but requires the Rust compiler, which complicates the llama.cpp build process.
## Building
To enable LLGuidance support, build llama.cpp with the `LLAMA_LLGUIDANCE` option:
```sh
cmake -B build -DLLAMA_LLGUIDANCE=ON
make -C build -j
```
This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install).
## Interface
There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance.
For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format.
## Performance
Computing a "token mask" (i.e., the set of allowed tokens) for a llama3 tokenizer with 128k tokens takes, on average, 50μs of single-core CPU time for the [JSON Schema Bench](https://github.com/guidance-ai/jsonschemabench). The p99 time is 0.5ms, and the p100 time is 20ms. These results are due to the lexer/parser split and several [optimizations](https://github.com/guidance-ai/llguidance/blob/main/docs/optimizations.md).
## JSON Schema
LLGuidance adheres closely to the JSON Schema specification. For example:
- `additionalProperties` defaults to `true`, unlike current grammars, though you can set `"additionalProperties": false` if needed.
- any whitespace is allowed.
- The definition order in the `"properties": {}` object is maintained, regardless of whether properties are required (current grammars always puts required properties first).
Unsupported schemas result in an error message—no keywords are silently ignored.
## Why Not Reuse GBNF Format?
GBNF lacks the concept of a lexer.
Most programming languages, including JSON, use a two-step process: a lexer (built with regular expressions) converts a byte stream into lexemes, which are then processed by a CFG parser. This approach is faster because lexers are cheaper to evaluate, and there is ~10x fewer lexemes than bytes.
LLM tokens often align with lexemes, so the parser is engaged in under 0.5% of tokens, with the lexer handling the rest.
However, the user has to provide the distinction between lexemes and CFG symbols. In [Lark](https://github.com/lark-parser/lark), lexeme names are uppercase, while CFG symbols are lowercase.
The [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) can often take care of this automatically.
See [LLGuidance syntax docs](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#terminals-vs-rules) for more details.
## Error Handling
Errors are currently printed to `stderr`, and generation continues. Improved error handling may be added in the future.

View file

@ -31,6 +31,11 @@ defer {
llama_model_free(model)
}
guard let vocab = llama_model_get_vocab(model) else {
print("Failed to get vocab")
exit(1)
}
var tokens = tokenize(text: prompt, add_bos: true)
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
@ -41,7 +46,7 @@ context_params.n_batch = UInt32(max(n_len, n_parallel))
context_params.n_threads = 8
context_params.n_threads_batch = 8
let context = llama_new_context_with_model(model, context_params)
let context = llama_init_from_model(model, context_params)
guard context != nil else {
print("Failed to initialize context")
exit(1)
@ -141,7 +146,7 @@ while n_cur <= n_len {
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
// is it an end of stream? -> mark the stream as finished
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
i_batch[i] = -1
// print("")
if n_parallel > 1 {
@ -207,7 +212,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Count = text.utf8.count
let n_tokens = utf8Count + (add_bos ? 1 : 0)
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
var swiftTokens: [llama_token] = []
for i in 0 ..< tokenCount {
swiftTokens.append(tokens[Int(i)])
@ -218,12 +223,12 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
var result = [CChar](repeating: 0, count: 8)
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false)
let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
if nTokens < 0 {
let actualTokensCount = -Int(nTokens)
result = .init(repeating: 0, count: actualTokensCount)
let check = llama_token_to_piece(
model,
vocab,
token,
&result,
Int32(result.count),

View file

@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
private var vocab: OpaquePointer
private var sampling: UnsafeMutablePointer<llama_sampler>
private var batch: llama_batch
private var tokens_list: [llama_token]
@ -47,6 +48,7 @@ actor LlamaContext {
self.sampling = llama_sampler_chain_init(sparams)
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
vocab = llama_model_get_vocab(model)
}
deinit {
@ -79,7 +81,7 @@ actor LlamaContext {
ctx_params.n_threads = Int32(n_threads)
ctx_params.n_threads_batch = Int32(n_threads)
let context = llama_new_context_with_model(model, ctx_params)
let context = llama_init_from_model(model, ctx_params)
guard let context else {
print("Could not load context!")
throw LlamaError.couldNotInitializeContext
@ -151,7 +153,7 @@ actor LlamaContext {
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
print("\n")
is_done = true
let new_token_str = String(cString: temporary_invalid_cchars + [0])
@ -297,7 +299,7 @@ actor LlamaContext {
let utf8Count = text.utf8.count
let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
var swiftTokens: [llama_token] = []
for i in 0..<tokenCount {
@ -316,7 +318,7 @@ actor LlamaContext {
defer {
result.deallocate()
}
let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
let nTokens = llama_token_to_piece(vocab, token, result, 8, 0, false)
if nTokens < 0 {
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
@ -324,7 +326,7 @@ actor LlamaContext {
defer {
newResult.deallocate()
}
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false)
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
return Array(bufferPointer)
} else {

View file

@ -0,0 +1,43 @@
# GLMV-EDGE
Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b).
## Usage
Build with cmake or run `make llama-llava-cli` to build it.
After building, run: `./llama-llava-cli` to see the usage. For example:
```sh
./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <image><|user|>\n prompt <|assistant|>\n"
```
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
**note**: For GPU offloading ensure to use the `-ngl` flag just like usual
## GGUF conversion
1. Clone a GLMV-EDGE model ([2B](https://huggingface.co/THUDM/glm-edge-v-2b) or [5B](https://huggingface.co/THUDM/glm-edge-v-5b)). For example:
```sh
git clone https://huggingface.co/THUDM/glm-edge-v-5b or https://huggingface.co/THUDM/glm-edge-v-2b
```
2. Use `glmedge-surgery.py` to split the GLMV-EDGE model to LLM and multimodel projector constituents:
```sh
python ./examples/llava/glmedge-surgery.py -m ../model_path
```
4. Use `glmedge-convert-image-encoder-to-gguf.py` to convert the GLMV-EDGE image encoder to GGUF:
```sh
python ./examples/llava/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path
```
5. Use `examples/convert_hf_to_gguf.py` to convert the LLM part of GLMV-EDGE to GGUF:
```sh
python convert_hf_to_gguf.py ../model_path
```
Now both the LLM part and the image encoder are in the `model_path` directory.

View file

@ -102,6 +102,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
#define KEY_HAS_GLM_PROJ "clip.has_glm_projector"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
#define KEY_USE_GELU "clip.use_gelu"
@ -160,6 +161,15 @@ static std::string format(const char * fmt, ...) {
#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
#define TN_MINICPMV_LN "resampler.ln_%s.%s"
#define TN_GLM_ADAPER_CONV "adapter.conv.%s"
#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s"
#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s"
#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
#define TN_GLM_BOI_W "adapter.boi"
#define TN_GLM_EOI_W "adapter.eoi"
enum projector_type {
PROJECTOR_TYPE_MLP,
@ -167,6 +177,7 @@ enum projector_type {
PROJECTOR_TYPE_LDP,
PROJECTOR_TYPE_LDPV2,
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_UNKNOWN,
};
@ -176,6 +187,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
};
@ -500,6 +512,12 @@ struct clip_vision_model {
struct ggml_tensor * mm_4_w = NULL;
struct ggml_tensor * mm_4_b = NULL;
//GLMV-Edge projection
struct ggml_tensor * mm_model_adapter_conv_w;
struct ggml_tensor * mm_model_adapter_conv_b;
struct ggml_tensor * boi_w;
struct ggml_tensor * eoi_w;
// MobileVLM projection
struct ggml_tensor * mm_model_mlp_1_w;
struct ggml_tensor * mm_model_mlp_1_b;
@ -560,6 +578,7 @@ struct clip_ctx {
bool has_vision_encoder = false;
bool has_llava_projector = false;
bool has_minicpmv_projector = false;
bool has_glm_projector = false;
bool has_qwen2vl_merger = false;
int minicpmv_version = 2;
@ -638,7 +657,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
const int batch_size = imgs->size;
if (ctx->has_llava_projector || ctx->has_minicpmv_projector) {
if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
GGML_ASSERT(batch_size == 1);
}
@ -734,8 +753,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
// loop over layers
if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
// TODO: figure out why we doing thing in this way ???
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
for (int il = 0; il < n_layer - 1; il++) {
@ -1095,7 +1113,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
GGML_ASSERT(false);
}
}
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
// glm projector
else if (ctx->has_glm_projector) {
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
//GLU
{
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
embeddings = ggml_gelu_inplace(ctx0, embeddings);
struct ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
embeddings = ggml_silu_inplace(ctx0, embeddings);
embeddings = ggml_mul(ctx0, embeddings,x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
}
} else {
GGML_ABORT("fatel error");
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@ -1284,6 +1328,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
}
idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ);
if (idx != -1) {
new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx);
}
idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
if (idx != -1) {
new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
@ -1308,6 +1357,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector);
LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector);
LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
}
@ -1575,6 +1625,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
}
else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight"));
vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias"));
vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight"));
vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight"));
vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias"));
vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight"));
vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight"));
vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W);
vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W);
}
else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
@ -2115,6 +2177,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
return true;
}
if (ctx->has_glm_projector) {
res_imgs->size = 1;
res_imgs->data = new clip_image_f32[res_imgs->size];
clip_image_u8 resized_image;
int32_t sz=ctx->vision_model.hparams.image_size;
bicubic_resize(*img, resized_image,sz,sz);
clip_image_f32 * res = clip_image_f32_init();
//clip_image_save_to_bmp(resized_image, "resized.bmp");
normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std);
res_imgs->data[0] = *res;
clip_image_f32_free(res);
return true;
}
bool pad_to_square = true;
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
@ -2300,7 +2376,8 @@ void clip_free(clip_ctx * ctx) {
}
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
int extra_tokens = ctx->has_glm_projector ? 2 : 0;
return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float);
}
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
@ -2342,7 +2419,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
n_patches /= 4;
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
if (ctx->minicpmv_version == 2) {
@ -2475,6 +2552,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
if (ctx->has_minicpmv_projector) {
GGML_ASSERT(batch_size == 1);
}
if (ctx->has_glm_projector) {
GGML_ASSERT(batch_size == 1);
ggml_tensor * boi = ctx->vision_model.boi_w;
ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi));
vec = (float*)(vec+ggml_nelements(boi)); //offset for boi
}
// build the inference graph
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
@ -2627,7 +2710,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
{
if (!ctx->has_glm_projector) {
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
int* patches_data = (int*)malloc(ggml_nbytes(patches));
for (int i = 0; i < num_patches; i++) {
@ -2651,6 +2734,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// copy the embeddings to the location passed by the user
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
if (ctx->has_glm_projector) {
//eoi
ggml_tensor * eoi = ctx->vision_model.eoi_w;
int offset = ggml_nelements(embeddings);
ggml_backend_tensor_get(eoi, vec+offset, 0, ggml_nbytes(eoi));
}
return true;
}
@ -2812,6 +2902,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return 3584;
}
}
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
}
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];
}
@ -2827,6 +2920,9 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
return 0;
}
bool clip_is_glm(const struct clip_ctx * ctx) {
return ctx->has_glm_projector;
}
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger;
}

View file

@ -93,6 +93,8 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
#ifdef __cplusplus
}
#endif

View file

@ -0,0 +1,280 @@
import argparse
import os
import json
import re
import torch
import numpy as np
from gguf import *
TEXT = "clip.text"
VISION = "clip.vision"
from transformers import SiglipVisionModel, SiglipVisionConfig
def k(raw_key: str, arch: str) -> str:
return raw_key.format(arch=arch)
def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
if name in (
"logit_scale",
"text_model.embeddings.position_ids",
"vision_model.embeddings.position_ids",
):
return True
if name in (
"vision_model.head.probe",
"vision_model.head.attention.in_proj_weight",
"vision_model.head.attention.in_proj_bias",
"vision_model.head.attention.out_proj.weight",
"vision_model.head.attention.out_proj.bias",
"vision_model.head.layernorm.weight",
"vision_model.head.layernorm.bias",
"vision_model.head.mlp.fc1.weight",
"vision_model.head.mlp.fc1.bias",
"vision_model.head.mlp.fc2.weight",
"vision_model.head.mlp.fc2.bias"
):
return True
if name.startswith("v") and not has_vision:
return True
if name.startswith("t") and not has_text:
return True
return False
def get_tensor_name(name: str) -> str:
if "projection" in name:
return name
if "mm_projector" in name:
name = name.replace("model.mm_projector", "mm")
name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
return name
return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False,
help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type))")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2","adapter"], default="adapter")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
default_image_mean = [0.5, 0.5, 0.5]
default_image_std = [0.5, 0.5, 0.5]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
# with proper
args = ap.parse_args()
if args.text_only and args.vision_only:
print("--text-only and --image-only arguments cannot be specified at the same time.")
exit(1)
if args.use_f32:
print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.")
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir
if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
vocab = None
tokens = None
else:
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
vocab = json.load(f)
tokens = [key for key in vocab]
with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
config = json.load(f)
if args.clip_model_is_vision:
v_hparams = config
t_hparams = None
else:
v_hparams = config["vision_config"]
t_hparams = None
# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1
if args.use_f32:
ftype = 0
vision_config = SiglipVisionConfig(**v_hparams)
model = SiglipVisionModel(vision_config)
model.load_state_dict(torch.load(os.path.join(dir_model, "glm.clip")))
fname_middle = None
has_text_encoder = False
has_vision_encoder = True
has_glm_projector = True
if args.text_only:
fname_middle = "text-"
has_vision_encoder = False
elif args.llava_projector is not None:
fname_middle = "mmproj-"
has_text_encoder = False
has_glm_projector = True
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
else:
fname_middle = ""
output_dir = args.output_dir if args.output_dir is not None else dir_model
os.makedirs(output_dir, exist_ok=True)
output_prefix = os.path.basename(output_dir).replace("ggml_", "")
fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
fout = GGUFWriter(path=fname_out, arch="clip")
fout.add_bool("clip.has_text_encoder", has_text_encoder)
fout.add_bool("clip.has_vision_encoder", has_vision_encoder)
fout.add_bool("clip.has_glm_projector", has_glm_projector)
fout.add_file_type(ftype)
model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model)
fout.add_name(model_name)
if has_glm_projector:
fout.add_description("image encoder for glm4v")
fout.add_string("clip.projector_type", "adapter")
else:
fout.add_description("two-tower CLIP model")
if has_text_encoder:
assert t_hparams is not None
assert tokens is not None
# text_model hparams
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
fout.add_token_list(tokens)
if has_vision_encoder:
# vision_model hparams
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
fout.add_uint32("clip.vision.projection_dim", 0)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), v_hparams["num_hidden_layers"])
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
image_std = args.image_std if args.image_std is not None else default_image_std
fout.add_array("clip.vision.image_mean", image_mean)
fout.add_array("clip.vision.image_std", image_std)
fout.add_bool("clip.use_gelu", True)
if has_glm_projector:
# model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
projector = torch.load(args.llava_projector)
for name, data in projector.items():
name = get_tensor_name(name)
# pw and dw conv ndim==4
if data.ndim == 2 or data.ndim == 4:
data = data.squeeze().numpy().astype(np.float16)
else:
data = data.squeeze().numpy().astype(np.float32)
if name.startswith("vision."):
name=name.replace("vision.","")
fout.add_tensor(name, data)
print(f"Projector {name} - {data.dtype} - shape = {data.shape}")
# print(f"Projector {name} tensors added\n")
state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
for name, data in state_dict.items():
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_glm_projector):
# we don't need this
print(f"skipping parameter: {name}")
continue
name = get_tensor_name(name)
data = data.squeeze().numpy()
n_dims = len(data.shape)
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if n_dims == 4:
print(f"tensor {name} is always saved in f16")
data = data.astype(np.float16)
ftype_cur = 1
elif ftype == 1:
if name[-7:] == ".weight" and n_dims == 2:
# print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
# print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
else:
if data.dtype != np.float32:
# print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
print(f"siglip {name} - {data.dtype} - shape = {data.shape}")
# print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
fout.add_tensor(name, data)
fout.write_header_to_file()
fout.write_kv_data_to_file()
fout.write_tensors_to_file()
fout.close()
print("Done. Output file: " + fname_out)

View file

@ -0,0 +1,33 @@
import argparse
import os
import torch
from transformers import AutoModel
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", help="Path to GLM model")
args = ap.parse_args()
# find the model part that includes the the multimodal projector weights
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
checkpoint = model.state_dict()
# get a list of mm tensor names
mm_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.adapter.")]
# store these tensors in a new dictionary and torch.save them
projector = {name: checkpoint[name].float() for name in mm_tensors}
torch.save(projector, f"{args.model}/glm.projector")
clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.vit.model.vision_model.")]
if len(clip_tensors) > 0:
clip = {name.replace("vision.vit.model.", ""): checkpoint[name].float() for name in clip_tensors}
torch.save(clip, f"{args.model}/glm.clip")
# added tokens should be removed to be able to convert Mistral models
if os.path.exists(f"{args.model}/added_tokens.json"):
with open(f"{args.model}/added_tokens.json", "w") as f:
f.write("{}\n")
print("Done!")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}glm.projector to prepare a glm-encoder.gguf file.")

View file

@ -311,6 +311,20 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
img_res_v.size = 0;
img_res_v.data = nullptr;
}
else if (clip_is_glm(ctx_clip)){
struct clip_image_size * load_image_size = clip_image_size_init();
load_image_size->width = img_res_v.data[0].nx;
load_image_size->height = img_res_v.data[0].ny;
clip_add_load_image_size(ctx_clip, load_image_size);
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);
int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2);
*n_img_pos = (pos * pos + 2);
if (!encoded){
LOG_ERR("Unable to encode image \n");
return false;
}
}
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
// flat / default llava-1.5 type embedding
*n_img_pos = clip_n_patches(ctx_clip);
@ -395,6 +409,9 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
if (clip_is_minicpmv(ctx_clip)) {
num_max_patches = 10;
}
if (clip_is_glm(ctx_clip)) {
num_max_patches = 1;
}
float * image_embd;
if (clip_is_qwen2vl(ctx_clip)) {
// qwen2vl don't split image into chunks, so `num_max_patches` is not needed.

View file

@ -24,15 +24,16 @@
#include <string>
#include <vector>
#include "chat-template.hpp"
#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
#include "llama-cpp.h"
#include "chat-template.hpp"
#include "log.h"
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
[[noreturn]] static void sigint_handler(int) {
printf("\n\033[0m");
printf("\n" LOG_COL_DEFAULT);
exit(0); // not ideal, but it's the only way to guarantee exit in all cases
}
#endif
@ -65,6 +66,13 @@ static int printe(const char * fmt, ...) {
return ret;
}
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
std::ostringstream oss;
oss << std::put_time(&tm, fmt);
return oss.str();
}
class Opt {
public:
int init(int argc, const char ** argv) {
@ -698,6 +706,39 @@ class LlamaData {
return download(url, bn, true);
}
int s3_dl(const std::string & model, const std::string & bn) {
const size_t slash_pos = model.find('/');
if (slash_pos == std::string::npos) {
return 1;
}
const std::string bucket = model.substr(0, slash_pos);
const std::string key = model.substr(slash_pos + 1);
const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
if (!access_key || !secret_key) {
printe("AWS credentials not found in environment\n");
return 1;
}
// Generate AWS Signature Version 4 headers
// (Implementation requires HMAC-SHA256 and date handling)
// Get current timestamp
const time_t now = time(nullptr);
const tm tm = *gmtime(&now);
const std::string date = strftime_fmt("%Y%m%d", tm);
const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
const std::vector<std::string> headers = {
"Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
"/us-east-1/s3/aws4_request",
"x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
};
const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
return download(url, bn, true, headers);
}
std::string basename(const std::string & path) {
const size_t pos = path.find_last_of("/\\");
if (pos == std::string::npos) {
@ -738,6 +779,9 @@ class LlamaData {
rm_until_substring(model_, "github:");
rm_until_substring(model_, "://");
ret = github_dl(model_, bn);
} else if (string_starts_with(model_, "s3://")) {
rm_until_substring(model_, "://");
ret = s3_dl(model_, bn);
} else { // ollama:// or nothing
rm_until_substring(model_, "ollama.com/library/");
rm_until_substring(model_, "://");
@ -847,7 +891,7 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n");
printf(LOG_COL_DEFAULT "\n");
printe("context size exceeded\n");
return 1;
}
@ -910,7 +954,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
batch = llama_batch_get_one(&new_token_id, 1);
}
printf("\033[0m");
printf(LOG_COL_DEFAULT);
return 0;
}
@ -919,7 +963,7 @@ static int read_user_input(std::string & user_input) {
#ifdef WIN32
printf(
"\r%*s"
"\r\033[0m%s",
"\r" LOG_COL_DEFAULT "%s",
get_terminal_width(), " ", prompt_prefix);
std::getline(std::cin, user_input);
@ -956,7 +1000,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
const bool stdout_a_terminal) {
// Set response color
if (stdout_a_terminal) {
printf("\033[33m");
printf(LOG_COL_YELLOW);
}
if (generate(llama_data, prompt, response)) {
@ -965,7 +1009,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
}
// End response with color reset and newline
printf("\n%s", stdout_a_terminal ? "\033[0m" : "");
printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : "");
return 0;
}

View file

@ -1128,6 +1128,7 @@ curl http://localhost:8080/v1/chat/completions \
- Hermes 2/3, Qwen 2.5
- Mistral Nemo
- Firefunction v2
- Command R7B
- DeepSeek R1 (WIP / seems reluctant to call any tools?)
<details>
@ -1202,21 +1203,28 @@ curl http://localhost:8080/v1/chat/completions \
```shell
# Native support:
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q6_K
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B )
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
# Native support requires the right template for these GGUFs:
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/firellama-3-firefunction-v2 )
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
# Generic format support
llama-server --jinja -fa -hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
```
- Test in CLI:

Binary file not shown.

View file

@ -131,6 +131,11 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}
std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
}
return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
@ -165,8 +170,9 @@ struct slot_params {
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
// {"grammar_trigger_words", sampling.grammar_trigger_words},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"preserved_tokens", sampling.preserved_tokens},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@ -363,12 +369,26 @@ struct server_task {
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
}
@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg message;
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
LOG_DBG("Parsing chat message: %s\n", content.c_str());
message = common_chat_parse(content, oaicompat_chat_format);
finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else {
message.content = content;
msg.content = content;
}
json tool_calls;
if (!message.tool_calls.empty()) {
if (!msg.tool_calls.empty()) {
tool_calls = json::array();
for (const auto & tc : message.tool_calls) {
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result {
}
}
json message {
{"content", msg.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
};
if (!msg.tool_plan.empty()) {
message["tool_plan"] = msg.tool_plan;
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", json {
{"content", message.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
}},
{"message", message},
};
if (!stream && probs_output.size() > 0) {
@ -2833,8 +2858,7 @@ struct server_context {
server_slot * slot_batched = nullptr;
auto accept_special_token = [&](server_slot & slot, llama_token token) {
const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens;
return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
};
// frist, add sampled tokens from any ongoing sequences

View file

@ -13,9 +13,12 @@ def create_server():
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
[
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
]

View file

@ -67,8 +67,8 @@ WEATHER_TOOL = {
def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
n_predict = 512
global server
n_predict = 512
# server = ServerPreset.stories15m_moe()
server.jinja = True
server.n_predict = n_predict
@ -139,29 +139,49 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
# TODO: fix these
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None):
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
@ -169,10 +189,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if template_override:
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predict,
@ -251,33 +273,55 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | None] | None):
def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = 512
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if template_override:
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 256,
"max_tokens": n_predict,
"messages": [
{"role": "user", "content": "What is the weather in Istanbul?"},
],
@ -298,19 +342,39 @@ def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | Non
@pytest.mark.slow
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None):
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
server.n_slots = 1
server.jinja = True
@ -318,10 +382,12 @@ def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo:
server.n_predict = 128
server.model_hf_repo = hf_repo
server.model_hf_file = None
if template_override:
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 256,

View file

@ -5,10 +5,6 @@
#include "llama.h"
#include "common/base64.hpp"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#include "httplib.h"
@ -662,6 +658,7 @@ static json oaicompat_completion_params_parse(
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}

View file

@ -154,8 +154,6 @@
placeholder="Type a message (Shift+Enter to add a new line)"
v-model="inputMsg"
@keydown.enter.exact.prevent="sendMessage"
@keydown.enter.shift.exact.prevent="inputMsg += '\n'"
:disabled="isGenerating"
id="msg-input"
dir="auto"
></textarea>

View file

@ -468,7 +468,10 @@ const mainApp = createApp({
URL.revokeObjectURL(url);
},
async sendMessage() {
if (!this.inputMsg) return;
// prevent sending empty message
// also allow typing the message while generating, but does not allow sending it (to match UX/UI behavior of other chat apps)
if (!this.inputMsg || this.isGenerating) return;
const currConvId = this.viewingConvId;
StorageUtils.appendMsg(currConvId, {

View file

@ -274,22 +274,25 @@ endif()
# Generate version info based on git commit.
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_NUMBER
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(NOT DEFINED GGML_BUILD_NUMBER)
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_NUMBER
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(GGML_BUILD_NUMBER EQUAL 1)
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
if(GGML_BUILD_NUMBER EQUAL 1)
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
endif()
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
endif()
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Capture variables prefixed with GGML_.

View file

@ -1775,7 +1775,7 @@ extern "C" {
struct ggml_tensor * a,
int k);
#define GGML_KQ_MASK_PAD 32
#define GGML_KQ_MASK_PAD 64
// q: [n_embd, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1]

View file

@ -93,12 +93,18 @@ endif()
if (GGML_CCACHE)
find_program(GGML_CCACHE_FOUND ccache)
find_program(GGML_SCCACHE_FOUND sccache)
if (GGML_CCACHE_FOUND)
if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND)
if(GGML_CCACHE_FOUND)
set(GGML_CCACHE_VARIANT ccache)
else()
set(GGML_CCACHE_VARIANT sccache)
endif()
# TODO: should not be set globally
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
set(ENV{CCACHE_SLOPPINESS} time_macros)
message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
else()
message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
endif ()

View file

@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_CUDA "*.cu")
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
file(GLOB SRCS "template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})

View file

@ -61,6 +61,13 @@
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220
@ -148,7 +155,7 @@ typedef float2 dfloat2;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define INT8_MMA_AVAILABLE
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@ -159,14 +166,24 @@ static constexpr bool fast_fp16_available(const int cc) {
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}
// Any FP16 tensor cores are available.
static constexpr bool fp16_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
}
static constexpr bool int8_mma_available(const int cc) {
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static constexpr bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
#else
return 32;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

View file

@ -516,6 +516,114 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}
// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int D, int ncols, int KQ_stride> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
const int bidx0 = blockIdx.x;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
return;
}
const int channel = kbc0 / (iter_k*iter_j);
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
dst += jt*ncols*ne02*D + channel*D;
// Load the partial result that needs a fixup:
float dst_val[ncols] = {0.0f};
float max_val[ncols] = {0.0f};
float rowsum[ncols] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
dst_val[j] = dst[j*ne02*D + threadIdx.x];
const float2 tmp = dst_fixup[bidx0*ncols + j];
max_val[j] = tmp.x;
rowsum[j] = tmp.y;
}
// Iterate over previous blocks and compute the combined results.
// All CUDA blocks that get here must have a previous block that needs a fixup.
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
continue;
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val[j], tmp.x);
const float diff_val = max_val[j] - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
max_val[j] = max_val_new;
}
// If this block started in a previous tile we are done and don't need to combine additional partial results.
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
break;
}
bidx--;
kbc_stop = kbc;
}
// Write back final result:
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
return;
}
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
@ -581,10 +689,11 @@ static void on_no_fattn_vec_case(const int D) {
}
}
template <int D, int parallel_blocks>
// parallel_blocks == 0 is stream-k decomposition
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
@ -603,20 +712,23 @@ void launch_fattn(
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
GGML_ASSERT(Q->ne[3] == 1);
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
char * K_data = (char *) K->data;
const char * K_data = (const char *) K->data;
size_t nb11 = K->nb[1];
size_t nb12 = K->nb[2];
size_t nb13 = K->nb[3];
char * V_data = (char *) V->data;
const char * V_data = (const char *) V->data;
size_t nb21 = V->nb[1];
size_t nb22 = V->nb[2];
size_t nb23 = V->nb[3];
@ -649,39 +761,60 @@ void launch_fattn(
nb23 = nb23*bs*sizeof(half)/ts;
}
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
const bool short_context = K->ne[1] < 4096;
const int nblocks_stream_k = 2*nsm;
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
blocks_num.y = 1;
blocks_num.z = 1;
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
} else {
blocks_num.x = parallel_blocks*ntiles_x;
blocks_num.y = Q->ne[2];
blocks_num.z = Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
}
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
(const char *) Q->data,
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@ -693,16 +826,22 @@ void launch_fattn(
);
CUDA_CHECK(cudaGetLastError());
if ((parallel_blocks) == 1) {
return;
if constexpr (parallel_blocks == 0) {
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}
} else if constexpr (parallel_blocks > 1) {
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
}
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}

View file

@ -0,0 +1,637 @@
#include "common.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half * const __restrict__ maskh,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
const float slope,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3,
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
typedef mma_C_I16J8<float> mma_C_KQ;
typedef mma_C_I16J8<half2> mma_C_VKQ;
static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
static_assert(D % nwarps == 0, "bad D");
static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
const int stride_Q = nb01 / sizeof(float2);
const int stride_KV = nb11 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half);
mma_B Q_B[D/(2*mma_B::K)];
mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
float2 KQ_rowsum = {0.0f, 0.0f};
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
float2 KQ_max_scale = {0.0f, 0.0f};
// Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
// The loading is done with decreasing granularity for D for better memory bandwidth.
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
if (jt*ncols + j < ne01) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
}
} else {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
}
}
}
}
__syncthreads();
{
const int j0 = (threadIdx.y / np) * mma_B::J;
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
const int k_VKQ_0 = kb0*KQ_stride;
mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
// Load K data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
}
}
}
__syncthreads();
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
mma_A K_A;
K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
}
}
__syncthreads();
if (use_logit_softcap) {
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
}
if (maskh) {
static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
#pragma unroll
for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const int i = i0 + mma_C_KQ::get_i(l);
const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
float2 KQ_max_new = KQ_max;
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
}
}
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
}
{
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.x = 0.0f;
}
if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.y = 0.0f;
}
KQ_max = KQ_max_new;
}
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
const float diff = KQ_C[k].x[l] - KQ_max_l;
KQ_C[k].x[l] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_C[k].x[l] = 0.0f;
}
if (l % 2 == 0) {
KQ_rowsum_add.x += KQ_C[k].x[l];
} else {
KQ_rowsum_add.y += KQ_C[k].x[l];
}
}
}
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < mma_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
// Convert KQ C tiles into B tiles for VKQ calculation:
mma_B B[KQ_stride/(np*2*mma_B::K)];
static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
B[k] = KQ_C[k].to_mma_B();
}
// Load V data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
const int i0_stop = D/2 - (D/2) % (1*stride_i);
const int stride_k = WARP_SIZE / stride_i;
#pragma unroll
for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
#pragma unroll
for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
}
}
}
__syncthreads();
// Calculate VKQ tile:
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
#pragma unroll
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
mma_A A;
A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
}
}
__syncthreads();
}
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8 threads each, does not need full reduce.
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
}
// Write VKQ accumulators to shared memory in column-major format.
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// Also for np > 1 the combination is done via these values in shared memory.
const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int k = k0 + mma_B::get_k(l);
tile_KV[j_cwd*D2_padded + k] = B.x[l];
}
}
const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
}
__syncthreads();
static_assert(np == 1 || np == 2 || np == 4, "bad np");
if (np == 1) {
// No combination is needed, the meta data can be directly written from registers to VRAM.
if (needs_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
if (is_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
} else if (threadIdx.y % np == 0) {
// Combine the meta data for parallel warps via shared memory.
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.
float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
KQ_cm = meta_j[0];
}
float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
}
const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
KQ_crs = KQ_cms*meta_j[1];
}
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
}
// Write back combined meta data:
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
meta_j[0] = KQ_cmn; // Combined max. KQ values.
meta_j[1] = KQ_crs; // Combined KQ rowsums.
meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
}
if (needs_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
if (is_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
}
if (np > 1) {
__syncthreads();
}
if (np == 1 || threadIdx.y % np == 0) {
// The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
// The values after that are for the partial results of the individual blocks.
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
#pragma unroll
for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
if (!is_fixup && jt*ncols + j_dst >= ne01) {
continue;
}
const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
for (int ip = 0; ip < np; ++ip) {
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
dstk_val.x += dstk_val_add.x*KQ_crs;
dstk_val.y += dstk_val_add.y*KQ_crs;
}
if (!needs_fixup && !is_fixup) {
const float KQ_rowsum_j = meta_j[1];
dstk_val.x /= KQ_rowsum_j;
dstk_val.y /= KQ_rowsum_j;
}
if (is_fixup) {
dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
} else {
dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
}
}
}
}
}
if (np > 1) {
__syncthreads();
}
#else
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
// kbc == k block continuous, current index in continuous ijk space.
int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
// In the most general case >2 seams can fall into the same tile.
// kb0 == k start index when in the output tile.
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(D/2);
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
}
kbc += iter_k;
kbc -= kbc % iter_k;
kb0_start = 0;
kb0_stop = min(iter_k, kbc_stop - kbc);
}
if (kbc >= kbc_stop) {
return;
}
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(D/2);
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
}
template <int D, int cols_per_block>
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
static_assert(D % mma_B::K == 0, "bad D");
static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
const ggml_tensor * KQV = dst;
constexpr int KQ_stride = D <= 128 ? 64 : 32;
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
}
#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \
template void ggml_cuda_flash_attn_ext_mma_f16_case \
<D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
extern DECL_FATTN_MMA_F16_CASE( 64, 8);
extern DECL_FATTN_MMA_F16_CASE( 80, 8);
extern DECL_FATTN_MMA_F16_CASE( 96, 8);
extern DECL_FATTN_MMA_F16_CASE(112, 8);
extern DECL_FATTN_MMA_F16_CASE(128, 8);
extern DECL_FATTN_MMA_F16_CASE(256, 8);
extern DECL_FATTN_MMA_F16_CASE( 64, 16);
extern DECL_FATTN_MMA_F16_CASE( 80, 16);
extern DECL_FATTN_MMA_F16_CASE( 96, 16);
extern DECL_FATTN_MMA_F16_CASE(112, 16);
extern DECL_FATTN_MMA_F16_CASE(128, 16);
extern DECL_FATTN_MMA_F16_CASE(256, 16);
extern DECL_FATTN_MMA_F16_CASE( 64, 32);
extern DECL_FATTN_MMA_F16_CASE( 80, 32);
extern DECL_FATTN_MMA_F16_CASE( 96, 32);
extern DECL_FATTN_MMA_F16_CASE(112, 32);
extern DECL_FATTN_MMA_F16_CASE(128, 32);
extern DECL_FATTN_MMA_F16_CASE(256, 32);
extern DECL_FATTN_MMA_F16_CASE( 64, 64);
extern DECL_FATTN_MMA_F16_CASE( 80, 64);
extern DECL_FATTN_MMA_F16_CASE( 96, 64);
extern DECL_FATTN_MMA_F16_CASE(112, 64);
extern DECL_FATTN_MMA_F16_CASE(128, 64);
extern DECL_FATTN_MMA_F16_CASE(256, 64);

View file

@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View file

@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32(
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View file

@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View file

@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View file

@ -0,0 +1,648 @@
// Old and deprecated WMMA FlashAttention implementation.
// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"
#ifdef FP16_MMA_AVAILABLE
#include <mma.h>
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr int D_padded = D + 8;
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
constexpr int mem_KQ = ncols*kqs_padded*kqar;
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
float * KQ_f = (float *) KQ;
half2 * KQ2 = (half2 *) KQ;
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
float KQ_max_f[ncols/nwarps];
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_f[j] = -FLT_MAX/2.0f;
}
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
half2 KQ_max_h2[ncols/nwarps];
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
}
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
}
}
// Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
__syncthreads();
// Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
if (use_logit_softcap) {
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
}
}
float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
KQ_max_scale_f[j0/nwarps] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale_f[j0/nwarps] = 0.0f;
}
KQ_max_f[j0/nwarps] = KQ_max_new;
float KQ_rowsum_add = 0.0f;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
}
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
if (use_logit_softcap) {
// There is no dedicated tangens hyperbolicus function for half2.
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
}
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
KQ_max_h2[j0/nwarps] = KQ_max_new;
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
}
}
__syncthreads();
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
nvcuda::wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
}
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
__syncthreads();
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
half2 VKQ_scale;
if (std::is_same<KQ_acc_t, float>::value) {
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
} else {
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int l = 0; l < VKQ_ratio; ++l) {
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
}
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
} else {
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (parallel_blocks == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
}
if (parallel_blocks == 1 || threadIdx.x != 0) {
continue;
}
float2 dst_meta_val;
if (std::is_same<KQ_acc_t, float>::value) {
dst_meta_val.x = KQ_max_f[j0/nwarps];
} else {
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
constexpr int get_max_power_of_2(int x) {
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
// Number of VKQ rows calculated in parallel:
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
if (prec != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
} else {
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
// case 256:
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
// break;
default:
GGML_ABORT("fatal error");
break;
}
}
return;
}
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}

View file

@ -1,543 +1,3 @@
#include "common.cuh"
#include "fattn-common.cuh"
#ifdef FP16_MMA_AVAILABLE
#include <mma.h>
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
#ifdef FP16_MMA_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr int D_padded = D + 8;
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
constexpr int mem_KQ = ncols*kqs_padded*kqar;
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
float * KQ_f = (float *) KQ;
half2 * KQ2 = (half2 *) KQ;
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
float KQ_max_f[ncols/nwarps];
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_f[j] = -FLT_MAX/2.0f;
}
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
half2 KQ_max_h2[ncols/nwarps];
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
}
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
}
}
// Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
__syncthreads();
// Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
if (use_logit_softcap) {
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
}
}
float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
KQ_max_scale_f[j0/nwarps] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale_f[j0/nwarps] = 0.0f;
}
KQ_max_f[j0/nwarps] = KQ_max_new;
float KQ_rowsum_add = 0.0f;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
}
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
if (use_logit_softcap) {
// There is no dedicated tangens hyperbolicus function for half2.
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
}
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
KQ_max_h2[j0/nwarps] = KQ_max_new;
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
}
}
__syncthreads();
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
nvcuda::wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
}
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
__syncthreads();
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
half2 VKQ_scale;
if (std::is_same<KQ_acc_t, float>::value) {
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
} else {
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int l = 0; l < VKQ_ratio; ++l) {
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
}
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
} else {
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (parallel_blocks == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
}
if (parallel_blocks == 1 || threadIdx.x != 0) {
continue;
}
float2 dst_meta_val;
if (std::is_same<KQ_acc_t, float>::value) {
dst_meta_val.x = KQ_max_f[j0/nwarps];
} else {
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
#endif // FP16_MMA_AVAILABLE
}
constexpr int get_max_power_of_2(int x) {
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
// Number of VKQ rows calculated in parallel:
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
}
#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
template void ggml_cuda_flash_attn_ext_wmma_f16_case \
<D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half);
extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half);
extern DECL_FATTN_WMMA_F16_CASE(128, 8, half);
extern DECL_FATTN_WMMA_F16_CASE(256, 8, half);
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -1,5 +1,6 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-tile-f16.cuh"
#include "fattn-tile-f32.cuh"
#include "fattn-vec-f16.cuh"
@ -7,144 +8,56 @@
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"
#include <cstdint>
template <int cols_per_block>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
if (prec != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
} else {
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
// case 256:
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
// break;
default:
GGML_ABORT("fatal error");
break;
}
}
return;
}
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
return;
}
if (Q->ne[1] <= 16) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
return;
}
if (Q->ne[1] <= 32) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
}
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
@ -323,10 +236,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fp16_mma_available(cc)) {
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
if (prec == GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
}
} else {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
}
}
return;
}
@ -341,5 +262,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
}
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (cc == GGML_CUDA_CC_VOLTA) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
}

View file

@ -1205,7 +1205,7 @@ static void ggml_cuda_op_mul_mat_cublas(
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
if (compute_capability == GGML_CUDA_CC_CDNA) {
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(
@ -1750,7 +1750,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
beta = &beta_f32;
}
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
cu_compute_type = CUBLAS_COMPUTE_32F;
alpha = &alpha_f32;
beta = &beta_f32;

View file

@ -1,11 +1,67 @@
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape I x K.
// B is a column-major matrix with shape K x J.
// C is a column-major matrix with shape I x J.
// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
#include "common.cuh"
struct mma_int_A_I16K4 {
#if CUDART_VERSION >= 11800
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
#ifdef NEW_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "+r"(ret) : "r"(x));
#else
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
return ret;
}
#else
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
// Imagine transposing row-major matrix to column-major matrix.
const int src_i_low = 2 * (threadIdx.x % 4);
const int src_i_high = src_i_low + 1;
const int src_j = threadIdx.x / 4;
const int src_laneid_low = src_i_low * 4 + src_j / 2;
const int src_laneid_high = src_i_high * 4 + src_j / 2;
const int shift_low = ((src_j + 0) % 2) * 16;
const int shift_high = ((src_j + 1) % 2) * 16;
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
return ret_low | ret_high;
}
#endif // CUDART_VERSION >= 11800
template <typename T>
struct mma_A_I16K4 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int I = 16;
static constexpr int K = 4;
static constexpr int ne = 2;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / K;
@ -21,27 +77,35 @@ struct mma_int_A_I16K4 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
struct mma_int_A_I16K8 {
template <typename T>
struct mma_A_I16K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int I = 16;
static constexpr int K = 8;
static constexpr int ne = 4;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
@ -57,31 +121,62 @@ struct mma_int_A_I16K8 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void transpose() {
int * xi = (int *) x;
xi[0] = ggml_cuda_movmatrix(xi[0]);
const int tmp = ggml_cuda_movmatrix(xi[1]);
xi[1] = ggml_cuda_movmatrix(xi[2]);
xi[2] = tmp;
xi[3] = ggml_cuda_movmatrix(xi[3]);
}
};
struct mma_int_B_J8K4 {
template <typename T>
struct mma_B_J8K4 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 4;
static constexpr int ne = 1;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / K;
@ -97,27 +192,34 @@ struct mma_int_B_J8K4 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(x[0])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(xi[0]) : "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
struct mma_int_B_J8K8 {
template <typename T>
struct mma_B_J8K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 8;
static constexpr int ne = 2;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / (K/2);
@ -133,22 +235,31 @@ struct mma_int_B_J8K8 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
struct mma_int_C_I16J8 {
template <typename T>
struct mma_C_I16J8 {};
template <>
struct mma_C_I16J8<int> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
@ -169,8 +280,8 @@ struct mma_int_C_I16J8 {
return ret;
}
__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
__device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@ -188,11 +299,11 @@ struct mma_int_C_I16J8 {
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
__device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@ -216,6 +327,132 @@ struct mma_int_C_I16J8 {
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
#endif // NEW_MMA_AVAILABLE
}
};
template <>
struct mma_C_I16J8<half2> {
static constexpr int I = 16;
static constexpr int J = 4;
static constexpr int ne = 2;
half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = l * (I/2) + threadIdx.x / J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x % J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
int * xi = (int *) x;
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
return mma_B;
}
};
template <>
struct mma_C_I16J8<float> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
mma_B.x[0] = make_half2(x[0], x[1]);
mma_B.x[1] = make_half2(x[2], x[3]);
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
return mma_B;
}
__device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_i(l)];
}
}
};

View file

@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
if (int8_mma_available(cc)) {
if (new_mma_available(cc)) {
return true;
}
@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

File diff suppressed because it is too large Load diff

View file

@ -5,9 +5,10 @@ template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z;
const int tid = threadIdx.x;
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
y += channel *stride_channel_y;
@ -18,8 +19,8 @@ static __global__ void mul_mat_vec(
extern __shared__ char data_mmv[];
float * buf_iw = (float *) data_mmv;
if (block_size > WARP_SIZE) {
if (tid < WARP_SIZE) {
if (block_size > warp_size) {
if (tid < warp_size) {
buf_iw[tid] = 0.0f;
}
__syncthreads();
@ -67,16 +68,16 @@ static __global__ void mul_mat_vec(
static_assert(std::is_same<T, void>::value, "unsupported type");
}
sumf = warp_reduce_sum(sumf);
sumf = warp_reduce_sum<warp_size>(sumf);
if (block_size > WARP_SIZE) {
buf_iw[tid/WARP_SIZE] = sumf;
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf;
__syncthreads();
if (tid >= WARP_SIZE) {
if (tid >= warp_size) {
return;
}
sumf = buf_iw[tid];
sumf = warp_reduce_sum(sumf);
sumf = warp_reduce_sum<warp_size>(sumf);
}
if (tid != 0) {
@ -96,10 +97,19 @@ static void launch_mul_mat_vec_cuda(
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(nchannels_y % nchannels_x == 0);
const int64_t channel_ratio = nchannels_y / nchannels_x;
int device;
int warp_size;
int64_t block_size_best = WARP_SIZE;
int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
int64_t max_block_size = 256;
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
max_block_size = 128;
}
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
if (niter < niter_best) {
niter_best = niter;
@ -107,7 +117,7 @@ static void launch_mul_mat_vec_cuda(
}
}
const int smem = WARP_SIZE*sizeof(float);
const int smem = warp_size*sizeof(float);
const dim3 block_nums(nrows, 1, nchannels_y);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {

View file

@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@ -126,7 +126,7 @@ static __global__ void soft_max_f32(
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#endif // __clang__
static __global__ void soft_max_back_f32(
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16);
DECL_FATTN_MMA_F16_CASE(80, 16);
DECL_FATTN_MMA_F16_CASE(96, 16);
DECL_FATTN_MMA_F16_CASE(112, 16);
DECL_FATTN_MMA_F16_CASE(128, 16);
DECL_FATTN_MMA_F16_CASE(256, 16);

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32);
DECL_FATTN_MMA_F16_CASE(80, 32);
DECL_FATTN_MMA_F16_CASE(96, 32);
DECL_FATTN_MMA_F16_CASE(112, 32);
DECL_FATTN_MMA_F16_CASE(128, 32);
DECL_FATTN_MMA_F16_CASE(256, 32);

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 64);
DECL_FATTN_MMA_F16_CASE(80, 64);
DECL_FATTN_MMA_F16_CASE(96, 64);
DECL_FATTN_MMA_F16_CASE(112, 64);
DECL_FATTN_MMA_F16_CASE(128, 64);
DECL_FATTN_MMA_F16_CASE(256, 64);

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8);
DECL_FATTN_MMA_F16_CASE(80, 8);
DECL_FATTN_MMA_F16_CASE(96, 8);
DECL_FATTN_MMA_F16_CASE(112, 8);
DECL_FATTN_MMA_F16_CASE(128, 8);
DECL_FATTN_MMA_F16_CASE(256, 8);

View file

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 16, float);
DECL_FATTN_WMMA_F16_CASE(80, 16, float);
DECL_FATTN_WMMA_F16_CASE(96, 16, float);
DECL_FATTN_WMMA_F16_CASE(112, 16, float);
DECL_FATTN_WMMA_F16_CASE(128, 16, float);
DECL_FATTN_WMMA_F16_CASE(256, 16, float);

View file

@ -1,9 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 32, float);
DECL_FATTN_WMMA_F16_CASE(80, 32, float);
DECL_FATTN_WMMA_F16_CASE(96, 32, float);
DECL_FATTN_WMMA_F16_CASE(112, 32, float);
DECL_FATTN_WMMA_F16_CASE(128, 32, float);

View file

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 16, half);
DECL_FATTN_WMMA_F16_CASE(80, 16, half);
DECL_FATTN_WMMA_F16_CASE(96, 16, half);
DECL_FATTN_WMMA_F16_CASE(112, 16, half);
DECL_FATTN_WMMA_F16_CASE(128, 16, half);
DECL_FATTN_WMMA_F16_CASE(256, 16, half);

View file

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 32, half);
DECL_FATTN_WMMA_F16_CASE(80, 32, half);
DECL_FATTN_WMMA_F16_CASE(96, 32, half);
DECL_FATTN_WMMA_F16_CASE(112, 32, half);
DECL_FATTN_WMMA_F16_CASE(128, 32, half);
DECL_FATTN_WMMA_F16_CASE(256, 32, half);

View file

@ -1,8 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 8, half);
DECL_FATTN_WMMA_F16_CASE(96, 8, half);
DECL_FATTN_WMMA_F16_CASE(128, 8, half);
DECL_FATTN_WMMA_F16_CASE(256, 8, half);

View file

@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p
DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
"""
SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
#include "../fattn-mma-f16.cuh"
"""
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@ -57,20 +57,12 @@ for vkq_size in [16, 32]:
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
for kq_acc_t in ["half", "float"]:
for cols_per_block in [8, 16, 32]:
if kq_acc_t == "float" and cols_per_block == 8:
continue
for cols_per_block in [8, 16, 32, 64]:
with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f:
f.write(SOURCE_FATTN_WMMA_START)
for head_size in [64, 80, 96, 112, 128, 256]:
if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
continue
if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
continue
f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
for head_size in [64, 80, 96, 112, 128, 256]:
f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:

View file

@ -1,5 +1,6 @@
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
@ -8,6 +9,7 @@
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@ -25,6 +27,7 @@
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate

View file

@ -50,7 +50,7 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})

View file

@ -20,7 +20,10 @@
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
// create residency sets only on macOS >= 15.0
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
#define GGML_METAL_HAS_RESIDENCY_SETS 1
#endif
@ -1071,7 +1074,7 @@ static bool ggml_backend_metal_buffer_rset_init(
}
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
if (@available(macOS 15.0, *)) {
if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
desc.label = @"ggml_backend_metal";
desc.initialCapacity = ctx->n_buffers;
@ -1106,7 +1109,7 @@ static bool ggml_backend_metal_buffer_rset_init(
// rset free
static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
if (@available(macOS 15.0, *)) {
if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
if (ctx->rset) {
[ctx->rset endResidency];
[ctx->rset removeAllAllocations];

View file

@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})

View file

@ -1357,6 +1357,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,

View file

@ -0,0 +1,156 @@
{{ bos_token }}{%- macro document_turn(documents) -%}
{# format documents into chat turn #}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{
"tool_call_id": "0",
"results": {
{% for doc in documents %}
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
{% endif %}
{% endfor %}
},
"is_error": null
}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
{%- set counter = namespace(value=0) %}
{%- set tool_call_id_seen = namespace(value=false) %}
{%- for msg in messages %}
{%- if msg.tool_calls %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
{{ counter.value }}
{%- set tool_call_id_seen.value = true %}
{%- endif %}
{%- set counter.value = counter.value + 1 %}
{%- endfor %}
{%- endif %}
{%- endfor %}
{%- endmacro %}
{%- macro format_tool_message(messages, tool_msg) -%}
{# format tool message #}
{
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
"results": {
"0": {{ tool_msg.content|tojson }}
},
"is_error": null
}
{%- endmacro -%}
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
{%- set tool_idx = namespace(value=0) %}
{%- set tool_ids_seen = namespace(value=[]) %}
{%- set sent_documents = namespace(value=false) %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
Your information cutoff date is June 2024.
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
{% if tools or documents %}
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
## Tool Use
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
Then carry out your plan by repeatedly executing the following steps.
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
{% if enable_citations %}
## Grounding
Importantly, note that "Reflection" and "Response" above can be grounded.
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
{% endif %}
## Available Tools
Here is the list of tools that you have available to you.
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
```json
[
{% if documents %}
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
{% endif %}
{% for tool in tools %}
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
{% endfor %}
]
```
{% endif %}
# Default Preamble
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
- Your name is Command.
- You are a large language model built by Cohere.
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
- If the input is ambiguous, ask clarifying follow-up questions.
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
- Use LaTeX to generate mathematical notation for complex equations.
- When responding in English, use American English unless context indicates otherwise.
- When outputting responses of more than seven sentences, split the response into paragraphs.
- Prefer the active voice.
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
- Use gender-neutral pronouns for unspecified persons.
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
- Use the third person when asked to write a summary.
- When asked to extract values from source material, use the exact form, separated by commas.
- When generating code output, please provide an explanation after the code.
- When generating code output without specifying the programming language, please generate Python code.
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
{%- if developer_preamble %}
# Developer Preamble
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
{{ developer_preamble }}
{%- endif -%}
<|END_OF_TURN_TOKEN|>
{%- for message in messages %}
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
{%- elif message.role|lower == 'user' %}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
{% for tc in message.tool_calls %}
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
{% set tool_idx.value = tool_idx.value + 1 %}
{% endfor %}
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{{ format_tool_message(messages, message) }}
{%- for msg in messages[loop.index0 + 1:] %}
{%- if msg.role|lower == 'tool' %},
{{ format_tool_message(messages, msg) }}
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
{%- else %}
{%- break %}
{%- endif %}
{%- endfor %}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
{%- endif %}
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View file

@ -1 +1 @@
32f0b85987396945afea2291d5f4c5862434292b
694244a6e40dc255f6bb4376fb17431c06633e6c

View file

@ -1024,6 +1024,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },

View file

@ -51,6 +51,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
@ -115,7 +116,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
return LLM_CHAT_TEMPLATE_PHI_3;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
return LLM_CHAT_TEMPLATE_FALCON_3;
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
return LLM_CHAT_TEMPLATE_ZEPHYR;
} else if (tmpl_contains("bos_token + message['role']")) {
@ -440,6 +441,14 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) {

View file

@ -31,6 +31,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGML_3,
LLM_CHAT_TEMPLATE_CHATGML_4,
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_RWKV_WORLD,

View file

@ -1213,5 +1213,7 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
}
grammar.partial_utf8 = decoded.second;
GGML_ASSERT(!grammar.stacks.empty());
if (grammar.stacks.empty()) {
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
}

View file

@ -1093,8 +1093,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 28: type = LLM_TYPE_6B; break;
case 40: type = LLM_TYPE_9B; break;
case 28: {
if (hparams.n_head(0) == 16) {
type = LLM_TYPE_1_5B;
} else {
type = LLM_TYPE_6B;
}
} break;
case 40: {
if (hparams.n_head(0) == 24) {
type = LLM_TYPE_4B;
} else {
type = LLM_TYPE_9B;
}
} break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
@ -3068,9 +3080,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
if (layer.wqkv == nullptr) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

View file

@ -7215,17 +7215,30 @@ struct llm_build_context {
struct ggml_tensor * Qcur = nullptr;
struct ggml_tensor * Kcur = nullptr;
struct ggml_tensor * Vcur = nullptr;
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
if (model.type == LLM_TYPE_1_5B || model.type == LLM_TYPE_4B || model.type == LLM_TYPE_9B) {
Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
}
Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
}
Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
}
} else {
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
if (model.layers[il].bqkv) {
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
}
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

View file

@ -86,6 +86,9 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE
llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
if (LLAMA_LLGUIDANCE)
llama_target_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
endif ()
if (NOT WIN32)
# these tests are disabled on Windows because they use internal functions not exported with LLAMA_API

View file

@ -175,6 +175,14 @@ int main(void) {
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "GLMEdge",
/* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
/* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
/* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",

View file

@ -22,9 +22,13 @@ static common_chat_msg msg_from_json(const json & message) {
"assistant",
"",
{},
/* .tool_plan = */ "",
};
if (message.contains("content") && !message.at("content").is_null()) {
ret.content = message.at("content").get<std::string>();
ret.content = message.at("content");
}
if (message.contains("tool_plan")) {
ret.tool_plan = message.at("tool_plan");
}
auto has_tool_calls = message.contains("tool_calls");
if (has_tool_calls) {
@ -171,8 +175,7 @@ const json llama_3_1_tools = { special_function_tool, code_interpreter_too
struct delta_data {
std::string delta;
std::string grammar;
common_chat_format format;
common_chat_params params;
};
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
@ -214,7 +217,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
break;
}
}
return { delta, params_full.grammar, params_full.format };
return { delta, params_full };
}
/*
@ -224,7 +227,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
*/
static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
bool skip_grammar_test = false, bool skip_parser_test = false) {
bool expect_grammar_triggered = true) {
common_chat_msg expected_msg = msg_from_json(test_message);
auto user_message = json{
@ -238,45 +241,110 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
assert_equals(expected_delta, data.delta);
}
if (!skip_parser_test) {
const auto msg = common_chat_parse(data.delta, data.format);
if (expect_grammar_triggered) {
const auto msg = common_chat_parse(data.delta, data.params.format);
assert_msg_equals(expected_msg, msg);
}
if (!expected_msg.tool_calls.empty()) {
GGML_ASSERT(!data.grammar.empty());
GGML_ASSERT(!data.params.grammar.empty());
}
if (!data.grammar.empty()) {
auto grammar = build_grammar(data.grammar);
if (!data.params.grammar.empty()) {
auto grammar = build_grammar(data.params.grammar);
if (!grammar) {
throw std::runtime_error("Failed to build grammar");
}
// TODO: exercice lazy grammars + triggers here, instead of skipping the test
if (!skip_grammar_test) {
if (!match_string(data.delta, grammar.get())) {
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
"\n\nGrammar: " + data.grammar);
auto earliest_trigger_pos = std::string::npos;
auto constrained = data.delta;
for (const auto & trigger : data.params.grammar_triggers) {
auto pos = constrained.find(trigger.word);
if (pos == std::string::npos) {
continue;
}
if (pos > 0 && trigger.at_start) {
fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
continue;
}
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
earliest_trigger_pos = pos;
}
}
auto grammar_triggered = false;
if (earliest_trigger_pos != std::string::npos) {
constrained = constrained.substr(earliest_trigger_pos);
grammar_triggered = true;
}
if (data.params.grammar_lazy) {
assert_equals(expect_grammar_triggered, grammar_triggered);
}
if (grammar_triggered && !match_string(constrained, grammar.get())) {
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
"\n\nGrammar: " + data.params.grammar);
}
}
}
}
static void test_template_output_parsers() {
auto text_message = json{
json text_message {
{ "role", "assistant" },
{ "content", "Hello, world!" },
{ "content", "Hello, world!\nWhat's up?" },
};
auto tool_call_message = json{
json tool_calls = json::array({{
{ "type", "function" },
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
}});
json tool_call_message {
{ "role", "assistant"},
{ "content", {}},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
},
}},
};
json tool_call_message_with_id {
{ "role", "assistant"},
{ "content", {}},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
{"id", "123456789"},
},
}},
{ "role", "assistant" },
{ "content", {} },
{ "tool_calls", json{ {
{ "type", "function" },
{ "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
} } }
{ "tool_calls", tool_calls }
};
json tool_call_plan_message_with_idx {
{ "role", "assistant"},
{ "content", {}},
{ "tool_plan", "I'm not so sure"},
{ "tool_calls", {
{
{ "type", "function" },
{ "function", {
{ "name", "special_function" },
{ "arguments", "{\"arg1\": 1}" },
}},
// Index of the tool call in the tool_calls array
{"id", "0"},
},
}},
{ "role", "assistant" },
{ "content", {} },
{ "tool_calls", tool_calls }
};
auto tool_call_message_with_id = json::parse(tool_call_message.dump());
tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
auto python_tool_call_message = json{
{ "role", "assistant" },
@ -311,7 +379,7 @@ static void test_template_output_parsers() {
common_chat_inputs inputs_no_tools;
inputs_no_tools.messages = {
{ { "role", "user" }, { "content", "Hey" } }
{ { "role", "user" }, { "content", "Hey\nThere" } }
};
common_chat_inputs inputs_tools = inputs_no_tools;
@ -322,6 +390,28 @@ static void test_template_output_parsers() {
inputs_tools_builtin.tools = json::array();
inputs_tools_builtin.tools.push_back(python_tool);
{
// Not supported yet
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
}
{
const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
"<|START_THINKING|>I'm not so sure<|END_THINKING|>"
"<|START_ACTION|>[\n"
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
"]<|END_ACTION|>");
test_template(tmpl, end_tokens, text_message, tools,
"<|START_RESPONSE|>Hello, world!\n"
"What's up?<|END_RESPONSE|>",
/* expect_grammar_triggered= */ false);
}
{
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens{ "<end_of_turn>" };
@ -339,7 +429,7 @@ static void test_template_output_parsers() {
assert_msg_equals(msg_from_json(text_message),
common_chat_parse("{\n"
" \"response\": \"Hello, world!\"\n"
" \"response\": \"Hello, world!\\nWhat's up?\"\n"
"}",
common_chat_params_init(tmpl, inputs_tools).format));
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
@ -362,11 +452,10 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(
tmpl, end_tokens, tool_call_message_with_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
/* skip_grammar_test= */ true);
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
}
{
const common_chat_template tmpl(
@ -388,7 +477,7 @@ static void test_template_output_parsers() {
inputs_tools)
.format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
@ -413,7 +502,7 @@ static void test_template_output_parsers() {
inputs_tools_builtin)
.format);
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
test_template(tmpl, end_tokens, python_tool_call_message, tools,
@ -428,7 +517,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
@ -440,7 +529,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
@ -454,8 +543,9 @@ static void test_template_output_parsers() {
test_template(tmpl, end_tokens, text_message, {},
"all\n"
"Hello, world!",
/* skip_grammar_test= */ true);
"Hello, world!\n"
"What's up?",
/* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"special_function\n"
"{\"arg1\": 1}");
@ -467,7 +557,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
}
@ -478,7 +568,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n"

View file

@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
}
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
}
static void test_simple_grammar() {

File diff suppressed because it is too large Load diff

View file

@ -1246,7 +1246,7 @@ int main() {
test_all("C++", [](const TestCase & tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
fprintf(stderr, "Error: %s\n", ex.what());