diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e19d8032a..efd977fcd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -59,16 +59,14 @@ jobs: id: cmake_build run: | sysctl -a - mkdir build - cd build - cmake .. \ + cmake -B build \ -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_CURL=ON \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ -DGGML_RPC=ON - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) - name: Test id: cmake_test @@ -199,13 +197,11 @@ jobs: - name: Build id: cmake_build run: | - mkdir build - cd build - cmake .. \ + cmake -B build \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_CURL=ON \ -DGGML_RPC=ON - cmake --build . --config Release -j $(nproc) + cmake --build build --config Release -j $(nproc) - name: Test id: cmake_test @@ -283,26 +279,52 @@ jobs: id: cmake_build if: ${{ matrix.sanitizer != 'THREAD' }} run: | - mkdir build - cd build - cmake .. \ + cmake -B build \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} - cmake --build . --config ${{ matrix.build_type }} -j $(nproc) + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) - name: Build (no OpenMP) id: cmake_build_no_openmp if: ${{ matrix.sanitizer == 'THREAD' }} run: | - mkdir build - cd build - cmake .. \ + cmake -B build \ -DLLAMA_FATAL_WARNINGS=ON \ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ -DGGML_OPENMP=OFF - cmake --build . --config ${{ matrix.build_type }} -j $(nproc) + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + + ubuntu-latest-llguidance: + runs-on: ubuntu-latest + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_LLGUIDANCE=ON + cmake --build . --config Release -j $(nproc) - name: Test id: cmake_test @@ -335,11 +357,9 @@ jobs: - name: Build id: cmake_build run: | - mkdir build - cd build - cmake .. \ + cmake -B build \ -DGGML_RPC=ON - cmake --build . --config Release -j $(nproc) + cmake --build build --config Release -j $(nproc) - name: Test id: cmake_test @@ -372,11 +392,9 @@ jobs: - name: Build id: cmake_build run: | - mkdir build - cd build - cmake .. \ + cmake -B build \ -DGGML_VULKAN=ON - cmake --build . --config Release -j $(nproc) + cmake --build build --config Release -j $(nproc) - name: Test id: cmake_test @@ -493,13 +511,11 @@ jobs: id: cmake_build run: | source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake .. \ + cmake -B build \ -DGGML_SYCL=ON \ -DCMAKE_C_COMPILER=icx \ -DCMAKE_CXX_COMPILER=icpx - cmake --build . --config Release -j $(nproc) + cmake --build build --config Release -j $(nproc) ubuntu-22-cmake-sycl-fp16: runs-on: ubuntu-22.04 @@ -543,14 +559,12 @@ jobs: id: cmake_build run: | source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake .. \ + cmake -B build \ -DGGML_SYCL=ON \ -DCMAKE_C_COMPILER=icx \ -DCMAKE_CXX_COMPILER=icpx \ -DGGML_SYCL_F16=ON - cmake --build . --config Release -j $(nproc) + cmake --build build --config Release -j $(nproc) macOS-latest-cmake-ios: runs-on: macos-latest @@ -576,9 +590,7 @@ jobs: id: cmake_build run: | sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ + cmake -B build -G Xcode \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_EXAMPLES=OFF \ @@ -587,7 +599,7 @@ jobs: -DCMAKE_SYSTEM_NAME=iOS \ -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO macOS-latest-cmake-tvos: runs-on: macos-latest @@ -613,9 +625,7 @@ jobs: id: cmake_build run: | sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ + cmake -B build -G Xcode \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_EXAMPLES=OFF \ @@ -624,7 +634,7 @@ jobs: -DCMAKE_SYSTEM_NAME=tvOS \ -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO macOS-latest-swift: runs-on: macos-latest @@ -654,17 +664,15 @@ jobs: id: cmake_build run: | sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ + cmake -B build -G Xcode \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_EXAMPLES=OFF \ -DLLAMA_BUILD_TESTS=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) - sudo cmake --install . --config Release + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + sudo cmake --install build --config Release - name: xcodebuild for swift package id: xcodebuild @@ -689,6 +697,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2.16 with: key: windows-msys2 + variant: sccache evict-old-files: 1d - name: Setup ${{ matrix.sys }} @@ -763,6 +772,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2.16 with: key: windows-latest-cmake-${{ matrix.build }} + variant: sccache evict-old-files: 1d - name: Clone Kompute submodule @@ -804,21 +814,19 @@ jobs: run: | git clone https://github.com/KhronosGroup/OpenCL-Headers cd OpenCL-Headers - mkdir build && cd build - cmake .. ` + cmake -B build ` -DBUILD_TESTING=OFF ` -DOPENCL_HEADERS_BUILD_TESTING=OFF ` -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" - cmake --build . --target install + cmake --build build --target install git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader cd OpenCL-ICD-Loader - mkdir build-arm64-release && cd build-arm64-release - cmake .. ` + cmake -B build-arm64-release ` -A arm64 ` -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" ` -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" - cmake --build . --target install --config release + cmake --build build-arm64-release --target install --config release - name: Build id: cmake_build @@ -1026,6 +1034,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2.16 with: key: ${{ github.job }}-${{ matrix.cuda }}-${{ matrix.build }} + variant: sccache evict-old-files: 1d - name: Install Cuda Toolkit 11.7 @@ -1167,6 +1176,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2.16 with: key: windows-latest-cmake-sycl + variant: sccache evict-old-files: 1d - name: Install @@ -1355,9 +1365,7 @@ jobs: id: cmake_build run: | sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ + cmake -B build -G Xcode \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_EXAMPLES=OFF \ @@ -1366,8 +1374,8 @@ jobs: -DCMAKE_SYSTEM_NAME=iOS \ -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO - sudo cmake --install . --config Release + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + sudo cmake --install build --config Release - name: xcodebuild for swift package id: xcodebuild diff --git a/.github/workflows/close-issue.yml b/.github/workflows/close-issue.yml index f63860d14..276a217d4 100644 --- a/.github/workflows/close-issue.yml +++ b/.github/workflows/close-issue.yml @@ -17,7 +17,7 @@ jobs: steps: - uses: actions/stale@v5 with: - exempt-issue-labels: "refactor,help wanted,good first issue,research,bug" + exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap" days-before-issue-stale: 30 days-before-issue-close: 14 stale-issue-label: "stale" diff --git a/AUTHORS b/AUTHORS index 2eb60806a..6796b2941 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,4 +1,4 @@ -# date: Thu Nov 28 20:46:15 EET 2024 +# date: Tue Feb 4 13:04:05 EET 2025 # this file is auto-generated by scripts/gen-authors.sh 0cc4m @@ -20,6 +20,8 @@ Adithya Balaji AdithyanI Adrian Adrian Hesketh +Adrien Gallouët +Adrien Gallouët Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com> Ahmet Zeer AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> @@ -55,6 +57,7 @@ Ananta Bastola Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> András Salamon Andreas (Andi) Kunar +Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Andrei Andrew Canis Andrew Downing @@ -91,13 +94,17 @@ Ben Siraphob Ben Williams Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com> Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com> +Benson Wong Bernat Vadell +Bernhard M. Wiedemann Bert Wagner +Billel Mokeddem Bingan <70050083+binganao@users.noreply.github.com> Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com> Bodo Graumann Bono Lv Borislav Stanimirov +Borislav Stanimirov Branden Butler Brandon Squizzato <35474886+bsquizz@users.noreply.github.com> Brian @@ -117,6 +124,7 @@ Casey Primozic Casey Primozic CausalLM <148736309+CausalLM@users.noreply.github.com> Cebtenzzre +CentricStorm Chad Brewbaker Changyeon Kim Chao Jiang @@ -131,12 +139,15 @@ Chris Kuehl Christian Demsar Christian Demsar Christian Falch <875252+chrfalch@users.noreply.github.com> +Christian Kastner Christian Kögler Christian Köhnenkamp Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> +Christopher Nielsen <62156882+mascguy@users.noreply.github.com> Clark Saben <76020733+csaben@users.noreply.github.com> Clint Herron Conrad Kramer +Corentin REGAL CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Csaba Kecskemeti Cuong Trinh Manh @@ -176,6 +187,7 @@ Dibakar Gope Didzis Gosko Diego Devesa Diogo Teles Sant'Anna +Djip007 <3705339+Djip007@users.noreply.github.com> Djip007 Don Mahurin DooWoong Lee (David) @@ -193,6 +205,7 @@ Edward Taylor Elaine Elbios <141279586+Elbios@users.noreply.github.com> Elton Kola +Emreerdog <34742675+Emreerdog@users.noreply.github.com> Engininja2 <139037756+Engininja2@users.noreply.github.com> Equim Eric Curtin @@ -233,6 +246,7 @@ Fred Douglas <43351173+fredlas@users.noreply.github.com> Frederik Vogel Gabe Goodhart Gabe Goodhart +Gaetan Bisson GainLee Galunid Gary Linscott @@ -249,6 +263,7 @@ Guillaume "Vermeille" Sanchez Guillaume Wenzek Guoliang Hua <32868157+nbcsm@users.noreply.github.com> Guoteng <32697156+SolenoidWGT@users.noreply.github.com> +Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com> Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com> Haggai Nuchi Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> @@ -259,11 +274,13 @@ Haoxiang Fei Harald Fernengel Hatsune Miku <129688334+at8u@users.noreply.github.com> HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com> +Haus1 Henk Poley Henri Vasserman Henrik Forstén Herman Semenov Hesen Peng +HimariO Hoang Nguyen Hong Bo PENG Hongyu Ouyang <96765450+casavaca@users.noreply.github.com> @@ -280,6 +297,7 @@ Icecream95 Ido S IgnacioFDM Igor Okulist +Ihar Hrachyshka Ikko Eltociear Ashimine Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com> Ionoclast Laboratories @@ -289,12 +307,14 @@ Ivan Ivan Filipov <159561759+vanaka11@users.noreply.github.com> Ivan Komarov Ivan Stepanov +JFLFY2255 JH23X <165871467+JH23X@users.noreply.github.com> Jack Mousseau Jack Mousseau JackJollimore <130917767+JackJollimore@users.noreply.github.com> Jaeden Amero Jaemin Son +Jafar Uruç Jag Chadha Jakub N James A Capozzoli <157492257+jac-jim@users.noreply.github.com> @@ -315,6 +335,7 @@ Jeffrey Morgan Jeffrey Quesnelle Jeroen Mostert Jesse Jojo Johnson +Jett Janiak Jeximo Jhen-Jie Hong Jiahao Li @@ -343,6 +364,7 @@ Josh Ramer Joyce Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Judd +Juk Armstrong <69222624+jukofyork@users.noreply.github.com> Julius Arkenberg Jun Hee Yoo Jun Jie <71215065+junnjiee16@users.noreply.github.com> @@ -357,6 +379,7 @@ Justine Tunney Juuso Alasuutari KASR Kamil Tomšík +Karol Kontny <82021046+kkontny@users.noreply.github.com> Karsten Weiss Karthick Karthik Kumar Viswanathan <195178+guilt@users.noreply.github.com> @@ -376,6 +399,7 @@ Kolen Cheung Konstantin Herud Konstantin Zhuravlyov Kunshang Ji +Kyle Bruene Kyle Liang Kyle Mistele Kylin <56434533+KyL0N@users.noreply.github.com> @@ -394,6 +418,7 @@ Liu Jia LoganDark Loïc Carrère LostRuins <39025047+LostRuins@users.noreply.github.com> +LostRuins Concedo <39025047+LostRuins@users.noreply.github.com> Luciano Luo Tian Lyle Dean @@ -423,6 +448,7 @@ MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com> Mateusz Charytoniuk Matheus C. França Matheus Gabriel Alves Silva +Mathieu Baudier Mathieu Geli Mathieu Nayrolles Mathijs Henquet @@ -444,6 +470,7 @@ Meng, Hengyu Mengqing Cao Merrick Christensen Michael Coppola +Michael Engel Michael Francis Michael Hueschen Michael Kesper @@ -452,7 +479,9 @@ Michael Podvitskiy Michael Potter Michael de Gans Michaël de Vries +Michał Moskal Michał Tuszyński +Michelle Tan <41475767+MichelleTanPY@users.noreply.github.com> Mihai Mike Mikko Juola @@ -477,6 +506,7 @@ Neo Zhang <14088817+arthw@users.noreply.github.com> Neo Zhang Neo Zhang Jianyu Neuman Vong +NeverLucky <92274250+nvrxq@users.noreply.github.com> Nexes the Old <124105151+Nexesenex@users.noreply.github.com> Nexesenex <124105151+Nexesenex@users.noreply.github.com> Niall Coates <1349685+Niall-@users.noreply.github.com> @@ -484,11 +514,15 @@ Nicholai Tukanov Nico Bosshard Nicolai Weitkemper Nicolás Pérez +Nicolò Scipione Nigel Bosch +Nikita Sarychev <42014488+sARY77@users.noreply.github.com> Niklas Korz NikolaiLyssogor <59844691+NikolaiLyssogor@users.noreply.github.com> +Nikolaos Pothitos Nikolas <127742645+nneubacher@users.noreply.github.com> Nindaleth +Nuno OSecret <135510162+OLSecret@users.noreply.github.com> Oleksandr Nikitin Oleksii Maryshchenko @@ -504,6 +538,7 @@ Pavel Zloi Pavol Rusnak Paweł Wodnicki <151604+32bitmicro@users.noreply.github.com> Pedro Cuenca +Peter Peter Sugihara Phil H <5756783+phiharri@users.noreply.github.com> Philip Taron @@ -529,9 +564,12 @@ Rand Xie Randall Fitzgerald Random Fly Reinforce-II +Rémy Oudompheng Ren Xuancheng Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com> +Reza Kakhki RhinoDevel +Riccardo Orlando Riceball LEE Rich Dougherty Richard Kiss @@ -544,6 +582,8 @@ Riley Stewart Rinne Rinne Robert Brisita <986796+rbrisita@users.noreply.github.com> +Robert Collins +Robert Ormandi <52251610+ormandi@users.noreply.github.com> Robert Sung-wook Shin Robey Holderith Robyn @@ -559,7 +599,9 @@ Roni Ronny Brendel Ronsor Rowan Hart +Ruan <47767371+ruanych@users.noreply.github.com> Ruchira Hasaranga +Rudi Servo Ruixin Huang <18860020911@163.com> Rune <43761327+Rune-AI@users.noreply.github.com> RunningLeon @@ -623,12 +665,14 @@ Steven Roussey Steward Garcia <57494570+FSSRepo@users.noreply.github.com> StrangeBytesDev <141275258+StrangeBytesDev@users.noreply.github.com> Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com> +Sukriti Sharma SuperUserNameMan Sutou Kouhei Tai Duc Nguyen Taikono-Himazin Tameem <113388789+AhmadTameem@users.noreply.github.com> Tamotsu Takahashi +Tei Home Thái Hoàng Tâm <75922889+RoyalHeart@users.noreply.github.com> Thatcher Chamberlin Theia Vogel @@ -640,6 +684,7 @@ Tim Miller Tim Wang Timmy Knight Timothy Cronin <40186632+4imothy@users.noreply.github.com> +Ting Lou Ting Lou Ting Sun Tobias Lütke @@ -661,6 +706,7 @@ Uzo Nweke Vaibhav Srivastav Val Kharitonov Valentin Konovalov +Valentin Mamedov <45292985+Inf1delis@users.noreply.github.com> Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com> Vali Malinoiu <0x4139@gmail.com> Victor Nogueira @@ -673,13 +719,17 @@ Vladimir Malyutin Vladimir Zorin VoidIsVoid <343750470@qq.com> Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com> +Wang Qin <37098874+wangqin0@users.noreply.github.com> +Wang Ran (汪然) WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> Weird Constructor Welby Seely Wentai Zhang WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com> William Tambellini +William Tambellini Willy Tarreau +Woof Dog <197125663+woof-dog@users.noreply.github.com> Wouter <9594229+DifferentialityDevelopment@users.noreply.github.com> Wu Jian Ping Wu Jian Ping @@ -692,6 +742,7 @@ Xie Yanbo Xingchen Song(宋星辰) Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com> Xuan Son Nguyen +Xuan-Son Nguyen Yaiko Yann Follet <131855179+YannFollet@users.noreply.github.com> Yaroslav @@ -702,7 +753,9 @@ Yoshi Suhara Yoshi Suhara Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com> +Yüg Yui +Yun Dou Yuri Khrustalev Yusuf Kağan Hanoğlu Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com> @@ -714,18 +767,23 @@ Zhang Peiyuan Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com> Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> Zhiyuan Li +Zhiyuan Li ZhouYuChen Ziad Ben Hadj-Alouane Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com> Zsapi a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com> +a3sh <38979186+A3shTnT@users.noreply.github.com> adel boussaken afrideva <95653597+afrideva@users.noreply.github.com> +ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com> agray3 akawrykow <142945436+akawrykow@users.noreply.github.com> +alek3y <44779186+alek3y@users.noreply.github.com> alexpinel <93524949+alexpinel@users.noreply.github.com> alonfaraj alwqx +amd-dwang amd-lalithnc amritahs-ibm andrijdavid @@ -737,6 +795,7 @@ arch-btw <57669023+arch-btw@users.noreply.github.com> arcrank ardfork <134447697+ardfork@users.noreply.github.com> arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com> +aryantandon01 <80969509+aryantandon01@users.noreply.github.com> at8u <129688334+at8u@users.noreply.github.com> automaticcat awatuna <23447591+awatuna@users.noreply.github.com> @@ -751,12 +810,14 @@ bryanSwk <93190252+bryanSwk@users.noreply.github.com> bsilvereagle bssrdf byte-6174 <88070277+byte-6174@users.noreply.github.com> +cduk <19917266+cduk@users.noreply.github.com> cebtenzzre chaihahaha chiranko <96988916+chiranko@users.noreply.github.com> clibdev <52199778+clibdev@users.noreply.github.com> clyang cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com> +codezjx coezbek comex compilade <113953597+compilade@users.noreply.github.com> @@ -780,14 +841,17 @@ drbh ds5t5 <145942675+ds5t5@users.noreply.github.com> dylan eastriver +ebraminio ebraminio eiery <19350831+eiery@users.noreply.github.com> eric8607242 fairydreaming <166155368+fairydreaming@users.noreply.github.com> fengerhu1 <2748250768@qq.com> +fj-y-saito <85871716+fj-y-saito@users.noreply.github.com> fraxy-v <65565042+fraxy-v@users.noreply.github.com> github-actions[bot] gliptic +gn64 goerch grahameth <96447521+grahameth@users.noreply.github.com> gtygo @@ -812,10 +876,12 @@ icppWorld <124377669+icppWorld@users.noreply.github.com> igarnier intelmatt <61025942+intelmatt@users.noreply.github.com> iohub +issixx <46835150+issixx@users.noreply.github.com> jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> jameswu2014 <545426914@qq.com> jdomke <28772296+jdomke@users.noreply.github.com> +jiahao su jiez <373447296@qq.com> jneem joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> @@ -828,6 +894,7 @@ junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> jwj7140 <32943891+jwj7140@users.noreply.github.com> k.h.lai kaizau +kallewoof kalomaze <66376113+kalomaze@users.noreply.github.com> kang katsu560 <118887472+katsu560@users.noreply.github.com> @@ -835,6 +902,7 @@ kchro3 <62481661+kchro3@users.noreply.github.com> khimaros kiltyj klosax <131523366+klosax@users.noreply.github.com> +krystiancha kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> kunnis kuronekosaiko @@ -847,6 +915,8 @@ ldwang le.chang leejet leo-pony +lexasub +lhez limitedAtonement liuwei-git <14815172+liuwei-git@users.noreply.github.com> lon <114724657+longregen@users.noreply.github.com> @@ -855,10 +925,13 @@ ltoniazzi <61414566+ltoniazzi@users.noreply.github.com> luoyu-intel m3ndax maddes8cht <55592906+maddes8cht@users.noreply.github.com> +mahorozte <41834471+mahorozte@users.noreply.github.com> makomk manikbhandari maor-ps <154728172+maor-ps@users.noreply.github.com> +mashdragon <122402293+mashdragon@users.noreply.github.com> matiaslin <45382001+matiaslin@users.noreply.github.com> +matt23654 matteo mdrokz mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> @@ -868,6 +941,7 @@ mmyjona momonga <115213907+mmnga@users.noreply.github.com> momonga <146910567+mmngays@users.noreply.github.com> moritzbrantner <31051084+moritzbrantner@users.noreply.github.com> +musoles <135031143+musoles@users.noreply.github.com> mzcu nanahi <130121847+na-na-hi@users.noreply.github.com> ngc92 <7938269+ngc92@users.noreply.github.com> @@ -885,6 +959,7 @@ oobabooga <112222186+oobabooga@users.noreply.github.com> opparco ostix360 <55257054+ostix360@users.noreply.github.com> pculliton +peidaqi pengxin99 perserk piDack <104877312+piDack@users.noreply.github.com> @@ -892,10 +967,12 @@ pmysl postmasters pudepiedj qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> +qingy1337 qouoq qunash rabidcopy rankaiyx +redbeard rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> rhuddleston rimoliga <53384203+rimoliga@users.noreply.github.com> @@ -912,6 +989,7 @@ sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> slaren <2141330+slaren@users.noreply.github.com> slaren snadampal <87143774+snadampal@users.noreply.github.com> +someone13574 <81528246+someone13574@users.noreply.github.com> standby24x7 staviq stduhpf @@ -931,6 +1009,7 @@ uint256_t uint256_t unbounded uvos +uvos valiray <133289098+valiray@users.noreply.github.com> vb vik @@ -951,6 +1030,7 @@ xaedes xctan xloem <0xloem@gmail.com> yangli2 +ymcki <84055651+ymcki@users.noreply.github.com> yuiseki yuri@FreeBSD zakkor @@ -963,4 +1043,5 @@ zrm 杨朱 · Kiki 源文雨 <41315874+fumiama@users.noreply.github.com> 蕭澧邦 <45505768+shou692199@users.noreply.github.com> +谢乃闻 Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c62d1788..74b48d24d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,6 +80,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) # 3rd party libs option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) +option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) # Required for relocatable CMake package include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) diff --git a/Makefile b/Makefile index ef152d246..dc3de3cb1 100644 --- a/Makefile +++ b/Makefile @@ -596,7 +596,7 @@ ifdef GGML_RPC OBJ_GGML_EXT += ggml/src/ggml-rpc.o endif # GGML_RPC -OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu)) +OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu)) OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu)) ifdef GGML_CUDA_FA_ALL_QUANTS diff --git a/README.md b/README.md index d40309875..d68330d2a 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [Bitnet b1.58 models](https://huggingface.co/1bitLLM) - [x] [Flan T5](https://huggingface.co/models?search=flan-t5) - [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca) -- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) +- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat) - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) @@ -117,6 +117,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM) - [x] [Moondream](https://huggingface.co/vikhyatk/moondream2) - [x] [Bunny](https://github.com/BAAI-DCAI/Bunny) +- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge) - [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d) @@ -135,6 +136,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - Rust (more features): [edgenai/llama_cpp-rs](https://github.com/edgenai/llama_cpp-rs) - Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs) +- Rust (automated build from crates.io): [ShelbyJenkins/llm_client](https://github.com/ShelbyJenkins/llm_client) - C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp) - C#/VB.NET (more features - community license): [LM-Kit.NET](https://docs.lm-kit.com/lm-kit-net/index.html) - Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 72f0915c1..e61015d2a 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -65,6 +65,7 @@ add_library(${TARGET} STATIC console.h json-schema-to-grammar.cpp json.hpp + llguidance.cpp log.cpp log.h minja.hpp @@ -91,6 +92,33 @@ if (LLAMA_CURL) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) endif () +if (LLAMA_LLGUIDANCE) + include(ExternalProject) + set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source) + set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release) + ExternalProject_Add(llguidance_ext + GIT_REPOSITORY https://github.com/guidance-ai/llguidance + # v0.6.12: + GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09 + PREFIX ${CMAKE_BINARY_DIR}/llguidance + SOURCE_DIR ${LLGUIDANCE_SRC} + BUILD_IN_SOURCE TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND cargo build --release + INSTALL_COMMAND "" + BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h + UPDATE_COMMAND "" + ) + target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE) + + add_library(llguidance STATIC IMPORTED) + set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a) + add_dependencies(llguidance llguidance_ext) + + target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH}) + set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance) +endif () + target_include_directories(${TARGET} PUBLIC .) target_compile_features (${TARGET} PUBLIC cxx_std_17) target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) diff --git a/common/chat.cpp b/common/chat.cpp index 58db12af9..4a113c0ca 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; + case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; default: throw std::runtime_error("Unknown chat format"); } @@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); } +static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"tool_call_id", { + {"type", "string"}, + // Command-R's template expects an integer string. + {"pattern", "^[0-9]{1,10}$"}, + }}, + {"tool_name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"parameters", function["parameters"]}, + }}, + {"required", json::array({"tool_call_id", "tool_name", "parameters"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); + }, grammar_options); + data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false}); + data.preserved_tokens = { + "<|START_RESPONSE|>", + "<|END_RESPONSE|>", + "<|START_THINKING|>", + "<|END_THINKING|>", + "<|END_ACTION|>", + }; + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; + return data; +} +static common_chat_msg common_chat_parse_command_r7b(const std::string & input) { + static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>"); + static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); + std::smatch match; + + common_chat_msg result; + result.role = "assistant"; + if (std::regex_match(input, match, response_regex)) { + result.content = match[1].str(); + } else if (std::regex_match(input, match, thought_action_regex)) { + result.tool_plan = match[1].str(); + auto actions_str = match[2].str(); + auto actions = json::parse(actions_str); + for (const auto & action : actions) { + result.tool_calls.push_back({ + /* .name = */ action["tool_name"], + /* .arguments = */ action["parameters"].dump(), + /* .id = */ action["tool_call_id"], + }); + } + } else { + LOG_ERR("Failed to parse command_r output"); + result.content = input; + } + return result; +} + static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); @@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); }); data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); + data.preserved_tokens = { + "<|tool▁sep|>", + "<|tool▁call▁end|>", + }; builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); }, grammar_options); data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); @@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); data.grammar_triggers.push_back({"", /* .at_start = */ false}); - // Not really a trigger but need to print this special token to get a successful parse. - data.grammar_triggers.push_back({"", /* .at_start = */ false}); + data.preserved_tokens = { "" }; }, grammar_options); data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); @@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co if (src.find("[TOOL_CALLS]") != std::string::npos) { return common_chat_params_init_mistral_nemo(tmpl, inputs); } + if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) { + return common_chat_params_init_command_r7b(tmpl, inputs); + } return common_chat_params_init_generic(tmpl, inputs); } @@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format return common_chat_parse_hermes_2_pro(input); case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return common_chat_parse_firefunction_v2(input); + case COMMON_CHAT_FORMAT_COMMAND_R7B: + return common_chat_parse_command_r7b(input); default: throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); } diff --git a/common/chat.hpp b/common/chat.hpp index ca165aa13..33e64a430 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -32,6 +32,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, + COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -42,6 +43,7 @@ struct common_chat_params { std::string grammar; bool grammar_lazy = false; std::vector grammar_triggers; + std::vector preserved_tokens; std::vector additional_stops; }; diff --git a/common/common.cpp b/common/common.cpp index 6c81d18f9..edba6fb4b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1869,11 +1869,19 @@ std::string common_chat_format_example(const common_chat_template & tmpl, bool u return common_chat_apply_template(tmpl, msgs, true, use_jinja); } +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) { - auto vocab = llama_model_get_vocab(model); - std::string default_template_src = chat_template_override; - std::string template_tool_use_src = chat_template_override; + std::string default_template_src; + std::string template_tool_use_src; + bool has_explicit_template = !chat_template_override.empty(); if (chat_template_override.empty()) { auto str = llama_model_chat_template(model, /* name */ nullptr); @@ -1886,21 +1894,21 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model template_tool_use_src = str; has_explicit_template = true; } + } else { + default_template_src = chat_template_override; } if (default_template_src.empty() || default_template_src == "chatml") { if (!template_tool_use_src.empty()) { default_template_src = template_tool_use_src; } else { - default_template_src = R"( - {%- for message in messages -%} - {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} - {%- endfor -%} - {%- if add_generation_prompt -%} - {{- "<|im_start|>assistant\n" -}} - {%- endif -%} - )"; + default_template_src = CHATML_TEMPLATE_SRC; } } + std::string token_bos; + std::string token_eos; + // TODO: update logic that adds BOS and EOS tokens to the tokenized prompt, in favour of the template. +#if 0 + auto vocab = llama_model_get_vocab(model); const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { if (token == LLAMA_TOKEN_NULL) { if (default_template_src.find(jinja_variable_name) != std::string::npos @@ -1912,15 +1920,25 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model return common_token_to_piece(vocab, token, true); } }; - auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); - auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); - return { - has_explicit_template, - std::make_unique(default_template_src, token_bos, token_eos), - template_tool_use_src.empty() - ? nullptr - : std::make_unique(template_tool_use_src, token_bos, token_eos) - }; + token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); +#endif + try { + return { + has_explicit_template, + std::make_unique(default_template_src, token_bos, token_eos), + template_tool_use_src.empty() + ? nullptr + : std::make_unique(template_tool_use_src, token_bos, token_eos), + }; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what()); + return { + has_explicit_template, + std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos), + nullptr, + }; + } } // diff --git a/common/common.h b/common/common.h index 6c1809277..b208d0c7e 100644 --- a/common/common.h +++ b/common/common.h @@ -4,6 +4,7 @@ #include "llama-cpp.h" +#include #include #include #include @@ -163,6 +164,7 @@ struct common_params_sampling { bool grammar_lazy = false; std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. + std::set preserved_tokens; std::vector logit_bias; // logit biases to apply @@ -621,6 +623,7 @@ struct common_chat_msg { std::string role; std::string content; std::vector tool_calls; + std::string tool_plan = ""; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 1f47e313e..3ebcc3d9f 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -991,7 +991,14 @@ public: } }; -std::string json_schema_to_grammar(const json & schema) { +std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { +#ifdef LLAMA_USE_LLGUIDANCE + if (!force_gbnf) { + return "%llguidance {}\nstart: %json " + schema.dump(); + } +#else + (void)force_gbnf; +#endif // LLAMA_USE_LLGUIDANCE return build_grammar([&](const common_grammar_builder & callbacks) { auto copy = schema; callbacks.resolve_refs(copy); diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index ba4112cb9..62a3b0a44 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -5,7 +5,8 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, + bool force_gbnf = false); struct common_grammar_builder { std::function add_rule; diff --git a/common/llguidance.cpp b/common/llguidance.cpp new file mode 100644 index 000000000..7aa8ddd80 --- /dev/null +++ b/common/llguidance.cpp @@ -0,0 +1,270 @@ +#include "sampling.h" +#include "log.h" + +#ifdef LLAMA_USE_LLGUIDANCE + +# include "llguidance.h" +# include + +struct llama_sampler_llg { + const llama_vocab * vocab; + std::string grammar_kind; + std::string grammar_data; + LlgTokenizer * tokenizer; + LlgConstraint * grammar; + LlgMaskResult llg_res; + bool has_llg_res; +}; + +static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, + const char * grammar_data) { + LlgConstraintInit cinit; + llg_constraint_init_set_defaults(&cinit, tokenizer); + const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL"); + if (log_level && *log_level) { + cinit.log_stderr_level = atoi(log_level); + } + auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data); + if (llg_get_error(c)) { + LOG_ERR("llg error: %s\n", llg_get_error(c)); + llg_free_constraint(c); + return nullptr; + } + return c; +} + +static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) { + return "llguidance"; +} + +static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_llg *) smpl->ctx; + if (ctx->grammar) { + LlgCommitResult res; + llg_commit_token(ctx->grammar, token, &res); + ctx->has_llg_res = false; + } +} + +static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_llg *) smpl->ctx; + if (ctx->grammar) { + if (!ctx->has_llg_res) { + if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) { + ctx->has_llg_res = true; + } else { + LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar)); + llg_free_constraint(ctx->grammar); + ctx->grammar = nullptr; + } + } + if (ctx->has_llg_res) { + if (ctx->llg_res.is_stop) { + for (size_t i = 0; i < cur_p->size; ++i) { + if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) { + cur_p->data[i].logit = -INFINITY; + } + } + } else { + const uint32_t * mask = ctx->llg_res.sample_mask; + for (size_t i = 0; i < cur_p->size; ++i) { + auto token = cur_p->data[i].id; + if ((mask[token / 32] & (1 << (token % 32))) == 0) { + cur_p->data[i].logit = -INFINITY; + } + } + } + } + } +} + +static void llama_sampler_llg_reset(llama_sampler * smpl) { + auto * ctx = (llama_sampler_llg *) smpl->ctx; + if (!ctx->grammar) { + return; + } + + auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str()); + llg_free_constraint(ctx->grammar); + ctx->grammar = grammar_new; + ctx->has_llg_res = false; +} + +static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_llg *) smpl->ctx; + + auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr); + + // copy the state + { + auto * result_ctx = (llama_sampler_llg *) result->ctx; + + if (ctx->grammar) { + result_ctx->grammar_kind = ctx->grammar_kind; + result_ctx->grammar_data = ctx->grammar_data; + result_ctx->grammar = llg_clone_constraint(ctx->grammar); + result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer); + } + } + + return result; +} + +static void llama_sampler_llg_free(llama_sampler * smpl) { + const auto * ctx = (llama_sampler_llg *) smpl->ctx; + + if (ctx->grammar) { + llg_free_constraint(ctx->grammar); + llg_free_tokenizer(ctx->tokenizer); + } + + delete ctx; +} + +static llama_sampler_i llama_sampler_llg_i = { + /* .name = */ llama_sampler_llg_name, + /* .accept = */ llama_sampler_llg_accept_impl, + /* .apply = */ llama_sampler_llg_apply, + /* .reset = */ llama_sampler_llg_reset, + /* .clone = */ llama_sampler_llg_clone, + /* .free = */ llama_sampler_llg_free, +}; + +static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len, + uint32_t * output_tokens, size_t output_tokens_len) { + const llama_vocab * vocab = (const llama_vocab *) user_data; + int r = 0; + try { + r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false, + true); + } catch (const std::exception & e) { + GGML_ABORT("llama_tokenize failed: %s\n", e.what()); + } + if (r < 0) { + return -r; + } + return r; +} + +static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) { + // TODO store the tokenizer in the vocab somehow + static const llama_vocab * vocab_cache; + static LlgTokenizer * tokenizer_cache; + + if (vocab_cache == vocab) { + return llg_clone_tokenizer(tokenizer_cache); + } + + auto tok_eos = llama_vocab_eot(vocab); + if (tok_eos == LLAMA_TOKEN_NULL) { + tok_eos = llama_vocab_eos(vocab); + } + + size_t vocab_size = llama_vocab_n_tokens(vocab); + + auto token_lens = new uint32_t[vocab_size]; + // we typically have ~7 bytes per token; let's go on the safe side here + auto token_bytes_size = vocab_size * 16 + 1024 * 1024; + auto token_bytes = new uint8_t[token_bytes_size]; + + size_t offset = 0; + for (size_t i = 0; i < vocab_size; i++) { + size_t max_token = 1024; + if (token_bytes_size - offset < max_token) { + GGML_ABORT("token_bytes buffer too small\n"); + } + + llama_token token = i; + auto dp = (char *) token_bytes + offset; + auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false); + if (size < 0) { + GGML_ABORT("llama_detokenize failed\n"); + } + if (size == 0) { + size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true); + if (size < 0) { + GGML_ABORT("llama_detokenize failed\n"); + } + if (size != 0) { + *dp = '\xff'; // special token prefix marker + size += 1; + } + } + + token_lens[i] = size; + offset += size; + } + + LlgTokenizerInit tinit = { + /* .vocab_size = */ (uint32_t) vocab_size, + /* .tok_eos = */ (uint32_t) tok_eos, + /* .token_lens = */ token_lens, + /* .token_bytes = */ token_bytes, + /* .tokenizer_json = */ nullptr, + /* .tokenize_assumes_string = */ true, + /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn, + /* .use_approximate_greedy_tokenize_fn = */ false, + /* .tokenize_user_data = */ vocab, + }; + + char error_buffer[1024]; + LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer)); + + delete[] token_bytes; + delete[] token_lens; + + if (tokenizer == nullptr) { + LOG_ERR("llg tokenizer error: %s\n", error_buffer); + return tokenizer; + } + + if (tokenizer_cache) { + llg_free_tokenizer(tokenizer_cache); + } + vocab_cache = vocab; + tokenizer_cache = tokenizer; + + return llg_clone_tokenizer(tokenizer_cache); +} + +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, + const char * grammar_data) { + auto * ctx = new llama_sampler_llg; + + if (grammar_kind != nullptr && grammar_kind[0] != '\0') { + auto tokenizer = llama_sampler_llg_new_tokenizer(vocab); + *ctx = { + /* .vocab = */ vocab, + /* .grammar_kind = */ grammar_kind, + /* .grammar_data = */ grammar_data, + /* .tokenizer = */ tokenizer, + /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data), + /* .llg_res = */ {}, + /* .has_llg_res = */ false, + }; + } else { + *ctx = { + /* .vocab = */ vocab, + /* .grammar_kind = */ {}, + /* .grammar_data = */ {}, + /* .tokenizer = */ nullptr, + /* .grammar = */ nullptr, + /* .llg_res = */ {}, + /* .has_llg_res = */ false, + }; + } + + return new llama_sampler{ + /* .iface = */ &llama_sampler_llg_i, + /* .ctx = */ ctx, + }; +} + +#else + +llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) { + LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); + return nullptr; +} + +#endif // LLAMA_USE_LLGUIDANCE diff --git a/common/log.cpp b/common/log.cpp index 0b8994ae1..4bfbecf15 100644 --- a/common/log.cpp +++ b/common/log.cpp @@ -14,16 +14,6 @@ void common_log_set_verbosity_thold(int verbosity) { common_log_verbosity_thold = verbosity; } -#define LOG_COL_DEFAULT "\033[0m" -#define LOG_COL_BOLD "\033[1m" -#define LOG_COL_RED "\033[31m" -#define LOG_COL_GREEN "\033[32m" -#define LOG_COL_YELLOW "\033[33m" -#define LOG_COL_BLUE "\033[34m" -#define LOG_COL_MAGENTA "\033[35m" -#define LOG_COL_CYAN "\033[36m" -#define LOG_COL_WHITE "\033[37m" - static int64_t t_us() { return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); } diff --git a/common/log.h b/common/log.h index 66605cc69..85dd4393b 100644 --- a/common/log.h +++ b/common/log.h @@ -2,6 +2,16 @@ #include "ggml.h" // for ggml_log_level +#define LOG_COL_DEFAULT "\033[0m" +#define LOG_COL_BOLD "\033[1m" +#define LOG_COL_RED "\033[31m" +#define LOG_COL_GREEN "\033[32m" +#define LOG_COL_YELLOW "\033[33m" +#define LOG_COL_BLUE "\033[34m" +#define LOG_COL_MAGENTA "\033[35m" +#define LOG_COL_CYAN "\033[36m" +#define LOG_COL_WHITE "\033[37m" + #ifndef __GNUC__ # define LOG_ATTRIBUTE_FORMAT(...) #elif defined(__MINGW32__) diff --git a/common/minja.hpp b/common/minja.hpp index f0e80fd7c..e77eb69d5 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; class TemplateToken { public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; static std::string typeToString(Type t) { switch (t) { @@ -714,6 +714,8 @@ public: case Type::EndFilter: return "endfilter"; case Type::Generation: return "generation"; case Type::EndGeneration: return "endgeneration"; + case Type::Break: return "break"; + case Type::Continue: return "continue"; } return "Unknown"; } @@ -815,6 +817,22 @@ struct CommentTemplateToken : public TemplateToken { CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} }; +enum class LoopControlType { Break, Continue }; + +class LoopControlException : public std::runtime_error { +public: + LoopControlType control_type; + LoopControlException(const std::string & message, LoopControlType control_type) : std::runtime_error(message), control_type(control_type) {} + LoopControlException(LoopControlType control_type) + : std::runtime_error((control_type == LoopControlType::Continue ? "continue" : "break") + std::string(" outside of a loop")), + control_type(control_type) {} +}; + +struct LoopControlTemplateToken : public TemplateToken { + LoopControlType control_type; + LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, location, pre, post), control_type(control_type) {} +}; + class TemplateNode { Location location_; protected: @@ -825,6 +843,12 @@ public: void render(std::ostringstream & out, const std::shared_ptr & context) const { try { do_render(out, context); + } catch (const LoopControlException & e) { + // TODO: make stack creation lazy. Only needed if it was thrown outside of a loop. + std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw LoopControlException(err.str(), e.control_type); } catch (const std::exception & e) { std::ostringstream err; err << e.what(); @@ -897,6 +921,15 @@ public: } }; +class LoopControlNode : public TemplateNode { + LoopControlType control_type_; + public: + LoopControlNode(const Location & location, LoopControlType control_type) : TemplateNode(location), control_type_(control_type) {} + void do_render(std::ostringstream &, const std::shared_ptr &) const override { + throw LoopControlException(control_type_); + } +}; + class ForNode : public TemplateNode { std::vector var_names; std::shared_ptr iterable; @@ -961,7 +994,12 @@ public: loop.set("last", i == (n - 1)); loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); - body->render(out, loop_context); + try { + body->render(out, loop_context); + } catch (const LoopControlException & e) { + if (e.control_type == LoopControlType::Break) break; + if (e.control_type == LoopControlType::Continue) continue; + } } } }; @@ -2159,7 +2197,7 @@ private: static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); @@ -2291,6 +2329,9 @@ private: } else if (keyword == "endfilter") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "break" || keyword == "continue") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue)); } else { throw std::runtime_error("Unexpected block: " + keyword); } @@ -2414,6 +2455,8 @@ private: children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); } else if (dynamic_cast(token.get())) { // Ignore comments + } else if (auto ctrl_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, ctrl_token->control_type)); } else if (dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) diff --git a/common/sampling.cpp b/common/sampling.cpp index bc7e49fdb..e4b21ca10 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -156,13 +156,25 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co for (const auto & str : params.grammar_trigger_words) { trigger_words.push_back(str.word.c_str()); } + + struct llama_sampler * grmr; + if (params.grammar.compare(0, 11, "%llguidance") == 0) { +#ifdef LLAMA_USE_LLGUIDANCE + grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); +#else + GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); +#endif // LLAMA_USE_LLGUIDANCE + } else { + grmr = params.grammar_lazy + ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", + trigger_words.data(), trigger_words.size(), + params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + } + auto * result = new common_sampler { /* .params = */ params, - /* .grmr = */ params.grammar_lazy - ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", - trigger_words.data(), trigger_words.size(), - params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) - : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"), + /* .grmr = */ grmr, /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, diff --git a/common/sampling.h b/common/sampling.h index 348911b18..2064421db 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -102,3 +102,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr); std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector common_sampler_types_from_chars(const std::string & chars); + +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, + const char * grammar_kind, const char * grammar_data); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 63b54a9cf..018a2a588 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -648,7 +648,7 @@ class Model: if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code res = "jina-v2-code" - if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": + if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b" or chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": # ref: https://huggingface.co/THUDM/glm-4-9b-chat res = "chatglm-bpe" if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee": @@ -4513,7 +4513,7 @@ class JaisModel(Model): self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) -@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration") +@Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -4619,47 +4619,15 @@ class ChatGLMModel(Model): from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) - vocab_size = hparams["padded_vocab_size"] + vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"]) assert max(tokenizer.get_vocab().values()) < vocab_size - tokpre = self.get_vocab_base_pre(tokenizer) - - merges = [] - vocab = {} - mergeable_ranks = tokenizer.mergeable_ranks - for token, rank in mergeable_ranks.items(): - vocab[ChatGLMModel.token_bytes_to_string(token)] = rank - if len(token) == 1: - continue - merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank) - assert len(merged) >= 2 and len(merged) <= 7 - merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged))) - - # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined - added_vocab = tokenizer.get_added_vocab() - reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()} - - for i in range(vocab_size): - if i not in reverse_vocab: - tokens.append(f"[PAD{i}]") - toktypes.append(gguf.TokenType.UNUSED) - elif reverse_vocab[i] in added_vocab: - tokens.append(reverse_vocab[i]) - if tokenizer.added_tokens_decoder[i].special: - toktypes.append(gguf.TokenType.CONTROL) - else: - toktypes.append(gguf.TokenType.USER_DEFINED) - else: - tokens.append(reverse_vocab[i]) - toktypes.append(gguf.TokenType.NORMAL) - + tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) - - special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) - special_vocab.merges = merges + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) # only add special tokens when they were not already loaded from config.json special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) @@ -4670,16 +4638,20 @@ class ChatGLMModel(Model): def set_gguf_parameters(self): n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) - n_head_kv = self.hparams.get("multi_query_group_num", n_head) + n_head_kv = self.hparams.get("multi_query_group_num", self.hparams.get("num_key_value_heads", n_head)) self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) - self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed)) - self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed))) + self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"])) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) - self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5)) self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_rope_dimension_count(64) + if "attention_dim" in self.hparams: + rope_dim = self.hparams["attention_dim"] + else: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) self.gguf_writer.add_add_bos_token(False) rope_freq = 10000 if "rope_ratio" in self.hparams: @@ -4689,7 +4661,7 @@ class ChatGLMModel(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.endswith(".rotary_pos_emb.inv_freq"): + if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."): return [] name = name.removeprefix("transformer.") diff --git a/docs/llguidance.md b/docs/llguidance.md new file mode 100644 index 000000000..792d20704 --- /dev/null +++ b/docs/llguidance.md @@ -0,0 +1,51 @@ +# LLGuidance Support in llama.cpp + +[LLGuidance](https://github.com/guidance-ai/llguidance) is a library for constrained decoding (also called constrained sampling or structured outputs) for Large Language Models (LLMs). Initially developed as the backend for the [Guidance](https://github.com/guidance-ai/guidance) library, it can also be used independently. + +LLGuidance supports JSON Schemas and arbitrary context-free grammars (CFGs) written in a [variant](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md) of Lark syntax. It is [very fast](https://github.com/guidance-ai/jsonschemabench/tree/main/maskbench) and has [excellent](https://github.com/guidance-ai/llguidance/blob/main/docs/json_schema.md) JSON Schema coverage but requires the Rust compiler, which complicates the llama.cpp build process. + +## Building + +To enable LLGuidance support, build llama.cpp with the `LLAMA_LLGUIDANCE` option: + +```sh +cmake -B build -DLLAMA_LLGUIDANCE=ON +make -C build -j +``` + +This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install). + +## Interface + +There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance. + +For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format. + +## Performance + +Computing a "token mask" (i.e., the set of allowed tokens) for a llama3 tokenizer with 128k tokens takes, on average, 50μs of single-core CPU time for the [JSON Schema Bench](https://github.com/guidance-ai/jsonschemabench). The p99 time is 0.5ms, and the p100 time is 20ms. These results are due to the lexer/parser split and several [optimizations](https://github.com/guidance-ai/llguidance/blob/main/docs/optimizations.md). + +## JSON Schema + +LLGuidance adheres closely to the JSON Schema specification. For example: + +- `additionalProperties` defaults to `true`, unlike current grammars, though you can set `"additionalProperties": false` if needed. +- any whitespace is allowed. +- The definition order in the `"properties": {}` object is maintained, regardless of whether properties are required (current grammars always puts required properties first). + +Unsupported schemas result in an error message—no keywords are silently ignored. + +## Why Not Reuse GBNF Format? + +GBNF lacks the concept of a lexer. + +Most programming languages, including JSON, use a two-step process: a lexer (built with regular expressions) converts a byte stream into lexemes, which are then processed by a CFG parser. This approach is faster because lexers are cheaper to evaluate, and there is ~10x fewer lexemes than bytes. +LLM tokens often align with lexemes, so the parser is engaged in under 0.5% of tokens, with the lexer handling the rest. + +However, the user has to provide the distinction between lexemes and CFG symbols. In [Lark](https://github.com/lark-parser/lark), lexeme names are uppercase, while CFG symbols are lowercase. +The [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) can often take care of this automatically. +See [LLGuidance syntax docs](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#terminals-vs-rules) for more details. + +## Error Handling + +Errors are currently printed to `stderr`, and generation continues. Improved error handling may be added in the future. diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 371917b2e..55c31166c 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -31,6 +31,11 @@ defer { llama_model_free(model) } +guard let vocab = llama_model_get_vocab(model) else { + print("Failed to get vocab") + exit(1) +} + var tokens = tokenize(text: prompt, add_bos: true) let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel) @@ -41,7 +46,7 @@ context_params.n_batch = UInt32(max(n_len, n_parallel)) context_params.n_threads = 8 context_params.n_threads_batch = 8 -let context = llama_new_context_with_model(model, context_params) +let context = llama_init_from_model(model, context_params) guard context != nil else { print("Failed to initialize context") exit(1) @@ -141,7 +146,7 @@ while n_cur <= n_len { let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) // is it an end of stream? -> mark the stream as finished - if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len { i_batch[i] = -1 // print("") if n_parallel > 1 { @@ -207,7 +212,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count let n_tokens = utf8Count + (add_bos ? 1 : 0) let tokens = UnsafeMutablePointer.allocate(capacity: n_tokens) - let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false) + let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false) var swiftTokens: [llama_token] = [] for i in 0 ..< tokenCount { swiftTokens.append(tokens[Int(i)]) @@ -218,12 +223,12 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] { private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? { var result = [CChar](repeating: 0, count: 8) - let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false) + let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false) if nTokens < 0 { let actualTokensCount = -Int(nTokens) result = .init(repeating: 0, count: actualTokensCount) let check = llama_token_to_piece( - model, + vocab, token, &result, Int32(result.count), diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 477c3e6f2..ee7141a66 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama actor LlamaContext { private var model: OpaquePointer private var context: OpaquePointer + private var vocab: OpaquePointer private var sampling: UnsafeMutablePointer private var batch: llama_batch private var tokens_list: [llama_token] @@ -47,6 +48,7 @@ actor LlamaContext { self.sampling = llama_sampler_chain_init(sparams) llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4)) llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234)) + vocab = llama_model_get_vocab(model) } deinit { @@ -79,7 +81,7 @@ actor LlamaContext { ctx_params.n_threads = Int32(n_threads) ctx_params.n_threads_batch = Int32(n_threads) - let context = llama_new_context_with_model(model, ctx_params) + let context = llama_init_from_model(model, ctx_params) guard let context else { print("Could not load context!") throw LlamaError.couldNotInitializeContext @@ -151,7 +153,7 @@ actor LlamaContext { new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) - if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len { print("\n") is_done = true let new_token_str = String(cString: temporary_invalid_cchars + [0]) @@ -297,7 +299,7 @@ actor LlamaContext { let utf8Count = text.utf8.count let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1 let tokens = UnsafeMutablePointer.allocate(capacity: n_tokens) - let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false) + let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false) var swiftTokens: [llama_token] = [] for i in 0...allocate(capacity: Int(-nTokens)) @@ -324,7 +326,7 @@ actor LlamaContext { defer { newResult.deallocate() } - let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false) + let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false) let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens)) return Array(bufferPointer) } else { diff --git a/examples/llava/README-glmedge.md b/examples/llava/README-glmedge.md new file mode 100644 index 000000000..603d01474 --- /dev/null +++ b/examples/llava/README-glmedge.md @@ -0,0 +1,43 @@ +# GLMV-EDGE + +Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b). + +## Usage +Build with cmake or run `make llama-llava-cli` to build it. + +After building, run: `./llama-llava-cli` to see the usage. For example: + +```sh +./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <|user|>\n prompt <|assistant|>\n" +``` + +**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so. +**note**: For GPU offloading ensure to use the `-ngl` flag just like usual + +## GGUF conversion + +1. Clone a GLMV-EDGE model ([2B](https://huggingface.co/THUDM/glm-edge-v-2b) or [5B](https://huggingface.co/THUDM/glm-edge-v-5b)). For example: + +```sh +git clone https://huggingface.co/THUDM/glm-edge-v-5b or https://huggingface.co/THUDM/glm-edge-v-2b +``` + +2. Use `glmedge-surgery.py` to split the GLMV-EDGE model to LLM and multimodel projector constituents: + +```sh +python ./examples/llava/glmedge-surgery.py -m ../model_path +``` + +4. Use `glmedge-convert-image-encoder-to-gguf.py` to convert the GLMV-EDGE image encoder to GGUF: + +```sh +python ./examples/llava/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path +``` + +5. Use `examples/convert_hf_to_gguf.py` to convert the LLM part of GLMV-EDGE to GGUF: + +```sh +python convert_hf_to_gguf.py ../model_path +``` + +Now both the LLM part and the image encoder are in the `model_path` directory. diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 24073c5a9..7367d44cb 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -102,6 +102,7 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_VIS_ENC "clip.has_vision_encoder" #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" @@ -160,6 +161,15 @@ static std::string format(const char * fmt, ...) { #define TN_MINICPMV_ATTN "resampler.attn.%s.%s" #define TN_MINICPMV_LN "resampler.ln_%s.%s" +#define TN_GLM_ADAPER_CONV "adapter.conv.%s" +#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" +#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" +#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s" +#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" +#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +#define TN_GLM_BOI_W "adapter.boi" +#define TN_GLM_EOI_W "adapter.eoi" + enum projector_type { PROJECTOR_TYPE_MLP, @@ -167,6 +177,7 @@ enum projector_type { PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_GLM_EDGE, PROJECTOR_TYPE_MERGER, PROJECTOR_TYPE_UNKNOWN, }; @@ -176,6 +187,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, }; @@ -500,6 +512,12 @@ struct clip_vision_model { struct ggml_tensor * mm_4_w = NULL; struct ggml_tensor * mm_4_b = NULL; + //GLMV-Edge projection + struct ggml_tensor * mm_model_adapter_conv_w; + struct ggml_tensor * mm_model_adapter_conv_b; + struct ggml_tensor * boi_w; + struct ggml_tensor * eoi_w; + // MobileVLM projection struct ggml_tensor * mm_model_mlp_1_w; struct ggml_tensor * mm_model_mlp_1_b; @@ -560,6 +578,7 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + bool has_glm_projector = false; bool has_qwen2vl_merger = false; int minicpmv_version = 2; @@ -638,7 +657,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const int batch_size = imgs->size; - if (ctx->has_llava_projector || ctx->has_minicpmv_projector) { + if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { GGML_ASSERT(batch_size == 1); } @@ -734,8 +753,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // loop over layers - if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) { - // TODO: figure out why we doing thing in this way ??? + if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { n_layer += 1; } for (int il = 0; il < n_layer - 1; il++) { @@ -1095,7 +1113,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } - else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + // glm projector + else if (ctx->has_glm_projector) { + if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + size_t gridsz = (size_t)sqrt(embeddings->ne[1]); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); + embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); + embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); + embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); + //GLU + { + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + embeddings = ggml_gelu_inplace(ctx0, embeddings); + struct ggml_tensor * x = embeddings; + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); + x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); + embeddings = ggml_silu_inplace(ctx0, embeddings); + embeddings = ggml_mul(ctx0, embeddings,x); + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); + } + } else { + GGML_ABORT("fatel error"); + } + } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -1284,6 +1328,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); } + idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ); + if (idx != -1) { + new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx); + } + idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); if (idx != -1) { new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); @@ -1308,6 +1357,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); + LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector); LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); } @@ -1575,6 +1625,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); } + else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight")); + vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias")); + vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight")); + vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight")); + vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias")); + vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); + vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight")); + vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); + vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W); + vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W); + } else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); @@ -2115,6 +2177,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli return true; } + if (ctx->has_glm_projector) { + res_imgs->size = 1; + res_imgs->data = new clip_image_f32[res_imgs->size]; + clip_image_u8 resized_image; + int32_t sz=ctx->vision_model.hparams.image_size; + bicubic_resize(*img, resized_image,sz,sz); + clip_image_f32 * res = clip_image_f32_init(); + //clip_image_save_to_bmp(resized_image, "resized.bmp"); + normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std); + res_imgs->data[0] = *res; + clip_image_f32_free(res); + return true; + } + bool pad_to_square = true; if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); @@ -2300,7 +2376,8 @@ void clip_free(clip_ctx * ctx) { } size_t clip_embd_nbytes(const struct clip_ctx * ctx) { - return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); + int extra_tokens = ctx->has_glm_projector ? 2 : 0; + return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float); } size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) { @@ -2342,7 +2419,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { + if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { if (ctx->minicpmv_version == 2) { @@ -2475,6 +2552,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (ctx->has_minicpmv_projector) { GGML_ASSERT(batch_size == 1); } + if (ctx->has_glm_projector) { + GGML_ASSERT(batch_size == 1); + ggml_tensor * boi = ctx->vision_model.boi_w; + ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi)); + vec = (float*)(vec+ggml_nelements(boi)); //offset for boi + } // build the inference graph ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); @@ -2627,7 +2710,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); - { + if (!ctx->has_glm_projector) { struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); int* patches_data = (int*)malloc(ggml_nbytes(patches)); for (int i = 0; i < num_patches; i++) { @@ -2651,6 +2734,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + if (ctx->has_glm_projector) { + //eoi + ggml_tensor * eoi = ctx->vision_model.eoi_w; + int offset = ggml_nelements(embeddings); + ggml_backend_tensor_get(eoi, vec+offset, 0, ggml_nbytes(eoi)); + } + return true; } @@ -2812,6 +2902,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return 3584; } } + if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){ + return ctx->vision_model.mm_model_mlp_3_w->ne[1]; + } if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { return ctx->vision_model.mm_1_b->ne[0]; } @@ -2827,6 +2920,9 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) { return 0; } +bool clip_is_glm(const struct clip_ctx * ctx) { + return ctx->has_glm_projector; +} bool clip_is_qwen2vl(const struct clip_ctx * ctx) { return ctx->has_qwen2vl_merger; } diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 1603edd26..841b4f6f9 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -93,6 +93,8 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); +CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); + #ifdef __cplusplus } #endif diff --git a/examples/llava/glmedge-convert-image-encoder-to-gguf.py b/examples/llava/glmedge-convert-image-encoder-to-gguf.py new file mode 100644 index 000000000..848ef1cf3 --- /dev/null +++ b/examples/llava/glmedge-convert-image-encoder-to-gguf.py @@ -0,0 +1,280 @@ +import argparse +import os +import json +import re + +import torch +import numpy as np +from gguf import * + +TEXT = "clip.text" +VISION = "clip.vision" +from transformers import SiglipVisionModel, SiglipVisionConfig + +def k(raw_key: str, arch: str) -> str: + return raw_key.format(arch=arch) + + +def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool: + if name in ( + "logit_scale", + "text_model.embeddings.position_ids", + "vision_model.embeddings.position_ids", + ): + return True + + if name in ( + "vision_model.head.probe", + "vision_model.head.attention.in_proj_weight", + "vision_model.head.attention.in_proj_bias", + "vision_model.head.attention.out_proj.weight", + "vision_model.head.attention.out_proj.bias", + "vision_model.head.layernorm.weight", + "vision_model.head.layernorm.bias", + "vision_model.head.mlp.fc1.weight", + "vision_model.head.mlp.fc1.bias", + "vision_model.head.mlp.fc2.weight", + "vision_model.head.mlp.fc2.bias" + ): + return True + + if name.startswith("v") and not has_vision: + return True + + if name.startswith("t") and not has_text: + return True + + return False + + +def get_tensor_name(name: str) -> str: + if "projection" in name: + return name + if "mm_projector" in name: + name = name.replace("model.mm_projector", "mm") + name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) + name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) + return name + + return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + + +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +ap = argparse.ArgumentParser() +ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) +ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") +ap.add_argument("--text-only", action="store_true", required=False, + help="Save a text-only model. It can't be used to encode images") +ap.add_argument("--vision-only", action="store_true", required=False, + help="Save a vision-only model. It can't be used to encode texts") +ap.add_argument("--clip-model-is-vision", action="store_true", required=False, + help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") +ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, + help="The clip model is from openclip (for ViT-SO400M type))") +ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") +ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2","adapter"], default="adapter") +ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 +# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 +default_image_mean = [0.5, 0.5, 0.5] +default_image_std = [0.5, 0.5, 0.5] +ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) +ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) + +# with proper +args = ap.parse_args() + + +if args.text_only and args.vision_only: + print("--text-only and --image-only arguments cannot be specified at the same time.") + exit(1) + +if args.use_f32: + print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + +# output in the same directory as the model if output_dir is None +dir_model = args.model_dir + +if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: + vocab = None + tokens = None +else: + with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f: + vocab = json.load(f) + tokens = [key for key in vocab] + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + config = json.load(f) + if args.clip_model_is_vision: + v_hparams = config + t_hparams = None + else: + v_hparams = config["vision_config"] + t_hparams = None + +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if args.use_f32: + ftype = 0 + +vision_config = SiglipVisionConfig(**v_hparams) +model = SiglipVisionModel(vision_config) +model.load_state_dict(torch.load(os.path.join(dir_model, "glm.clip"))) + +fname_middle = None +has_text_encoder = False +has_vision_encoder = True +has_glm_projector = True +if args.text_only: + fname_middle = "text-" + has_vision_encoder = False +elif args.llava_projector is not None: + fname_middle = "mmproj-" + has_text_encoder = False + has_glm_projector = True +elif args.vision_only: + fname_middle = "vision-" + has_text_encoder = False +else: + fname_middle = "" + +output_dir = args.output_dir if args.output_dir is not None else dir_model +os.makedirs(output_dir, exist_ok=True) +output_prefix = os.path.basename(output_dir).replace("ggml_", "") +fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf") +fout = GGUFWriter(path=fname_out, arch="clip") + +fout.add_bool("clip.has_text_encoder", has_text_encoder) +fout.add_bool("clip.has_vision_encoder", has_vision_encoder) +fout.add_bool("clip.has_glm_projector", has_glm_projector) +fout.add_file_type(ftype) +model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model) +fout.add_name(model_name) +if has_glm_projector: + fout.add_description("image encoder for glm4v") + fout.add_string("clip.projector_type", "adapter") +else: + fout.add_description("two-tower CLIP model") + +if has_text_encoder: + assert t_hparams is not None + assert tokens is not None + # text_model hparams + fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"]) + fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"]) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"]) + fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) + fout.add_token_list(tokens) + +if has_vision_encoder: + # vision_model hparams + fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) + fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) + fout.add_uint32("clip.vision.projection_dim", 0) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) + fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), v_hparams["num_hidden_layers"]) + + image_mean = args.image_mean if args.image_mean is not None else default_image_mean + image_std = args.image_std if args.image_std is not None else default_image_std + fout.add_array("clip.vision.image_mean", image_mean) + fout.add_array("clip.vision.image_std", image_std) + +fout.add_bool("clip.use_gelu", True) + + +if has_glm_projector: + # model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue] + projector = torch.load(args.llava_projector) + for name, data in projector.items(): + name = get_tensor_name(name) + # pw and dw conv ndim==4 + if data.ndim == 2 or data.ndim == 4: + data = data.squeeze().numpy().astype(np.float16) + else: + data = data.squeeze().numpy().astype(np.float32) + if name.startswith("vision."): + name=name.replace("vision.","") + fout.add_tensor(name, data) + print(f"Projector {name} - {data.dtype} - shape = {data.shape}") + # print(f"Projector {name} tensors added\n") + +state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue] +for name, data in state_dict.items(): + if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_glm_projector): + # we don't need this + print(f"skipping parameter: {name}") + continue + + name = get_tensor_name(name) + data = data.squeeze().numpy() + + n_dims = len(data.shape) + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0 + if n_dims == 4: + print(f"tensor {name} is always saved in f16") + data = data.astype(np.float16) + ftype_cur = 1 + elif ftype == 1: + if name[-7:] == ".weight" and n_dims == 2: + # print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + # print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + # print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + print(f"siglip {name} - {data.dtype} - shape = {data.shape}") + # print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}") + fout.add_tensor(name, data) + + +fout.write_header_to_file() +fout.write_kv_data_to_file() +fout.write_tensors_to_file() +fout.close() + +print("Done. Output file: " + fname_out) diff --git a/examples/llava/glmedge-surgery.py b/examples/llava/glmedge-surgery.py new file mode 100644 index 000000000..16bb915d0 --- /dev/null +++ b/examples/llava/glmedge-surgery.py @@ -0,0 +1,33 @@ +import argparse +import os +import torch +from transformers import AutoModel + +ap = argparse.ArgumentParser() +ap.add_argument("-m", "--model", help="Path to GLM model") +args = ap.parse_args() + +# find the model part that includes the the multimodal projector weights +model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +checkpoint = model.state_dict() + +# get a list of mm tensor names +mm_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.adapter.")] + +# store these tensors in a new dictionary and torch.save them +projector = {name: checkpoint[name].float() for name in mm_tensors} +torch.save(projector, f"{args.model}/glm.projector") + +clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.vit.model.vision_model.")] +if len(clip_tensors) > 0: + clip = {name.replace("vision.vit.model.", ""): checkpoint[name].float() for name in clip_tensors} + torch.save(clip, f"{args.model}/glm.clip") + + # added tokens should be removed to be able to convert Mistral models + if os.path.exists(f"{args.model}/added_tokens.json"): + with open(f"{args.model}/added_tokens.json", "w") as f: + f.write("{}\n") + +print("Done!") +print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") +print(f"Also, use {args.model}glm.projector to prepare a glm-encoder.gguf file.") diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 2cac7933d..300714045 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -311,6 +311,20 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli img_res_v.size = 0; img_res_v.data = nullptr; } + else if (clip_is_glm(ctx_clip)){ + struct clip_image_size * load_image_size = clip_image_size_init(); + load_image_size->width = img_res_v.data[0].nx; + load_image_size->height = img_res_v.data[0].ny; + clip_add_load_image_size(ctx_clip, load_image_size); + + bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); + int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2); + *n_img_pos = (pos * pos + 2); + if (!encoded){ + LOG_ERR("Unable to encode image \n"); + return false; + } + } else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); @@ -395,6 +409,9 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co if (clip_is_minicpmv(ctx_clip)) { num_max_patches = 10; } + if (clip_is_glm(ctx_clip)) { + num_max_patches = 1; + } float * image_embd; if (clip_is_qwen2vl(ctx_clip)) { // qwen2vl don't split image into chunks, so `num_max_patches` is not needed. diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 9cecae48c..ca9273155 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -24,15 +24,16 @@ #include #include +#include "chat-template.hpp" #include "common.h" #include "json.hpp" #include "linenoise.cpp/linenoise.h" #include "llama-cpp.h" -#include "chat-template.hpp" +#include "log.h" #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) [[noreturn]] static void sigint_handler(int) { - printf("\n\033[0m"); + printf("\n" LOG_COL_DEFAULT); exit(0); // not ideal, but it's the only way to guarantee exit in all cases } #endif @@ -65,6 +66,13 @@ static int printe(const char * fmt, ...) { return ret; } +static std::string strftime_fmt(const char * fmt, const std::tm & tm) { + std::ostringstream oss; + oss << std::put_time(&tm, fmt); + + return oss.str(); +} + class Opt { public: int init(int argc, const char ** argv) { @@ -698,6 +706,39 @@ class LlamaData { return download(url, bn, true); } + int s3_dl(const std::string & model, const std::string & bn) { + const size_t slash_pos = model.find('/'); + if (slash_pos == std::string::npos) { + return 1; + } + + const std::string bucket = model.substr(0, slash_pos); + const std::string key = model.substr(slash_pos + 1); + const char * access_key = std::getenv("AWS_ACCESS_KEY_ID"); + const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY"); + if (!access_key || !secret_key) { + printe("AWS credentials not found in environment\n"); + return 1; + } + + // Generate AWS Signature Version 4 headers + // (Implementation requires HMAC-SHA256 and date handling) + // Get current timestamp + const time_t now = time(nullptr); + const tm tm = *gmtime(&now); + const std::string date = strftime_fmt("%Y%m%d", tm); + const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm); + const std::vector headers = { + "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date + + "/us-east-1/s3/aws4_request", + "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime + }; + + const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key; + + return download(url, bn, true, headers); + } + std::string basename(const std::string & path) { const size_t pos = path.find_last_of("/\\"); if (pos == std::string::npos) { @@ -738,6 +779,9 @@ class LlamaData { rm_until_substring(model_, "github:"); rm_until_substring(model_, "://"); ret = github_dl(model_, bn); + } else if (string_starts_with(model_, "s3://")) { + rm_until_substring(model_, "://"); + ret = s3_dl(model_, bn); } else { // ollama:// or nothing rm_until_substring(model_, "ollama.com/library/"); rm_until_substring(model_, "://"); @@ -847,7 +891,7 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); if (n_ctx_used + batch.n_tokens > n_ctx) { - printf("\033[0m\n"); + printf(LOG_COL_DEFAULT "\n"); printe("context size exceeded\n"); return 1; } @@ -910,7 +954,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str batch = llama_batch_get_one(&new_token_id, 1); } - printf("\033[0m"); + printf(LOG_COL_DEFAULT); return 0; } @@ -919,7 +963,7 @@ static int read_user_input(std::string & user_input) { #ifdef WIN32 printf( "\r%*s" - "\r\033[0m%s", + "\r" LOG_COL_DEFAULT "%s", get_terminal_width(), " ", prompt_prefix); std::getline(std::cin, user_input); @@ -956,7 +1000,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, const bool stdout_a_terminal) { // Set response color if (stdout_a_terminal) { - printf("\033[33m"); + printf(LOG_COL_YELLOW); } if (generate(llama_data, prompt, response)) { @@ -965,7 +1009,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // End response with color reset and newline - printf("\n%s", stdout_a_terminal ? "\033[0m" : ""); + printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : ""); return 0; } diff --git a/examples/server/README.md b/examples/server/README.md index 276b43013..e9d0374ad 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1128,6 +1128,7 @@ curl http://localhost:8080/v1/chat/completions \ - Hermes 2/3, Qwen 2.5 - Mistral Nemo - Firefunction v2 + - Command R7B - DeepSeek R1 (WIP / seems reluctant to call any tools?)
@@ -1202,21 +1203,28 @@ curl http://localhost:8080/v1/chat/completions \ ```shell # Native support: llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M - llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M - llama-server --jinja -fa -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q6_K + llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M - llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \ - --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B ) + llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M # Native support requires the right template for these GGUFs: + + llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \ + --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use ) + llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \ --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use ) + llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \ - --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/firellama-3-firefunction-v2 ) + --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use ) + + llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \ + --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use ) # Generic format support - llama-server --jinja -fa -hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M - llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q4_K_M + llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0 + llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0 + llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K ``` - Test in CLI: diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 582ccc0d3..e876776e8 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3451e96a2..e0acc4705 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -131,6 +131,11 @@ struct slot_params { lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); } + std::vector grammar_trigger_words; + for (const auto & trigger : sampling.grammar_trigger_words) { + grammar_trigger_words.push_back(trigger.word); + } + return json { {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, @@ -165,8 +170,9 @@ struct slot_params { {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, - // {"grammar_trigger_words", sampling.grammar_trigger_words}, + {"grammar_trigger_words", grammar_trigger_words}, {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, + {"preserved_tokens", sampling.preserved_tokens}, {"samplers", samplers}, {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, @@ -363,12 +369,26 @@ struct server_task { if (ids.size() == 1) { LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); params.sampling.grammar_trigger_tokens.push_back(ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); continue; } LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); params.sampling.grammar_trigger_words.push_back(trigger); } } + const auto preserved_tokens = data.find("preserved_tokens"); + if (preserved_tokens != data.end()) { + for (const auto & t : *preserved_tokens) { + auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + LOG_DBG("Preserved token: %d\n", ids[0]); + params.sampling.preserved_tokens.insert(ids[0]); + } else { + // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. + LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get().c_str()); + } + } + } if (params.sampling.grammar_lazy) { GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0); } @@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_oaicompat_chat() { std::string finish_reason = "length"; - common_chat_msg message; + common_chat_msg msg; if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { LOG_DBG("Parsing chat message: %s\n", content.c_str()); - message = common_chat_parse(content, oaicompat_chat_format); - finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls"; + msg = common_chat_parse(content, oaicompat_chat_format); + finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls"; } else { - message.content = content; + msg.content = content; } json tool_calls; - if (!message.tool_calls.empty()) { + if (!msg.tool_calls.empty()) { tool_calls = json::array(); - for (const auto & tc : message.tool_calls) { + for (const auto & tc : msg.tool_calls) { tool_calls.push_back({ {"type", "function"}, {"function", { @@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result { } } + json message { + {"content", msg.content}, + {"tool_calls", tool_calls}, + {"role", "assistant"}, + }; + if (!msg.tool_plan.empty()) { + message["tool_plan"] = msg.tool_plan; + } + json choice { {"finish_reason", finish_reason}, {"index", 0}, - {"message", json { - {"content", message.content}, - {"tool_calls", tool_calls}, - {"role", "assistant"}, - }}, + {"message", message}, }; if (!stream && probs_output.size() > 0) { @@ -2833,8 +2858,7 @@ struct server_context { server_slot * slot_batched = nullptr; auto accept_special_token = [&](server_slot & slot, llama_token token) { - const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens; - return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end(); + return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); }; // frist, add sampled tokens from any ongoing sequences diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index f5d8b0572..f23d5cff4 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -13,9 +13,12 @@ def create_server(): @pytest.mark.parametrize( "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", [ + (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None), + (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None), (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), - (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), + (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ] diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index e6ed9c9be..4a551404f 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -67,8 +67,8 @@ WEATHER_TOOL = { def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): - n_predict = 512 global server + n_predict = 512 # server = ServerPreset.stories15m_moe() server.jinja = True server.n_predict = n_predict @@ -139,29 +139,49 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, @pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [ (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), - (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), # TODO: fix these # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server n_predict = 512 server.n_slots = 1 server.jinja = True @@ -169,10 +189,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str server.n_predict = n_predict server.model_hf_repo = hf_repo server.model_hf_file = None - if template_override: + if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predict, @@ -251,33 +273,55 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t @pytest.mark.slow @pytest.mark.parametrize("hf_repo,template_override", [ + ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q4_K_M", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_weather(hf_repo: str, template_override: Tuple[str, str | None] | None): global server + n_predict = 512 server.n_slots = 1 server.jinja = True server.n_ctx = 8192 - server.n_predict = 512 + server.n_predict = n_predict server.model_hf_repo = hf_repo server.model_hf_file = None - if template_override: + if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": 256, + "max_tokens": n_predict, "messages": [ {"role": "user", "content": "What is the weather in Istanbul?"}, ], @@ -298,19 +342,39 @@ def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | Non @pytest.mark.slow @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ - (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), - ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + + (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + + ('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + + (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + + (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. + (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server server.n_slots = 1 server.jinja = True @@ -318,10 +382,12 @@ def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: server.n_predict = 128 server.model_hf_repo = hf_repo server.model_hf_file = None - if template_override: + if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) res = server.make_request("POST", "/chat/completions", data={ "max_tokens": 256, diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index bfe623c4c..5f97df5fd 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -5,10 +5,6 @@ #include "llama.h" #include "common/base64.hpp" -#ifndef NDEBUG -// crash the server in debug mode, otherwise send an http 500 error -#define CPPHTTPLIB_NO_EXCEPTIONS 1 -#endif // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 #include "httplib.h" @@ -662,6 +658,7 @@ static json oaicompat_completion_params_parse( }); } llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; for (const auto & stop : chat_params.additional_stops) { llama_params["stop"].push_back(stop); } diff --git a/examples/server/webui/index.html b/examples/server/webui/index.html index d3893ea4e..1ad85e739 100644 --- a/examples/server/webui/index.html +++ b/examples/server/webui/index.html @@ -154,8 +154,6 @@ placeholder="Type a message (Shift+Enter to add a new line)" v-model="inputMsg" @keydown.enter.exact.prevent="sendMessage" - @keydown.enter.shift.exact.prevent="inputMsg += '\n'" - :disabled="isGenerating" id="msg-input" dir="auto" > diff --git a/examples/server/webui/src/main.js b/examples/server/webui/src/main.js index 90f4ca368..50197a355 100644 --- a/examples/server/webui/src/main.js +++ b/examples/server/webui/src/main.js @@ -468,7 +468,10 @@ const mainApp = createApp({ URL.revokeObjectURL(url); }, async sendMessage() { - if (!this.inputMsg) return; + // prevent sending empty message + // also allow typing the message while generating, but does not allow sending it (to match UX/UI behavior of other chat apps) + if (!this.inputMsg || this.isGenerating) return; + const currConvId = this.viewingConvId; StorageUtils.appendMsg(currConvId, { diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 7c069e420..75b5ea3b4 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -274,22 +274,25 @@ endif() # Generate version info based on git commit. -find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) -execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - OUTPUT_VARIABLE GGML_BUILD_NUMBER - OUTPUT_STRIP_TRAILING_WHITESPACE -) +if(NOT DEFINED GGML_BUILD_NUMBER) + find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) + execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_NUMBER + OUTPUT_STRIP_TRAILING_WHITESPACE + ) -if(GGML_BUILD_NUMBER EQUAL 1) - message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") + if(GGML_BUILD_NUMBER EQUAL 1) + message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") + endif() + + execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) endif() -execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - OUTPUT_VARIABLE GGML_BUILD_COMMIT - OUTPUT_STRIP_TRAILING_WHITESPACE -) # Capture variables prefixed with GGML_. diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1198dc1fd..5bd8d9c8b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1775,7 +1775,7 @@ extern "C" { struct ggml_tensor * a, int k); -#define GGML_KQ_MASK_PAD 32 +#define GGML_KQ_MASK_PAD 64 // q: [n_embd, n_batch, n_head, 1] // k: [n_embd, n_kv, n_head_kv, 1] diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 566709135..0002ac18a 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -93,12 +93,18 @@ endif() if (GGML_CCACHE) find_program(GGML_CCACHE_FOUND ccache) + find_program(GGML_SCCACHE_FOUND sccache) - if (GGML_CCACHE_FOUND) + if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND) + if(GGML_CCACHE_FOUND) + set(GGML_CCACHE_VARIANT ccache) + else() + set(GGML_CCACHE_VARIANT sccache) + endif() # TODO: should not be set globally - set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}") set(ENV{CCACHE_SLOPPINESS} time_macros) - message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.") + message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.") else() message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF") endif () diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 14761650f..119fd39b8 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_CUDA "*.cu") - file(GLOB SRCS "template-instances/fattn-wmma*.cu") + file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8d8d3932e..174916bc9 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -61,6 +61,13 @@ #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) +#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) + #define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY2 220 @@ -148,7 +155,7 @@ typedef float2 dfloat2; #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -#define INT8_MMA_AVAILABLE +#define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) @@ -159,14 +166,24 @@ static constexpr bool fast_fp16_available(const int cc) { return cc >= GGML_CUDA_CC_PASCAL && cc != 610; } +// Any FP16 tensor cores are available. static constexpr bool fp16_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; } -static constexpr bool int8_mma_available(const int cc) { +// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. +static constexpr bool new_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING; } +static constexpr __device__ int ggml_cuda_get_physical_warp_size() { +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + return __AMDGCN_WAVEFRONT_SIZE; +#else + return 32; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index ee9752da6..d40ee2da4 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,6 +516,114 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } +// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ + +template // D == head size +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / KQ_stride; + const int iter_j = (ne01 + (ncols - 1)) / ncols; + + const int bidx0 = blockIdx.x; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + dst += jt*ncols*ne02*D + channel*D; + + // Load the partial result that needs a fixup: + float dst_val[ncols] = {0.0f}; + float max_val[ncols] = {0.0f}; + float rowsum[ncols] = {0.0f}; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + break; + } + dst_val[j] = dst[j*ne02*D + threadIdx.x]; + + const float2 tmp = dst_fixup[bidx0*ncols + j]; + max_val[j] = tmp.x; + rowsum[j] = tmp.y; + } + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + break; + } + const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val[j], tmp.x); + + const float diff_val = max_val[j] - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add; + rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y; + + max_val[j] = max_val_new; + } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + return; + } + dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j]; + } +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -581,10 +689,11 @@ static void on_no_fattn_vec_case(const int D) { } } -template +// parallel_blocks == 0 is stream-k decomposition +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, - const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V + const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V ) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -603,20 +712,23 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + GGML_ASSERT(Q->ne[3] == 1); + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); - char * K_data = (char *) K->data; + const char * K_data = (const char *) K->data; size_t nb11 = K->nb[1]; size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - char * V_data = (char *) V->data; + const char * V_data = (const char *) V->data; size_t nb21 = V->nb[1]; size_t nb22 = V->nb[2]; size_t nb23 = V->nb[3]; @@ -649,39 +761,60 @@ void launch_fattn( nb23 = nb23*bs*sizeof(half)/ts; } - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } + const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block); + const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3]; const dim3 block_dim(WARP_SIZE, nwarps, 1); - const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); - const int shmem = 0; + dim3 blocks_num; + if (parallel_blocks == 0) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm; + const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total; + const bool short_context = K->ne[1] < 4096; + + const int nblocks_stream_k = 2*nsm; + + blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float)); + } else { + blocks_num.x = parallel_blocks*ntiles_x; + blocks_num.y = Q->ne[2]; + blocks_num.z = Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; - memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float)); + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); if (logit_softcap != 0.0f) { scale /= logit_softcap; } const uint32_t n_head = Q->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - fattn_kernel<<>>( + fattn_kernel<<>>( (const char *) Q->data, K_data, V_data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -693,16 +826,22 @@ void launch_fattn( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { - return; + if constexpr (parallel_blocks == 0) { + if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine = blocks_num; + + flash_attn_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if constexpr (parallel_blocks > 1) { + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); } - - const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; - - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh new file mode 100644 index 000000000..05bc91a3b --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -0,0 +1,637 @@ +#include "common.cuh" +#include "mma.cuh" +#include "fattn-common.cuh" + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef NEW_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C_KQ; + typedef mma_C_I16J8 mma_C_VKQ; + + static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps"); + constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column. + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements. + + const int stride_Q = nb01 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + mma_B Q_B[D/(2*mma_B::K)]; + mma_C_VKQ VKQ_C[D/mma_C_VKQ::I]; + + float2 KQ_rowsum = {0.0f, 0.0f}; + float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; + float2 KQ_max_scale = {0.0f, 0.0f}; + + // Temporarily load Q data into tile_KV, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { + const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jt*ncols + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; + tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + { + const int j0 = (threadIdx.y / np) * mma_B::J; + +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { + Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) { + const int k_VKQ_0 = kb0*KQ_stride; + mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)]; + + // Load K data into tile with decreasing granularity for D for better memory bandwidth: + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) { + const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) { + const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ]; + } + } + } + + __syncthreads(); + + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) { + mma_A K_A; + K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]); + } + } + + __syncthreads(); + + if (use_logit_softcap) { + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) { +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + if (maskh) { + static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size"); + static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size"); +#pragma unroll + for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I; +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + const int i = i0 + mma_C_KQ::get_i(l); + const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l); + + KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + float2 KQ_max_new = KQ_max; + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { +#pragma unroll + for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) { + KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); + KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); + KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); + } + + { + const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); + KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); + if (diff.x <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale.x = 0.0f; + } + if (diff.y <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale.y = 0.0f; + } + KQ_max = KQ_max_new; + } + + float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y; + const float diff = KQ_C[k].x[l] - KQ_max_l; + KQ_C[k].x[l] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_C[k].x[l] = 0.0f; + } + + if (l % 2 == 0) { + KQ_rowsum_add.x += KQ_C[k].x[l]; + } else { + KQ_rowsum_add.y += KQ_C[k].x[l]; + } + } + } + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; + KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y; + + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); +#pragma unroll + for (int i = 0; i < D/mma_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < mma_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + mma_B B[KQ_stride/(np*2*mma_B::K)]; + static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) { + B[k] = KQ_C[k].to_mma_B(); + } + + // Load V data into tile with decreasing granularity for D for better memory bandwidth: + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i); + const int i0_stop = D/2 - (D/2) % (1*stride_i); + const int stride_k = WARP_SIZE / stride_i; + +#pragma unroll + for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) { + const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i); + +#pragma unroll + for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) { + const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i); + + tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V]; + } + } + } + + __syncthreads(); + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) { + static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) { + const int k0 = k00 + (threadIdx.y % np)*mma_A::K; + + mma_A A; + A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]); + } + } + + __syncthreads(); + } + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8 threads each, does not need full reduce. +#pragma unroll + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE); + KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE); + } + + // Write VKQ accumulators to shared memory in column-major format. + // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // Also for np > 1 the combination is done via these values in shared memory. + const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { + const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int k = k0 + mma_B::get_k(l); + + tile_KV[j_cwd*D2_padded + k] = B.x[l]; + } + } + + const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset + const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta + const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + static_assert(np == 1 || np == 2 || np == 4, "bad np"); + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[j_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[j_cwm] = KQ_cmr; + } + } else if (threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2; + + float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + KQ_cm = meta_j[0]; + } + + float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. + float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + KQ_crs = KQ_cms*meta_j[1]; + } +#pragma unroll + for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + meta_j[0] = KQ_cmn; // Combined max. KQ values. + meta_j[1] = KQ_crs; // Combined KQ rowsums. + meta_j[2] = KQ_cms; // KQ max scales per parallel warp. + } + if (needs_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + + if (np > 1) { + __syncthreads(); + } + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; + } + +#pragma unroll + for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { + const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J; + + if (!is_fixup && jt*ncols + j_dst >= ne01) { + continue; + } + const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2]; + const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[j_dst*(D/2) + k] = dstk_val; + } else { + dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val; + } + } + } + } + } + + if (np > 1) { + __syncthreads(); + } +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE +} + +template +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 2) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int iter_k = ne11 / KQ_stride; + const int iter_j = (ne01 + (ncols - 1)) / ncols; + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape + const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(D/2); + + const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape + const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(D/2); + + const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); +} + +template +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + + static_assert(D % mma_B::K == 0, "bad D"); + static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block"); + + const ggml_tensor * KQV = dst; + + constexpr int KQ_stride = D <= 128 ? 64 : 32; + constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? + cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8); + constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); +} + +#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_MMA_F16_CASE( 64, 8); +extern DECL_FATTN_MMA_F16_CASE( 80, 8); +extern DECL_FATTN_MMA_F16_CASE( 96, 8); +extern DECL_FATTN_MMA_F16_CASE(112, 8); +extern DECL_FATTN_MMA_F16_CASE(128, 8); +extern DECL_FATTN_MMA_F16_CASE(256, 8); + +extern DECL_FATTN_MMA_F16_CASE( 64, 16); +extern DECL_FATTN_MMA_F16_CASE( 80, 16); +extern DECL_FATTN_MMA_F16_CASE( 96, 16); +extern DECL_FATTN_MMA_F16_CASE(112, 16); +extern DECL_FATTN_MMA_F16_CASE(128, 16); +extern DECL_FATTN_MMA_F16_CASE(256, 16); + +extern DECL_FATTN_MMA_F16_CASE( 64, 32); +extern DECL_FATTN_MMA_F16_CASE( 80, 32); +extern DECL_FATTN_MMA_F16_CASE( 96, 32); +extern DECL_FATTN_MMA_F16_CASE(112, 32); +extern DECL_FATTN_MMA_F16_CASE(128, 32); +extern DECL_FATTN_MMA_F16_CASE(256, 32); + +extern DECL_FATTN_MMA_F16_CASE( 64, 64); +extern DECL_FATTN_MMA_F16_CASE( 80, 64); +extern DECL_FATTN_MMA_F16_CASE( 96, 64); +extern DECL_FATTN_MMA_F16_CASE(112, 64); +extern DECL_FATTN_MMA_F16_CASE(128, 64); +extern DECL_FATTN_MMA_F16_CASE(256, 64); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 4d314dacb..d4edbad07 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; return; @@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; + constexpr int D = 64; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; + constexpr int D = 128; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index bb3360447..0d274f332 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32( NO_DEVICE_CODE; return; #endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; return; @@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; + constexpr int D = 64; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; + constexpr int D = 128; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 34a2992c7..d9ac44246 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index a28fc8b7f..6ef8f9dcc 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu new file mode 100644 index 000000000..45702ad65 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -0,0 +1,648 @@ +// Old and deprecated WMMA FlashAttention implementation. +// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing. +// Long-term the WMMA code should be replaced with a dedicated Volta implementation. + +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +#ifdef FP16_MMA_AVAILABLE +#include +#endif // FP16_MMA_AVAILABLE + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + const half2 slope2 = make_half2(slopef, slopef); + + const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + + if (use_logit_softcap) { + KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + } + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + + if (use_logit_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + } + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + } + + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + } +#else + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template +void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + constexpr int nwarps = 4; + + constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (4*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 4; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 2; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + return; + } + constexpr int parallel_blocks = 1; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); +} + +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + + if (prec != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + } else { + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + // case 256: + // ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + // break; + default: + GGML_ABORT("fatal error"); + break; + } + } + return; + } + + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } + + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 860d0e6dc..beeea95eb 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,543 +1,3 @@ #include "common.cuh" -#include "fattn-common.cuh" -#ifdef FP16_MMA_AVAILABLE -#include -#endif // FP16_MMA_AVAILABLE - -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { -#ifdef FP16_MMA_AVAILABLE - // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - - static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); - static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); - constexpr int frag_m = ncols == 8 ? 32 : 16; - constexpr int frag_n = ncols == 8 ? 8 : 16; - static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); - typedef nvcuda::wmma::fragment frag_a_K; - typedef nvcuda::wmma::fragment frag_a_V; - typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c_KQ; - typedef nvcuda::wmma::fragment frag_c_VKQ; - - constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. - constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. - static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); - - // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: - constexpr int D_padded = D + 8; - constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; - constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); - - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); - - const int stride_Q = nb01 / sizeof(float); - const int stride_KV = nb11 / sizeof(half); - - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - const half2 slope2 = make_half2(slopef, slopef); - - const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); - - frag_b Q_b[D/16][ncols/frag_n]; - - // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded*kqar; - constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; - __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; - float * KQ_f = (float *) KQ; - half2 * KQ2 = (half2 *) KQ; - - float KQ_rowsum_f[ncols/nwarps] = {0.0f}; - float KQ_max_f[ncols/nwarps]; - float KQ_max_scale_f[ncols/nwarps] = {0.0f}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_f[j] = -FLT_MAX/2.0f; - } - - half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max_h2[ncols/nwarps]; - half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); - } - - __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. - half2 * VKQ2 = (half2 *) VKQ; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } - VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); - } - } - - // Convert Q to half and apply scale, temporarily store in KQ: -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } - } - - __syncthreads(); - - // Load Q into tensor core fragments/registers since it will be used frequently: -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 16) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); - } - } - - __syncthreads(); - - // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { - // Calculate tile of KQ: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c_KQ KQ_c[ncols/frag_n]; -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); - } -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { - frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); - } - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (std::is_same::value) { - float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; - - if (use_logit_softcap) { - KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); - } - } - - float KQ_max_new = KQ_max_f[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); - } - KQ_max_new = warp_reduce_max(KQ_max_new); - - const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; - KQ_max_scale_f[j0/nwarps] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale_f[j0/nwarps] = 0.0f; - } - KQ_max_f[j0/nwarps] = KQ_max_new; - - float KQ_rowsum_add = 0.0f; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; - KQ_f_tmp[k0/WARP_SIZE] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_f_tmp[k0/WARP_SIZE] = 0.0f; - } - KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; - KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; - } else { - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - - if (use_logit_softcap) { - // There is no dedicated tangens hyperbolicus function for half2. - KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); - KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) - /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); - - KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; - } - } - - half2 KQ_max_new = KQ_max_h2[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; - KQ_max_scale_h2[j0/nwarps] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; - KQ_max_h2[j0/nwarps] = KQ_max_new; - - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; - } - } - - __syncthreads(); - - frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - nvcuda::wmma::load_matrix_sync( - KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, - kqar*kqs_padded); - } - } - - frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); - } - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - - frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); - } - } - } - - __syncthreads(); - - const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), - VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - half2 VKQ_scale; - if (std::is_same::value) { - VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); - } else { - VKQ_scale = KQ_max_scale_h2[j0/nwarps]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } - - half2 VKQ_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int l = 0; l < VKQ_ratio; ++l) { - VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; - } - VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; - } - } - - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { - return; - } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - - float KQ_rowsum_j; - if (std::is_same::value) { - KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; - } else { - KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); - } - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - float dst_val = VKQ[j_VKQ*D_padded + i]; - if (parallel_blocks == 1) { - dst_val /= KQ_rowsum_j; - } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; - } - - if (parallel_blocks == 1 || threadIdx.x != 0) { - continue; - } - - float2 dst_meta_val; - if (std::is_same::value) { - dst_meta_val.x = KQ_max_f[j0/nwarps]; - } else { - dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); - } - dst_meta_val.y = KQ_rowsum_j; - dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; - } -#else - NO_DEVICE_CODE; -#endif // FP16_MMA_AVAILABLE -} - -constexpr int get_max_power_of_2(int x) { - return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; -} - -static_assert(get_max_power_of_2(1) == 1, "Test failed."); -static_assert(get_max_power_of_2(2) == 2, "Test failed."); -static_assert(get_max_power_of_2(4) == 4, "Test failed."); -static_assert(get_max_power_of_2(6) == 2, "Test failed."); - -// Number of VKQ rows calculated in parallel: -constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { - return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; -} - -static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); -static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); - -template -void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - constexpr int nwarps = 4; - - constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; - const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (4*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); - return; - } - if (2*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); - return; - } - constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); -} - -#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ - template void ggml_cuda_flash_attn_ext_wmma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); -extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); -extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(112, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(128, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float); -extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float); -extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float); -extern DECL_FATTN_WMMA_F16_CASE(112, 32, float); -extern DECL_FATTN_WMMA_F16_CASE(128, 32, float); -// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 8, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 8, half); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half); -extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(112, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half); -extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0b26b0f8e..b0cf152f5 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,5 +1,6 @@ #include "common.cuh" #include "fattn-common.cuh" +#include "fattn-mma-f16.cuh" #include "fattn-tile-f16.cuh" #include "fattn-tile-f32.cuh" #include "fattn-vec-f16.cuh" @@ -7,144 +8,56 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -#include +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; -static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - - if (prec != GGML_PREC_DEFAULT) { - if (Q->ne[1] <= 32 || Q->ne[0] > 128) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - } else { - constexpr int cols_per_block = 32; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - break; - // case 256: - // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - // break; - default: - GGML_ABORT("fatal error"); - break; - } - } - return; - } - - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { - constexpr int cols_per_block = 8; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - return; - } - - if (Q->ne[1] <= 32) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - return; - } - - constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst); break; default: GGML_ABORT("fatal error"); break; } } + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + return; + } + + if (Q->ne[1] <= 16) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst); + return; + } + + if (Q->ne[1] <= 32) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst); +} + #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ @@ -323,10 +236,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fp16_mma_available(cc)) { - if (Q->ne[1] <= 8) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + if (prec == GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + } } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); + } } return; } @@ -341,5 +262,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } } - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: + if (cc == GGML_CUDA_CC_VOLTA) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 383131c77..bda10aec1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1205,7 +1205,7 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (compute_capability == GGML_CUDA_CC_CDNA) { + if (GGML_CUDA_CC_IS_CDNA(compute_capability)) { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( @@ -1750,7 +1750,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { + if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) { cu_compute_type = CUBLAS_COMPUTE_32F; alpha = &alpha_f32; beta = &beta_f32; diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 7d11540af..9788a1389 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1,11 +1,67 @@ +// This file contains primitives that expose the tensor core PTX instructions for CUDA code. +// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. +// The documentation for the PTX instructions can be found under: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction +// +// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. +// A is a row-major matrix with shape I x K. +// B is a column-major matrix with shape K x J. +// C is a column-major matrix with shape I x J. +// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements. +// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// All matrix tiles have ne physical 32 bit elements per warp. +// +// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. + #include "common.cuh" -struct mma_int_A_I16K4 { + +#if CUDART_VERSION >= 11800 + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + int ret = 0; + +#ifdef NEW_MMA_AVAILABLE + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(ret) : "r"(x)); +#else + NO_DEVICE_CODE; +#endif // defined(NEW_MMA_AVAILABLE) + return ret; +} + +#else + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + // Imagine transposing row-major matrix to column-major matrix. + const int src_i_low = 2 * (threadIdx.x % 4); + const int src_i_high = src_i_low + 1; + const int src_j = threadIdx.x / 4; + + const int src_laneid_low = src_i_low * 4 + src_j / 2; + const int src_laneid_high = src_i_high * 4 + src_j / 2; + + const int shift_low = ((src_j + 0) % 2) * 16; + const int shift_high = ((src_j + 1) % 2) * 16; + + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + + return ret_low | ret_high; +} + +#endif // CUDART_VERSION >= 11800 + + +template +struct mma_A_I16K4 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int I = 16; static constexpr int K = 4; static constexpr int ne = 2; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_i(const int l) { const int ret = (l%2) * (I/2) + threadIdx.x / K; @@ -21,27 +77,35 @@ struct mma_int_A_I16K4 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) - const int * xs = xs0 + (threadIdx.x%I)*stride; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(x[0]), "+r"(x[1]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_i(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(xi[0]), "+r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_A_I16K8 { +template +struct mma_A_I16K8 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int I = 16; static constexpr int K = 8; static constexpr int ne = 4; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_i(const int l) { const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); @@ -57,31 +121,62 @@ struct mma_int_A_I16K8 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) - const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_i(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) } - __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) { - ((mma_int_A_I16K4 *) x)[0].load(xs0, stride); + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int * ) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int * ) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ void transpose() { + int * xi = (int *) x; + xi[0] = ggml_cuda_movmatrix(xi[0]); + + const int tmp = ggml_cuda_movmatrix(xi[1]); + xi[1] = ggml_cuda_movmatrix(xi[2]); + xi[2] = tmp; + + xi[3] = ggml_cuda_movmatrix(xi[3]); } }; -struct mma_int_B_J8K4 { +template +struct mma_B_J8K4 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int J = 8; static constexpr int K = 4; static constexpr int ne = 1; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_j(const int /* l */) { const int ret = threadIdx.x / K; @@ -97,27 +192,34 @@ struct mma_int_B_J8K4 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster - const int * xs = xs0 + (threadIdx.x%J)*stride; - asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" - : "+r"(x[0]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_j(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride; + asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" + : "+r"(xi[0]) : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_B_J8K8 { +template +struct mma_B_J8K8 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int J = 8; static constexpr int K = 8; static constexpr int ne = 2; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_j(const int /* l */) { const int ret = threadIdx.x / (K/2); @@ -133,22 +235,31 @@ struct mma_int_B_J8K8 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster - const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(x[0]), "+r"(x[1]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_j(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(xi[0]), "+r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_C_I16J8 { +template +struct mma_C_I16J8 {}; + +template <> +struct mma_C_I16J8 { static constexpr int I = 16; static constexpr int J = 8; static constexpr int ne = 4; @@ -169,8 +280,8 @@ struct mma_int_C_I16J8 { return ret; } - __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) { -#ifdef INT8_MMA_AVAILABLE + __device__ __forceinline__ void mma(const mma_A_I16K4 & mma_A, const mma_B_J8K4 & mma_B) { +#ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) @@ -188,11 +299,11 @@ struct mma_int_C_I16J8 { GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { -#ifdef INT8_MMA_AVAILABLE + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) @@ -216,6 +327,132 @@ struct mma_int_C_I16J8 { GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE + } +}; + +template <> +struct mma_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 4; + static constexpr int ne = 2; + + half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = l * (I/2) + threadIdx.x / J; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int /* l */) { + const int ret = threadIdx.x % J; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE + int * Axi = (int *) mma_A.x; + int * Bxi = (int *) mma_B.x; + int * xi = (int *) x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ mma_B_J8K8 to_mma_B() { + mma_B_J8K8 mma_B; + + int * xi = (int *) x; + int * Bxi = (int *) mma_B.x; + Bxi[0] = ggml_cuda_movmatrix(xi[0]); + Bxi[1] = ggml_cuda_movmatrix(xi[1]); + + return mma_B; + } +}; + +template <> +struct mma_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 8; + static constexpr int ne = 4; + + float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int l) { + const int ret = 2 * (threadIdx.x % (J/2)) + l%2; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE + int * Axi = (int *) mma_A.x; + int * Bxi = (int *) mma_B.x; + int * xi = (int *) x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ mma_B_J8K8 to_mma_B() { + mma_B_J8K8 mma_B; + mma_B.x[0] = make_half2(x[0], x[1]); + mma_B.x[1] = make_half2(x[2], x[3]); + + int * Bxi = (int *) mma_B.x; + Bxi[0] = ggml_cuda_movmatrix(Bxi[0]); + Bxi[1] = ggml_cuda_movmatrix(Bxi[1]); + + return mma_B; + } + + __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) { +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_i(l)]; + } } }; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 270251df4..45212f66c 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (int8_mma_available(cc)) { + if (new_mma_available(cc)) { return true; } @@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 3cd508a1d..7a2c4d85b 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -87,7 +87,7 @@ struct tile_x_sizes { }; static constexpr int get_mmq_x_max_host(const int cc) { - return int8_mma_available(cc) ? 128 : + return new_mma_available(cc) ? 128 : #ifdef GGML_CUDA_FORCE_MMQ cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64; #else @@ -96,9 +96,9 @@ static constexpr int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE return 128; -#else // INT8_MMA_AVAILABLE +#else // NEW_MMA_AVAILABLE #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) return 128; @@ -116,11 +116,11 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } static constexpr int get_mmq_y_host(const int cc) { - return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64); + return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64); } static constexpr __device__ int get_mmq_y_device() { @@ -209,10 +209,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8; + return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8; } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 16 : 8; } @@ -220,21 +220,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { return 8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_0; const int kqsx = threadIdx.x % QI4_0; @@ -250,12 +250,12 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; @@ -271,11 +271,11 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -322,14 +322,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_1; const int kqsx = threadIdx.x % QI4_1; @@ -345,12 +345,12 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; @@ -366,11 +366,11 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -417,14 +417,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_0; const int kqsx = threadIdx.x % QI5_0; @@ -456,13 +456,13 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; @@ -478,25 +478,25 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_1; const int kqsx = threadIdx.x % QI5_1; @@ -526,13 +526,13 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; @@ -548,25 +548,25 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI8_0; const int kqsx = threadIdx.x % QI8_0; @@ -581,13 +581,13 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); #else x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; @@ -603,11 +603,11 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -645,9 +645,9 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -672,7 +672,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { const int k0 = k00 + k01; - A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll @@ -695,7 +695,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( mma_B B; float dB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -711,7 +711,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][k01/QI8_0], B); + C.mma(A[n][k01/QI8_0], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -756,9 +756,9 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -782,7 +782,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll @@ -805,7 +805,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( mma_B B; float2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -817,7 +817,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][k01/QI8_1], B); + C.mma(A[n][k01/QI8_1], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -864,12 +864,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_A_I16K8 mma_A_K8; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -893,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll @@ -916,8 +916,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( mma_B B[2]; float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -929,8 +930,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); + C[0].mma(A[n][k01/4 + 0], B[0]); + C[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -942,20 +943,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI2_K; @@ -977,11 +978,11 @@ template static __device__ __forceinlin const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int sc_m = bxi->scales[kqsx]; @@ -992,11 +993,11 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1051,12 +1052,12 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_A_I16K8 mma_A_K8; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -1081,7 +1082,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } } @@ -1118,24 +1119,25 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { mma_B B[2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); mma_C Cm[2]; if (k01 >= WARP_SIZE * 3/4) { mma_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; - Cm[0].mma_K4(A1, B[0]); - Cm[1].mma_K4(A1, B[1]); + Cm[0].mma(A1, B[0]); + Cm[1].mma(A1, B[1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C Cd[2]; - Cd[0].mma_K4(A[n][k01/4 + 0], B[0]); - Cd[1].mma_K4(A[n][k01/4 + 1], B[1]); + Cd[0].mma(A[n][k01/4 + 0], B[0]); + Cd[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -1172,13 +1174,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else @@ -1186,7 +1188,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI3_K; @@ -1212,11 +1214,11 @@ template static __device__ __forceinlin const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1242,7 +1244,7 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; @@ -1252,10 +1254,10 @@ template static __device__ __forceinlin } #else x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifndef INT8_MMA_AVAILABLE +#ifndef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; @@ -1268,7 +1270,7 @@ template static __device__ __forceinlin x_df[i] = bxi->d; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1317,7 +1319,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co template static __device__ __forceinline__ void load_tiles_q4_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else @@ -1325,7 +1327,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1338,15 +1340,15 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, threadIdx.x); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { @@ -1407,7 +1409,7 @@ template static __device__ __forceinlin x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1446,7 +1448,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); #else @@ -1454,7 +1456,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1478,16 +1480,16 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { @@ -1548,7 +1550,7 @@ template static __device__ __forceinlin x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1587,7 +1589,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); @@ -1596,7 +1598,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1619,13 +1621,13 @@ template static __device__ __forceinlin const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 @@ -1641,11 +1643,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } #pragma unroll @@ -1658,11 +1660,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); #else x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1702,11 +1704,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -1732,8 +1734,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); - A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll @@ -1771,8 +1773,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( mma_B B[2]; float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -1784,8 +1787,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); + C[0].mma(A[n][k01/4 + 0], B[0]); + C[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -1805,20 +1808,20 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_NL; const int kqsx = threadIdx.x % QI4_NL; @@ -1836,13 +1839,13 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; @@ -1858,25 +1861,25 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_XXS/2); @@ -1905,36 +1908,36 @@ template static __device__ __forceinlin const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_XS/2); @@ -1959,38 +1962,38 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_S/2); @@ -2022,38 +2025,38 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq3_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI3_XXS/2); @@ -2080,36 +2083,36 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq3_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI3_S/2); @@ -2143,36 +2146,36 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq1_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI1_S; @@ -2198,37 +2201,37 @@ template static __device__ __forceinlin const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI4_XS const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS @@ -2246,13 +2249,13 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } #pragma unroll @@ -2270,11 +2273,11 @@ template static __device__ __forceinlin const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -2307,16 +2310,16 @@ template static __device__ __forceinline__ void mmq_write_back_mma( const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - typedef mma_int_C_I16J8 mma_C; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { @@ -2505,13 +2508,13 @@ static __device__ void mul_mat_q_process_tile( int * tile_y = (int *) data_mul_mat_q; int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -2643,7 +2646,7 @@ static __global__ void mul_mat_q( const int jt = kbc / (blocks_per_ne00*nty); const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; - constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks. + constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, it, jt, kb0_start, kb0_stop); @@ -2749,7 +2752,7 @@ template static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); - const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); } diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index ac45f2d17..5a9ddd958 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -5,9 +5,10 @@ template static __global__ void mul_mat_vec( const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { - const int64_t row = blockIdx.x; - const int64_t channel = blockIdx.z; - const int tid = threadIdx.x; + const int64_t row = blockIdx.x; + const int64_t channel = blockIdx.z; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); x += (channel/channel_ratio)*stride_channel_x + row*stride_row; y += channel *stride_channel_y; @@ -18,8 +19,8 @@ static __global__ void mul_mat_vec( extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; - if (block_size > WARP_SIZE) { - if (tid < WARP_SIZE) { + if (block_size > warp_size) { + if (tid < warp_size) { buf_iw[tid] = 0.0f; } __syncthreads(); @@ -67,16 +68,16 @@ static __global__ void mul_mat_vec( static_assert(std::is_same::value, "unsupported type"); } - sumf = warp_reduce_sum(sumf); + sumf = warp_reduce_sum(sumf); - if (block_size > WARP_SIZE) { - buf_iw[tid/WARP_SIZE] = sumf; + if (block_size > warp_size) { + buf_iw[tid/warp_size] = sumf; __syncthreads(); - if (tid >= WARP_SIZE) { + if (tid >= warp_size) { return; } sumf = buf_iw[tid]; - sumf = warp_reduce_sum(sumf); + sumf = warp_reduce_sum(sumf); } if (tid != 0) { @@ -96,10 +97,19 @@ static void launch_mul_mat_vec_cuda( GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(nchannels_y % nchannels_x == 0); const int64_t channel_ratio = nchannels_y / nchannels_x; + int device; + int warp_size; - int64_t block_size_best = WARP_SIZE; - int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE); - for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) { + CUDA_CHECK(cudaGetDevice(&device)); + warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t block_size_best = warp_size; + int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); + int64_t max_block_size = 256; + if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { + max_block_size = 128; + } + for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); if (niter < niter_best) { niter_best = niter; @@ -107,7 +117,7 @@ static void launch_mul_mat_vec_cuda( } } - const int smem = WARP_SIZE*sizeof(float); + const int smem = warp_size*sizeof(float); const dim3 block_nums(nrows, 1, nchannels_y); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index da377200e..aac6e0999 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32(half val) { #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" -#endif +#endif // __clang__ template static __global__ void soft_max_f32( const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, @@ -126,7 +126,7 @@ static __global__ void soft_max_f32( } #ifdef __clang__ #pragma clang diagnostic pop -#endif +#endif // __clang__ static __global__ void soft_max_back_f32( const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu new file mode 100644 index 000000000..f09bdeff7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16); +DECL_FATTN_MMA_F16_CASE(80, 16); +DECL_FATTN_MMA_F16_CASE(96, 16); +DECL_FATTN_MMA_F16_CASE(112, 16); +DECL_FATTN_MMA_F16_CASE(128, 16); +DECL_FATTN_MMA_F16_CASE(256, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu new file mode 100644 index 000000000..221108873 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32); +DECL_FATTN_MMA_F16_CASE(80, 32); +DECL_FATTN_MMA_F16_CASE(96, 32); +DECL_FATTN_MMA_F16_CASE(112, 32); +DECL_FATTN_MMA_F16_CASE(128, 32); +DECL_FATTN_MMA_F16_CASE(256, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu new file mode 100644 index 000000000..d24b08575 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64); +DECL_FATTN_MMA_F16_CASE(80, 64); +DECL_FATTN_MMA_F16_CASE(96, 64); +DECL_FATTN_MMA_F16_CASE(112, 64); +DECL_FATTN_MMA_F16_CASE(128, 64); +DECL_FATTN_MMA_F16_CASE(256, 64); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu new file mode 100644 index 000000000..bdf86c0ea --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8); +DECL_FATTN_MMA_F16_CASE(80, 8); +DECL_FATTN_MMA_F16_CASE(96, 8); +DECL_FATTN_MMA_F16_CASE(112, 8); +DECL_FATTN_MMA_F16_CASE(128, 8); +DECL_FATTN_MMA_F16_CASE(256, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu deleted file mode 100644 index 2d94e65c2..000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 16, float); -DECL_FATTN_WMMA_F16_CASE(80, 16, float); -DECL_FATTN_WMMA_F16_CASE(96, 16, float); -DECL_FATTN_WMMA_F16_CASE(112, 16, float); -DECL_FATTN_WMMA_F16_CASE(128, 16, float); -DECL_FATTN_WMMA_F16_CASE(256, 16, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu deleted file mode 100644 index c3d9df3c4..000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ /dev/null @@ -1,9 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 32, float); -DECL_FATTN_WMMA_F16_CASE(80, 32, float); -DECL_FATTN_WMMA_F16_CASE(96, 32, float); -DECL_FATTN_WMMA_F16_CASE(112, 32, float); -DECL_FATTN_WMMA_F16_CASE(128, 32, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu deleted file mode 100644 index bb680e401..000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 16, half); -DECL_FATTN_WMMA_F16_CASE(80, 16, half); -DECL_FATTN_WMMA_F16_CASE(96, 16, half); -DECL_FATTN_WMMA_F16_CASE(112, 16, half); -DECL_FATTN_WMMA_F16_CASE(128, 16, half); -DECL_FATTN_WMMA_F16_CASE(256, 16, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu deleted file mode 100644 index 073f71b1f..000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 32, half); -DECL_FATTN_WMMA_F16_CASE(80, 32, half); -DECL_FATTN_WMMA_F16_CASE(96, 32, half); -DECL_FATTN_WMMA_F16_CASE(112, 32, half); -DECL_FATTN_WMMA_F16_CASE(128, 32, half); -DECL_FATTN_WMMA_F16_CASE(256, 32, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu deleted file mode 100644 index d30710c5f..000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ /dev/null @@ -1,8 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 8, half); -DECL_FATTN_WMMA_F16_CASE(96, 8, half); -DECL_FATTN_WMMA_F16_CASE(128, 8, half); -DECL_FATTN_WMMA_F16_CASE(256, 8, half); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index d7874e6ea..a2628f16e 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v}); """ -SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. +SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. -#include "../fattn-wmma-f16.cuh" +#include "../fattn-mma-f16.cuh" """ -SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,20 +57,12 @@ for vkq_size in [16, 32]: with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for kq_acc_t in ["half", "float"]: - for cols_per_block in [8, 16, 32]: - if kq_acc_t == "float" and cols_per_block == 8: - continue +for cols_per_block in [8, 16, 32, 64]: + with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f: + f.write(SOURCE_FATTN_MMA_START) - with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f: - f.write(SOURCE_FATTN_WMMA_START) - - for head_size in [64, 80, 96, 112, 128, 256]: - if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32 - continue - if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance - continue - f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size)) + for head_size in [64, 80, 96, 112, 128, 256]: + f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 8594093f0..81964611c 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -1,5 +1,6 @@ #pragma once +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 #include #include #include @@ -8,6 +9,7 @@ // for rocblas_initialize() #include "rocblas/rocblas.h" #endif // __HIP_PLATFORM_AMD__ + #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F @@ -25,6 +27,7 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 7a877bdc1..eb03e10fa 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -50,7 +50,7 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") -file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 76f8e4291..9605914ff 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -20,7 +20,10 @@ #define GGML_METAL_MAX_COMMAND_BUFFERS 8 // create residency sets only on macOS >= 15.0 -#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 +#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ + TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 #define GGML_METAL_HAS_RESIDENCY_SETS 1 #endif @@ -1071,7 +1074,7 @@ static bool ggml_backend_metal_buffer_rset_init( } #if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, *)) { + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; desc.label = @"ggml_backend_metal"; desc.initialCapacity = ctx->n_buffers; @@ -1106,7 +1109,7 @@ static bool ggml_backend_metal_buffer_rset_init( // rset free static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { #if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, *)) { + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { if (ctx->rset) { [ctx->rset endResidency]; [ctx->rset removeAllAllocations]; diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 415b2b2e0..2f555416e 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8fe84df21..ecac5b4bb 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1357,6 +1357,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.OUTPUT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_DOWN, diff --git a/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja b/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja new file mode 100644 index 000000000..078e9f545 --- /dev/null +++ b/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja @@ -0,0 +1,156 @@ +{{ bos_token }}{%- macro document_turn(documents) -%} +{# format documents into chat turn #} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[ + {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}} +]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ + { + "tool_call_id": "0", + "results": { +{% for doc in documents %} + "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %}, + {% endif %} +{% endfor %} + + }, + "is_error": null + } +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %} +{%- macro tool_call_id_to_int(messages, tool_call_id) %} +{%- set counter = namespace(value=0) %} +{%- set tool_call_id_seen = namespace(value=false) %} +{%- for msg in messages %} + {%- if msg.tool_calls %} + {%- for tool_call in msg.tool_calls %} + {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%} + {{ counter.value }} + {%- set tool_call_id_seen.value = true %} + {%- endif %} + {%- set counter.value = counter.value + 1 %} + {%- endfor %} + {%- endif %} +{%- endfor %} +{%- endmacro %} +{%- macro format_tool_message(messages, tool_msg) -%} +{# format tool message #} + { + "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}", + "results": { + "0": {{ tool_msg.content|tojson }} + }, + "is_error": null + } +{%- endmacro -%} +{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %} +{%- set tool_idx = namespace(value=0) %} +{%- set tool_ids_seen = namespace(value=[]) %} +{%- set sent_documents = namespace(value=false) %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble +You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes. + +Your information cutoff date is June 2024. + +You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages. +{% if tools or documents %} + +You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests. + +## Tool Use +Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first. + +0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed. + NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools. + +Then carry out your plan by repeatedly executing the following steps. +1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields. + When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>. +2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results. + Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id". +3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded. + NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user. + +You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user. + +4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>. +{% if enable_citations %} + +## Grounding +Importantly, note that "Reflection" and "Response" above can be grounded. +Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1". +{% endif %} + +## Available Tools +Here is the list of tools that you have available to you. +You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it. +Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema). + +```json +[ +{% if documents %} + {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %} + +{% endif %} +{% for tool in tools %} + {"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %} + +{% endfor %} +] +``` + +{% endif %} +# Default Preamble +The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt. +- Your name is Command. +- You are a large language model built by Cohere. +- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions. +- If the input is ambiguous, ask clarifying follow-up questions. +- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks). +- Use LaTeX to generate mathematical notation for complex equations. +- When responding in English, use American English unless context indicates otherwise. +- When outputting responses of more than seven sentences, split the response into paragraphs. +- Prefer the active voice. +- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references. +- Use gender-neutral pronouns for unspecified persons. +- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list. +- Use the third person when asked to write a summary. +- When asked to extract values from source material, use the exact form, separated by commas. +- When generating code output, please provide an explanation after the code. +- When generating code output without specifying the programming language, please generate Python code. +- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer. +{%- if developer_preamble %} + + +# Developer Preamble +The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions. +{{ developer_preamble }} +{%- endif -%} +<|END_OF_TURN_TOKEN|> +{%- for message in messages %} + {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|> + {%- elif message.role|lower == 'user' %} +<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %} + {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[ + {% for tc in message.tool_calls %} + {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %} + + {% set tool_idx.value = tool_idx.value + 1 %} + {% endfor %} +]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %} + {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ +{{ format_tool_message(messages, message) }} + {%- for msg in messages[loop.index0 + 1:] %} + {%- if msg.role|lower == 'tool' %}, +{{ format_tool_message(messages, msg) }} + {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %} + {%- else %} + {%- break %} + {%- endif %} + {%- endfor %} + +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|> + {%- endif %} +{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index ddb9d817e..26a105f64 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -32f0b85987396945afea2291d5f4c5862434292b +694244a6e40dc255f6bb4376fb17431c06633e6c diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a7260f495..97a1e7e5e 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1024,6 +1024,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 5c19bab24..028a64794 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -51,6 +51,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, @@ -115,7 +116,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { return LLM_CHAT_TEMPLATE_PHI_3; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { - return LLM_CHAT_TEMPLATE_FALCON_3; + return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { return LLM_CHAT_TEMPLATE_ZEPHYR; } else if (tmpl_contains("bos_token + message['role']")) { @@ -440,6 +441,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF for (auto message : chat) { diff --git a/src/llama-chat.h b/src/llama-chat.h index 3a4d07ce3..2f6a0e3e2 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -31,6 +31,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGML_3, LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_RWKV_WORLD, diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 6be5cbe0e..9b518d1ac 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1213,5 +1213,7 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string } grammar.partial_utf8 = decoded.second; - GGML_ASSERT(!grammar.stacks.empty()); + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); + } } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 18bd0b071..0487c978b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1093,8 +1093,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 28: type = LLM_TYPE_6B; break; - case 40: type = LLM_TYPE_9B; break; + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -3068,9 +3080,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); diff --git a/src/llama.cpp b/src/llama.cpp index 192b20a27..5760017e0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7215,17 +7215,30 @@ struct llm_build_context { struct ggml_tensor * Qcur = nullptr; struct ggml_tensor * Kcur = nullptr; struct ggml_tensor * Vcur = nullptr; - - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - + if (model.type == LLM_TYPE_1_5B || model.type == LLM_TYPE_4B || model.type == LLM_TYPE_9B) { + Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 40f83ff0d..7a158d602 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -86,6 +86,9 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf) +if (LLAMA_LLGUIDANCE) + llama_target_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf) +endif () if (NOT WIN32) # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 4563f9dcb..e0314ae1d 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -175,6 +175,14 @@ int main(void) { /* .bos_token= */ "", /* .eos_token= */ "", }, + { + /* .name= */ "GLMEdge", + /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>", + /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", + /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", + /* .bos_token= */ "", + /* .eos_token= */ "", + }, { /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF", /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}", diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index ccc65d87a..50bd40738 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -22,9 +22,13 @@ static common_chat_msg msg_from_json(const json & message) { "assistant", "", {}, + /* .tool_plan = */ "", }; if (message.contains("content") && !message.at("content").is_null()) { - ret.content = message.at("content").get(); + ret.content = message.at("content"); + } + if (message.contains("tool_plan")) { + ret.tool_plan = message.at("tool_plan"); } auto has_tool_calls = message.contains("tool_calls"); if (has_tool_calls) { @@ -171,8 +175,7 @@ const json llama_3_1_tools = { special_function_tool, code_interpreter_too struct delta_data { std::string delta; - std::string grammar; - common_chat_format format; + common_chat_params params; }; static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, @@ -214,7 +217,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto break; } } - return { delta, params_full.grammar, params_full.format }; + return { delta, params_full }; } /* @@ -224,7 +227,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto */ static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, const json & test_message, const json & tools = {}, const std::string & expected_delta = "", - bool skip_grammar_test = false, bool skip_parser_test = false) { + bool expect_grammar_triggered = true) { common_chat_msg expected_msg = msg_from_json(test_message); auto user_message = json{ @@ -238,45 +241,110 @@ static void test_template(const common_chat_template & tmpl, const std::vector 0 && trigger.at_start) { + fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str()); + continue; + } + if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) { + earliest_trigger_pos = pos; + } + } + auto grammar_triggered = false; + if (earliest_trigger_pos != std::string::npos) { + constrained = constrained.substr(earliest_trigger_pos); + grammar_triggered = true; + } + if (data.params.grammar_lazy) { + assert_equals(expect_grammar_triggered, grammar_triggered); + } + + if (grammar_triggered && !match_string(constrained, grammar.get())) { + throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + + "\n\nGrammar: " + data.params.grammar); } } } } static void test_template_output_parsers() { - auto text_message = json{ + json text_message { { "role", "assistant" }, - { "content", "Hello, world!" }, + { "content", "Hello, world!\nWhat's up?" }, }; - auto tool_call_message = json{ + json tool_calls = json::array({{ + { "type", "function" }, + { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, + }}); + + json tool_call_message { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, + }; + json tool_call_message_with_id { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + {"id", "123456789"}, + }, + }}, { "role", "assistant" }, { "content", {} }, - { "tool_calls", json{ { - { "type", "function" }, - { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, - } } } + { "tool_calls", tool_calls } + }; + json tool_call_plan_message_with_idx { + { "role", "assistant"}, + { "content", {}}, + { "tool_plan", "I'm not so sure"}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + // Index of the tool call in the tool_calls array + {"id", "0"}, + }, + }}, + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", tool_calls } }; - auto tool_call_message_with_id = json::parse(tool_call_message.dump()); - tool_call_message_with_id["tool_calls"][0]["id"] = "123456789"; auto python_tool_call_message = json{ { "role", "assistant" }, @@ -311,7 +379,7 @@ static void test_template_output_parsers() { common_chat_inputs inputs_no_tools; inputs_no_tools.messages = { - { { "role", "user" }, { "content", "Hey" } } + { { "role", "user" }, { "content", "Hey\nThere" } } }; common_chat_inputs inputs_tools = inputs_no_tools; @@ -322,6 +390,28 @@ static void test_template_output_parsers() { inputs_tools_builtin.tools = json::array(); inputs_tools_builtin.tools.push_back(python_tool); + { + // Not supported yet + const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "", ""); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format); + } + { + const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", ""); + std::vector end_tokens{ "<|END_OF_TURN_TOKEN|>" }; + + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); + + test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools, + "<|START_THINKING|>I'm not so sure<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" + "]<|END_ACTION|>"); + test_template(tmpl, end_tokens, text_message, tools, + "<|START_RESPONSE|>Hello, world!\n" + "What's up?<|END_RESPONSE|>", + /* expect_grammar_triggered= */ false); + } { const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "", ""); std::vector end_tokens{ "" }; @@ -339,7 +429,7 @@ static void test_template_output_parsers() { assert_msg_equals(msg_from_json(text_message), common_chat_parse("{\n" - " \"response\": \"Hello, world!\"\n" + " \"response\": \"Hello, world!\\nWhat's up?\"\n" "}", common_chat_params_init(tmpl, inputs_tools).format)); test_template(tmpl, end_tokens, tool_call_message_with_id, tools, @@ -362,11 +452,10 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template( tmpl, end_tokens, tool_call_message_with_id, tools, - "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]", - /* skip_grammar_test= */ true); + "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } { const common_chat_template tmpl( @@ -388,7 +477,7 @@ static void test_template_output_parsers() { inputs_tools) .format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template(tmpl, end_tokens, tool_call_message, tools, "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" @@ -413,7 +502,7 @@ static void test_template_output_parsers() { inputs_tools_builtin) .format); - // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true); + // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false); test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); test_template(tmpl, end_tokens, python_tool_call_message, tools, @@ -428,7 +517,7 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template(tmpl, end_tokens, tool_call_message, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } @@ -440,7 +529,7 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template(tmpl, end_tokens, tool_call_message, tools, "{\"arg1\": 1}"); } @@ -454,8 +543,9 @@ static void test_template_output_parsers() { test_template(tmpl, end_tokens, text_message, {}, "all\n" - "Hello, world!", - /* skip_grammar_test= */ true); + "Hello, world!\n" + "What's up?", + /* expect_grammar_triggered= */ false); test_template(tmpl, end_tokens, tool_call_message, tools, "special_function\n" "{\"arg1\": 1}"); @@ -467,7 +557,7 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template(tmpl, end_tokens, tool_call_message, tools, " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } @@ -478,7 +568,7 @@ static void test_template_output_parsers() { assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); - test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); + test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_template(tmpl, end_tokens, tool_call_message, tools, "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" "```json\n" diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 288e08f51..890608648 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings); } static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector & passing_strings, const std::vector & failing_strings) { - test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings); + test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings); } static void test_simple_grammar() { diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp new file mode 100644 index 000000000..8b696006b --- /dev/null +++ b/tests/test-grammar-llguidance.cpp @@ -0,0 +1,1140 @@ +#ifdef NDEBUG +# undef NDEBUG +#endif + +#include "unicode.h" +#include "sampling.h" + +#include +#include +#include + +static const llama_vocab * vocab; + +static bool match_string(const std::string & input, llama_sampler * grammar) { + llama_sampler_reset(grammar); + auto tokens = common_tokenize(vocab, input, false, false); + + auto n_vocab = llama_vocab_n_tokens(vocab); + + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) { + cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f }); + } + auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false }; + + for (const auto token : tokens) { + for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) { + cur[token_id].logit = 0.0f; + } + llama_sampler_apply(grammar, &tok_arr); + if (cur[token].logit < 0.0f) { + return false; + } + llama_sampler_accept(grammar, token); + } + + // do we allow EOS at the end? if so the grammar is accepting + + auto tok_eos = llama_vocab_eot(vocab); + if (tok_eos == LLAMA_TOKEN_NULL) { + tok_eos = llama_vocab_eos(vocab); + } + + cur[tok_eos].logit = 0.0f; + llama_sampler_apply(grammar, &tok_arr); + + return cur[tok_eos].logit >= 0.0f; +} + +static void test(const std::string & test_desc, const std::string & grammar_str, + const std::vector & passing_strings, const std::vector & failing_strings) { + fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str()); + fflush(stderr); + + auto * grammar = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str()); + + fprintf(stderr, " 🔵 Valid strings:\n"); + + // Passing strings + for (const auto & test_string : passing_strings) { + fprintf(stderr, " \"%s\" ", test_string.c_str()); + fflush(stderr); + + bool matched = match_string(test_string, grammar); + + if (!matched) { + fprintf(stderr, "❌ (failed to match)\n"); + + // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed. + // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf + FILE * grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w"); + if (grammar_file) { + fprintf(grammar_file, "%s", grammar_str.c_str()); + fclose(grammar_file); + } + + // DEBUG: Write the test string to test-grammar-integration.string.txt + FILE * string_file = fopen("test-grammar-integration.string.txt", "w"); + if (string_file) { + fprintf(string_file, "%s", test_string.c_str()); + fclose(string_file); + } + + fprintf(stderr, + "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following " + "command: ./llama-gbnf-validator test-grammar-integration.grammar.gbnf " + "test-grammar-integration.string.txt\n\n"); + } else { + fprintf(stdout, "✅︎\n"); + } + + assert(matched); + } + + fprintf(stderr, " 🟠 Invalid strings:\n"); + + // Failing strings + for (const auto & test_string : failing_strings) { + fprintf(stderr, " \"%s\" ", test_string.c_str()); + fflush(stderr); + + bool matched = match_string(test_string, grammar); + + if (matched) { + fprintf(stderr, "❌ (incorrectly matched)\n"); + } else { + fprintf(stdout, "✅︎\n"); + } + assert(!matched); + } + + llama_sampler_free(grammar); +} + +static void test_grammar(const std::string & test_desc, const std::string & grammar_str, + const std::vector & passing_strings, + const std::vector & failing_strings) { + test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings); +} + +static void test_schema(const std::string & test_desc, const std::string & schema_str, + const std::vector & passing_strings, + const std::vector & failing_strings) { + test(test_desc + ". Schema: " + schema_str, "%llguidance {}\nstart: %json " + schema_str, passing_strings, + failing_strings); +} + +static void test_simple_grammar() { + test_schema("min 0", + R"""({ + "type": "integer", + "minimum": 0 + })""", + // Passing strings + { + "0", + "10", + "12", + "10000", + }, + // Failing strings + { + "-1", + "-10", + "-10000", + "-100000000000000000000000000000000", + // "100000000000000000000000000000000", + "00", + "01", + "-0", + }); + test_schema("min 2", + // Schema + R"""({ + "type": "integer", + "minimum": 2 + })""", + // Passing strings + { + "2", + "3", + "4", + "10", + "20", + "1234567890000000", + }, + // Failing strings + { + "0", "1", "-1", "-100", "0", "1", "01", "02", + // "12345678900000000", + }); + test_schema("min 456", + R"""({ + "type": "integer", + "minimum": 456 + })""", + // Passing strings + { + "456", + "4560", + "457", + "460", + "500", + }, + // Failing strings + { + "455", + "356", + "50", + "050", + "-1", + "-456", + }); + test_schema("min -123", + R"""({ + "type": "integer", + "minimum": -123 + })""", + // Passing strings + { + "-123", + "-122", + "-11", + "-1", + "0", + "1", + "123", + "1234", + "2345", + }, + // Failing strings + { + "-1234", + "-124", + }); + + test_schema("max 9999", + // Schema + R"""({ + "type": "integer", + "maximum": 9999 + })""", + // Passing strings + { + "-99999", + "0", + "9999", + }, + // Failing strings + { + "10000", + "99991", + }); + test_schema("max -9999", + // Schema + R"""({ + "type": "integer", + "maximum": -9999 + })""", + // Passing strings + { + "-10000", + "-9999", + }, + // Failing strings + { + "-9998", + "0", + "9999", + }); + test_schema("min 5 max 30", + // Schema + R"""({ + "type": "integer", + "minimum": 5, + "maximum": 30 + })""", + // Passing strings + { + "5", + "10", + "30", + }, + // Failing strings + { + "05", + "4", + "-1", + "31", + "123", + "0123", + }); + test_schema("min -1 max 1", + R"""({ + "type": "integer", + "minimum": -1, + "maximum": 1 + })""", + // Passing strings + { + "-1", + "0", + "1", + }, + // Failing strings + { + "-11", + "-10", + "-2", + "2", + "10", + "11", + }); + test_schema("min -123 max 42", + R"""({ + "type": "integer", + "minimum": -123, + "maximum": 42 + })""", + // Passing strings + { + "-123", + "-122", + "-13", + "-11", + "-2", + "-1", + "0", + "1", + "5", + "10", + "39", + "40", + "42", + }, + // Failing strings + { + "-0123", + "-124", + "-1123", + "-200", + "43", + "123", + "0123", + }); + test_schema("exclusive min / max", + // Schema + R"""({ + "type": "integer", + "exclusiveMinimum": 0, + "exclusiveMaximum": 10000 + })""", + // Passing strings + { + "1", + "9999", + }, + // Failing strings + { + "0", + "01", + "10000", + "99999", + }); + + // Test case for a simple grammar + test_grammar("simple grammar", + R"""( + start: expr + expr: term ("+" term)* + term: number + number: /[0-9]+/ )""", + // Passing strings + { + "42", + "1+2+3+4+5", + "123+456", + }, + // Failing strings + { + "+", + "/ 3", + "1+2+3+4+5+", + "12a45", + }); +} + +static void test_complex_grammar() { + // Test case for a more complex grammar, with both failure strings and success strings + test_grammar("medium complexity grammar", + // Grammar + R"""( + start: expression + expression: term ws (("+"|"-") ws term)* + term: factor ws (("*"|"/") ws factor)* + factor: number | variable | "(" expression ")" | function-call + number: /[0-9]+/ + variable: /[a-zA-Z_][a-zA-Z0-9_]*/ + function-call: variable ws "(" (expression ("," ws expression)*)? ")" + ws: /[ \t\n\r]?/ )""", + // Passing strings + { "42", + "1*2*3*4*5", + "x", + "x+10", + "x1+y2", + "(a+b)*(c-d)", + "func()", + "func(x,y+2)", + "a*(b+c)-d/e", + "f(g(x),h(y,z))", + "x + 10", + "x1 + y2", + "(a + b) * (c - d)", + "func()", + "func(x, y + 2)", + "a * (b + c) - d / e", + "f(g(x), h(y, z))", + "123+456", + "123*456*789-123/456+789*123", + "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" }, + // Failing strings + { + "+", + "/ 3x", + "x + + y", + "a * / b", + "func(,)", + "func(x y)", + "(a + b", + "x + y)", + "a + b * (c - d", + "42 +", + "x +", + "x + 10 +", + "(a + b) * (c - d", + "func(", + "func(x, y + 2", + "a * (b + c) - d /", + "f(g(x), h(y, z)", + "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", + }); +} + +static void test_special_chars() { + // A collection of tests to exercise special characters such as "." + test_grammar("special characters", + // Grammar + R"""( + start: /.../ "abc" /.../ + )""", + // Passing strings + { "abcabcabc", "aaaabcccc", + // NOTE: Also ensures that multi-byte characters still count as a single character + "🔵🟠✅abc❌🟠🔵" }, + // Failing strings + { "aaabcccc", "aaaaabcccc", "aaaabccc", "aaaabccccc", "🔵🟠✅❌abc❌✅🟠🔵", "🔵🟠abc🟠🔵" }); +} + +static void test_quantifiers() { + // A collection of tests to exercise * + and ? quantifiers + + test_grammar( + "* quantifier", + // Grammar + R"""(start: "a"*)""", + // Passing strings + { "", "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" }, + // Failing strings + { "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" }); + test_grammar( + "+ quantifier", + // Grammar + R"""(start: "a"+)""", + // Passing strings + { "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" }, + // Failing strings + { "", "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" }); + test_grammar("? quantifier", + // Grammar + R"""(start: "a"?)""", + // Passing strings + { "", "a" }, + // Failing strings + { + "b", + "ab", + "aa", + "ba", + }); + test_grammar("mixed quantifiers", + // Grammar + R"""( + start: cons+ vowel* cons? (vowel cons)* + vowel: /[aeiouy]/ + cons: /[bcdfghjklmnpqrstvwxyz]/ + )""", + // Passing strings + { + "yes", + "no", + "noyes", + "crwth", + "four", + "bryyyy", + }, + // Failing strings + { + "yess", + "yesno", + "forty", + "catyyy", + }); + test_grammar("simple exact repetition", + // Grammar + R"""( + start: /[ab]{4}/ + )""", + // Passing strings + { + "aaaa", + "bbbb", + "abab", + }, + // Failing strings + { + "a", + "b", + "aaaaa", + }); + test_grammar("simple min repetition", + // Grammar + R"""( + start: /[ab]{4,}/ + )""", + // Passing strings + { + "aaaa", + "aaaaab", + "bbbb", + "ababab", + }, + // Failing strings + { + "", + "aba", + }); + test_grammar("simple max repetition", + // Grammar + R"""( + start: /[ab]{0,4}/ + )""", + // Passing strings + { + "", + "a", + "aa", + "aaa", + "aaab", + }, + // Failing strings + { + "aaaaa", + }); + // test_grammar("min / max repetition", + // // Grammar + // R"""( + // start: ("0x" /[A-F0-9]{2}/ " "?){3,5} + // )""", + // // Passing strings + // { + // "0xFF 0x12 0xAB", + // "0xFF 0x12 0xAB 0x00 0x00", + // }, + // // Failing strings + // { + // "", + // "0xFF", + // "0xFF 0x12", + // "0xFF 0x12 0xAB 0x00 0x00 0x00", + // }); +} + +static void test_json_schema() { + // Note that this is similar to the regular grammar tests, + // but we convert each json schema to a grammar before parsing. + // Otherwise, this test structure is the same. + + test_schema("empty schema (object)", + // Schema + R"""( + {"type":"object"} + )""", + // Passing strings + { + R"""({})""", + R"""({"foo": "bar"})""", + }, + // Failing strings + { + "", + "[]", + "null", + R"""("")""", + "true", + }); + + test_schema( + "exotic formats (list)", + // Schema + R"""({ + "items": [ + { "format": "date" }, + { "format": "uuid" }, + { "format": "time" }, + { "format": "date-time" } + ] + })""", + // Passing strings + { + // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it? + // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it? + R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""", + //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it? + //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it? + }, + // Failing strings + { + R"""(["foo", "bar"])""", + R"""(["12345678-1234-1234-1234-1234567890ab"])""", + }); + + test_schema("string", + // Schema + R"""({ + "type": "string" + })""", + // Passing strings + { + R"""("foo")""", + R"""("bar")""", + R"""("")""", + }, + // Failing strings + { + R"""({})""", + R"""("foo": "bar")""", + }); + + test_schema("string w/ min length 1", + // Schema + R"""({ + "type": "string", + "minLength": 1 + })""", + // Passing strings + { + R"""("foo")""", + R"""("bar")""", + }, + // Failing strings + { + R"""("")""", + R"""({})""", + R"""("foo": "bar")""", + }); + + test_schema("string w/ min length 3", + // Schema + R"""({ + "type": "string", + "minLength": 3 + })""", + // Passing strings + { + R"""("foo")""", + R"""("bar")""", + R"""("foobar")""", + }, + // Failing strings + { + R"""("")""", + R"""("f")""", + R"""("fo")""", + }); + + test_schema("string w/ max length", + // Schema + R"""({ + "type": "string", + "maxLength": 3 + })""", + // Passing strings + { + R"""("foo")""", + R"""("bar")""", + R"""("")""", + R"""("f")""", + R"""("fo")""", + }, + // Failing strings + { + R"""("foobar")""", + }); + + test_schema("string w/ min & max length", + // Schema + R"""({ + "type": "string", + "minLength": 1, + "maxLength": 4 + })""", + // Passing strings + { + R"""("foo")""", + R"""("bar")""", + R"""("f")""", + R"""("barf")""", + }, + // Failing strings + { + R"""("")""", + R"""("barfo")""", + R"""("foobar")""", + }); + + test_schema("boolean", + // Schema + R"""({ + "type": "boolean" + })""", + // Passing strings + { + "true", + "false", + }, + // Failing strings + { + R"""("")""", + R"""("true")""", + R"""(True)""", + R"""(FALSE)""", + }); + + test_schema("integer", + // Schema + R"""({ + "type": "integer" + })""", + // Passing strings + { + R"""(0)""", + R"""(12345)""", + R"""(1234567890123456)""", + }, + // Failing strings + { + R"""()""", + R"""(01)""", + R"""(007)""", + R"""(12345678901234567 )""", + }); + + test_schema("string const", + // Schema + R"""({ + "const": "foo" + })""", + // Passing strings + { + R"""("foo")""", + }, + // Failing strings + { + R"""(foo)""", + R"""("bar")""", + }); + + test_schema("non-string const", + // Schema + R"""({ + "const": true + })""", + // Passing strings + { + R"""(true)""", + }, + // Failing strings + { + R"""()""", + R"""(foo)""", + R"""("true")""", + }); + + test_schema("non-string const", + // Schema + R"""({ + "enum": ["red", "amber", "green", null, 42, ["foo"]] + })""", + // Passing strings + { + R"""("red")""", + R"""(null)""", + R"""(42)""", + R"""(["foo"])""", + }, + // Failing strings + { + R"""()""", + R"""(420)""", + R"""(true)""", + R"""(foo)""", + }); + + test_schema("simple pattern", + // Schema + R"""({ + "pattern": "^[a-zA-Z0-9_-]*$" + })""", + // Passing strings + { + R"""("")""", + R"""("He_llo-12")""", + }, + // Failing strings + { + R"""("!")""", + R"""("Hello World")""", + }); + + test_schema("pattern with escapes", + // Schema + R"""({ + "pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$" + })""", + // Passing strings + { + R"""("a^$.[]()|{}*+?b")""", + }, + // Failing strings + { + R"""("ab")""", + }); + + test_schema("", + // Schema + R"""( + { + "type": ["array", "null"], + "items": { "type": "string" } + } + )""", + // Passing strings + { + "null", + "[]", + "[\"123\"]", + "[\"foo\", \"bar\"]", + }, + // Failing strings + { + "", + "[123]", + "\"foo\"", + "[\"foo\", 42]", + }); + + test_schema("min+max items", + // Schema + R"""({ + "items": { + "type": ["number", "integer"] + }, + "minItems": 3, + "maxItems": 5 + })""", + // Passing strings + { + R"""([1, 2, 3])""", + R"""([1, 2, 3, 4])""", + R"""([1, 2, 3, 4, 5])""", + // this is in fact correct; keyword do not apply if the type is wrong + R"""(1)""", + }, + // Failing strings + { + R"""([1, 2])""", + R"""([1, 2, 3, 4, 5, 6])""", + }); + + // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties) + test_schema("object properties", + // Schema + R"""({ + "type": "object", + "properties": { + "number": { "type": "number" }, + "street_name": { "type": "string" }, + "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } + }, + "additionalProperties": false + })""", + // Passing strings + { + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""", + // "By default, leaving out properties is valid" + R"""({ "street_name": "Pennsylvania" })""", + R"""({ "number": 1600, "street_name": "Pennsylvania" })""", + // "By extension, even an empty object is valid" + R"""({})""", + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", + }, + // Failing strings + { + // Change datatype from number to string + R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""", + // Reorder properties + R"""({ "street_name": "Pennsylvania", "number": 1600 })""", + // Reorder properties + R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""", + // Additional properties set to false + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""", + + }); + + test_schema("additional properties can't override other properties", + R"""({ + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"} + }, + "additionalProperties": true + })""", + // Passing strings + { + R"""({"a": 42})""", + R"""({"c": ""})""", + R"""({"a": 42, "c": ""})""", + R"""({"a_": ""})""", + }, + // Failing strings + { + R"""()""", + R"""({"a": ""})""", + R"""({"a": "", "b": ""})""", + }); + + // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties) + test_schema("object properties, additionalProperties: true", + // Schema + R"""({ + "type": "object", + "properties": { + "number": { "type": "number" }, + "street_name": { "type": "string" }, + "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } + }, + "additionalProperties": true + })""", + // Passing strings + { + // "By extension, even an empty object is valid" + R"""({})""", + R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""", + // "By default, leaving out properties is valid" + R"""({ "street_name": "Pennsylvania" })""", + R"""({ "number": 1600, "street_name": "Pennsylvania" })""", + // "By default, providing additional properties is valid" + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""", + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", + }, + // Failing strings + { + // Change datatype from number to string + R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""", + // Reorder properties + R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""", + }); + + // Additional properties: false + test_schema( + "required + optional props each in original order", + // Schema + R"""({ + "type": "object", + "properties": { + "number": { "type": "number" }, + "street_name": { "type": "string" }, + "street_type": { "enum": ["Street", "Avenue", "Boulevard"] } + }, + "additionalProperties": false + })""", + // Passing strings + { + R"""({ "street_name": "Pennsylvania" })""", + R"""({ "number": 1600, "street_type":"Avenue"})""", + R"""({ "number": 1600, "street_name": "Pennsylvania" })""", + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""", + // Spaces are permitted around enum values + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""", + }, + // Failing strings + { + // Reorder properties + R"""({ "street_type": "Avenue", "number": 1600 })""", + // Add "direction" + R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""", + }); + + test_schema("required + optional props each in original order", + // Schema + R"""({ + "properties": { + "b": {"type": "string"}, + "a": {"type": "string"}, + "d": {"type": "string"}, + "c": {"type": "string"} + }, + "required": ["a", "b"], + "additionalProperties": false + })""", + // Passing strings + { + R"""({"b": "foo", "a": "bar"})""", + R"""({"b":"foo","a":"bar","d":"qux"})""", + R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""", + }, + // Failing strings + { + R"""({"a": "foo", "b": "bar"})""", + R"""({"b": "bar"})""", + R"""({"a": "foo", "c": "baz"})""", + R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""", + }); + + // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties + test_schema( + "required props", + // Schema + R"""({ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://example.com/product.schema.json", + "title": "Product", + "description": "A product from Acme's catalog", + "type": "object", + "properties": { + "productId": { + "description": "The unique identifier for a product", + "type": "integer" + }, + "productName": { + "description": "Name of the product", + "type": "string" + }, + "price": { + "description": "The price of the product", + "type": "number", + "exclusiveMinimum": 0 + }, + "tags": { + "description": "Tags for the product", + "type": "array", + "items": { + "type": "string" + }, + "minItems": 1, + "DISABLED_uniqueItems": true + }, + "dimensions": { + "type": "object", + "properties": { + "length": { + "type": "number" + }, + "width": { + "type": "number" + }, + "height": { + "type": "number" + } + }, + "required": [ "length", "width", "height" ] + } + }, + "required": [ "productId", "productName", "price" ] + })""", + // Passing strings + { + R"""({"productId": 1, "productName": "A green door", "price": 12.50})""", + R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""", + R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""", + }, + // Failing strings + { + R"""({})""", // Missing all required properties + R"""({"productName": "A green door", "price": 12.50, "productId": 1})""", // Out of order properties + // `exclusiveMinimum` is OK for llg + R"""({"productId": 1, "productName": "A green door", "price": -12.50})""", + R"""({"productId": 1, "productName": "A green door"})""", // Missing required property (price) + R"""({"productName": "A green door", "price": 12.50})""", // Missing required property (productId) + R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""", // tags is empty, but minItems is 1 + R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""", // Tags and dimensions are out of order + // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement. + // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""", + }); +} + +int main(int argc, const char ** argv) { + fprintf(stdout, "Running llguidance integration tests...\n"); + + if (argc != 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + const char * vocab_file = argv[1]; + + fprintf(stderr, "reading vocab from: '%s'\n", vocab_file); + + llama_model * model; + llama_context * ctx; + + llama_backend_init(); + + // load the vocab + { + auto mparams = llama_model_default_params(); + + mparams.vocab_only = true; + + model = llama_model_load_from_file(vocab_file, mparams); + + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file); + return 1; + } + + // needed? + auto cparams = llama_context_default_params(); + + ctx = llama_init_from_model(model, cparams); + + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file); + llama_model_free(model); + return 1; + } + } + + vocab = llama_model_get_vocab(model); + + test_simple_grammar(); + test_complex_grammar(); + test_special_chars(); + test_quantifiers(); + test_json_schema(); + fprintf(stdout, "All tests passed.\n"); + return 0; +} diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 9d2db91f5..f38994c92 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1246,7 +1246,7 @@ int main() { test_all("C++", [](const TestCase & tc) { try { - tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema))); + tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true)); tc.verify_status(SUCCESS); } catch (const std::runtime_error & ex) { fprintf(stderr, "Error: %s\n", ex.what());