From cbd08b4204ec0bf720be628f65150d3704681667 Mon Sep 17 00:00:00 2001
From: HimariO
Date: Fri, 29 Nov 2024 22:18:15 +0800
Subject: [PATCH] resolve linter, test errors

---
 CMakePresets.json                  | 177 +++++++++++++----------------
 convert_hf_to_gguf.py              |   2 +-
 examples/llava/clip.cpp            |  17 +--
 examples/llava/clip.h              |   3 +-
 examples/llava/qwen2_vl_surgery.py |  30 +++--
 examples/llava/qwen2vl-cli.cpp     |  13 +--
 ggml/include/ggml.h                |   4 +-
 ggml/src/ggml-cpu/ggml-cpu.c       |  55 +++++++++
 ggml/src/ggml.c                    |   7 --
 gguf-py/gguf/gguf_writer.py        |   2 +-
 src/llama.cpp                      |   5 +-
 11 files changed, 168 insertions(+), 147 deletions(-)

diff --git a/CMakePresets.json b/CMakePresets.json
index e354b61f0..4d3f546f7 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -1,100 +1,85 @@
 {
     "version": 4,
     "configurePresets": [
-        {
-            "name": "base",
-            "hidden": true,
-            "generator": "Ninja",
-            "binaryDir": "${sourceDir}/build-${presetName}",
-            "cacheVariables": {
-                "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
-                "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
-            }
-        },
-        {
-            "name": "sycl-base",
-            "hidden": true,
-            "generator": "Ninja",
-            "binaryDir": "${sourceDir}/build-${presetName}",
-            "cacheVariables": {
-                "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
-                "CMAKE_CXX_COMPILER": "icx",
-                "CMAKE_C_COMPILER": "cl",
-                "GGML_SYCL": "ON",
-                "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
-            }
-        },
-        { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } },
-        { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
-        { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
-        { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
-        { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
-
-        {
-            "name": "arm64-windows-msvc", "hidden": true,
-            "architecture": { "value": "arm64", "strategy": "external" },
-            "toolset": { "value": "host=x64", "strategy": "external" },
-            "cacheVariables": {
-                "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-msvc.cmake"
-            }
-        },
-
-        {
-            "name": "arm64-windows-llvm", "hidden": true,
-            "architecture": { "value": "arm64", "strategy": "external" },
-            "toolset": { "value": "host=x64", "strategy": "external" },
-            "cacheVariables": {
-                "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake"
-            }
-        },
-
-        {
-            "name": "arm64-apple-clang", "hidden": true,
-            "architecture": { "value": "arm64", "strategy": "external" },
-            "toolset": { "value": "host=x64", "strategy": "external" },
-            "cacheVariables": {
-                "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
-            }
-        },
-
-        { "name": "arm64-windows-llvm-debug" , "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
-        { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
-        { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] },
-
-        { "name": "arm64-apple-clang-debug" , "inherits": [ "base", "arm64-apple-clang", "debug" ] },
-        { "name": "arm64-apple-clang-release" , "inherits": [ "base", "arm64-apple-clang", "reldbg" ] },
-        { "name": "arm64-apple-clang+static-release" , "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] },
-
-        { "name": "arm64-windows-msvc-debug" , "inherits": [ "base", "arm64-windows-msvc", "debug" ] },
-        { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
-        { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
-
-        { "name": "x64-windows-msvc-debug" , "inherits": [ "base", "debug" ] },
-        { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
-        { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
-
-        { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
-        { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
-        { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
-        { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] },
-        {
-            "name": "x86-cuda-linux",
-            "description": "",
-            "displayName": "",
-            "inherits": [
-                "base",
-                "debug"
-            ],
-            "cacheVariables": {
-                "GGML_CUDA": "1",
-                "CUDA_PATH": "/usr/local/cuda",
-                "CUDAToolkit_ROOT": "/usr/local/cuda",
-                "CUDAToolkit_INCLUDE_DIR": "/usr/local/cuda/include/",
-                "CUDAToolkit_LIBRARY_DIR": "/usr/local/cuda/lib64",
-                "CUDA_NVCC_FLAGS": "-g -G",
-                "CMAKE_CUDA_FLAGS_DEBUG": "-g -G",
-                "CMAKE_CUDA_FLAGS": "-maxrregcount=40"
-            }
-        }
+    {
+      "name": "base",
+      "hidden": true,
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/build-${presetName}",
+      "cacheVariables": {
+        "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
+        "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
+      }
+    },
+    {
+      "name": "sycl-base",
+      "hidden": true,
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/build-${presetName}",
+      "cacheVariables": {
+        "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
+        "CMAKE_CXX_COMPILER": "icx",
+        "CMAKE_C_COMPILER": "cl",
+        "GGML_SYCL": "ON",
+        "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
+      }
+    },
+    { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } },
+    { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
+    { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
+    { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
+    { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
+    { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } },
+
+    {
+      "name": "arm64-windows-msvc", "hidden": true,
+      "architecture": { "value": "arm64", "strategy": "external" },
+      "toolset": { "value": "host=x64", "strategy": "external" },
+      "cacheVariables": {
+        "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-msvc.cmake"
+      }
+    },
+
+    {
+      "name": "arm64-windows-llvm", "hidden": true,
+      "architecture": { "value": "arm64", "strategy": "external" },
+      "toolset": { "value": "host=x64", "strategy": "external" },
+      "cacheVariables": {
+        "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake"
+      }
+    },
+
+    {
+      "name": "arm64-apple-clang", "hidden": true,
+      "architecture": { "value": "arm64", "strategy": "external" },
+      "toolset": { "value": "host=x64", "strategy": "external" },
+      "cacheVariables": {
+        "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
+      }
+    },
+
+    { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
+    { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
+    { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] },
+
+    { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] },
+    { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] },
+    { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] },
+
+    { "name": "arm64-windows-msvc-debug", "inherits": [ "base", "arm64-windows-msvc", "debug" ] },
+    { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
+    { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
+
+    { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] },
+    { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
+    { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
+
+    { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] },
+    { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
+    { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
+    { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] },
+
+    { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] },
+    { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] }
     ]
-}
+  }
\ No newline at end of file
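
[Note: illustrative usage, not part of the patch. Assumes CMake >= 3.19 preset
support and an installed Vulkan SDK.] With the new hidden "vulkan" preset, a
Vulkan build is configured and built like any other preset; the build directory
name follows the base preset's "binaryDir" ("${sourceDir}/build-${presetName}"):

    cmake --preset x64-windows-vulkan-release
    cmake --build build-x64-windows-vulkan-release
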
"arm64-windows-msvc", "reldbg", "static" ] }, - - { "name": "x64-windows-msvc-debug" , "inherits": [ "base", "debug" ] }, - { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, - { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, - - { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] }, - { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, - { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, - { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, - { - "name": "x86-cuda-linux", - "description": "", - "displayName": "", - "inherits": [ - "base", - "debug" - ], - "cacheVariables": { - "GGML_CUDA": "1", - "CUDA_PATH": "/usr/local/cuda", - "CUDAToolkit_ROOT": "/usr/local/cuda", - "CUDAToolkit_INCLUDE_DIR": "/usr/local/cuda/include/", - "CUDAToolkit_LIBRARY_DIR": "/usr/local/cuda/lib64", - "CUDA_NVCC_FLAGS": "-g -G", - "CMAKE_CUDA_FLAGS_DEBUG": "-g -G", - "CMAKE_CUDA_FLAGS": "-maxrregcount=40" - } - } + { + "name": "base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { + "name": "sycl-base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_CXX_COMPILER": "icx", + "CMAKE_C_COMPILER": "cl", + "GGML_SYCL": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, + { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, + { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } }, + + { + "name": "arm64-windows-msvc", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-msvc.cmake" + } + }, + + { + "name": "arm64-windows-llvm", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-apple-clang", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake" + } + }, + + { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, + { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, + { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, + + { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] }, + { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, + { 
"name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, + + { "name": "arm64-windows-msvc-debug", "inherits": [ "base", "arm64-windows-msvc", "debug" ] }, + { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] }, + { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] }, + + { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] }, + { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, + { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, + + { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] }, + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, + + { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] }, + { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] } ] -} + } \ No newline at end of file diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index df14a7988..5ce828ddc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1991,7 +1991,7 @@ class Qwen2VLModel(Model): self._set_vocab_sentencepiece() except FileNotFoundError: self._set_vocab_gpt2() - + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: for name, data in super().get_tensors(): if name.startswith("visual."): diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index c61a4d415..050b04ce2 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -2590,12 +2590,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima int* positions_data = (int*)malloc(ggml_nbytes(positions)); int ptr = 0; - for (size_t y = 0; y < ph; y+=2) + for (int y = 0; y < ph; y+=2) { - for (size_t x = 0; x < pw; x+=2) + for (int x = 0; x < pw; x+=2) { - for (size_t dy = 0; dy < 2; dy++) { - for (size_t dx = 0; dx < 2; dx++) { + for (int dy = 0; dy < 2; dy++) { + for (int dx = 0; dx < 2; dx++) { positions_data[ptr] = y + dy; positions_data[num_patches + ptr] = x + dx; positions_data[num_patches * 2 + ptr] = y + dy; @@ -2820,20 +2820,15 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) { } -bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { +bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { clip_image_f32 clip_img; clip_img.buf.resize(h * w * 3); - for (size_t i = 0; i < h*w*3; i++) + for (int i = 0; i < h*w*3; i++) { clip_img.buf[i] = img[i]; } clip_img.nx = w; clip_img.ny = h; - // ctx->vision_model.hparams.image_size = h; clip_image_encode(ctx, n_threads, &clip_img, vec); return true; } - -void tmp_clip_set_layers (struct clip_ctx * ctx, int layers) { - ctx->vision_model.hparams.n_layer = layers; -} \ No newline at end of file diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 750a0438e..1603edd26 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -91,8 +91,7 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); -CLIP_API bool tmp_clip_image_encode (struct 
diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py
index c71bc973f..56d933fde 100644
--- a/examples/llava/qwen2_vl_surgery.py
+++ b/examples/llava/qwen2_vl_surgery.py
@@ -1,12 +1,11 @@
 import argparse
-import glob
-import os
-from typing import Any, Dict
+from typing import Dict
 
 import torch
+import numpy as np
 from gguf import *
 from transformers import (
-    Qwen2VLForConditionalGeneration, 
+    Qwen2VLForConditionalGeneration,
     Qwen2VLProcessor,
     AutoProcessor,
     Qwen2VLConfig
@@ -44,7 +43,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
             else:  # bias
                 c3 = ten.shape[0]
                 assert c3 % 3 == 0
-                c = c3//3
+                c = c3 // 3
                 wq = ten[:c]
                 wk = ten[c: c * 2]
                 wv = ten[c * 2:]
@@ -68,7 +67,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
                 tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
             else:
                 tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
-    
+
     for new_name, ten in tensor_map.items():
         if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
             tensor_map[new_name] = ten.astype(np.float32)
@@ -89,16 +88,14 @@ def main(args):
         ftype = 1
     else:
         raise ValueError()
-    
+
     model_name = args.model_name
     print("model_name: ", model_name)
     qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
         model_name, torch_dtype=dtype, device_map="cpu"
    )
-    cfg: Qwen2VLConfig = qwen2vl.config
+    cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
     vcfg = cfg.vision_config
-    rope_cfg = cfg.rope_scaling
-
 
     fname_out = "qwen2vl-vision.gguf"
     fout = GGUFWriter(path=fname_out, arch="clip")
@@ -125,23 +122,22 @@ def main(args):
         fout.add_tensor(name, data)
 
     fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
-    fout.add_uint32("clip.vision.image_size", 14*40)  # some reasonable size that is divable by (14*2)
+    fout.add_uint32("clip.vision.image_size", 14 * 40)  # some reasonable size that is divisible by (14*2)
     fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
     fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
-    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0)  # BUG: not sure what this does
+    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0)  # not sure what this does, put 0 here as a placeholder
     fout.add_name(model_name)
     """
-    HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig, 
+    HACK: Since vision rope-related parameters aren't stored in the `Qwen2VLConfig`,
         it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
     """
 
     processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
-    # breakpoint()
-    fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean)
-    fout.add_array("clip.vision.image_std", processor.image_processor.image_std)
+    fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean)  # type: ignore[reportAttributeAccessIssue]
+    fout.add_array("clip.vision.image_std", processor.image_processor.image_std)  # type: ignore[reportAttributeAccessIssue]
 
     fout.write_header_to_file()
     fout.write_kv_data_to_file()
@@ -154,4 +150,4 @@ if __name__ == "__main__":
     parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
     parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
     args = parser.parse_args()
-    main(args)
\ No newline at end of file
+    main(args)
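
[Note: illustrative sketch, not part of the patch.] The find_vision_tensors()
hunk above splits Hugging Face's fused attention tensors: Qwen2-VL's vision
blocks store q/k/v as a single qkv weight (and bias), while the GGUF layout
expects three separate tensors, so dimension 0 is cut into equal thirds:

    import numpy as np

    def split_qkv(ten: np.ndarray):
        c3 = ten.shape[0]
        assert c3 % 3 == 0
        c = c3 // 3
        return ten[:c], ten[c:c * 2], ten[c * 2:]  # wq, wk, wv

    # e.g. a fused weight of shape (3 * embed_dim, embed_dim):
    q, k, v = split_qkv(np.zeros((3 * 1280, 1280), dtype=np.float32))
    assert q.shape == k.shape == v.shape == (1280, 1280)
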
""" processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name) - # breakpoint() - fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) - fout.add_array("clip.vision.image_std", processor.image_processor.image_std) + fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue] + fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue] fout.write_header_to_file() fout.write_kv_data_to_file() @@ -154,4 +150,4 @@ if __name__ == "__main__": parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 73f94d8fa..4a1c12cbb 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -26,9 +26,9 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla auto img_tokens = image_embed->n_image_pos; llama_pos mrope_pos[img_tokens * 4]; - for (size_t y = 0; y < ph; y++) + for (int y = 0; y < ph; y++) { - for (size_t x = 0; x < pw; x++) + for (int x = 0; x < pw; x++) { int i = y * pw + x; mrope_pos[i] = *st_pos_id; @@ -270,7 +270,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG("\n"); - struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams); + struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); if (!smpl) { LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -422,10 +422,7 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, common_params * int ne = n_embd * 4; float vals[56 * 56 * 3]; float embd[ne]; - // for (int i = 0; i < 3*56*56; i++) - // { - // vals[i] = 0.1; - // } + for (int i = 0; i < 56*56; i++) { for (int c = 0; c < 3; c++) @@ -433,7 +430,7 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, common_params * } // auto param = &ctx_llava->ctx_clip->vision_model.hparams; - tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd); + clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd); std::ofstream outFile("img_embed.bin", std::ios::binary); if (outFile.is_open()) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 56c6f2c05..e1c620e15 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -238,8 +238,8 @@ #define GGML_EXIT_ABORTED 1 #define GGML_ROPE_TYPE_NEOX 2 -#define GGML_ROPE_TYPE_MROPE 4 -#define GGML_ROPE_TYPE_VISION 12 +#define GGML_ROPE_TYPE_MROPE 8 +#define GGML_ROPE_TYPE_VISION 24 #define GGUF_MAGIC "GGUF" diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 1cba6e96e..fb9fcff67 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -9205,6 +9205,61 @@ static void ggml_rope_cache_init( } } +static void ggml_mrope_cache_init( + float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects, + float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py + float theta_t = theta_base_t; + float 
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 6289aee5a..c3726163b 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3575,12 +3575,6 @@ struct ggml_tensor * ggml_mrope_ext(
         GGML_ASSERT(c->ne[0] >= n_dims / 2);
     }
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
     int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -3595,7 +3589,6 @@ struct ggml_tensor * ggml_mrope_ext(
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op     = GGML_OP_ROPE;
-    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
     result->src[2] = c;
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 3b1d7e9e9..65a64e10d 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -750,7 +750,7 @@ class GGUFWriter:
 
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
-    
+
     def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
         self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
 
diff --git a/src/llama.cpp b/src/llama.cpp
index d35e2cf27..15052006b 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2436,7 +2436,7 @@ struct llama_hparams {
     float    rope_freq_scale_train;
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul;
-    std::array<uint32_t, 4> rope_mrope_sections;
+    std::array<int, 4> rope_mrope_sections;
 
     // for State Space Models
     uint32_t ssm_d_conv = 0;
@@ -12540,7 +12540,8 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-        int * sections = (int *)hparams.rope_mrope_sections.data();
+        int sections[4];
+        std::copy(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), sections);
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
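
[Note: illustrative sketch, not part of the patch.] The ggml.h hunk above
renumbers the rope-type constants; they behave as bit flags that ggml tests
with `mode & GGML_ROPE_TYPE_MROPE`, and GGML_ROPE_TYPE_VISION (24 = 8 | 16)
deliberately contains the MROPE bit, so a vision rope is still handled as an
M-RoPE. Why the values moved off bit 2 is not stated in the patch. A minimal
Python check of the flag relationships:

    GGML_ROPE_TYPE_NEOX   = 2
    GGML_ROPE_TYPE_MROPE  = 8
    GGML_ROPE_TYPE_VISION = 24

    mode = GGML_ROPE_TYPE_VISION
    assert mode & GGML_ROPE_TYPE_MROPE        # vision implies mrope handling
    assert not (mode & GGML_ROPE_TYPE_NEOX)   # and does not claim neox
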