diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index fb719a550..c6db1666e 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -565,6 +565,31 @@ jobs:
path: |
cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip
+ windows-latest-cmake-sycl:
+ runs-on: windows-latest
+ defaults:
+ run:
+ shell: bash
+
+ env:
+ WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/62641e01-1e8d-4ace-91d6-ae03f7f8a71f/w_BaseKit_p_2024.0.0.49563_offline.exe
+ WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel
+
+
+ steps:
+ - name: Clone
+ id: checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: Install
+ run: scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
+
+ - name: Build
+ id: cmake_build
+ run: examples/sycl/win-build-sycl.bat
+
ios-xcode-build:
runs-on: macos-latest
diff --git a/.github/workflows/editorconfig.yml b/.github/workflows/editorconfig.yml
index b4e535acf..0e0993cd4 100644
--- a/.github/workflows/editorconfig.yml
+++ b/.github/workflows/editorconfig.yml
@@ -1,6 +1,12 @@
name: EditorConfig Checker
on:
+ workflow_dispatch: # allows manual triggering
+ inputs:
+ create_release:
+ description: 'Create new release'
+ required: true
+ type: boolean
push:
branches:
- master
diff --git a/.gitignore b/.gitignore
index cb0069bfb..b84459b92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -89,3 +89,4 @@ examples/jeopardy/results.txt
poetry.lock
poetry.toml
+nppBackup
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 65a6f3971..15a1101aa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -507,7 +507,11 @@ if (LLAMA_SYCL)
set(GGML_HEADERS_SYCL ggml.h ggml-sycl.h)
set(GGML_SOURCES_SYCL ggml-sycl.cpp)
- set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+ if (WIN32)
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl sycl7 OpenCL mkl_sycl_blas_dll.lib mkl_intel_ilp64_dll.lib mkl_sequential_dll.lib mkl_core_dll.lib)
+ else()
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+ endif()
endif()
if (LLAMA_KOMPUTE)
diff --git a/README_sycl.md b/README-sycl.md
similarity index 62%
rename from README_sycl.md
rename to README-sycl.md
index 94722489c..de3641293 100644
--- a/README_sycl.md
+++ b/README-sycl.md
@@ -8,10 +8,14 @@
[Linux](#linux)
+[Windows](#windows)
+
[Environment Variable](#environment-variable)
[Known Issue](#known-issue)
+[Q&A](#q&a)
+
[Todo](#todo)
## Background
@@ -33,7 +37,7 @@ For Intel CPU, recommend to use llama.cpp for X86 (Intel MKL building).
|OS|Status|Verified|
|-|-|-|
|Linux|Support|Ubuntu 22.04|
-|Windows|Ongoing| |
+|Windows|Support|Windows 11|
## Intel GPU
@@ -42,7 +46,7 @@ For Intel CPU, recommend to use llama.cpp for X86 (Intel MKL building).
|-|-|-|
|Intel Data Center Max Series| Support| Max 1550|
|Intel Data Center Flex Series| Support| Flex 170|
-|Intel Arc Series| Support| Arc 770|
+|Intel Arc Series| Support| Arc 770, 730M|
|Intel built-in Arc GPU| Support| built-in Arc GPU in Meteor Lake|
|Intel iGPU| Support| iGPU in i5-1250P, i7-1260P, i7-1165G7|
@@ -149,6 +153,8 @@ cmake .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
# Or, build all binary
cmake --build . --config Release -v
+
+cd ..
```
or
@@ -233,11 +239,175 @@ Note:
5. Check the device ID in output
-Like:
+Like:
```
Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
```
+## Windows
+
+### Setup Environment
+
+1. Install Intel GPU driver.
+
+Please install Intel GPU driver by official guide: [Install GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html).
+
+2. Install Intel® oneAPI Base toolkit.
+
+a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
+
+Recommend to install to default folder: **/opt/intel/oneapi**.
+
+Following guide uses the default folder as example. If you use other folder, please modify the following guide info with your folder.
+
+b. Enable oneAPI running environment:
+
+- In Search, input 'oneAPI'.
+
+Search & open "Intel oneAPI command prompt for Intel 64 for Visual Studio 2022"
+
+- In Run:
+
+In CMD:
+```
+"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64
+```
+
+c. Check GPU
+
+In oneAPI command line:
+
+```
+sycl-ls
+```
+
+There should be one or more level-zero devices. Like **[ext_oneapi_level_zero:gpu:0]**.
+
+Output (example):
+```
+[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
+[opencl:cpu:1] Intel(R) OpenCL, 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
+[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Iris(R) Xe Graphics OpenCL 3.0 NEO [31.0.101.5186]
+[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Iris(R) Xe Graphics 1.3 [1.3.28044]
+
+```
+
+3. Install cmake & make
+
+a. Download & install cmake for windows: https://cmake.org/download/
+
+b. Download & install make for windows provided by mingw-w64: https://www.mingw-w64.org/downloads/
+
+
+### Build locally:
+
+In oneAPI command line window:
+
+```
+mkdir -p build
+cd build
+@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
+
+:: for FP16
+:: faster for long-prompt inference
+:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
+
+:: for FP32
+cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release
+
+
+:: build example/main only
+:: make main
+
+:: build all binary
+make -j
+cd ..
+```
+
+or
+
+```
+.\examples\sycl\win-build-sycl.bat
+```
+
+Note:
+
+- By default, it will build for all binary files. It will take more time. To reduce the time, we recommend to build for **example/main** only.
+
+### Run
+
+1. Put model file to folder **models**
+
+2. Enable oneAPI running environment
+
+- In Search, input 'oneAPI'.
+
+Search & open "Intel oneAPI command prompt for Intel 64 for Visual Studio 2022"
+
+- In Run:
+
+In CMD:
+```
+"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64
+```
+
+3. List device ID
+
+Run without parameter:
+
+```
+build\bin\ls-sycl-device.exe
+
+or
+
+build\bin\main.exe
+```
+
+Check the ID in startup log, like:
+
+```
+found 4 SYCL devices:
+ Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
+ max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
+ Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
+ max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
+ Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
+ max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
+ Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
+ max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
+
+```
+
+|Attribute|Note|
+|-|-|
+|compute capability 1.3|Level-zero running time, recommended |
+|compute capability 3.0|OpenCL running time, slower than level-zero in most cases|
+
+4. Set device ID and execute llama.cpp
+
+Set device ID = 0 by **set GGML_SYCL_DEVICE=0**
+
+```
+set GGML_SYCL_DEVICE=0
+build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0
+```
+or run by script:
+
+```
+.\examples\sycl\win-run-llama2.bat
+```
+
+Note:
+
+- By default, mmap is used to read model file. In some cases, it leads to the hang issue. Recommend to use parameter **--no-mmap** to disable mmap() to skip this issue.
+
+
+5. Check the device ID in output
+
+Like:
+```
+Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
+```
## Environment Variable
@@ -248,7 +418,7 @@ Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
|LLAMA_SYCL|ON (mandatory)|Enable build with SYCL code path.
For FP32/FP16, LLAMA_SYCL=ON is mandatory.|
|LLAMA_SYCL_F16|ON (optional)|Enable FP16 build with SYCL code path. Faster for long-prompt inference.
For FP32, not set it.|
|CMAKE_C_COMPILER|icx|Use icx compiler for SYCL code path|
-|CMAKE_CXX_COMPILER|icpx|use icpx for SYCL code path|
+|CMAKE_CXX_COMPILER|icpx (Linux), icx (Windows)|use icpx/icx for SYCL code path|
#### Running
@@ -260,19 +430,24 @@ Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
## Known Issue
-- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
-
- Miss to enable oneAPI running environment.
-
- Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
-
-
- Hang during startup
llama.cpp use mmap as default way to read model file and copy to GPU. In some system, memcpy will be abnormal and block.
Solution: add **--no-mmap**.
+## Q&A
+
+- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
+
+ Miss to enable oneAPI running environment.
+
+ Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
+
+- In Windows, no result, not error.
+
+ Miss to enable oneAPI running environment.
+
## Todo
- Support to build in Windows.
diff --git a/README.md b/README.md
index 96d4aaa26..f26f8e91a 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,8 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
### Hot topics
+- ⚠️ Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138
+ - [SYCL backend](README-sycl.md) is ready (1/28/2024), support Linux/Windows in Intel GPUs (iGPU, Arc/Flex/Max series)
- New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow
- Collecting Apple Silicon performance stats:
- M-series: https://github.com/ggerganov/llama.cpp/discussions/4167
diff --git a/common/common.cpp b/common/common.cpp
index 288013676..9d976c7c8 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1520,7 +1520,9 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
+ fprintf(stream, "cpu_has_vulkan: %s\n", ggml_cpu_has_vulkan() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
+ fprintf(stream, "cpu_has_kompute: %s\n", ggml_cpu_has_kompute() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index f239415d3..542cc7bb8 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -563,6 +563,7 @@ struct test {
static const bool cuda;
static const bool opencl;
static const bool vulkan;
+ static const bool kompute;
static const bool metal;
static const bool gpu_blas;
static const bool blas;
@@ -647,6 +648,9 @@ struct test {
if (vulkan) {
return "Vulkan";
}
+ if (kompute) {
+ return "Kompute";
+ }
if (metal) {
return "Metal";
}
@@ -662,7 +666,7 @@ struct test {
static const std::vector & get_fields() {
static const std::vector fields = {
"build_commit", "build_number",
- "cuda", "opencl", "vulkan", "metal", "gpu_blas", "blas",
+ "cuda", "opencl", "vulkan", "kompute", "metal", "gpu_blas", "blas",
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "type_k", "type_v",
@@ -686,8 +690,9 @@ struct test {
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
- if (field == "cuda" || field == "opencl" || field == "vulkan"|| field == "metal" || field == "gpu_blas" || field == "blas" ||
- field == "f16_kv" || field == "no_kv_offload" || field == "mul_mat_q") {
+ if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
+ field == "gpu_blas" || field == "blas" || field == "f16_kv" || field == "no_kv_offload" ||
+ field == "mul_mat_q") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -714,7 +719,8 @@ struct test {
}
std::vector values = {
build_commit, std::to_string(build_number),
- std::to_string(cuda), std::to_string(opencl), std::to_string(vulkan), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
+ std::to_string(cuda), std::to_string(opencl), std::to_string(vulkan), std::to_string(vulkan),
+ std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
@@ -743,6 +749,7 @@ const int test::build_number = LLAMA_BUILD_NUMBER;
const bool test::cuda = !!ggml_cpu_has_cublas();
const bool test::opencl = !!ggml_cpu_has_clblast();
const bool test::vulkan = !!ggml_cpu_has_vulkan();
+const bool test::kompute = !!ggml_cpu_has_kompute();
const bool test::metal = !!ggml_cpu_has_metal();
const bool test::gpu_blas = !!ggml_cpu_has_gpublas();
const bool test::blas = !!ggml_cpu_has_blas();
diff --git a/examples/llava/MobileVLM-README.md b/examples/llava/MobileVLM-README.md
index c6258eba6..9eba791da 100644
--- a/examples/llava/MobileVLM-README.md
+++ b/examples/llava/MobileVLM-README.md
@@ -111,17 +111,71 @@ llama_print_timings: eval time = 1279.03 ms / 18 runs ( 71.06 m
llama_print_timings: total time = 34570.79 ms
```
+## Orin compile and run
+### compile
+```sh
+make LLAMA_CUBLAS=1 CUDA_DOCKER_ARCH=sm_87 LLAMA_CUDA_F16=1 -j 32
+```
+
+### run on Orin
+### case 1
+**input**
+```sh
+./llava-cli \
+ -m /data/local/tmp/ggml-model-q4_k.gguf \
+ --mmproj /data/local/tmp/mmproj-model-f16.gguf \
+ --image /data/local/tmp/demo.jpeg \
+ -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" \
+ --n-gpu-layers 999
+```
+**output**
+```sh
+
+encode_image_with_clip: image encoded in 296.62 ms by CLIP ( 2.06 ms per image patch)
+
+ Susan Wise Bauer
+
+llama_print_timings: load time = 1067.64 ms
+llama_print_timings: sample time = 1.53 ms / 6 runs ( 0.25 ms per token, 3934.43 tokens per second)
+llama_print_timings: prompt eval time = 306.84 ms / 246 tokens ( 1.25 ms per token, 801.72 tokens per second)
+llama_print_timings: eval time = 91.50 ms / 6 runs ( 15.25 ms per token, 65.58 tokens per second)
+llama_print_timings: total time = 1352.63 ms / 252 tokens
+```
+
+### case 2
+**input**
+```sh
+./llava-cli \
+ -m /data/local/tmp/ggml-model-q4_k.gguf \
+ --mmproj /data/local/tmp/mmproj-model-f16.gguf \
+ -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" \
+ --n-gpu-layers 999
+
+```
+**output**
+```sh
+encode_image_with_clip: image encoded in 302.15 ms by CLIP ( 2.10 ms per image patch)
+
+ The image features a cat lying in the grass.
+
+llama_print_timings: load time = 1057.07 ms
+llama_print_timings: sample time = 3.27 ms / 11 runs ( 0.30 ms per token, 3360.83 tokens per second)
+llama_print_timings: prompt eval time = 213.60 ms / 232 tokens ( 0.92 ms per token, 1086.14 tokens per second)
+llama_print_timings: eval time = 166.65 ms / 11 runs ( 15.15 ms per token, 66.01 tokens per second)
+llama_print_timings: total time = 1365.47 ms / 243 tokens
+```
+
## Minor shortcomings
The `n_patch` of output in `ldp` is 1/4 of the input. In order to implement quickly, we uniformly modified `clip_n_patches` function to a quarter. when counting the time consumption, the calculated time will be 4 times bigger than the real cost.
## TODO
-- [ ] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid`
+- [x] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid`
- [ ] Optimize LDP projector performance
- Optimize the structure definition to avoid unnecessary memory rearrangements, to reduce the use of `ggml_permute_cpy`;
- Optimize operator implementation (ARM CPU/NVIDIA GPU): such as depthwise conv, hardswish, hardsigmoid, etc.
-- [ ] run MobileVLM on `Jetson Orin`
+- [x] run MobileVLM on `Jetson Orin`
- [ ] Support more model variants, such as `MobileVLM-3B`.
diff --git a/examples/sycl/ls-sycl-device.cpp b/examples/sycl/ls-sycl-device.cpp
index 42847154a..52442e4ca 100644
--- a/examples/sycl/ls-sycl-device.cpp
+++ b/examples/sycl/ls-sycl-device.cpp
@@ -1,7 +1,9 @@
-/*MIT license
- Copyright (C) 2024 Intel Corporation
- SPDX-License-Identifier: MIT
-*/
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
#include "ggml-sycl.h"
diff --git a/examples/sycl/win-build-sycl.bat b/examples/sycl/win-build-sycl.bat
new file mode 100644
index 000000000..f9d43f8ed
--- /dev/null
+++ b/examples/sycl/win-build-sycl.bat
@@ -0,0 +1,23 @@
+
+:: MIT license
+:: Copyright (C) 2024 Intel Corporation
+:: SPDX-License-Identifier: MIT
+
+mkdir -p build
+cd build
+@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
+
+:: for FP16
+:: faster for long-prompt inference
+:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
+
+:: for FP32
+cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release
+
+
+:: build example/main only
+:: make main
+
+:: build all binary
+make -j
+cd ..
diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat
new file mode 100644
index 000000000..28d935541
--- /dev/null
+++ b/examples/sycl/win-run-llama2.bat
@@ -0,0 +1,13 @@
+:: MIT license
+:: Copyright (C) 2024 Intel Corporation
+:: SPDX-License-Identifier: MIT
+
+INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
+@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
+
+
+set GGML_SYCL_DEVICE=0
+rem set GGML_SYCL_DEBUG=1
+.\build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0
+
+
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 949bc8a1c..e56595742 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -524,6 +524,8 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
@@ -540,6 +542,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define CUDA_PAD_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
+#define CUDA_POOL2D_BLOCK_SIZE 256
#define CUDA_Q8_0_NE_ALIGN 2048
@@ -823,6 +826,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
dst[i] = fmaxf(x[i], 0);
}
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -5823,7 +5844,7 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
}
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
- const int row = blockIdx.y;
+ const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
@@ -6145,9 +6166,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
-static __global__ void im2col_f32_f16(
- const float * x, half * dst,
- int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
+template
+static __global__ void im2col_kernel(
+ const float * x, T * dst, int batch_offset,
+ int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
@@ -6160,21 +6182,73 @@ static __global__ void im2col_f32_f16(
const int ky = (i - kd) / OW;
const int ix = i % OW;
+ const int oh = blockIdx.y;
+ const int batch = blockIdx.z / IC;
+ const int ic = blockIdx.z % IC;
+
const int64_t iiw = ix * s0 + kx * d0 - p0;
- const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
+ const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
- (blockIdx.y * OW + ix) * CHW +
- (blockIdx.z * (KW * KH) + ky * KW + kx);
+ ((batch * OH + oh) * OW + ix) * CHW +
+ (ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- dst[offset_dst] = __float2half(0.0f);
+ dst[offset_dst] = 0.0f;
} else {
- const int64_t offset_src = blockIdx.z * offset_delta;
- dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+ const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
+template
+static __global__ void pool2d_nchw_kernel(
+ const int ih, const int iw, const int oh, const int ow,
+ const int kh, const int kw, const int sh, const int sw,
+ const int ph, const int pw, const int parallel_elements,
+ const Ti* src, To* dst, const enum ggml_op_pool op) {
+ int idx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (idx >= parallel_elements) {
+ return;
+ }
+
+ const int I_HW = ih * iw;
+ const int O_HW = oh * ow;
+ const int nc = idx / O_HW;
+ const int cur_oh = idx % O_HW / ow;
+ const int cur_ow = idx % O_HW % ow;
+ const Ti* i_ptr = src + nc * I_HW;
+ To* o_ptr = dst + nc * O_HW;
+ const int start_h = cur_oh * sh - ph;
+ const int bh = max(0, start_h);
+ const int eh = min(ih, start_h + kh);
+ const int start_w = cur_ow * sw - pw;
+ const int bw = max(0, start_w);
+ const int ew = min(iw, start_w + kw);
+ const To scale = 1. / (kh * kw);
+ To res = 0;
+
+ switch (op) {
+ case GGML_OP_POOL_AVG: res = 0; break;
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+ }
+
+ for (int i = bh; i < eh; i += 1) {
+ for (int j = bw; j < ew; j += 1) {
+ #if __CUDA_ARCH__ >= 350
+ Ti cur = __ldg(i_ptr + i * iw + j);
+ #else
+ Ti cur = i_ptr[i * iw + j];
+ #endif
+ switch (op) {
+ case GGML_OP_POOL_AVG: res += cur * scale; break;
+ case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+ }
+ }
+ }
+ o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
template
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -6388,6 +6462,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
relu_f32<<>>(x, dst, k);
}
+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+ hardsigmoid_f32<<>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+ hardswish_f32<<>>(x, dst, k);
+}
+
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<>>(x, dst, k, negative_slope);
@@ -7475,7 +7559,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
- const dim3 block_nums(1, nrows, 1);
+ const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<>>(x, dst, ncols);
}
@@ -7587,14 +7671,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
}
}
-static void im2col_f32_f16_cuda(const float* x, half* dst,
+template
+static void im2col_cuda(const float* x, T* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
- int offset_delta,
+ int batch, int batch_offset, int offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
- dim3 block_nums(num_blocks, OH, IC);
- im2col_f32_f16<<>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+ dim3 block_nums(num_blocks, OH, batch * IC);
+ im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
@@ -8179,6 +8264,34 @@ static void ggml_cuda_op_relu(
(void) src1_dd;
}
+static void ggml_cuda_op_hardsigmoid(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+static void ggml_cuda_op_hardswish(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
static void ggml_cuda_op_leaky_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
@@ -8810,13 +8923,46 @@ static void ggml_cuda_op_alibi(
(void) src1_dd;
}
+static void ggml_cuda_op_pool2d(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = static_cast(opts[0]);
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+
+ const int64_t IH = src0->ne[1];
+ const int64_t IW = src0->ne[0];
+
+ const int64_t N = dst->ne[3];
+ const int64_t OC = dst->ne[2];
+ const int64_t OH = dst->ne[1];
+ const int64_t OW = dst->ne[0];
+
+ const int parallel_elements = N * OC * OH * OW;
+ const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+ dim3 block_nums(num_blocks);
+ pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
+
+ (void) src1;
+ (void) src1_dd;
+}
+
static void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -8838,8 +8984,14 @@ static void ggml_cuda_op_im2col(
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+ const int64_t batch = src1->ne[3];
+ const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
- im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ if(dst->type == GGML_TYPE_F16) {
+ im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ } else {
+ im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ }
(void) src0;
(void) src0_dd;
@@ -9435,6 +9587,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
}
+static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid);
+}
+
+static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish);
+}
static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
}
@@ -10220,6 +10379,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
+static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d);
+}
+
static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
@@ -10321,6 +10484,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_UNARY_OP_RELU:
func = ggml_cuda_relu;
break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ func = ggml_cuda_hardsigmoid;
+ break;
+ case GGML_UNARY_OP_HARDSWISH:
+ func = ggml_cuda_hardswish;
+ break;
default:
return false;
}
@@ -10395,6 +10564,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
+ case GGML_OP_POOL_2D:
+ func = ggml_cuda_pool2d;
+ break;
case GGML_OP_SUM_ROWS:
func = ggml_cuda_sum_rows;
break;
@@ -11123,6 +11295,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
return true;
@@ -11221,6 +11395,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ROPE:
case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
+ case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
diff --git a/ggml-metal.m b/ggml-metal.m
index f87859552..5260ed827 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ROPE_F16,
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
+ return true;
+ case GGML_OP_POOL_1D:
+ case GGML_OP_POOL_2D:
+ return false;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARGSORT:
@@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute(
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
@@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute(
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int32_t N = src1->ne[is_2D ? 3 : 2];
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
id pipeline = nil;
- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+ switch (dst->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
default: GGML_ASSERT(false);
};
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 2614d82e8..efed6ad46 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1775,9 +1775,29 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope;
-kernel void kernel_im2col_f16(
+typedef void (im2col_t)(
device const float * x,
- device half * dst,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
+template
+kernel void kernel_im2col(
+ device const float * x,
+ device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
@@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+ device T * pdst = (device T *) (dst);
+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- dst[offset_dst] = 0.0f;
+ pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
- dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
+
kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,
diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index 3fc346975..1cc55ef52 100644
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -1,7 +1,14 @@
-/*MIT license
- Copyright (C) 2024 Intel Corporation
- SPDX-License-Identifier: MIT
-*/
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
#include
#include
diff --git a/ggml-sycl.h b/ggml-sycl.h
index 0eabb53cc..ba0c61473 100644
--- a/ggml-sycl.h
+++ b/ggml-sycl.h
@@ -1,7 +1,8 @@
-/*MIT license
- Copyright (C) 2024 Intel Corporation
- SPDX-License-Identifier: MIT
-*/
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
#pragma once
diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp
index 321e36383..e2e9be22c 100644
--- a/ggml-vulkan-shaders.hpp
+++ b/ggml-vulkan-shaders.hpp
@@ -890,7 +890,7 @@ const uint64_t cpy_f32_f32_len = 2472;
unsigned char dequant_f16_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0x87,0x02,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0x81,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x09,0x00,0x00,0x00,
0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,0x0b,0x00,0x06,0x00,
0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,0x2e,0x73,0x74,0x64,
@@ -898,7 +898,7 @@ unsigned char dequant_f16_data[] = {
0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x0f,0x00,0x09,0x00,
0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x6d,0x61,0x69,0x6e,
0x00,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x51,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x10,0x00,0x06,0x00,
+0x4f,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x10,0x00,0x06,0x00,
0x04,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x00,0x01,0x00,0x00,
0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
0x0c,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
@@ -910,23 +910,23 @@ unsigned char dequant_f16_data[] = {
0x48,0x00,0x05,0x00,0x14,0x00,0x00,0x00,0x03,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x47,0x00,0x03,0x00,
0x14,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x48,0x00,0x04,0x00,0x4f,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x4f,0x00,0x00,0x00,
+0x4c,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x48,0x00,0x04,0x00,0x4d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x4d,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x03,0x00,0x4f,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x51,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x51,0x00,0x00,0x00,
+0x47,0x00,0x03,0x00,0x4d,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x4f,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x4f,0x00,0x00,0x00,
0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x5c,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x48,0x00,0x04,0x00,0x5d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x5d,0x00,0x00,0x00,
+0x5a,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x48,0x00,0x04,0x00,0x5b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x5b,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x03,0x00,0x5d,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x5f,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x5f,0x00,0x00,0x00,
+0x47,0x00,0x03,0x00,0x5b,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x5d,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x5d,0x00,0x00,0x00,
0x21,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x80,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
+0x7e,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
0x13,0x00,0x02,0x00,0x02,0x00,0x00,0x00,0x21,0x00,0x03,0x00,
0x03,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x15,0x00,0x04,0x00,
0x06,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
@@ -945,330 +945,109 @@ unsigned char dequant_f16_data[] = {
0x16,0x00,0x00,0x00,0x09,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
0x06,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0x20,0x00,0x04,0x00,0x18,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
-0x06,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x1b,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x14,0x00,0x02,0x00,
-0x24,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x37,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x48,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x16,0x00,0x03,0x00,0x4a,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x4e,0x00,0x00,0x00,
-0x4a,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x4f,0x00,0x00,0x00,
-0x4e,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x50,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
-0x50,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x54,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x4a,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x5c,0x00,0x00,0x00,
-0x4a,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x5d,0x00,0x00,0x00,
-0x5c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x5e,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
-0x5e,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x61,0x00,0x00,0x00,
-0x03,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,
-0x79,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x09,0x00,0x00,0x00,0x7f,0x00,0x00,0x00,0x00,0x01,0x00,0x00,
-0x2c,0x00,0x06,0x00,0x0a,0x00,0x00,0x00,0x80,0x00,0x00,0x00,
-0x7f,0x00,0x00,0x00,0x79,0x00,0x00,0x00,0x79,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x6c,0x02,0x00,0x00,
-0x11,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x6d,0x02,0x00,0x00,0x12,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x6e,0x02,0x00,0x00,0x13,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x6f,0x02,0x00,0x00,
-0x04,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x70,0x02,0x00,0x00,0x14,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x71,0x02,0x00,0x00,0x05,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x72,0x02,0x00,0x00,
-0x15,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x73,0x02,0x00,0x00,0x06,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x74,0x02,0x00,0x00,0x16,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x75,0x02,0x00,0x00,
-0x07,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x76,0x02,0x00,0x00,0x17,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x77,0x02,0x00,0x00,0x08,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x78,0x02,0x00,0x00,
-0x18,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x79,0x02,0x00,0x00,0x09,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x7a,0x02,0x00,0x00,0x19,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x7b,0x02,0x00,0x00,
-0x0a,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x7c,0x02,0x00,0x00,0x1a,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x7d,0x02,0x00,0x00,0x0b,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x7e,0x02,0x00,0x00,
-0x1b,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x7f,0x02,0x00,0x00,0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x80,0x02,0x00,0x00,0x1c,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x81,0x02,0x00,0x00,
-0x0d,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x82,0x02,0x00,0x00,0x1d,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x83,0x02,0x00,0x00,0x0e,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x84,0x02,0x00,0x00,
-0x1e,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x85,0x02,0x00,0x00,0x0f,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x86,0x02,0x00,0x00,0x1f,0x00,0x00,0x00,
-0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x81,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0xfb,0x00,0x03,0x00,0x0d,0x00,0x00,0x00,
-0x82,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x82,0x00,0x00,0x00,
-0x41,0x00,0x05,0x00,0x0e,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x0d,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x09,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,
-0x7c,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,
-0x19,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x17,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
-0x19,0x00,0x00,0x00,0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x1c,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,
-0x8b,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x1d,0x00,0x00,0x00,
-0x11,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x87,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
-0x1c,0x00,0x00,0x00,0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x26,0x00,0x00,0x00,0x1d,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,
-0xaf,0x00,0x05,0x00,0x24,0x00,0x00,0x00,0x29,0x00,0x00,0x00,
-0x26,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0xa8,0x00,0x04,0x00,
-0x24,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,0x29,0x00,0x00,0x00,
-0xf7,0x00,0x03,0x00,0x2c,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0xfa,0x00,0x04,0x00,0x2a,0x00,0x00,0x00,0x2b,0x00,0x00,0x00,
-0x2c,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x2b,0x00,0x00,0x00,
-0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,0x2f,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x30,0x00,0x00,0x00,0x2f,0x00,0x00,0x00,
-0xaf,0x00,0x05,0x00,0x24,0x00,0x00,0x00,0x31,0x00,0x00,0x00,
-0x23,0x00,0x00,0x00,0x30,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x2c,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x2c,0x00,0x00,0x00,
-0xf5,0x00,0x07,0x00,0x24,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
-0x29,0x00,0x00,0x00,0x82,0x00,0x00,0x00,0x31,0x00,0x00,0x00,
-0x2b,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x34,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,0x32,0x00,0x00,0x00,
-0x33,0x00,0x00,0x00,0x34,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x33,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x81,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x34,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
-0x18,0x00,0x00,0x00,0x38,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x37,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x39,0x00,0x00,0x00,0x38,0x00,0x00,0x00,0x87,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x39,0x00,0x00,0x00,
-0x1b,0x00,0x00,0x00,0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x3e,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x40,0x00,0x00,0x00,
-0x3e,0x00,0x00,0x00,0x1d,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x54,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x51,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x40,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x56,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x58,0x00,0x00,0x00,
-0x40,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x54,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x51,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x58,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x59,0x00,0x00,0x00,
-0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x61,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x63,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x64,0x00,0x00,0x00,
-0x23,0x00,0x00,0x00,0x63,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x64,0x00,0x00,0x00,
-0x26,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x6e,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x67,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x6e,0x00,0x00,0x00,
-0x56,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x78,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x48,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x7c,0x00,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x78,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x7c,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x92,0x00,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x95,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x9c,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x17,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x9f,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x9c,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x9f,0x00,0x00,0x00,
-0x92,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xa6,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x6c,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0xa8,0x00,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xa6,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xa8,0x00,0x00,0x00,0x95,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xb2,0x00,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xb5,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xbc,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x37,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0xbf,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xbc,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0xbf,0x00,0x00,0x00,
-0xb2,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xc6,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x6d,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0xc8,0x00,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xc6,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xc8,0x00,0x00,0x00,0xb5,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xd2,0x00,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xd5,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xdc,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x61,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0xdf,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xdc,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0xdf,0x00,0x00,0x00,
-0xd2,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xe6,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x6e,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0xe8,0x00,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xe6,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xe8,0x00,0x00,0x00,0xd5,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xf2,0x00,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xf5,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xfc,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x6f,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0xff,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xfc,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0xff,0x00,0x00,0x00,
-0xf2,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x06,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x70,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x08,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x06,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x08,0x01,0x00,0x00,0xf5,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x12,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x15,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x1c,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x71,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x1f,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x1c,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0x1f,0x01,0x00,0x00,
-0x12,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x26,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x72,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x28,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x26,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x28,0x01,0x00,0x00,0x15,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x32,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x35,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x3c,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x73,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x3f,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x3c,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0x3f,0x01,0x00,0x00,
-0x32,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x46,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x74,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x48,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x46,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x48,0x01,0x00,0x00,0x35,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x52,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x55,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x5c,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x75,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x5f,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x5c,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0x5f,0x01,0x00,0x00,
-0x52,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x66,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x76,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x68,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x66,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x68,0x01,0x00,0x00,0x55,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x72,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x75,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x7c,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x77,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x7f,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x7c,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0x7f,0x01,0x00,0x00,
-0x72,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x86,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x78,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x88,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x86,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x88,0x01,0x00,0x00,0x75,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x92,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x95,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x9c,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x79,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x9f,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x9c,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0x9f,0x01,0x00,0x00,
-0x92,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xa6,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x7a,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0xa8,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xa6,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xa8,0x01,0x00,0x00,0x95,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xb2,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xb5,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xbc,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x7b,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0xbf,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xbc,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0xbf,0x01,0x00,0x00,
-0xb2,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xc6,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x7c,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0xc8,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xc6,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xc8,0x01,0x00,0x00,0xb5,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xd2,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xd5,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xdc,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x7d,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0xdf,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xdc,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0xdf,0x01,0x00,0x00,
-0xd2,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xe6,0x01,0x00,0x00,0x67,0x00,0x00,0x00,0x7e,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0xe8,0x01,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xe6,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xe8,0x01,0x00,0x00,0xd5,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xf2,0x01,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xf5,0x01,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xfc,0x01,0x00,0x00,0x67,0x00,0x00,0x00,
-0x7f,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0xff,0x01,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xfc,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0xff,0x01,0x00,0x00,
-0xf2,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x06,0x02,0x00,0x00,0x67,0x00,0x00,0x00,0x80,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x08,0x02,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x06,0x02,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x08,0x02,0x00,0x00,0xf5,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x12,0x02,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x15,0x02,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x1c,0x02,0x00,0x00,0x67,0x00,0x00,0x00,
-0x81,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x1f,0x02,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x1c,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,0x1f,0x02,0x00,0x00,
-0x12,0x02,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x26,0x02,0x00,0x00,0x67,0x00,0x00,0x00,0x82,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x28,0x02,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x26,0x02,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x28,0x02,0x00,0x00,0x15,0x02,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x32,0x02,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x35,0x02,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x3c,0x02,0x00,0x00,0x67,0x00,0x00,0x00,
-0x83,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x3f,0x02,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x3c,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,0x3f,0x02,0x00,0x00,
-0x32,0x02,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x46,0x02,0x00,0x00,0x67,0x00,0x00,0x00,0x84,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x48,0x02,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x46,0x02,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x48,0x02,0x00,0x00,0x35,0x02,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x52,0x02,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x55,0x02,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x5c,0x02,0x00,0x00,0x67,0x00,0x00,0x00,
-0x85,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x5f,0x02,0x00,0x00,0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x5c,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,0x5f,0x02,0x00,0x00,
-0x52,0x02,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x66,0x02,0x00,0x00,0x67,0x00,0x00,0x00,0x86,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x68,0x02,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x66,0x02,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x68,0x02,0x00,0x00,0x55,0x02,0x00,0x00,
-0xf9,0x00,0x02,0x00,0x81,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x81,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
-
+0x06,0x00,0x00,0x00,0x14,0x00,0x02,0x00,0x23,0x00,0x00,0x00,
+0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x36,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x16,0x00,0x03,0x00,
+0x48,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
+0x4c,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x4d,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x4e,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x52,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
+0x5a,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x5b,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x5c,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x5c,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x5f,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x09,0x00,0x00,0x00,0x77,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x7d,0x00,0x00,0x00,
+0x00,0x01,0x00,0x00,0x2c,0x00,0x06,0x00,0x0a,0x00,0x00,0x00,
+0x7e,0x00,0x00,0x00,0x7d,0x00,0x00,0x00,0x77,0x00,0x00,0x00,
+0x77,0x00,0x00,0x00,0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,
+0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
+0x7f,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfb,0x00,0x03,0x00,
+0x0d,0x00,0x00,0x00,0x80,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
+0x80,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x0e,0x00,0x00,0x00,
+0x0f,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x0d,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
+0x0f,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x11,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x18,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x17,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x1a,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x87,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
+0x17,0x00,0x00,0x00,0x8b,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x1c,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,
+0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x11,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0xaf,0x00,0x05,0x00,
+0x23,0x00,0x00,0x00,0x28,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
+0x1a,0x00,0x00,0x00,0xa8,0x00,0x04,0x00,0x23,0x00,0x00,0x00,
+0x29,0x00,0x00,0x00,0x28,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
+0x2b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,
+0x29,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,0x2b,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x2a,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x18,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x2d,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x2f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xaf,0x00,0x05,0x00,
+0x23,0x00,0x00,0x00,0x30,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x2f,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x2b,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x2b,0x00,0x00,0x00,0xf5,0x00,0x07,0x00,
+0x23,0x00,0x00,0x00,0x31,0x00,0x00,0x00,0x28,0x00,0x00,0x00,
+0x80,0x00,0x00,0x00,0x30,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
+0xf7,0x00,0x03,0x00,0x33,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfa,0x00,0x04,0x00,0x31,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
+0x33,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x32,0x00,0x00,0x00,
+0xf9,0x00,0x02,0x00,0x7f,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
+0x33,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,
+0x37,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x36,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
+0x37,0x00,0x00,0x00,0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x39,0x00,0x00,0x00,0x38,0x00,0x00,0x00,0x17,0x00,0x00,0x00,
+0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0x39,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,
+0x1c,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x52,0x00,0x00,0x00,
+0x53,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x3f,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x48,0x00,0x00,0x00,
+0x54,0x00,0x00,0x00,0x53,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x56,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,
+0x17,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x52,0x00,0x00,0x00,
+0x57,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x56,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x48,0x00,0x00,0x00,
+0x58,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x18,0x00,0x00,0x00,0x60,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x5f,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x61,0x00,0x00,0x00,0x60,0x00,0x00,0x00,0x84,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x61,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x65,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
+0x41,0x00,0x06,0x00,0x52,0x00,0x00,0x00,0x6c,0x00,0x00,0x00,
+0x5d,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
+0x3e,0x00,0x03,0x00,0x6c,0x00,0x00,0x00,0x54,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x76,0x00,0x00,0x00,
+0x65,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x52,0x00,0x00,0x00,0x7a,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
+0x2d,0x00,0x00,0x00,0x76,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
+0x7a,0x00,0x00,0x00,0x58,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
+0x7f,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x7f,0x00,0x00,0x00,
+0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
};
-const uint64_t dequant_f16_len = 4392;
+const uint64_t dequant_f16_len = 1748;
unsigned char dequant_f16_fp32_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0xc8,0x02,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0x86,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,
0x0b,0x00,0x06,0x00,0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,
0x2e,0x73,0x74,0x64,0x2e,0x34,0x35,0x30,0x00,0x00,0x00,0x00,
0x0e,0x00,0x03,0x00,0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0x0f,0x00,0x09,0x00,0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
0x6d,0x61,0x69,0x6e,0x00,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x52,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x50,0x00,0x00,0x00,0x60,0x00,0x00,0x00,
0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
0x00,0x01,0x00,0x00,0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0x47,0x00,0x04,0x00,0x0c,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
@@ -1280,23 +1059,23 @@ unsigned char dequant_f16_fp32_data[] = {
0x08,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x14,0x00,0x00,0x00,
0x03,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
0x47,0x00,0x03,0x00,0x14,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x4f,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x50,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x4d,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x50,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x50,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x52,0x00,0x00,0x00,
+0x4e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x4e,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x50,0x00,0x00,0x00,
0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x52,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x5f,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x60,0x00,0x00,0x00,
+0x50,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x5d,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x5e,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x60,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x60,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x62,0x00,0x00,0x00,
+0x5e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x5e,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x60,0x00,0x00,0x00,
0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x62,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x85,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
+0x60,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x83,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
0x19,0x00,0x00,0x00,0x13,0x00,0x02,0x00,0x02,0x00,0x00,0x00,
0x21,0x00,0x03,0x00,0x03,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
0x15,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
@@ -1315,405 +1094,105 @@ unsigned char dequant_f16_fp32_data[] = {
0x15,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x17,0x00,0x00,0x00,
0x01,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x18,0x00,0x00,0x00,
-0x09,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
-0x14,0x00,0x02,0x00,0x24,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x37,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x48,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x16,0x00,0x03,0x00,
-0x4a,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x16,0x00,0x03,0x00,
-0x4e,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
-0x4f,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
-0x50,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x51,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x50,0x00,0x00,0x00,
-0x3b,0x00,0x04,0x00,0x51,0x00,0x00,0x00,0x52,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x55,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
-0x5f,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
-0x60,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x61,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x60,0x00,0x00,0x00,
-0x3b,0x00,0x04,0x00,0x61,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x64,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x09,0x00,0x00,0x00,0x7d,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x84,0x00,0x00,0x00,
-0x00,0x01,0x00,0x00,0x2c,0x00,0x06,0x00,0x0a,0x00,0x00,0x00,
-0x85,0x00,0x00,0x00,0x84,0x00,0x00,0x00,0x7d,0x00,0x00,0x00,
-0x7d,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xad,0x02,0x00,0x00,0x11,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xae,0x02,0x00,0x00,0x12,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xaf,0x02,0x00,0x00,
-0x13,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xb0,0x02,0x00,0x00,0x04,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xb1,0x02,0x00,0x00,0x14,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xb2,0x02,0x00,0x00,
-0x05,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xb3,0x02,0x00,0x00,0x15,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xb4,0x02,0x00,0x00,0x06,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xb5,0x02,0x00,0x00,
-0x16,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xb6,0x02,0x00,0x00,0x07,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xb7,0x02,0x00,0x00,0x17,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xb8,0x02,0x00,0x00,
-0x08,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xb9,0x02,0x00,0x00,0x18,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xba,0x02,0x00,0x00,0x09,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xbb,0x02,0x00,0x00,
-0x19,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xbc,0x02,0x00,0x00,0x0a,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xbd,0x02,0x00,0x00,0x1a,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xbe,0x02,0x00,0x00,
-0x0b,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xbf,0x02,0x00,0x00,0x1b,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xc0,0x02,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xc1,0x02,0x00,0x00,
-0x1c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xc2,0x02,0x00,0x00,0x0d,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xc3,0x02,0x00,0x00,0x1d,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xc4,0x02,0x00,0x00,
-0x0e,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xc5,0x02,0x00,0x00,0x1e,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0xc6,0x02,0x00,0x00,0x0f,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0xc7,0x02,0x00,0x00,
-0x1f,0x00,0x00,0x00,0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,
-0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
-0x86,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfb,0x00,0x03,0x00,
-0x0d,0x00,0x00,0x00,0x87,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x87,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x0e,0x00,0x00,0x00,
-0x0f,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x0d,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
-0x0f,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x11,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
-0x18,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x17,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x1a,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x87,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
-0x1b,0x00,0x00,0x00,0x8b,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x1d,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
-0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x11,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x84,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x26,0x00,0x00,0x00,0x1d,0x00,0x00,0x00,
-0x1b,0x00,0x00,0x00,0xaf,0x00,0x05,0x00,0x24,0x00,0x00,0x00,
-0x29,0x00,0x00,0x00,0x26,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
-0xa8,0x00,0x04,0x00,0x24,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
-0x29,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x2c,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,0x2a,0x00,0x00,0x00,
-0x2b,0x00,0x00,0x00,0x2c,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x2b,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,
-0x2f,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x30,0x00,0x00,0x00,
-0x2f,0x00,0x00,0x00,0xaf,0x00,0x05,0x00,0x24,0x00,0x00,0x00,
-0x31,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x30,0x00,0x00,0x00,
-0xf9,0x00,0x02,0x00,0x2c,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x2c,0x00,0x00,0x00,0xf5,0x00,0x07,0x00,0x24,0x00,0x00,0x00,
-0x32,0x00,0x00,0x00,0x29,0x00,0x00,0x00,0x87,0x00,0x00,0x00,
-0x31,0x00,0x00,0x00,0x2b,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
-0x34,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,
-0x32,0x00,0x00,0x00,0x33,0x00,0x00,0x00,0x34,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x33,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x86,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x34,0x00,0x00,0x00,
-0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x37,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x09,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x14,0x00,0x02,0x00,
+0x23,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x2d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x36,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x16,0x00,0x03,0x00,0x48,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
+0x16,0x00,0x03,0x00,0x4c,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
+0x1d,0x00,0x03,0x00,0x4d,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,
+0x1e,0x00,0x03,0x00,0x4e,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x4f,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x4e,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x4f,0x00,0x00,0x00,
+0x50,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x53,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,
+0x1d,0x00,0x03,0x00,0x5d,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,
+0x1e,0x00,0x03,0x00,0x5e,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x5f,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x5e,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x5f,0x00,0x00,0x00,
+0x60,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x03,0x00,0x00,0x00,
+0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x7b,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,
+0x82,0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x2c,0x00,0x06,0x00,
+0x0a,0x00,0x00,0x00,0x83,0x00,0x00,0x00,0x82,0x00,0x00,0x00,
+0x7b,0x00,0x00,0x00,0x7b,0x00,0x00,0x00,0x36,0x00,0x05,0x00,
+0x02,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x03,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,
+0xf7,0x00,0x03,0x00,0x84,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfb,0x00,0x03,0x00,0x0d,0x00,0x00,0x00,0x85,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x85,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x0e,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x0d,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x09,0x00,0x00,0x00,
+0x10,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
+0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
+0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,
+0x1a,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x8b,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
+0x1b,0x00,0x00,0x00,0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,
+0xaf,0x00,0x05,0x00,0x23,0x00,0x00,0x00,0x28,0x00,0x00,0x00,
+0x1c,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0xa8,0x00,0x04,0x00,
+0x23,0x00,0x00,0x00,0x29,0x00,0x00,0x00,0x28,0x00,0x00,0x00,
+0xf7,0x00,0x03,0x00,0x2b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfa,0x00,0x04,0x00,0x29,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
+0x2b,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x2a,0x00,0x00,0x00,
+0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x2f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0xaf,0x00,0x05,0x00,0x23,0x00,0x00,0x00,0x30,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0x2f,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
+0x2b,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x2b,0x00,0x00,0x00,
+0xf5,0x00,0x07,0x00,0x23,0x00,0x00,0x00,0x31,0x00,0x00,0x00,
+0x28,0x00,0x00,0x00,0x85,0x00,0x00,0x00,0x30,0x00,0x00,0x00,
+0x2a,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x33,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,0x31,0x00,0x00,0x00,
+0x32,0x00,0x00,0x00,0x33,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
+0x32,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x84,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x33,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x18,0x00,0x00,0x00,0x37,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x36,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x38,0x00,0x00,0x00,0x37,0x00,0x00,0x00,0x87,0x00,0x05,0x00,
0x06,0x00,0x00,0x00,0x39,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
-0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
-0x39,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x84,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x3e,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x3a,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x40,0x00,0x00,0x00,0x3e,0x00,0x00,0x00,0x1d,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0x56,0x00,0x00,0x00,
-0x52,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x40,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x57,0x00,0x00,0x00,
-0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x58,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x40,0x00,0x00,0x00,
-0x17,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x52,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x5a,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x5c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,
-0x41,0x00,0x05,0x00,0x18,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x64,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x66,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
-0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x23,0x00,0x00,0x00,0x66,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x26,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x71,0x00,0x00,0x00,0x58,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0x72,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x72,0x00,0x00,0x00,0x71,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x7c,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,
-0x48,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x80,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0x81,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x7c,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x81,0x00,0x00,0x00,0x80,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x97,0x00,0x00,0x00,0x56,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x98,0x00,0x00,0x00,
-0x97,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x9b,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x9c,0x00,0x00,0x00,0x9b,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0xa3,0x00,0x00,0x00,
-0x6a,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0xa6,0x00,0x00,0x00,0x98,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0xa7,0x00,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xa3,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xa7,0x00,0x00,0x00,0xa6,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0xae,0x00,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xad,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0xb0,0x00,0x00,0x00,0x9c,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0xb1,0x00,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xae,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xb1,0x00,0x00,0x00,0xb0,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xbb,0x00,0x00,0x00,
-0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xbc,0x00,0x00,0x00,0xbb,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0xbf,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xc0,0x00,0x00,0x00,
-0xbf,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xc7,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,0x37,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xca,0x00,0x00,0x00,
-0xbc,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0xcb,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xc7,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0xcb,0x00,0x00,0x00,
-0xca,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xd2,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,0xae,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xd4,0x00,0x00,0x00,
-0xc0,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0xd5,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xd2,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0xd5,0x00,0x00,0x00,
-0xd4,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xdf,0x00,0x00,0x00,0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0xe0,0x00,0x00,0x00,0xdf,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xe3,0x00,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xe4,0x00,0x00,0x00,0xe3,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xeb,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,
-0x64,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xee,0x00,0x00,0x00,0xe0,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0xef,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0xeb,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0xef,0x00,0x00,0x00,0xee,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xf6,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xaf,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xf8,0x00,0x00,0x00,0xe4,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0xf9,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0xf6,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0xf9,0x00,0x00,0x00,0xf8,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x03,0x01,0x00,0x00,0x56,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x04,0x01,0x00,0x00,
-0x03,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x07,0x01,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x08,0x01,0x00,0x00,0x07,0x01,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x0f,0x01,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xb0,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x12,0x01,0x00,0x00,0x04,0x01,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0x13,0x01,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x0f,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x13,0x01,0x00,0x00,0x12,0x01,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x1a,0x01,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xb1,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x1c,0x01,0x00,0x00,0x08,0x01,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0x1d,0x01,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x1a,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x1d,0x01,0x00,0x00,0x1c,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x27,0x01,0x00,0x00,
-0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x28,0x01,0x00,0x00,0x27,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x2b,0x01,0x00,0x00,0x5b,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x2c,0x01,0x00,0x00,
-0x2b,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x33,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,0xb2,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x36,0x01,0x00,0x00,
-0x28,0x01,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0x37,0x01,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x33,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0x37,0x01,0x00,0x00,
-0x36,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x3e,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,0xb3,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x40,0x01,0x00,0x00,
-0x2c,0x01,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0x41,0x01,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x3e,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0x41,0x01,0x00,0x00,
-0x40,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x4b,0x01,0x00,0x00,0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x4c,0x01,0x00,0x00,0x4b,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x4f,0x01,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x50,0x01,0x00,0x00,0x4f,0x01,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x57,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xb4,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x5a,0x01,0x00,0x00,0x4c,0x01,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0x5b,0x01,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x57,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x5b,0x01,0x00,0x00,0x5a,0x01,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x62,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xb5,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x64,0x01,0x00,0x00,0x50,0x01,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0x65,0x01,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x62,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x65,0x01,0x00,0x00,0x64,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x6f,0x01,0x00,0x00,0x56,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x70,0x01,0x00,0x00,
-0x6f,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x73,0x01,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x74,0x01,0x00,0x00,0x73,0x01,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x7b,0x01,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xb6,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x7e,0x01,0x00,0x00,0x70,0x01,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0x7f,0x01,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x7b,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x7f,0x01,0x00,0x00,0x7e,0x01,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x86,0x01,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xb7,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x88,0x01,0x00,0x00,0x74,0x01,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0x89,0x01,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x86,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x89,0x01,0x00,0x00,0x88,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x93,0x01,0x00,0x00,
-0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x94,0x01,0x00,0x00,0x93,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x97,0x01,0x00,0x00,0x5b,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x98,0x01,0x00,0x00,
-0x97,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x9f,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,0xb8,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xa2,0x01,0x00,0x00,
-0x94,0x01,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0xa3,0x01,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x9f,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0xa3,0x01,0x00,0x00,
-0xa2,0x01,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0xaa,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,0xb9,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xac,0x01,0x00,0x00,
-0x98,0x01,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0xad,0x01,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0xaa,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,0xad,0x01,0x00,0x00,
-0xac,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xb7,0x01,0x00,0x00,0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0xb8,0x01,0x00,0x00,0xb7,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xbb,0x01,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0xbc,0x01,0x00,0x00,0xbb,0x01,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xc3,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xba,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xc6,0x01,0x00,0x00,0xb8,0x01,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0xc7,0x01,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0xc3,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,
-0xc7,0x01,0x00,0x00,0xc6,0x01,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xce,0x01,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xbb,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xd0,0x01,0x00,0x00,0xbc,0x01,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0xd1,0x01,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0xce,0x01,0x00,0x00,0x3e,0x00,0x03,0x00,
-0xd1,0x01,0x00,0x00,0xd0,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0xdb,0x01,0x00,0x00,0x56,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0xdc,0x01,0x00,0x00,
-0xdb,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xdf,0x01,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0xe0,0x01,0x00,0x00,0xdf,0x01,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0xe7,0x01,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xbc,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0xea,0x01,0x00,0x00,0xdc,0x01,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0xeb,0x01,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xe7,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xeb,0x01,0x00,0x00,0xea,0x01,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0xf2,0x01,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xbd,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0xf4,0x01,0x00,0x00,0xe0,0x01,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0xf5,0x01,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0xf2,0x01,0x00,0x00,
-0x3e,0x00,0x03,0x00,0xf5,0x01,0x00,0x00,0xf4,0x01,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0xff,0x01,0x00,0x00,
-0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x00,0x02,0x00,0x00,0xff,0x01,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x03,0x02,0x00,0x00,0x5b,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x04,0x02,0x00,0x00,
-0x03,0x02,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x0b,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,0xbe,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x0e,0x02,0x00,0x00,
-0x00,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0x0f,0x02,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x0b,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,0x0f,0x02,0x00,0x00,
-0x0e,0x02,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x16,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,0xbf,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x18,0x02,0x00,0x00,
-0x04,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0x19,0x02,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x16,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,0x19,0x02,0x00,0x00,
-0x18,0x02,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x23,0x02,0x00,0x00,0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x24,0x02,0x00,0x00,0x23,0x02,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x27,0x02,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x28,0x02,0x00,0x00,0x27,0x02,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x2f,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xc0,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x32,0x02,0x00,0x00,0x24,0x02,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0x33,0x02,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x2f,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x33,0x02,0x00,0x00,0x32,0x02,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x3a,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xc1,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x3c,0x02,0x00,0x00,0x28,0x02,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0x3d,0x02,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x3a,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x3d,0x02,0x00,0x00,0x3c,0x02,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x47,0x02,0x00,0x00,0x56,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x48,0x02,0x00,0x00,
-0x47,0x02,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x4b,0x02,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x4c,0x02,0x00,0x00,0x4b,0x02,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x53,0x02,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xc2,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x56,0x02,0x00,0x00,0x48,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0x57,0x02,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x53,0x02,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x57,0x02,0x00,0x00,0x56,0x02,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x5e,0x02,0x00,0x00,
-0x6a,0x00,0x00,0x00,0xc3,0x02,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x60,0x02,0x00,0x00,0x4c,0x02,0x00,0x00,
-0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,0x61,0x02,0x00,0x00,
-0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x5e,0x02,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x61,0x02,0x00,0x00,0x60,0x02,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x6b,0x02,0x00,0x00,
-0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x6c,0x02,0x00,0x00,0x6b,0x02,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x4e,0x00,0x00,0x00,0x6f,0x02,0x00,0x00,0x5b,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,0x70,0x02,0x00,0x00,
-0x6f,0x02,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x77,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,0xc4,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x7a,0x02,0x00,0x00,
-0x6c,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0x7b,0x02,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x77,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,0x7b,0x02,0x00,0x00,
-0x7a,0x02,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x82,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,0xc5,0x02,0x00,0x00,
-0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x84,0x02,0x00,0x00,
-0x70,0x02,0x00,0x00,0x41,0x00,0x06,0x00,0x55,0x00,0x00,0x00,
-0x85,0x02,0x00,0x00,0x62,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x82,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,0x85,0x02,0x00,0x00,
-0x84,0x02,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x8f,0x02,0x00,0x00,0x56,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x4a,0x00,0x00,0x00,0x90,0x02,0x00,0x00,0x8f,0x02,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x93,0x02,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x4a,0x00,0x00,0x00,
-0x94,0x02,0x00,0x00,0x93,0x02,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x9b,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xc6,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0x9e,0x02,0x00,0x00,0x90,0x02,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0x9f,0x02,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x9b,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x9f,0x02,0x00,0x00,0x9e,0x02,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xa6,0x02,0x00,0x00,0x6a,0x00,0x00,0x00,
-0xc7,0x02,0x00,0x00,0x73,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
-0xa8,0x02,0x00,0x00,0x94,0x02,0x00,0x00,0x41,0x00,0x06,0x00,
-0x55,0x00,0x00,0x00,0xa9,0x02,0x00,0x00,0x62,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0xa6,0x02,0x00,0x00,0x3e,0x00,0x03,0x00,
-0xa9,0x02,0x00,0x00,0xa8,0x02,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x86,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x86,0x00,0x00,0x00,
-0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
+0x17,0x00,0x00,0x00,0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x3d,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x39,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,
+0x3d,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x53,0x00,0x00,0x00,0x54,0x00,0x00,0x00,0x50,0x00,0x00,0x00,
+0x2d,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x4c,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x54,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x48,0x00,0x00,0x00,0x56,0x00,0x00,0x00,
+0x55,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x58,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x17,0x00,0x00,0x00,
+0x41,0x00,0x06,0x00,0x53,0x00,0x00,0x00,0x59,0x00,0x00,0x00,
+0x50,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x58,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x4c,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,
+0x59,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x48,0x00,0x00,0x00,
+0x5b,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x18,0x00,0x00,0x00,0x63,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x62,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x64,0x00,0x00,0x00,0x63,0x00,0x00,0x00,0x84,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x64,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x68,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x4c,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,
+0x56,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x53,0x00,0x00,0x00,
+0x70,0x00,0x00,0x00,0x60,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x68,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x70,0x00,0x00,0x00,
+0x6f,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x7a,0x00,0x00,0x00,0x68,0x00,0x00,0x00,0x17,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x4c,0x00,0x00,0x00,0x7e,0x00,0x00,0x00,
+0x5b,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x53,0x00,0x00,0x00,
+0x7f,0x00,0x00,0x00,0x60,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x7a,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x7f,0x00,0x00,0x00,
+0x7e,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x84,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x84,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,
+0x38,0x00,0x01,0x00,
};
-const uint64_t dequant_f16_fp32_len = 5420;
+const uint64_t dequant_f16_fp32_len = 1816;
unsigned char dequant_q2_K_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
@@ -15313,7 +14792,7 @@ const uint64_t gelu_f32_len = 1408;
unsigned char get_rows_f16_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0x7a,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0x77,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x09,0x00,0x00,0x00,
0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,0x0b,0x00,0x06,0x00,
0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,0x2e,0x73,0x74,0x64,
@@ -15321,7 +14800,7 @@ unsigned char get_rows_f16_data[] = {
0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x0f,0x00,0x0a,0x00,
0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x6d,0x61,0x69,0x6e,
0x00,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x1f,0x00,0x00,0x00,
-0x2d,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
+0x2d,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
0x00,0x02,0x00,0x00,0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0x47,0x00,0x04,0x00,0x0b,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
@@ -15341,22 +14820,184 @@ unsigned char get_rows_f16_data[] = {
0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x2d,0x00,0x00,0x00,
0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
0x2d,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x54,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x55,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x52,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x53,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x55,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x55,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x57,0x00,0x00,0x00,
+0x53,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x53,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x55,0x00,0x00,0x00,
0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x57,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x62,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x63,0x00,0x00,0x00,
+0x55,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x60,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x61,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x63,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x63,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x65,0x00,0x00,0x00,
+0x61,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x61,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x63,0x00,0x00,0x00,
0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x65,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x63,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x74,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
+0x19,0x00,0x00,0x00,0x13,0x00,0x02,0x00,0x02,0x00,0x00,0x00,
+0x21,0x00,0x03,0x00,0x03,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x15,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x17,0x00,0x04,0x00,0x09,0x00,0x00,0x00,
+0x06,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x0a,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x0a,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x0d,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x15,0x00,0x04,0x00,0x10,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x10,0x00,0x00,0x00,
+0x12,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x16,0x00,0x03,0x00,0x1c,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
+0x1e,0x00,0x06,0x00,0x1d,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x06,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x1e,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
+0x1d,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x1e,0x00,0x00,0x00,
+0x1f,0x00,0x00,0x00,0x09,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x10,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x21,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
+0x06,0x00,0x00,0x00,0x14,0x00,0x02,0x00,0x24,0x00,0x00,0x00,
+0x1d,0x00,0x03,0x00,0x2a,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
+0x1e,0x00,0x03,0x00,0x2b,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x2c,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x2b,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x2c,0x00,0x00,0x00,
+0x2d,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x10,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x30,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x10,0x00,0x00,0x00,0x16,0x00,0x03,0x00,0x4e,0x00,0x00,0x00,
+0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x52,0x00,0x00,0x00,
+0x4e,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x53,0x00,0x00,0x00,
+0x52,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x54,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x53,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
+0x54,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x58,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x4e,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x60,0x00,0x00,0x00,
+0x4e,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x61,0x00,0x00,0x00,
+0x60,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x62,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x61,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
+0x62,0x00,0x00,0x00,0x63,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x73,0x00,0x00,0x00,
+0x00,0x02,0x00,0x00,0x2c,0x00,0x06,0x00,0x09,0x00,0x00,0x00,
+0x74,0x00,0x00,0x00,0x73,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,
+0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
+0x75,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfb,0x00,0x03,0x00,
+0x0c,0x00,0x00,0x00,0x76,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
+0x76,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x0d,0x00,0x00,0x00,
+0x0e,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,
+0x0e,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,0x10,0x00,0x00,0x00,
+0x11,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x84,0x00,0x05,0x00,
+0x10,0x00,0x00,0x00,0x13,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
+0x12,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x14,0x00,0x00,0x00,0x13,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x0d,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x18,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,
+0x10,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
+0x7c,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
+0x19,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x21,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0x1f,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0xae,0x00,0x05,0x00,0x24,0x00,0x00,0x00,
+0x25,0x00,0x00,0x00,0x14,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0xf7,0x00,0x03,0x00,0x27,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfa,0x00,0x04,0x00,0x25,0x00,0x00,0x00,0x26,0x00,0x00,0x00,
+0x27,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x26,0x00,0x00,0x00,
+0xf9,0x00,0x02,0x00,0x75,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
+0x27,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x30,0x00,0x00,0x00,
+0x31,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x1a,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x10,0x00,0x00,0x00,
+0x32,0x00,0x00,0x00,0x31,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x33,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
+0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
+0x33,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
+0x14,0x00,0x00,0x00,0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x3f,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
+0x3f,0x00,0x00,0x00,0x14,0x00,0x00,0x00,0x86,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x47,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x48,0x00,0x00,0x00,
+0x47,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x89,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x82,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x4d,0x00,0x00,0x00,0x41,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,
+0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,0x59,0x00,0x00,0x00,
+0x55,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,
+0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x5c,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
+0x55,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,
+0x5d,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x66,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,0x48,0x00,0x00,0x00,
+0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,0x6b,0x00,0x00,0x00,
+0x63,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x66,0x00,0x00,0x00,
+0x3e,0x00,0x03,0x00,0x6b,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,
+0x66,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x58,0x00,0x00,0x00,0x72,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
+0x2e,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
+0x72,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
+0x75,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x75,0x00,0x00,0x00,
+0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
+};
+const uint64_t get_rows_f16_len = 1892;
+
+unsigned char get_rows_f16_f32_data[] = {
+0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
+0x7a,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x09,0x00,0x00,0x00,
+0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,0x0b,0x00,0x06,0x00,
+0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,0x2e,0x73,0x74,0x64,
+0x2e,0x34,0x35,0x30,0x00,0x00,0x00,0x00,0x0e,0x00,0x03,0x00,
+0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x0f,0x00,0x0a,0x00,
+0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x6d,0x61,0x69,0x6e,
+0x00,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x1f,0x00,0x00,0x00,
+0x2d,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
+0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
+0x00,0x02,0x00,0x00,0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x0b,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
+0x1c,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x1d,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x48,0x00,0x05,0x00,0x1d,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x23,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
+0x1d,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x08,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x1d,0x00,0x00,0x00,
+0x03,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x47,0x00,0x03,0x00,0x1d,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x2a,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x04,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x2b,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
+0x2b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x2b,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x2d,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x2d,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x52,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x53,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
+0x53,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x53,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x55,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x55,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x60,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
+0x04,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x61,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
+0x61,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x61,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x63,0x00,0x00,0x00,
+0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x63,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
0x47,0x00,0x04,0x00,0x77,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
0x19,0x00,0x00,0x00,0x13,0x00,0x02,0x00,0x02,0x00,0x00,0x00,
0x21,0x00,0x03,0x00,0x03,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
@@ -15388,198 +15029,28 @@ unsigned char get_rows_f16_data[] = {
0x2d,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
0x10,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x20,0x00,0x04,0x00,0x30,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x44,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x49,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x16,0x00,0x03,0x00,0x50,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
-0x1d,0x00,0x03,0x00,0x54,0x00,0x00,0x00,0x50,0x00,0x00,0x00,
-0x1e,0x00,0x03,0x00,0x55,0x00,0x00,0x00,0x54,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x56,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x56,0x00,0x00,0x00,
-0x57,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x5a,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x50,0x00,0x00,0x00,
-0x1d,0x00,0x03,0x00,0x62,0x00,0x00,0x00,0x50,0x00,0x00,0x00,
-0x1e,0x00,0x03,0x00,0x63,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x64,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x63,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x64,0x00,0x00,0x00,
-0x65,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x71,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x76,0x00,0x00,0x00,
-0x00,0x02,0x00,0x00,0x2c,0x00,0x06,0x00,0x09,0x00,0x00,0x00,
-0x77,0x00,0x00,0x00,0x76,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,
-0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
-0x78,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfb,0x00,0x03,0x00,
-0x0c,0x00,0x00,0x00,0x79,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x79,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x0d,0x00,0x00,0x00,
-0x0e,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,
-0x0e,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,0x10,0x00,0x00,0x00,
-0x11,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x84,0x00,0x05,0x00,
-0x10,0x00,0x00,0x00,0x13,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
-0x12,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x14,0x00,0x00,0x00,0x13,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
-0x0d,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x18,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,
-0x10,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
-0x7c,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
-0x19,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x21,0x00,0x00,0x00,
-0x22,0x00,0x00,0x00,0x1f,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x22,0x00,0x00,0x00,0xae,0x00,0x05,0x00,0x24,0x00,0x00,0x00,
-0x25,0x00,0x00,0x00,0x14,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0xf7,0x00,0x03,0x00,0x27,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0xfa,0x00,0x04,0x00,0x25,0x00,0x00,0x00,0x26,0x00,0x00,0x00,
-0x27,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x26,0x00,0x00,0x00,
-0xf9,0x00,0x02,0x00,0x78,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x27,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x30,0x00,0x00,0x00,
-0x31,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x1a,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x10,0x00,0x00,0x00,
-0x32,0x00,0x00,0x00,0x31,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x33,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
-0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
-0x33,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
-0x14,0x00,0x00,0x00,0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x3f,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
-0x3f,0x00,0x00,0x00,0x14,0x00,0x00,0x00,0x86,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x45,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
-0x44,0x00,0x00,0x00,0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x48,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
-0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x4a,0x00,0x00,0x00,
-0x48,0x00,0x00,0x00,0x49,0x00,0x00,0x00,0x89,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
-0x44,0x00,0x00,0x00,0x82,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x4f,0x00,0x00,0x00,0x41,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,
-0x57,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x45,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x50,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x5e,0x00,0x00,0x00,0x45,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,
-0x57,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x50,0x00,0x00,0x00,0x60,0x00,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x68,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x4a,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,0x6d,0x00,0x00,0x00,
-0x65,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x68,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x6d,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x72,0x00,0x00,0x00,
-0x68,0x00,0x00,0x00,0x71,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x5a,0x00,0x00,0x00,0x75,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x72,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x75,0x00,0x00,0x00,0x60,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x78,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x78,0x00,0x00,0x00,
-0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
-};
-const uint64_t get_rows_f16_len = 1940;
-
-unsigned char get_rows_f16_f32_data[] = {
-0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0x7d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
-0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x09,0x00,0x00,0x00,
-0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,0x0b,0x00,0x06,0x00,
-0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,0x2e,0x73,0x74,0x64,
-0x2e,0x34,0x35,0x30,0x00,0x00,0x00,0x00,0x0e,0x00,0x03,0x00,
-0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x0f,0x00,0x0a,0x00,
-0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x6d,0x61,0x69,0x6e,
-0x00,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x1f,0x00,0x00,0x00,
-0x2d,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
-0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
-0x00,0x02,0x00,0x00,0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x0b,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
-0x1c,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x1d,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x48,0x00,0x05,0x00,0x1d,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x23,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x1d,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x08,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x1d,0x00,0x00,0x00,
-0x03,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x47,0x00,0x03,0x00,0x1d,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x2a,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x04,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x2b,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x2b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x2b,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x2d,0x00,0x00,0x00,
-0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x2d,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x54,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x55,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x55,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x55,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x57,0x00,0x00,0x00,
-0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x57,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x62,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x04,0x00,0x00,0x00,0x48,0x00,0x04,0x00,0x63,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
-0x63,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,0x63,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x65,0x00,0x00,0x00,
-0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x65,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x7a,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
-0x19,0x00,0x00,0x00,0x13,0x00,0x02,0x00,0x02,0x00,0x00,0x00,
-0x21,0x00,0x03,0x00,0x03,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x15,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x17,0x00,0x04,0x00,0x09,0x00,0x00,0x00,
-0x06,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x0a,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
-0x3b,0x00,0x04,0x00,0x0a,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x0d,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x15,0x00,0x04,0x00,0x10,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x10,0x00,0x00,0x00,
-0x12,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x16,0x00,0x03,0x00,0x1c,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
-0x1e,0x00,0x06,0x00,0x1d,0x00,0x00,0x00,0x06,0x00,0x00,0x00,
-0x06,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x1e,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
-0x1d,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x1e,0x00,0x00,0x00,
-0x1f,0x00,0x00,0x00,0x09,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x10,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x21,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
-0x06,0x00,0x00,0x00,0x14,0x00,0x02,0x00,0x24,0x00,0x00,0x00,
-0x1d,0x00,0x03,0x00,0x2a,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
-0x1e,0x00,0x03,0x00,0x2b,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x2c,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x2b,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x2c,0x00,0x00,0x00,
-0x2d,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x10,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x30,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x44,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x49,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x16,0x00,0x03,0x00,0x50,0x00,0x00,0x00,0x10,0x00,0x00,0x00,
-0x1d,0x00,0x03,0x00,0x54,0x00,0x00,0x00,0x50,0x00,0x00,0x00,
-0x1e,0x00,0x03,0x00,0x55,0x00,0x00,0x00,0x54,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x56,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x55,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x56,0x00,0x00,0x00,
-0x57,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x5a,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x50,0x00,0x00,0x00,
-0x1d,0x00,0x03,0x00,0x62,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
-0x1e,0x00,0x03,0x00,0x63,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x64,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x63,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x64,0x00,0x00,0x00,
-0x65,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x6e,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x73,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x79,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x2c,0x00,0x06,0x00,
-0x09,0x00,0x00,0x00,0x7a,0x00,0x00,0x00,0x79,0x00,0x00,0x00,
+0x10,0x00,0x00,0x00,0x16,0x00,0x03,0x00,0x4e,0x00,0x00,0x00,
+0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x52,0x00,0x00,0x00,
+0x4e,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x53,0x00,0x00,0x00,
+0x52,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x54,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x53,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
+0x54,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x58,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x4e,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x60,0x00,0x00,0x00,
+0x1c,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x61,0x00,0x00,0x00,
+0x60,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x62,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x61,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
+0x62,0x00,0x00,0x00,0x63,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x20,0x00,0x04,0x00,0x6c,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x1c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x76,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x2c,0x00,0x06,0x00,
+0x09,0x00,0x00,0x00,0x77,0x00,0x00,0x00,0x76,0x00,0x00,0x00,
0x16,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x36,0x00,0x05,0x00,
0x02,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x03,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,
-0xf7,0x00,0x03,0x00,0x7b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0xfb,0x00,0x03,0x00,0x0c,0x00,0x00,0x00,0x7c,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x7c,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0xf7,0x00,0x03,0x00,0x78,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfb,0x00,0x03,0x00,0x0c,0x00,0x00,0x00,0x79,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x79,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
0x0d,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
0x0c,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
0x0f,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,
@@ -15600,7 +15071,7 @@ unsigned char get_rows_f16_f32_data[] = {
0x23,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x27,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,0x25,0x00,0x00,0x00,
0x26,0x00,0x00,0x00,0x27,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x26,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x7b,0x00,0x00,0x00,
+0x26,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x78,0x00,0x00,0x00,
0xf8,0x00,0x02,0x00,0x27,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
0x30,0x00,0x00,0x00,0x31,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
0x2e,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
@@ -15613,51 +15084,51 @@ unsigned char get_rows_f16_f32_data[] = {
0x06,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
0x41,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x14,0x00,0x00,0x00,
-0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x45,0x00,0x00,0x00,
-0x3a,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x89,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
-0x44,0x00,0x00,0x00,0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x4a,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x49,0x00,0x00,0x00,
-0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,
-0x41,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x82,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
-0x4e,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x45,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x50,0x00,0x00,0x00,
-0x5c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,0x45,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x5e,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x50,0x00,0x00,0x00,
-0x60,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x68,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,
-0x4a,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,
-0x6d,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x6e,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x68,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x6f,0x00,0x00,0x00,0x6d,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x74,0x00,0x00,0x00,0x68,0x00,0x00,0x00,
-0x73,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,
-0x77,0x00,0x00,0x00,0x60,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x6e,0x00,0x00,0x00,0x78,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x74,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x78,0x00,0x00,0x00,0x77,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x7b,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x7b,0x00,0x00,0x00,
+0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
+0x3a,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x89,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x47,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x48,0x00,0x00,0x00,0x47,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,
+0x41,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x82,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
+0x4c,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,
+0x59,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x44,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
+0x5a,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,
+0x5d,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x5c,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4e,0x00,0x00,0x00,
+0x5e,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x66,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,
+0x48,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,
+0x6b,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x6c,0x00,0x00,0x00,0x6d,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
+0x2e,0x00,0x00,0x00,0x66,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
+0x6d,0x00,0x00,0x00,0x6b,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x71,0x00,0x00,0x00,0x66,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,
+0x74,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x6c,0x00,0x00,0x00,0x75,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
+0x2e,0x00,0x00,0x00,0x71,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
+0x75,0x00,0x00,0x00,0x74,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
+0x78,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x78,0x00,0x00,0x00,
0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
};
-const uint64_t get_rows_f16_f32_len = 1988;
+const uint64_t get_rows_f16_f32_len = 1940;
unsigned char get_rows_f16_f32_fp32_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0x7d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0x7a,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,
0x0b,0x00,0x06,0x00,0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,
0x2e,0x73,0x74,0x64,0x2e,0x34,0x35,0x30,0x00,0x00,0x00,0x00,
0x0e,0x00,0x03,0x00,0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0x0f,0x00,0x0a,0x00,0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
0x6d,0x61,0x69,0x6e,0x00,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
-0x1f,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x57,0x00,0x00,0x00,
-0x67,0x00,0x00,0x00,0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,
+0x1f,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
+0x65,0x00,0x00,0x00,0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,
0x11,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x01,0x00,0x00,0x00,
0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x0b,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
@@ -15676,23 +15147,23 @@ unsigned char get_rows_f16_f32_fp32_data[] = {
0x2b,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
0x2d,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x47,0x00,0x04,0x00,0x2d,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x54,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x52,0x00,0x00,0x00,
0x06,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,
-0x55,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
-0x48,0x00,0x05,0x00,0x55,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x53,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
+0x48,0x00,0x05,0x00,0x53,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,
-0x55,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x57,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x57,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x64,0x00,0x00,0x00,
+0x53,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x55,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x55,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x62,0x00,0x00,0x00,
0x06,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x48,0x00,0x04,0x00,
-0x65,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
-0x48,0x00,0x05,0x00,0x65,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x63,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
+0x48,0x00,0x05,0x00,0x63,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,
-0x65,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x67,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x67,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x7a,0x00,0x00,0x00,
+0x63,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x65,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x65,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x77,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x13,0x00,0x02,0x00,
0x02,0x00,0x00,0x00,0x21,0x00,0x03,0x00,0x03,0x00,0x00,0x00,
0x02,0x00,0x00,0x00,0x15,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
@@ -15723,32 +15194,28 @@ unsigned char get_rows_f16_f32_fp32_data[] = {
0x2c,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
0x2b,0x00,0x04,0x00,0x10,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x30,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x49,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x16,0x00,0x03,0x00,0x53,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x54,0x00,0x00,0x00,
-0x53,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x55,0x00,0x00,0x00,
-0x54,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x56,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
-0x56,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x5a,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x53,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x64,0x00,0x00,0x00,
-0x1c,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x65,0x00,0x00,0x00,
-0x64,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x66,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
-0x66,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x6f,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x1c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x74,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x79,0x00,0x00,0x00,0x00,0x02,0x00,0x00,
-0x2c,0x00,0x06,0x00,0x09,0x00,0x00,0x00,0x7a,0x00,0x00,0x00,
-0x79,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x16,0x00,0x03,0x00,
+0x51,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
+0x52,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x53,0x00,0x00,0x00,0x52,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x54,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x53,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x54,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x58,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
+0x62,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x63,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x64,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x64,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x6d,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0x76,0x00,0x00,0x00,0x00,0x02,0x00,0x00,
+0x2c,0x00,0x06,0x00,0x09,0x00,0x00,0x00,0x77,0x00,0x00,0x00,
+0x76,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x7b,0x00,0x00,0x00,
+0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x78,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0xfb,0x00,0x03,0x00,0x0c,0x00,0x00,0x00,
-0x7c,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x7c,0x00,0x00,0x00,
+0x79,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x79,0x00,0x00,0x00,
0x41,0x00,0x05,0x00,0x0d,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
0x06,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,
@@ -15770,7 +15237,7 @@ unsigned char get_rows_f16_f32_fp32_data[] = {
0x27,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,
0x25,0x00,0x00,0x00,0x26,0x00,0x00,0x00,0x27,0x00,0x00,0x00,
0xf8,0x00,0x02,0x00,0x26,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x7b,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x27,0x00,0x00,0x00,
+0x78,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x27,0x00,0x00,0x00,
0x41,0x00,0x06,0x00,0x30,0x00,0x00,0x00,0x31,0x00,0x00,0x00,
0x2d,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
0x3d,0x00,0x04,0x00,0x10,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
@@ -15783,51 +15250,51 @@ unsigned char get_rows_f16_f32_fp32_data[] = {
0x1a,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
0x06,0x00,0x00,0x00,0x41,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,
0x14,0x00,0x00,0x00,0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x45,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
-0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x48,0x00,0x00,0x00,
-0x3a,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x86,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x4a,0x00,0x00,0x00,0x48,0x00,0x00,0x00,
-0x49,0x00,0x00,0x00,0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x4e,0x00,0x00,0x00,0x41,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
-0x82,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,
-0x41,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x5a,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x57,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x45,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x53,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
-0x5c,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x5f,0x00,0x00,0x00,0x45,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,0x60,0x00,0x00,0x00,
-0x57,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x53,0x00,0x00,0x00,0x61,0x00,0x00,0x00,
-0x60,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,
-0x62,0x00,0x00,0x00,0x61,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,
-0x4a,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x6f,0x00,0x00,0x00,
-0x70,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x6a,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x70,0x00,0x00,0x00,
-0x5d,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x75,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,0x74,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x6f,0x00,0x00,0x00,0x78,0x00,0x00,0x00,
-0x67,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x75,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x78,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0xf9,0x00,0x02,0x00,0x7b,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x7b,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
+0x44,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x47,0x00,0x00,0x00,
+0x3a,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x86,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x47,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x4c,0x00,0x00,0x00,0x41,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x82,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,
+0x41,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x58,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
+0x2e,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x51,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x59,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,
+0x5a,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x5d,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,
+0x55,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x51,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,
+0x5e,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,
+0x60,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x68,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,
+0x48,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x6d,0x00,0x00,0x00,
+0x6e,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x68,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x6e,0x00,0x00,0x00,
+0x5b,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x72,0x00,0x00,0x00,0x68,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x41,0x00,0x06,0x00,0x6d,0x00,0x00,0x00,0x75,0x00,0x00,0x00,
+0x65,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x72,0x00,0x00,0x00,
+0x3e,0x00,0x03,0x00,0x75,0x00,0x00,0x00,0x60,0x00,0x00,0x00,
+0xf9,0x00,0x02,0x00,0x78,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
+0x78,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
};
-const uint64_t get_rows_f16_f32_fp32_len = 1980;
+const uint64_t get_rows_f16_f32_fp32_len = 1932;
unsigned char get_rows_f16_fp32_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0x7e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0x7b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,
0x0b,0x00,0x06,0x00,0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,
0x2e,0x73,0x74,0x64,0x2e,0x34,0x35,0x30,0x00,0x00,0x00,0x00,
0x0e,0x00,0x03,0x00,0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0x0f,0x00,0x0a,0x00,0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
0x6d,0x61,0x69,0x6e,0x00,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
-0x1f,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x57,0x00,0x00,0x00,
-0x67,0x00,0x00,0x00,0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,
+0x1f,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
+0x65,0x00,0x00,0x00,0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,
0x11,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x01,0x00,0x00,0x00,
0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x0b,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
@@ -15846,23 +15313,23 @@ unsigned char get_rows_f16_fp32_data[] = {
0x2b,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
0x2d,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x47,0x00,0x04,0x00,0x2d,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x54,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x52,0x00,0x00,0x00,
0x06,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,
-0x55,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
-0x48,0x00,0x05,0x00,0x55,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x53,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
+0x48,0x00,0x05,0x00,0x53,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,
-0x55,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x57,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x57,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x64,0x00,0x00,0x00,
+0x53,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x55,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x55,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x62,0x00,0x00,0x00,
0x06,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x48,0x00,0x04,0x00,
-0x65,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
-0x48,0x00,0x05,0x00,0x65,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x63,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
+0x48,0x00,0x05,0x00,0x63,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,
-0x65,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x67,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x67,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x7b,0x00,0x00,0x00,
+0x63,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x65,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x65,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
+0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x78,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x13,0x00,0x02,0x00,
0x02,0x00,0x00,0x00,0x21,0x00,0x03,0x00,0x03,0x00,0x00,0x00,
0x02,0x00,0x00,0x00,0x15,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
@@ -15893,31 +15360,27 @@ unsigned char get_rows_f16_fp32_data[] = {
0x2c,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
0x2b,0x00,0x04,0x00,0x10,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x30,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x20,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x49,0x00,0x00,0x00,
-0x02,0x00,0x00,0x00,0x16,0x00,0x03,0x00,0x53,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x54,0x00,0x00,0x00,
-0x53,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x55,0x00,0x00,0x00,
-0x54,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x56,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
-0x56,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x5a,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x53,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,0x64,0x00,0x00,0x00,
-0x53,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0x65,0x00,0x00,0x00,
-0x64,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x66,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
-0x66,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x74,0x00,0x00,0x00,
-0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x7a,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x2c,0x00,0x06,0x00,
-0x09,0x00,0x00,0x00,0x7b,0x00,0x00,0x00,0x7a,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x16,0x00,0x03,0x00,
+0x51,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
+0x52,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x53,0x00,0x00,0x00,0x52,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x54,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x53,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x54,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x58,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
+0x62,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x63,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x64,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x64,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x77,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x2c,0x00,0x06,0x00,
+0x09,0x00,0x00,0x00,0x78,0x00,0x00,0x00,0x77,0x00,0x00,0x00,
0x16,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x36,0x00,0x05,0x00,
0x02,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x03,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,
-0xf7,0x00,0x03,0x00,0x7c,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0xfb,0x00,0x03,0x00,0x0c,0x00,0x00,0x00,0x7d,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x7d,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0xf7,0x00,0x03,0x00,0x79,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfb,0x00,0x03,0x00,0x0c,0x00,0x00,0x00,0x7a,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x7a,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
0x0d,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
0x0c,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
0x0f,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,0x7c,0x00,0x04,0x00,
@@ -15938,7 +15401,7 @@ unsigned char get_rows_f16_fp32_data[] = {
0x23,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x27,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,0x25,0x00,0x00,0x00,
0x26,0x00,0x00,0x00,0x27,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x26,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x7c,0x00,0x00,0x00,
+0x26,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x79,0x00,0x00,0x00,
0xf8,0x00,0x02,0x00,0x27,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
0x30,0x00,0x00,0x00,0x31,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
0x2e,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
@@ -15951,42 +15414,42 @@ unsigned char get_rows_f16_fp32_data[] = {
0x06,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x1a,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
0x41,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x14,0x00,0x00,0x00,
-0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x45,0x00,0x00,0x00,
-0x3a,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x89,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
-0x44,0x00,0x00,0x00,0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x4a,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x49,0x00,0x00,0x00,
-0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x4e,0x00,0x00,0x00,
-0x41,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x82,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
-0x4e,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x45,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x53,0x00,0x00,0x00,
-0x5c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x1c,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,
-0x45,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
-0x5a,0x00,0x00,0x00,0x60,0x00,0x00,0x00,0x57,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x53,0x00,0x00,0x00,0x61,0x00,0x00,0x00,0x60,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,0x62,0x00,0x00,0x00,
-0x61,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x6a,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,0x4a,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x53,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,
-0x5d,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,
-0x70,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x6a,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x70,0x00,0x00,0x00,
-0x6f,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x75,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,0x74,0x00,0x00,0x00,
-0x73,0x00,0x04,0x00,0x53,0x00,0x00,0x00,0x78,0x00,0x00,0x00,
-0x62,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x5a,0x00,0x00,0x00,
-0x79,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x75,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x79,0x00,0x00,0x00,
-0x78,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x7c,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x7c,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,
+0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
+0x3a,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x89,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x47,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x86,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x48,0x00,0x00,0x00,0x47,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x89,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x4c,0x00,0x00,0x00,
+0x41,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x82,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
+0x4c,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,
+0x59,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x44,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x51,0x00,0x00,0x00,
+0x5a,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
+0x1c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
+0x44,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x58,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
+0x2e,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x51,0x00,0x00,0x00,0x5f,0x00,0x00,0x00,0x5e,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x1c,0x00,0x00,0x00,0x60,0x00,0x00,0x00,
+0x5f,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x68,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,0x48,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x51,0x00,0x00,0x00,0x6d,0x00,0x00,0x00,
+0x5b,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,
+0x6e,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x68,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x6e,0x00,0x00,0x00,
+0x6d,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x72,0x00,0x00,0x00,0x68,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x51,0x00,0x00,0x00,0x75,0x00,0x00,0x00,
+0x60,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x58,0x00,0x00,0x00,
+0x76,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x72,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x76,0x00,0x00,0x00,
+0x75,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x79,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x79,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,
0x38,0x00,0x01,0x00,
};
-const uint64_t get_rows_f16_fp32_len = 1996;
+const uint64_t get_rows_f16_fp32_len = 1948;
unsigned char get_rows_q4_0_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
@@ -52701,7 +52164,7 @@ const uint64_t mul_f32_len = 1456;
unsigned char mul_mat_vec_f16_f32_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0xba,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0xb6,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
0x01,0x00,0x00,0x00,0x11,0x00,0x02,0x00,0x51,0x11,0x00,0x00,
0x0b,0x00,0x06,0x00,0x01,0x00,0x00,0x00,0x47,0x4c,0x53,0x4c,
0x2e,0x73,0x74,0x64,0x2e,0x34,0x35,0x30,0x00,0x00,0x00,0x00,
@@ -52709,9 +52172,9 @@ unsigned char mul_mat_vec_f16_f32_data[] = {
0x0f,0x00,0x0c,0x00,0x05,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
0x6d,0x61,0x69,0x6e,0x00,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
0x13,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
-0x51,0x00,0x00,0x00,0x66,0x00,0x00,0x00,0xad,0x00,0x00,0x00,
+0x51,0x00,0x00,0x00,0x65,0x00,0x00,0x00,0xaa,0x00,0x00,0x00,
0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,0x11,0x00,0x00,0x00,
-0x20,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0x47,0x00,0x04,0x00,0x0c,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
0x1a,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x13,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
@@ -52729,23 +52192,23 @@ unsigned char mul_mat_vec_f16_f32_data[] = {
0x47,0x00,0x04,0x00,0x51,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x51,0x00,0x00,0x00,
0x21,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x63,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
-0x48,0x00,0x04,0x00,0x64,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x64,0x00,0x00,0x00,
+0x62,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
+0x48,0x00,0x04,0x00,0x63,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x18,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0x63,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x03,0x00,0x64,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x66,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x66,0x00,0x00,0x00,
+0x47,0x00,0x03,0x00,0x63,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x65,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x65,0x00,0x00,0x00,
0x21,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0xaa,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
-0x48,0x00,0x04,0x00,0xab,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0xab,0x00,0x00,0x00,
+0xa7,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
+0x48,0x00,0x04,0x00,0xa8,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x19,0x00,0x00,0x00,0x48,0x00,0x05,0x00,0xa8,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x03,0x00,0xab,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0xad,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0xad,0x00,0x00,0x00,
+0x47,0x00,0x03,0x00,0xa8,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0xaa,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0xaa,0x00,0x00,0x00,
0x21,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0xb5,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
+0xb2,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
0x13,0x00,0x02,0x00,0x02,0x00,0x00,0x00,0x21,0x00,0x03,0x00,
0x03,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x15,0x00,0x04,0x00,
0x06,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
@@ -52760,7 +52223,7 @@ unsigned char mul_mat_vec_f16_f32_data[] = {
0x3b,0x00,0x04,0x00,0x0b,0x00,0x00,0x00,0x13,0x00,0x00,0x00,
0x01,0x00,0x00,0x00,0x16,0x00,0x03,0x00,0x17,0x00,0x00,0x00,
0x20,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,
-0x18,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x1c,0x00,0x04,0x00,
+0x18,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x1c,0x00,0x04,0x00,
0x19,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
0x20,0x00,0x04,0x00,0x1a,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
0x19,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x1a,0x00,0x00,0x00,
@@ -52775,7 +52238,7 @@ unsigned char mul_mat_vec_f16_f32_data[] = {
0x29,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
0x20,0x00,0x04,0x00,0x2b,0x00,0x00,0x00,0x09,0x00,0x00,0x00,
0x06,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x20,0x00,0x00,0x00,0x14,0x00,0x02,0x00,
+0x2e,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x14,0x00,0x02,0x00,
0x30,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
0x35,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x16,0x00,0x03,0x00,
0x4d,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
@@ -52784,26 +52247,22 @@ unsigned char mul_mat_vec_f16_f32_data[] = {
0x50,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x4f,0x00,0x00,0x00,
0x3b,0x00,0x04,0x00,0x50,0x00,0x00,0x00,0x51,0x00,0x00,0x00,
0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x54,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x1d,0x00,0x03,0x00,0x63,0x00,0x00,0x00,0x17,0x00,0x00,0x00,
-0x1e,0x00,0x03,0x00,0x64,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
-0x20,0x00,0x04,0x00,0x65,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x64,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,0x65,0x00,0x00,0x00,
-0x66,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x6e,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x17,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x77,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0x80,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
-0x09,0x00,0x00,0x00,0x8b,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x8c,0x00,0x00,0x00,
-0x08,0x01,0x00,0x00,0x1d,0x00,0x03,0x00,0xaa,0x00,0x00,0x00,
-0x17,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0xab,0x00,0x00,0x00,
-0xaa,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0xac,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0xab,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
-0xac,0x00,0x00,0x00,0xad,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x2c,0x00,0x06,0x00,0x0a,0x00,0x00,0x00,0xb5,0x00,0x00,0x00,
-0x18,0x00,0x00,0x00,0x77,0x00,0x00,0x00,0x77,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x4d,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
+0x62,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x63,0x00,0x00,0x00,0x62,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x64,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x63,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x64,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x6d,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x17,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x09,0x00,0x00,0x00,0x88,0x00,0x00,0x00,0x02,0x00,0x00,0x00,
+0x2b,0x00,0x04,0x00,0x09,0x00,0x00,0x00,0x89,0x00,0x00,0x00,
+0x08,0x01,0x00,0x00,0x1d,0x00,0x03,0x00,0xa7,0x00,0x00,0x00,
+0x17,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,0xa8,0x00,0x00,0x00,
+0xa7,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0xa9,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0xa8,0x00,0x00,0x00,0x3b,0x00,0x04,0x00,
+0xa9,0x00,0x00,0x00,0xaa,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
+0x2c,0x00,0x06,0x00,0x0a,0x00,0x00,0x00,0xb2,0x00,0x00,0x00,
+0x18,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x18,0x00,0x00,0x00,
0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,0x04,0x00,0x00,0x00,
0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
0x05,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x0e,0x00,0x00,0x00,
@@ -52819,122 +52278,91 @@ unsigned char mul_mat_vec_f16_f32_data[] = {
0x1b,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
0x1f,0x00,0x00,0x00,0x1d,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
0x22,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x22,0x00,0x00,0x00,
-0xf5,0x00,0x07,0x00,0x06,0x00,0x00,0x00,0xb8,0x00,0x00,0x00,
-0x21,0x00,0x00,0x00,0x05,0x00,0x00,0x00,0x8a,0x00,0x00,0x00,
+0xf5,0x00,0x07,0x00,0x06,0x00,0x00,0x00,0xb5,0x00,0x00,0x00,
+0x21,0x00,0x00,0x00,0x05,0x00,0x00,0x00,0x87,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x2b,0x00,0x00,0x00,
0x2c,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
0x2c,0x00,0x00,0x00,0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
0x2f,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
0xb1,0x00,0x05,0x00,0x30,0x00,0x00,0x00,0x31,0x00,0x00,0x00,
-0xb8,0x00,0x00,0x00,0x2f,0x00,0x00,0x00,0xf6,0x00,0x04,0x00,
+0xb5,0x00,0x00,0x00,0x2f,0x00,0x00,0x00,0xf6,0x00,0x04,0x00,
0x24,0x00,0x00,0x00,0x23,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
0xfa,0x00,0x04,0x00,0x31,0x00,0x00,0x00,0x23,0x00,0x00,0x00,
0x24,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x23,0x00,0x00,0x00,
-0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x34,0x00,0x00,0x00,
-0xb8,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x84,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x37,0x00,0x00,0x00,0x35,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x38,0x00,0x00,0x00,0x34,0x00,0x00,0x00,0x37,0x00,0x00,0x00,
-0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,
-0x11,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,
-0x38,0x00,0x00,0x00,0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x40,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
-0x8b,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x43,0x00,0x00,0x00,
-0x38,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x87,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x43,0x00,0x00,0x00,
-0x35,0x00,0x00,0x00,0x82,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x49,0x00,0x00,0x00,0x38,0x00,0x00,0x00,0x43,0x00,0x00,0x00,
-0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,0x55,0x00,0x00,0x00,
-0x51,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0x40,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x4d,0x00,0x00,0x00,0x56,0x00,0x00,0x00,
-0x55,0x00,0x00,0x00,0x73,0x00,0x04,0x00,0x17,0x00,0x00,0x00,
-0x57,0x00,0x00,0x00,0x56,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x40,0x00,0x00,0x00,
-0x59,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
-0x5b,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x5a,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4d,0x00,0x00,0x00,
-0x5c,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
-0x17,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,
-0x41,0x00,0x05,0x00,0x2b,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x2a,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x68,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
-0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,
-0x68,0x00,0x00,0x00,0x49,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x6c,0x00,0x00,0x00,0x6a,0x00,0x00,0x00,
-0x44,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x6e,0x00,0x00,0x00,
-0x6f,0x00,0x00,0x00,0x66,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x6c,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x17,0x00,0x00,0x00,
-0x70,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x17,0x00,0x00,0x00,0x73,0x00,0x00,0x00,0x1f,0x00,0x00,0x00,
-0x0c,0x00,0x08,0x00,0x17,0x00,0x00,0x00,0x74,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x32,0x00,0x00,0x00,0x57,0x00,0x00,0x00,
-0x70,0x00,0x00,0x00,0x73,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x1f,0x00,0x00,0x00,0x74,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x81,0x00,0x00,0x00,0x6c,0x00,0x00,0x00,
-0x80,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x6e,0x00,0x00,0x00,
-0x82,0x00,0x00,0x00,0x66,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x81,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x17,0x00,0x00,0x00,
-0x83,0x00,0x00,0x00,0x82,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x17,0x00,0x00,0x00,0x86,0x00,0x00,0x00,0x1f,0x00,0x00,0x00,
-0x0c,0x00,0x08,0x00,0x17,0x00,0x00,0x00,0x87,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x32,0x00,0x00,0x00,0x5d,0x00,0x00,0x00,
-0x83,0x00,0x00,0x00,0x86,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
-0x1f,0x00,0x00,0x00,0x87,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0x8a,0x00,0x00,0x00,0xb8,0x00,0x00,0x00,
-0x35,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x22,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x24,0x00,0x00,0x00,0xe0,0x00,0x04,0x00,
-0x8b,0x00,0x00,0x00,0x8b,0x00,0x00,0x00,0x8c,0x00,0x00,0x00,
-0xf9,0x00,0x02,0x00,0x8e,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x8e,0x00,0x00,0x00,0xf5,0x00,0x07,0x00,0x06,0x00,0x00,0x00,
-0xb9,0x00,0x00,0x00,0x80,0x00,0x00,0x00,0x24,0x00,0x00,0x00,
-0xa5,0x00,0x00,0x00,0x91,0x00,0x00,0x00,0xad,0x00,0x05,0x00,
-0x30,0x00,0x00,0x00,0x94,0x00,0x00,0x00,0xb9,0x00,0x00,0x00,
-0x21,0x00,0x00,0x00,0xf6,0x00,0x04,0x00,0x90,0x00,0x00,0x00,
-0x91,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,
-0x94,0x00,0x00,0x00,0x8f,0x00,0x00,0x00,0x90,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x8f,0x00,0x00,0x00,0xb1,0x00,0x05,0x00,
-0x30,0x00,0x00,0x00,0x97,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0xb9,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x99,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,0x97,0x00,0x00,0x00,
-0x98,0x00,0x00,0x00,0x99,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x98,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
-0x9d,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0xb9,0x00,0x00,0x00,
-0x41,0x00,0x05,0x00,0x1e,0x00,0x00,0x00,0x9e,0x00,0x00,0x00,
-0x1b,0x00,0x00,0x00,0x9d,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
-0x17,0x00,0x00,0x00,0x9f,0x00,0x00,0x00,0x9e,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x17,0x00,0x00,0x00,0xa1,0x00,0x00,0x00,
-0x1f,0x00,0x00,0x00,0x81,0x00,0x05,0x00,0x17,0x00,0x00,0x00,
-0xa2,0x00,0x00,0x00,0xa1,0x00,0x00,0x00,0x9f,0x00,0x00,0x00,
-0x3e,0x00,0x03,0x00,0x1f,0x00,0x00,0x00,0xa2,0x00,0x00,0x00,
-0xf9,0x00,0x02,0x00,0x99,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x99,0x00,0x00,0x00,0xe0,0x00,0x04,0x00,0x8b,0x00,0x00,0x00,
-0x8b,0x00,0x00,0x00,0x8c,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x91,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x91,0x00,0x00,0x00,
-0xc3,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0xa5,0x00,0x00,0x00,
-0xb9,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
-0x8e,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x90,0x00,0x00,0x00,
-0xaa,0x00,0x05,0x00,0x30,0x00,0x00,0x00,0xa7,0x00,0x00,0x00,
-0x16,0x00,0x00,0x00,0x21,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
-0xa9,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,
-0xa7,0x00,0x00,0x00,0xa8,0x00,0x00,0x00,0xa9,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0xa8,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
-0x2b,0x00,0x00,0x00,0xae,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
-0x35,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
-0xaf,0x00,0x00,0x00,0xae,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
-0x06,0x00,0x00,0x00,0xb1,0x00,0x00,0x00,0xaf,0x00,0x00,0x00,
-0x11,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x1e,0x00,0x00,0x00,
-0xb2,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x17,0x00,0x00,0x00,0xb3,0x00,0x00,0x00,
-0xb2,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x6e,0x00,0x00,0x00,
-0xb4,0x00,0x00,0x00,0xad,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0xb1,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0xb4,0x00,0x00,0x00,
-0xb3,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0xa9,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0xa9,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,
-0x38,0x00,0x01,0x00,
+0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x37,0x00,0x00,0x00,
+0x35,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x80,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x38,0x00,0x00,0x00,0xb5,0x00,0x00,0x00,
+0x37,0x00,0x00,0x00,0x84,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x3d,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,
+0x3d,0x00,0x00,0x00,0x38,0x00,0x00,0x00,0x87,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x40,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,
+0x2e,0x00,0x00,0x00,0x8b,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x43,0x00,0x00,0x00,0x38,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x87,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
+0x43,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x82,0x00,0x05,0x00,
+0x06,0x00,0x00,0x00,0x49,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
+0x43,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x54,0x00,0x00,0x00,
+0x55,0x00,0x00,0x00,0x51,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
+0x40,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x4d,0x00,0x00,0x00,
+0x56,0x00,0x00,0x00,0x55,0x00,0x00,0x00,0x73,0x00,0x04,0x00,
+0x17,0x00,0x00,0x00,0x57,0x00,0x00,0x00,0x56,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x59,0x00,0x00,0x00,
+0x40,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x54,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,0x51,0x00,0x00,0x00,
+0x21,0x00,0x00,0x00,0x59,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x4d,0x00,0x00,0x00,0x5b,0x00,0x00,0x00,0x5a,0x00,0x00,0x00,
+0x73,0x00,0x04,0x00,0x17,0x00,0x00,0x00,0x5c,0x00,0x00,0x00,
+0x5b,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x2b,0x00,0x00,0x00,
+0x66,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x67,0x00,0x00,0x00,
+0x66,0x00,0x00,0x00,0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,
+0x69,0x00,0x00,0x00,0x67,0x00,0x00,0x00,0x49,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x6b,0x00,0x00,0x00,
+0x69,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x6d,0x00,0x00,0x00,0x6e,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
+0x21,0x00,0x00,0x00,0x6b,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x17,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,0x6e,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x17,0x00,0x00,0x00,0x72,0x00,0x00,0x00,
+0x1f,0x00,0x00,0x00,0x0c,0x00,0x08,0x00,0x17,0x00,0x00,0x00,
+0x73,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
+0x57,0x00,0x00,0x00,0x6f,0x00,0x00,0x00,0x72,0x00,0x00,0x00,
+0x3e,0x00,0x03,0x00,0x1f,0x00,0x00,0x00,0x73,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x7e,0x00,0x00,0x00,
+0x6b,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x6d,0x00,0x00,0x00,0x7f,0x00,0x00,0x00,0x65,0x00,0x00,0x00,
+0x21,0x00,0x00,0x00,0x7e,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x17,0x00,0x00,0x00,0x80,0x00,0x00,0x00,0x7f,0x00,0x00,0x00,
+0x3d,0x00,0x04,0x00,0x17,0x00,0x00,0x00,0x83,0x00,0x00,0x00,
+0x1f,0x00,0x00,0x00,0x0c,0x00,0x08,0x00,0x17,0x00,0x00,0x00,
+0x84,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
+0x5c,0x00,0x00,0x00,0x80,0x00,0x00,0x00,0x83,0x00,0x00,0x00,
+0x3e,0x00,0x03,0x00,0x1f,0x00,0x00,0x00,0x84,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0x87,0x00,0x00,0x00,
+0xb5,0x00,0x00,0x00,0x35,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
+0x22,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x24,0x00,0x00,0x00,
+0xe0,0x00,0x04,0x00,0x88,0x00,0x00,0x00,0x88,0x00,0x00,0x00,
+0x89,0x00,0x00,0x00,0xaa,0x00,0x05,0x00,0x30,0x00,0x00,0x00,
+0xa4,0x00,0x00,0x00,0x16,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
+0xf7,0x00,0x03,0x00,0xa6,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfa,0x00,0x04,0x00,0xa4,0x00,0x00,0x00,0xa5,0x00,0x00,0x00,
+0xa6,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0xa5,0x00,0x00,0x00,
+0x41,0x00,0x05,0x00,0x2b,0x00,0x00,0x00,0xab,0x00,0x00,0x00,
+0x2a,0x00,0x00,0x00,0x35,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x06,0x00,0x00,0x00,0xac,0x00,0x00,0x00,0xab,0x00,0x00,0x00,
+0x80,0x00,0x05,0x00,0x06,0x00,0x00,0x00,0xae,0x00,0x00,0x00,
+0xac,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x1e,0x00,0x00,0x00,0xaf,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,
+0x21,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x17,0x00,0x00,0x00,
+0xb0,0x00,0x00,0x00,0xaf,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x6d,0x00,0x00,0x00,0xb1,0x00,0x00,0x00,0xaa,0x00,0x00,0x00,
+0x21,0x00,0x00,0x00,0xae,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
+0xb1,0x00,0x00,0x00,0xb0,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
+0xa6,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0xa6,0x00,0x00,0x00,
+0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
};
-const uint64_t mul_mat_vec_f16_f32_len = 2788;
+const uint64_t mul_mat_vec_f16_f32_len = 2372;
unsigned char mul_mat_vec_nc_f16_f32_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 1d93ec6bb..bccc40bf5 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -817,7 +817,7 @@ static void ggml_vk_load_shaders() {
// mulmat
std::initializer_list warptile_l = { 128, 128, 128, 16, vk_device.subgroup_size * 2, 64, 2, 4, 4, vk_device.subgroup_size };
std::initializer_list warptile_m = { 128, 64, 64, 16, vk_device.subgroup_size, 32, 2, 4, 2, vk_device.subgroup_size };
- std::initializer_list warptile_s = { vk_device.subgroup_size, 32, 32, 8, 32, 32, 2, 2, 2, vk_device.subgroup_size };
+ std::initializer_list warptile_s = { vk_device.subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, vk_device.subgroup_size };
std::array l_wg_denoms = {128, 128, 1 };
std::array m_wg_denoms = { 64, 64, 1 };
@@ -2873,7 +2873,8 @@ static void ggml_vk_op_f32(vk_context * ctx, const ggml_tensor * src0, const ggm
if (op == GGML_OP_CPY) {
GGML_ASSERT(!transfer_src0);
GGML_ASSERT(!transfer_src1);
- d_sz = dst->ne[1] * dst->nb[1];
+ x_sz = ggml_nbytes(src0);
+ d_sz = ggml_nbytes(dst);
if (extra->offset + d_sz >= d_D->size) {
d_sz = VK_WHOLE_SIZE;
@@ -4556,8 +4557,15 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
}
ggml_vk_preallocate_buffers();
+ int last_node = cgraph->n_nodes - 1;
+
+ // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
+ while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_GPU) {
+ last_node -= 1;
+ }
+
for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_vk_build_graph(cgraph->nodes[i], i == cgraph->n_nodes - 1);
+ ggml_vk_build_graph(cgraph->nodes[i], i == last_node);
}
ggml_compute_params params = {};
diff --git a/ggml.c b/ggml.c
index a7a9ea319..ee994c875 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5349,7 +5349,7 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
int s0,
int p0,
int d0) {
- struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
@@ -5427,16 +5427,15 @@ struct ggml_tensor * ggml_conv_depthwise_2d(
int p1,
int d0,
int d1) {
+
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
- s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW]
-
- struct ggml_tensor * result =
- ggml_mul_mat(ctx,
- ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1), // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
- ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+ s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+ struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+ new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+ struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
return result;
@@ -5457,7 +5456,8 @@ struct ggml_tensor * ggml_im2col(
int p1,
int d0,
int d1,
- bool is_2D) {
+ bool is_2D,
+ enum ggml_type dst_type) {
if(is_2D) {
GGML_ASSERT(a->ne[2] == b->ne[2]);
@@ -5481,7 +5481,7 @@ struct ggml_tensor * ggml_im2col(
is_2D ? b->ne[3] : 1,
};
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
ggml_set_op_params(result, params, sizeof(params));
@@ -5506,7 +5506,7 @@ struct ggml_tensor * ggml_conv_2d(
int p1,
int d0,
int d1) {
- struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
@@ -5632,12 +5632,13 @@ struct ggml_tensor * ggml_pool_2d(
is_node = true;
}
+ struct ggml_tensor * result;
const int64_t ne[3] = {
ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
a->ne[2],
};
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+ result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
ggml_set_op_params(result, params, sizeof(params));
@@ -5645,7 +5646,6 @@ struct ggml_tensor * ggml_pool_2d(
result->op = GGML_OP_POOL_2D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
-
return result;
}
@@ -12493,6 +12493,92 @@ static void ggml_compute_forward_conv_transpose_1d(
}
}
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst: result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = is_2D ? ne13 : ne12;
+ const int64_t IC = is_2D ? ne12 : ne11;
+ const int64_t IH = is_2D ? ne11 : 1;
+ const int64_t IW = ne10;
+
+ const int64_t KH = is_2D ? ne01 : 1;
+ const int64_t KW = ne00;
+
+ const int64_t OH = is_2D ? ne2 : 1;
+ const int64_t OW = ne1;
+
+ int ofs0 = is_2D ? nb13 : nb12;
+ int ofs1 = is_2D ? nb12 : nb11;
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+ for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+ } else {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+
// src0: kernel [OC, IC, KH, KW]
// src1: image [N, IC, IH, IW]
// dst: result [N, OH, OW, IC*KH*KW]
@@ -12583,14 +12669,14 @@ static void ggml_compute_forward_im2col(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- switch (src0->type) {
+ switch (dst->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_im2col_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
- GGML_ASSERT(false);
+ ggml_compute_forward_im2col_f32(params, src0, src1, dst);
} break;
default:
{
@@ -12781,8 +12867,8 @@ static void ggml_compute_forward_pool_2d(
const struct ggml_compute_params * params,
const struct ggml_tensor * src,
struct ggml_tensor * dst) {
- assert(src->type == GGML_TYPE_F32);
- assert(params->ith == 0);
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
@@ -16985,12 +17071,16 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
struct ggml_cplan cplan;
memset(&cplan, 0, sizeof(struct ggml_cplan));
+ int max_tasks = 1;
+
// thread scheduling for the different operations + work buffer size estimation
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
const int n_tasks = ggml_get_n_tasks(node, n_threads);
+ max_tasks = MAX(max_tasks, n_tasks);
+
size_t cur = 0;
switch (node->op) {
@@ -17157,7 +17247,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
work_size += CACHE_LINE_SIZE*(n_threads - 1);
}
- cplan.n_threads = n_threads;
+ cplan.n_threads = MIN(max_tasks, n_threads);
cplan.work_size = work_size;
cplan.work_data = NULL;
@@ -20473,6 +20563,14 @@ int ggml_cpu_has_vulkan(void) {
#endif
}
+int ggml_cpu_has_kompute(void) {
+#if defined(GGML_USE_KOMPUTE)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
int ggml_cpu_has_sycl(void) {
#if defined(GGML_USE_SYCL)
return 1;
@@ -20482,7 +20580,8 @@ int ggml_cpu_has_sycl(void) {
}
int ggml_cpu_has_gpublas(void) {
- return ggml_cpu_has_cublas() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_sycl();
+ return ggml_cpu_has_cublas() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
+ ggml_cpu_has_sycl();
}
int ggml_cpu_has_sse3(void) {
diff --git a/ggml.h b/ggml.h
index bf782e6ad..e0a4799f3 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1495,7 +1495,8 @@ extern "C" {
int p1,
int d0,
int d1,
- bool is_2D);
+ bool is_2D,
+ enum ggml_type dst_type);
GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,
@@ -2266,6 +2267,7 @@ extern "C" {
GGML_API int ggml_cpu_has_cublas (void);
GGML_API int ggml_cpu_has_clblast (void);
GGML_API int ggml_cpu_has_vulkan (void);
+ GGML_API int ggml_cpu_has_kompute (void);
GGML_API int ggml_cpu_has_gpublas (void);
GGML_API int ggml_cpu_has_sse3 (void);
GGML_API int ggml_cpu_has_ssse3 (void);
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index d0861fde4..6b1b82bf3 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -19,8 +19,8 @@ shader_int8_ext = """
# Type-specific defines
shader_f16_defines = """
-#define QUANT_K 32
-#define QUANT_R 2
+#define QUANT_K 1
+#define QUANT_R 1
#define A_TYPE float16_t
"""
diff --git a/llama.cpp b/llama.cpp
index 7b9a5c079..bb23689fa 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2713,10 +2713,10 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
- case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XSS - 2.0625 bpw";
+ case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_Q3_K_XS:return "Q3_K - Extra small";
- case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XSS - 3.0625 bpw";
+ case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
default: return "unknown, may not work";
}
@@ -6878,11 +6878,6 @@ static int llama_decode_internal(
n_threads = std::min(4, n_threads);
}
- const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
- if ((ggml_cpu_has_cublas() || ggml_cpu_has_vulkan()) && fully_offloaded) {
- n_threads = 1;
- }
-
#ifdef GGML_USE_MPI
const int64_t n_layer = hparams.n_layer;
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
diff --git a/scripts/install-oneapi.bat b/scripts/install-oneapi.bat
new file mode 100644
index 000000000..e99bef14a
--- /dev/null
+++ b/scripts/install-oneapi.bat
@@ -0,0 +1,19 @@
+:: MIT license
+:: Copyright (C) 2024 Intel Corporation
+:: SPDX-License-Identifier: MIT
+
+
+set URL=%1
+set COMPONENTS=%2
+
+curl.exe --output %TEMP%\webimage.exe --url %URL% --retry 5 --retry-delay 5
+start /b /wait %TEMP%\webimage.exe -s -x -f webimage_extracted --log extract.log
+del %TEMP%\webimage.exe
+if "%COMPONENTS%"=="" (
+ webimage_extracted\bootstrapper.exe -s --action install --eula=accept -p=NEED_VS2017_INTEGRATION=0 -p=NEED_VS2019_INTEGRATION=0 -p=NEED_VS2022_INTEGRATION=0 --log-dir=.
+) else (
+ webimage_extracted\bootstrapper.exe -s --action install --components=%COMPONENTS% --eula=accept -p=NEED_VS2017_INTEGRATION=0 -p=NEED_VS2019_INTEGRATION=0 -p=NEED_VS2022_INTEGRATION=0 --log-dir=.
+)
+set installer_exit_code=%ERRORLEVEL%
+rd /s/q "webimage_extracted"
+exit /b %installer_exit_code%
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 1d29070b6..eb06123d2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -227,6 +227,14 @@ static std::string var_to_str(ggml_type type) {
return ggml_type_name(type);
}
+static std::string var_to_str(ggml_op_pool pool) {
+ switch (pool) {
+ case GGML_OP_POOL_AVG: return "avg";
+ case GGML_OP_POOL_MAX: return "max";
+ default: return std::to_string(pool);
+ }
+}
+
#define VARS_TO_STR1(a) VAR_TO_STR(a)
#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
@@ -238,6 +246,7 @@ static std::string var_to_str(ggml_type type) {
#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
+#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
#ifdef GGML_USE_SYCL
static bool inline _isinf(float f) {
@@ -1162,10 +1171,45 @@ struct test_alibi : public test_case {
}
};
+// GGML_OP_POOL2D
+struct test_pool2d : public test_case {
+ enum ggml_op_pool pool_type;
+ const ggml_type type_input;
+ const std::array ne_input;
+ // kernel size
+ const int k0;
+ const int k1;
+ // stride
+ const int s0;
+ const int s1;
+ // padding
+ const int p0;
+ const int p1;
+
+ std::string vars() override {
+ return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1);
+ }
+
+ test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
+ ggml_type type_input = GGML_TYPE_F32,
+ std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+ int k0 = 3, int k1 = 3,
+ int s0 = 1, int s1 = 1,
+ int p0 = 1, int p1 = 1)
+ : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+ ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
+ return out;
+ }
+};
+
// GGML_OP_IM2COL
struct test_im2col : public test_case {
const ggml_type type_input;
const ggml_type type_kernel;
+ const ggml_type dst_type;
const std::array ne_input;
const std::array ne_kernel;
// stride
@@ -1181,22 +1225,22 @@ struct test_im2col : public test_case {
const bool is_2D;
std::string vars() override {
- return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
+ return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
}
- test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16,
+ test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
int s0 = 1, int s1 = 1,
int p0 = 1, int p1 = 1,
int d0 = 1, int d1 = 1,
bool is_2D = true)
- : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
+ : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
- ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D);
+ ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
return out;
}
};
@@ -1912,6 +1956,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
+ for (ggml_type type_input : {GGML_TYPE_F32}) {
+ for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
+ for (int k0 : {1, 3}) {
+ for (int k1 : {1, 3}) {
+ for (int s0 : {1, 2}) {
+ for (int s1 : {1, 2}) {
+ for (int p0 : {0, 1}) {
+ for (int p1 : {0, 1}) {
+ test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+ test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
@@ -2049,7 +2114,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
test_cases.emplace_back(new test_alibi());
- test_cases.emplace_back(new test_im2col());
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
test_cases.emplace_back(new test_concat(GGML_TYPE_I32));