resolve linter, test errors

This commit is contained in:
HimariO 2024-11-29 22:18:15 +08:00
parent fac034530f
commit cbd08b4204
11 changed files with 168 additions and 147 deletions

View file

@ -1,100 +1,85 @@
{
"version": 4,
"configurePresets": [
{
"name": "base",
"hidden": true,
"generator": "Ninja",
"binaryDir": "${sourceDir}/build-${presetName}",
"cacheVariables": {
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
}
},
{
"name": "sycl-base",
"hidden": true,
"generator": "Ninja",
"binaryDir": "${sourceDir}/build-${presetName}",
"cacheVariables": {
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"CMAKE_CXX_COMPILER": "icx",
"CMAKE_C_COMPILER": "cl",
"GGML_SYCL": "ON",
"CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
}
},
{ "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } },
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
{
"name": "arm64-windows-msvc", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-msvc.cmake"
}
},
{
"name": "arm64-windows-llvm", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake"
}
},
{
"name": "arm64-apple-clang", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
}
},
{ "name": "arm64-windows-llvm-debug" , "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
{ "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
{ "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] },
{ "name": "arm64-apple-clang-debug" , "inherits": [ "base", "arm64-apple-clang", "debug" ] },
{ "name": "arm64-apple-clang-release" , "inherits": [ "base", "arm64-apple-clang", "reldbg" ] },
{ "name": "arm64-apple-clang+static-release" , "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] },
{ "name": "arm64-windows-msvc-debug" , "inherits": [ "base", "arm64-windows-msvc", "debug" ] },
{ "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
{ "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
{ "name": "x64-windows-msvc-debug" , "inherits": [ "base", "debug" ] },
{ "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
{ "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] },
{
"name": "x86-cuda-linux",
"description": "",
"displayName": "",
"inherits": [
"base",
"debug"
],
"cacheVariables": {
"GGML_CUDA": "1",
"CUDA_PATH": "/usr/local/cuda",
"CUDAToolkit_ROOT": "/usr/local/cuda",
"CUDAToolkit_INCLUDE_DIR": "/usr/local/cuda/include/",
"CUDAToolkit_LIBRARY_DIR": "/usr/local/cuda/lib64",
"CUDA_NVCC_FLAGS": "-g -G",
"CMAKE_CUDA_FLAGS_DEBUG": "-g -G",
"CMAKE_CUDA_FLAGS": "-maxrregcount=40"
}
}
{
"name": "base",
"hidden": true,
"generator": "Ninja",
"binaryDir": "${sourceDir}/build-${presetName}",
"cacheVariables": {
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
}
},
{
"name": "sycl-base",
"hidden": true,
"generator": "Ninja",
"binaryDir": "${sourceDir}/build-${presetName}",
"cacheVariables": {
"CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
"CMAKE_CXX_COMPILER": "icx",
"CMAKE_C_COMPILER": "cl",
"GGML_SYCL": "ON",
"CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.."
}
},
{ "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } },
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
{ "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } },
{
"name": "arm64-windows-msvc", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-msvc.cmake"
}
},
{
"name": "arm64-windows-llvm", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake"
}
},
{
"name": "arm64-apple-clang", "hidden": true,
"architecture": { "value": "arm64", "strategy": "external" },
"toolset": { "value": "host=x64", "strategy": "external" },
"cacheVariables": {
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
}
},
{ "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
{ "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
{ "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] },
{ "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] },
{ "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] },
{ "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] },
{ "name": "arm64-windows-msvc-debug", "inherits": [ "base", "arm64-windows-msvc", "debug" ] },
{ "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
{ "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
{ "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] },
{ "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
{ "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] },
{ "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] },
{ "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] },
{ "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] }
]
}
}

View file

@ -1991,7 +1991,7 @@ class Qwen2VLModel(Model):
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, data in super().get_tensors():
if name.startswith("visual."):

View file

@ -2590,12 +2590,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
int* positions_data = (int*)malloc(ggml_nbytes(positions));
int ptr = 0;
for (size_t y = 0; y < ph; y+=2)
for (int y = 0; y < ph; y+=2)
{
for (size_t x = 0; x < pw; x+=2)
for (int x = 0; x < pw; x+=2)
{
for (size_t dy = 0; dy < 2; dy++) {
for (size_t dx = 0; dx < 2; dx++) {
for (int dy = 0; dy < 2; dy++) {
for (int dx = 0; dx < 2; dx++) {
positions_data[ptr] = y + dy;
positions_data[num_patches + ptr] = x + dx;
positions_data[num_patches * 2 + ptr] = y + dy;
@ -2820,20 +2820,15 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
}
bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);
for (size_t i = 0; i < h*w*3; i++)
for (int i = 0; i < h*w*3; i++)
{
clip_img.buf[i] = img[i];
}
clip_img.nx = w;
clip_img.ny = h;
// ctx->vision_model.hparams.image_size = h;
clip_image_encode(ctx, n_threads, &clip_img, vec);
return true;
}
void tmp_clip_set_layers (struct clip_ctx * ctx, int layers) {
ctx->vision_model.hparams.n_layer = layers;
}

View file

@ -91,8 +91,7 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
CLIP_API void tmp_clip_set_layers (struct clip_ctx * ctx, int layers);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
#ifdef __cplusplus
}

View file

@ -1,12 +1,11 @@
import argparse
import glob
import os
from typing import Any, Dict
from typing import Dict
import torch
import numpy as np
from gguf import *
from transformers import (
Qwen2VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
Qwen2VLProcessor,
AutoProcessor,
Qwen2VLConfig
@ -44,7 +43,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
else: # bias
c3 = ten.shape[0]
assert c3 % 3 == 0
c = c3//3
c = c3 // 3
wq = ten[:c]
wk = ten[c: c * 2]
wv = ten[c * 2:]
@ -68,7 +67,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
else:
tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
for new_name, ten in tensor_map.items():
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
tensor_map[new_name] = ten.astype(np.float32)
@ -89,16 +88,14 @@ def main(args):
ftype = 1
else:
raise ValueError()
model_name = args.model_name
print("model_name: ", model_name)
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype, device_map="cpu"
)
cfg: Qwen2VLConfig = qwen2vl.config
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
vcfg = cfg.vision_config
rope_cfg = cfg.rope_scaling
fname_out = "qwen2vl-vision.gguf"
fout = GGUFWriter(path=fname_out, arch="clip")
@ -125,23 +122,22 @@ def main(args):
fout.add_tensor(name, data)
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
fout.add_uint32("clip.vision.image_size", 14*40) # some reasonable size that is divable by (14*2)
fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # BUG: not sure what this does
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder
fout.add_name(model_name)
"""
HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
"""
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
# breakpoint()
fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean)
fout.add_array("clip.vision.image_std", processor.image_processor.image_std)
fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
fout.write_header_to_file()
fout.write_kv_data_to_file()
@ -154,4 +150,4 @@ if __name__ == "__main__":
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
args = parser.parse_args()
main(args)
main(args)

View file

@ -26,9 +26,9 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
auto img_tokens = image_embed->n_image_pos;
llama_pos mrope_pos[img_tokens * 4];
for (size_t y = 0; y < ph; y++)
for (int y = 0; y < ph; y++)
{
for (size_t x = 0; x < pw; x++)
for (int x = 0; x < pw; x++)
{
int i = y * pw + x;
mrope_pos[i] = *st_pos_id;
@ -270,7 +270,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
LOG("\n");
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
@ -422,10 +422,7 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, common_params *
int ne = n_embd * 4;
float vals[56 * 56 * 3];
float embd[ne];
// for (int i = 0; i < 3*56*56; i++)
// {
// vals[i] = 0.1;
// }
for (int i = 0; i < 56*56; i++)
{
for (int c = 0; c < 3; c++)
@ -433,7 +430,7 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, common_params *
}
// auto param = &ctx_llava->ctx_clip->vision_model.hparams;
tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
std::ofstream outFile("img_embed.bin", std::ios::binary);
if (outFile.is_open()) {

View file

@ -238,8 +238,8 @@
#define GGML_EXIT_ABORTED 1
#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 4
#define GGML_ROPE_TYPE_VISION 12
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
#define GGUF_MAGIC "GGUF"

View file

@ -9205,6 +9205,61 @@ static void ggml_rope_cache_init(
}
}
static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta_t = theta_base_t;
float theta_h = theta_base_h;
float theta_w = theta_base_w;
float theta_e = theta_base_e; // extra position id for vision encoder
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
GGML_ASSERT(sect_dims <= ne0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
int sector = (i0 / 2) % sect_dims;
if (indep_sects) {
if (sector == 0) {
theta_t = theta_base_t;
}
else if (sector == sections[0]) {
theta_h = theta_base_h;;
}
else if (sector == sections[1]) {
theta_w = theta_base_w;
}
else if (sector == sections[2]) {
theta_e = theta_base_e;
}
}
float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
rope_yarn(
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
);
cache[i0 + 1] *= sin_sign;
theta_t *= theta_scale;
theta_w *= theta_scale;
theta_h *= theta_scale;
theta_e *= theta_scale;
}
}
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,

View file

@ -3575,12 +3575,6 @@ struct ggml_tensor * ggml_mrope_ext(
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@ -3595,7 +3589,6 @@ struct ggml_tensor * ggml_mrope_ext(
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;

View file

@ -750,7 +750,7 @@ class GGUFWriter:
def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)

View file

@ -2436,7 +2436,7 @@ struct llama_hparams {
float rope_freq_scale_train;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
std::array<uint32_t, 4> rope_mrope_sections;
std::array<int, 4> rope_mrope_sections;
// for State Space Models
uint32_t ssm_d_conv = 0;
@ -12540,7 +12540,8 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
int * sections = (int *)hparams.rope_mrope_sections.data();
int sections[4];
std::copy(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), sections);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;