minor updates on debug util, bug fixes
This commit is contained in:
parent 12f17f754d
commit 3ba7664de9
3 changed files with 114 additions and 70 deletions
@@ -666,17 +666,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             ctx0, inp,
             hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
         inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
-        // inp = ggml_reshape_2d(
-        //     ctx0, inp,
-        //     hidden_size * 4, (patches_w / 2) * batch_size * (patches_h / 2));
         inp = ggml_reshape_3d(
             ctx0, inp,
             hidden_size, patches_w * patches_h, batch_size);
-
-        // ggml_build_forward_expand(gf, inp);
-        // ggml_free(ctx0);
-
-        // return gf;
     }
     else {
         inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
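Note on the surviving code path: the reshape, permute, reshape sequence above folds each 2x2 block of neighboring patches into the channel dimension and back, so every reshape must preserve the total element count. A standalone sketch of that shape bookkeeping (the sizes are assumed for illustration only; they are not taken from this commit):

// shape_check.cpp - hypothetical element-count check for the merge above
#include <cstdio>

int main() {
    const long hidden_size = 1280, patches_w = 16, patches_h = 16, batch_size = 1;
    const long total = hidden_size * patches_w * patches_h * batch_size;
    // first reshape: [hidden_size*2, patches_w/2, 2, batch_size*(patches_h/2)]
    const long r1 = (hidden_size * 2) * (patches_w / 2) * 2 * (batch_size * (patches_h / 2));
    // final reshape: [hidden_size, patches_w*patches_h, batch_size]
    const long r2 = hidden_size * (patches_w * patches_h) * batch_size;
    printf("total=%ld r1=%ld r2=%ld\n", total, r1, r2);
    return (r1 == total && r2 == total) ? 0 : 1; // ggml_reshape_* asserts the same invariant
}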
@@ -830,11 +822,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = cur;

     }
-
-    // ggml_build_forward_expand(gf, embeddings);
-    // ggml_free(ctx0);
-
-    // return gf;

     // post-layernorm
     if (ctx->has_post_norm) {
@@ -1100,11 +1087,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32

         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-
-        // // First LayerNorm
-        // embeddings = ggml_norm(ctx0, embeddings, eps);
-        // embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
-        //     model.mm_1_b);

         // GELU activation
         embeddings = ggml_gelu(ctx0, embeddings);
@@ -1112,11 +1094,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         // Second linear layer
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
-
-        // // Second LayerNorm
-        // embeddings = ggml_norm(ctx0, embeddings, eps);
-        // embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
-        //     model.mm_4_b);
     }

     // build the graph
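Taken together, the two projector hunks above delete only commented-out LayerNorms; the multimodal projector that remains is a plain two-layer MLP. Read off the surviving context lines, the kept sequence is (restated here for orientation, not new code):

embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); // first linear layer
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
embeddings = ggml_gelu(ctx0, embeddings);                  // GELU activation
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); // second linear layer
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);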
@@ -8,6 +8,14 @@
 #include "llama.h"
 #include "ggml.h"

+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+#ifdef NDEBUG
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#endif
+
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
@@ -352,72 +360,127 @@ static void llava_free(struct llava_context * ctx_llava) {

 #ifndef NDEBUG

-static void tmp_test_rope() {
-    int n_threads = 1;
-    static size_t buf_size = 512u*1024*1024;
-    static void * buf = malloc(buf_size);
+static void debug_test_mrope_2d() {
+    // 1. Initialize backend
+    ggml_backend_t backend = NULL;
+    std::string backend_name = "";
+#ifdef GGML_USE_CUDA
+    fprintf(stderr, "%s: using CUDA backend\n", __func__);
+    backend = ggml_backend_cuda_init(0); // init device 0
+    backend_name = "cuda";
+    if (!backend) {
+        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+    }
+#endif
+    // if there aren't GPU Backends fallback to CPU backend
+    if (!backend) {
+        backend = ggml_backend_cpu_init();
+        backend_name = "cpu";
+    }

-    struct ggml_init_params init_params = {
-        /*.mem_size =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc =*/ false,
+    // Calculate the size needed to allocate
+    size_t ctx_size = 0;
+    ctx_size += 2 * ggml_tensor_overhead(); // tensors
+    // no need to allocate anything else!
+
+    // 2. Allocate `ggml_context` to store tensor data
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
     };
+    struct ggml_context * ctx = ggml_init(params);

-    struct ggml_context * ctx0 = ggml_init(init_params);
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 128, 12, 30);
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30);
     ggml_set_name(inp_raw, "inp_raw");
     ggml_set_input(inp_raw);

+    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4);
+    ggml_set_name(pos, "pos");
+    ggml_set_input(pos);
+
     std::vector<float> dummy_q;
     dummy_q.resize(128 * 12 * 30);
     std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
-    memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
-
-    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 30);
-    ggml_set_name(pos, "pos");
-    ggml_set_input(pos);
+    // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));

     std::vector<int> pos_id;
-    pos_id.resize(30);
-    for (int i = 0; i < 30; i ++) pos_id[i] = i;
-    memcpy(pos->data, pos_id.data(), (30) * ggml_element_size(pos));
-
-    auto encode = ggml_rope_ext(
-        ctx0, inp_raw, pos, nullptr,
-        128, LLAMA_ROPE_TYPE_NEOX, 32768, 1000000, 1,
-        0, 1, 32, 1);
-
-    ggml_build_forward_expand(gf, encode);
-    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-
-    std::vector<float> embd;
-    embd.resize(128 * 12 * 30);
-    memcpy(
-        embd.data(),
-        (float *) ggml_get_data(encode),
-        sizeof(float) * 128 * 12 * 30);
-    ggml_free(ctx0);
-
-    // Open a binary file for writing
-    std::ofstream outFile("rope.bin", std::ios::binary);
-    // Check if file is open
+    pos_id.resize(30 * 4);
+    for (int i = 0; i < 30; i ++) {
+        pos_id[i] = i;
+        pos_id[i + 30] = i + 10;
+        pos_id[i + 60] = i + 20;
+        pos_id[i + 90] = i + 30;
+    }
+    int sections[4] = {32, 32, 0, 0};
+
+    // 4. Allocate a `ggml_backend_buffer` to store all tensors
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+    // 5. Copy tensor data from main memory (RAM) to backend buffer
+    ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw));
+    ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos));
+
+    // 6. Create a `ggml_cgraph` for mul_mat operation
+    struct ggml_cgraph * gf = NULL;
+    struct ggml_context * ctx_cgraph = NULL;
+
+    // create a temporally context to build the graph
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+    };
+    ctx_cgraph = ggml_init(params0);
+    gf = ggml_new_graph(ctx_cgraph);
+
+    struct ggml_tensor * result0 = ggml_rope_multi(
+        ctx_cgraph, inp_raw, pos, nullptr,
+        128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1,
+        0, 1, 32, 1);
+
+    // Add "result" tensor and all of its dependencies to the cgraph
+    ggml_build_forward_expand(gf, result0);
+
+    // 7. Create a `ggml_gallocr` for cgraph computation
+    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    ggml_gallocr_alloc_graph(allocr, gf);
+
+    // 9. Run the computation
+    int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+    ggml_backend_graph_compute(backend, gf);
+
+    // 10. Retrieve results (output tensors)
+    // in this example, output tensor is always the last tensor in the graph
+    struct ggml_tensor * result = result0;
+    // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
+    float * result_data = (float *)malloc(ggml_nbytes(result));
+    // because the tensor data is stored in device buffer, we need to copy it back to RAM
+    ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
+    const std::string bin_file = "mrope_2d_" + backend_name +".bin";
+    std::ofstream outFile(bin_file, std::ios::binary);

     if (outFile.is_open()) {
-        // Write the vector to the file
-        outFile.write(reinterpret_cast<const char*>(embd.data()), embd.size() * sizeof(int));
-
-        // Close the file
+        outFile.write(reinterpret_cast<const char*>(result_data), ggml_nbytes(result));
         outFile.close();
-        std::cout << "Data successfully written to output.bin" << std::endl;
+        std::cout << "Data successfully written to " + bin_file << std::endl;
     } else {
         std::cerr << "Error opening file!" << std::endl;
     }

+    free(result_data);
+    // 11. Free memory and exit
+    ggml_free(ctx_cgraph);
+    ggml_gallocr_free(allocr);
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+    ggml_backend_free(backend);
 }

-static void tmp_dump_img_embed(struct llava_context * ctx_llava) {
+static void debug_dump_img_embed(struct llava_context * ctx_llava) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
     int ne = n_embd * 4;
     float vals[56 * 56 * 3];
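Because debug_test_mrope_2d() now writes the rotated tensor to "mrope_2d_" + backend_name + ".bin", a CPU run and a CUDA run of the same binary produce two files that can be compared offline. A possible checker, sketched here as a hypothetical helper (not part of this commit; the file names simply follow the naming scheme above):

// compare_mrope_bins.cpp - hypothetical offline checker for the dumps above
#include <cmath>
#include <cstdio>
#include <fstream>
#include <vector>

static std::vector<float> load_f32(const char * path) {
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    if (!f) return {};
    std::vector<float> v((size_t) f.tellg() / sizeof(float));
    f.seekg(0);
    f.read(reinterpret_cast<char *>(v.data()), v.size() * sizeof(float));
    return v;
}

int main() {
    const auto a = load_f32("mrope_2d_cpu.bin");
    const auto b = load_f32("mrope_2d_cuda.bin");
    if (a.empty() || a.size() != b.size()) {
        fprintf(stderr, "missing dump or size mismatch\n");
        return 1;
    }
    float max_diff = 0.0f;
    for (size_t i = 0; i < a.size(); i++) {
        max_diff = std::fmax(max_diff, std::fabs(a[i] - b[i]));
    }
    printf("n=%zu max_abs_diff=%g\n", a.size(), max_diff); // expect a small value if both backends agree
    return 0;
}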
@@ -485,7 +548,8 @@ int main(int argc, char ** argv) {
     } else if (params.image[0].empty()) {
         auto ctx_llava = llava_init_context(&params, model);

-        tmp_dump_img_embed(ctx_llava);
+        debug_test_mrope_2d();
+        debug_dump_img_embed(ctx_llava);

         llama_perf_context_print(ctx_llava->ctx_llama);
         ctx_llava->model = NULL;
@@ -9216,6 +9216,7 @@ static void ggml_mrope_cache_init(
     float theta_e = theta_base_e; // extra position id for vision encoder
     int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
     int sec_w = sections[1] + sections[0];
+    int sec_e = sections[2] + sec_w;
     GGML_ASSERT(sect_dims <= ne0);

     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -9223,16 +9224,18 @@ static void ggml_mrope_cache_init(

         int sector = (i0 / 2) % sect_dims;
         if (indep_sects) {
+            // compute theta independently for each dim sections
+            // (i.e. reset corresponding theta when `i0` go from one section to another)
             if (sector == 0) {
                 theta_t = theta_base_t;
             }
             else if (sector == sections[0]) {
                 theta_h = theta_base_h;;
             }
-            else if (sector == sections[1]) {
+            else if (sector == sec_w) {
                 theta_w = theta_base_w;
             }
-            else if (sector == sections[2]) {
+            else if (sector == sec_e) {
                 theta_e = theta_base_e;
             }
         }
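This is the bug fix the commit message refers to: `sector` is the dim-pair index modulo `sect_dims`, so the theta resets must trigger at the cumulative section boundaries (sections[0], sec_w, sec_e), not at the raw widths sections[1] and sections[2] as before. A standalone illustration of the boundary arithmetic, with assumed section widths {16, 24, 24, 0} chosen only so the boundaries are distinct:

// mrope_sections.cpp - hypothetical illustration of the boundary fix
#include <cstdio>

int main() {
    const int sections[4] = {16, 24, 24, 0}; // t/h/w/e widths (assumed, not from the commit)
    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; // 64
    const int sec_w = sections[1] + sections[0]; // 40: where the height block ends
    const int sec_e = sections[2] + sec_w;       // 64: where the width block ends

    for (int sector = 0; sector < sect_dims; sector++) {
        const char * comp = sector < sections[0] ? "t"  // temporal
                          : sector < sec_w       ? "h"  // height
                          : sector < sec_e       ? "w"  // width
                          :                        "e"; // extra
        printf("sector %2d -> %s\n", sector, comp);
    }
    // the old code reset theta_w at sector == sections[1] (24), a point inside
    // the height block, instead of at the cumulative boundary sec_w (40)
    return 0;
}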