resolve linter, test errors

This commit is contained in:
HimariO 2024-11-29 22:18:15 +08:00
parent fac034530f
commit cbd08b4204
11 changed files with 168 additions and 147 deletions

View file

@ -29,6 +29,7 @@
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
{ "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } },
{ {
"name": "arm64-windows-msvc", "hidden": true, "name": "arm64-windows-msvc", "hidden": true,
@ -57,44 +58,28 @@
} }
}, },
{ "name": "arm64-windows-llvm-debug" , "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
{ "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
{ "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] },
{ "name": "arm64-apple-clang-debug" , "inherits": [ "base", "arm64-apple-clang", "debug" ] }, { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] },
{ "name": "arm64-apple-clang-release" , "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] },
{ "name": "arm64-apple-clang+static-release" , "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] },
{ "name": "arm64-windows-msvc-debug" , "inherits": [ "base", "arm64-windows-msvc", "debug" ] }, { "name": "arm64-windows-msvc-debug", "inherits": [ "base", "arm64-windows-msvc", "debug" ] },
{ "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] }, { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] },
{ "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] }, { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] },
{ "name": "x64-windows-msvc-debug" , "inherits": [ "base", "debug" ] }, { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] },
{ "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] },
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] }, { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] },
{ "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] },
{
"name": "x86-cuda-linux", { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] },
"description": "", { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] }
"displayName": "",
"inherits": [
"base",
"debug"
],
"cacheVariables": {
"GGML_CUDA": "1",
"CUDA_PATH": "/usr/local/cuda",
"CUDAToolkit_ROOT": "/usr/local/cuda",
"CUDAToolkit_INCLUDE_DIR": "/usr/local/cuda/include/",
"CUDAToolkit_LIBRARY_DIR": "/usr/local/cuda/lib64",
"CUDA_NVCC_FLAGS": "-g -G",
"CMAKE_CUDA_FLAGS_DEBUG": "-g -G",
"CMAKE_CUDA_FLAGS": "-maxrregcount=40"
}
}
] ]
} }

View file

@ -2590,12 +2590,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
int* positions_data = (int*)malloc(ggml_nbytes(positions)); int* positions_data = (int*)malloc(ggml_nbytes(positions));
int ptr = 0; int ptr = 0;
for (size_t y = 0; y < ph; y+=2) for (int y = 0; y < ph; y+=2)
{ {
for (size_t x = 0; x < pw; x+=2) for (int x = 0; x < pw; x+=2)
{ {
for (size_t dy = 0; dy < 2; dy++) { for (int dy = 0; dy < 2; dy++) {
for (size_t dx = 0; dx < 2; dx++) { for (int dx = 0; dx < 2; dx++) {
positions_data[ptr] = y + dy; positions_data[ptr] = y + dy;
positions_data[num_patches + ptr] = x + dx; positions_data[num_patches + ptr] = x + dx;
positions_data[num_patches * 2 + ptr] = y + dy; positions_data[num_patches * 2 + ptr] = y + dy;
@ -2820,20 +2820,15 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
} }
bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img; clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3); clip_img.buf.resize(h * w * 3);
for (size_t i = 0; i < h*w*3; i++) for (int i = 0; i < h*w*3; i++)
{ {
clip_img.buf[i] = img[i]; clip_img.buf[i] = img[i];
} }
clip_img.nx = w; clip_img.nx = w;
clip_img.ny = h; clip_img.ny = h;
// ctx->vision_model.hparams.image_size = h;
clip_image_encode(ctx, n_threads, &clip_img, vec); clip_image_encode(ctx, n_threads, &clip_img, vec);
return true; return true;
} }
void tmp_clip_set_layers (struct clip_ctx * ctx, int layers) {
ctx->vision_model.hparams.n_layer = layers;
}

View file

@ -91,8 +91,7 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
CLIP_API void tmp_clip_set_layers (struct clip_ctx * ctx, int layers);
#ifdef __cplusplus #ifdef __cplusplus
} }

View file

@ -1,9 +1,8 @@
import argparse import argparse
import glob from typing import Dict
import os
from typing import Any, Dict
import torch import torch
import numpy as np
from gguf import * from gguf import *
from transformers import ( from transformers import (
Qwen2VLForConditionalGeneration, Qwen2VLForConditionalGeneration,
@ -44,7 +43,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
else: # bias else: # bias
c3 = ten.shape[0] c3 = ten.shape[0]
assert c3 % 3 == 0 assert c3 % 3 == 0
c = c3//3 c = c3 // 3
wq = ten[:c] wq = ten[:c]
wk = ten[c: c * 2] wk = ten[c: c * 2]
wv = ten[c * 2:] wv = ten[c * 2:]
@ -95,10 +94,8 @@ def main(args):
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype, device_map="cpu" model_name, torch_dtype=dtype, device_map="cpu"
) )
cfg: Qwen2VLConfig = qwen2vl.config cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
vcfg = cfg.vision_config vcfg = cfg.vision_config
rope_cfg = cfg.rope_scaling
fname_out = "qwen2vl-vision.gguf" fname_out = "qwen2vl-vision.gguf"
fout = GGUFWriter(path=fname_out, arch="clip") fout = GGUFWriter(path=fname_out, arch="clip")
@ -125,13 +122,13 @@ def main(args):
fout.add_tensor(name, data) fout.add_tensor(name, data)
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size) fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
fout.add_uint32("clip.vision.image_size", 14*40) # some reasonable size that is divable by (14*2) fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # BUG: not sure what this does fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder
fout.add_name(model_name) fout.add_name(model_name)
""" """
HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig, HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
@ -139,9 +136,8 @@ def main(args):
""" """
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name) processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
# breakpoint() fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
fout.add_array("clip.vision.image_std", processor.image_processor.image_std)
fout.write_header_to_file() fout.write_header_to_file()
fout.write_kv_data_to_file() fout.write_kv_data_to_file()

View file

@ -26,9 +26,9 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
auto img_tokens = image_embed->n_image_pos; auto img_tokens = image_embed->n_image_pos;
llama_pos mrope_pos[img_tokens * 4]; llama_pos mrope_pos[img_tokens * 4];
for (size_t y = 0; y < ph; y++) for (int y = 0; y < ph; y++)
{ {
for (size_t x = 0; x < pw; x++) for (int x = 0; x < pw; x++)
{ {
int i = y * pw + x; int i = y * pw + x;
mrope_pos[i] = *st_pos_id; mrope_pos[i] = *st_pos_id;
@ -270,7 +270,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
LOG("\n"); LOG("\n");
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams); struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
if (!smpl) { if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
exit(1); exit(1);
@ -422,10 +422,7 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, common_params *
int ne = n_embd * 4; int ne = n_embd * 4;
float vals[56 * 56 * 3]; float vals[56 * 56 * 3];
float embd[ne]; float embd[ne];
// for (int i = 0; i < 3*56*56; i++)
// {
// vals[i] = 0.1;
// }
for (int i = 0; i < 56*56; i++) for (int i = 0; i < 56*56; i++)
{ {
for (int c = 0; c < 3; c++) for (int c = 0; c < 3; c++)
@ -433,7 +430,7 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, common_params *
} }
// auto param = &ctx_llava->ctx_clip->vision_model.hparams; // auto param = &ctx_llava->ctx_clip->vision_model.hparams;
tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd); clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
std::ofstream outFile("img_embed.bin", std::ios::binary); std::ofstream outFile("img_embed.bin", std::ios::binary);
if (outFile.is_open()) { if (outFile.is_open()) {

View file

@ -238,8 +238,8 @@
#define GGML_EXIT_ABORTED 1 #define GGML_EXIT_ABORTED 1
#define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 4 #define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 12 #define GGML_ROPE_TYPE_VISION 24
#define GGUF_MAGIC "GGUF" #define GGUF_MAGIC "GGUF"

View file

@ -9205,6 +9205,61 @@ static void ggml_rope_cache_init(
} }
} }
static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta_t = theta_base_t;
float theta_h = theta_base_h;
float theta_w = theta_base_w;
float theta_e = theta_base_e; // extra position id for vision encoder
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
GGML_ASSERT(sect_dims <= ne0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
int sector = (i0 / 2) % sect_dims;
if (indep_sects) {
if (sector == 0) {
theta_t = theta_base_t;
}
else if (sector == sections[0]) {
theta_h = theta_base_h;;
}
else if (sector == sections[1]) {
theta_w = theta_base_w;
}
else if (sector == sections[2]) {
theta_e = theta_base_e;
}
}
float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
rope_yarn(
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
);
cache[i0 + 1] *= sin_sign;
theta_t *= theta_scale;
theta_w *= theta_scale;
theta_h *= theta_scale;
theta_e *= theta_scale;
}
}
static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst, struct ggml_tensor * dst,

View file

@ -3575,12 +3575,6 @@ struct ggml_tensor * ggml_mrope_ext(
GGML_ASSERT(c->ne[0] >= n_dims / 2); GGML_ASSERT(c->ne[0] >= n_dims / 2);
} }
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_tensor * result = ggml_dup_tensor(ctx, a); struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@ -3595,7 +3589,6 @@ struct ggml_tensor * ggml_mrope_ext(
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE; result->op = GGML_OP_ROPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a; result->src[0] = a;
result->src[1] = b; result->src[1] = b;
result->src[2] = c; result->src[2] = c;

View file

@ -2436,7 +2436,7 @@ struct llama_hparams {
float rope_freq_scale_train; float rope_freq_scale_train;
uint32_t n_ctx_orig_yarn; uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul; float rope_yarn_log_mul;
std::array<uint32_t, 4> rope_mrope_sections; std::array<int, 4> rope_mrope_sections;
// for State Space Models // for State Space Models
uint32_t ssm_d_conv = 0; uint32_t ssm_d_conv = 0;
@ -12540,7 +12540,8 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
int * sections = (int *)hparams.rope_mrope_sections.data(); int sections[4];
std::copy(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), sections);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;