mtl : confirmed get_rows_q4_0 is working correctly
This commit is contained in:
parent a8fd9dc128
commit 794704e409

4 changed files with 60 additions and 4 deletions
@@ -5,6 +5,8 @@
#include <cstring>
#include <cstdlib>

#include <vector> // tmp

int main(int argc, char ** argv) {
    ggml_time_init();
@@ -37,6 +39,17 @@ int main(int argc, char ** argv) {
    // this allocates all Metal resources and memory buffers
    auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);

    // TODO: tmp to match the input used when creating the cgraph
    {
        const int n_ctx   = 128;
        const int n_batch = 32;

        const std::vector<int> tmp(n_batch, 1); // BOS

        struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
        memcpy(input->data, tmp.data(), tmp.size() * sizeof(int));
    }

    // the actual inference happens here
    llama_mtl_eval(ctx_mtl, &gf);
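The memcpy above only works because the dummy batch has exactly the shape the graph was exported with. A minimal self-contained sketch of that invariant, reusing the names from the diff; fill_dummy_input itself is hypothetical, not in the repository:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: stuff a graph input with n_batch BOS tokens.
// The batch must match the shape the cgraph was created with, otherwise
// the memcpy writes the wrong number of bytes into the "embd" tensor.
void fill_dummy_input(void * input_data, int64_t input_nelements, int n_batch) {
    const std::vector<int> tmp(n_batch, 1); // token id 1 = BOS, as in the diff
    assert((int64_t) tmp.size() == input_nelements);
    std::memcpy(input_data, tmp.data(), tmp.size() * sizeof(int));
}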
@@ -357,6 +357,8 @@ int llama_mtl_eval(

    // extract results from the GPU
    {
        fprintf(stderr, "%s: extract results from the GPU\n", __func__);

        if (encoder != nil) {
            [encoder endEncoding];
            encoder = nil;
@@ -367,6 +369,8 @@ int llama_mtl_eval(
        id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
        id<MTLBuffer> id_dst = ctx->out;

        printf("XXXXX n = %d\n", ggml_nelements(out));

        id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
        [encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
        [encoder_blit endEncoding];
@@ -383,5 +387,24 @@ int llama_mtl_eval(
    // TODO
    const float * logits = ctx->out.contents;

    {
        struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
        float * data = (float *) ctx->out.contents;
        printf("data: ");
        int n = t->ne[0];
        if (n > 10) {
            n = 10;
        }
        for (int i = 0; i < n; i++) {
            printf("%f ", data[i]);
        }
        printf("\n");
        double sum = 0.0;
        for (int i = 0; i < ggml_nelements(t); i++) {
            sum += data[i];
        }
        printf("sum: %f\n", sum);
    }

    return 0;
}
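The only verification signal in this debug path is the head of the tensor plus an element sum, printed once from the CPU run and once from the GPU readback. A self-contained helper in the same spirit (hypothetical, not from the repository) that compares the two buffers directly; it also tracks the maximum absolute difference, since a plain sum can hide errors that cancel:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Hypothetical check: print the first few values of both buffers and their
// element sums, and report whether they agree within tol.
bool check_buffers(const float * cpu, const float * gpu, int n, float tol = 1e-4f) {
    printf("cpu: ");
    for (int i = 0; i < std::min(n, 10); i++) {
        printf("%f ", cpu[i]);
    }
    printf("\ngpu: ");
    for (int i = 0; i < std::min(n, 10); i++) {
        printf("%f ", gpu[i]);
    }
    double sum_cpu = 0.0, sum_gpu = 0.0, max_diff = 0.0;
    for (int i = 0; i < n; i++) {
        sum_cpu += cpu[i];
        sum_gpu += gpu[i];
        max_diff = std::max(max_diff, (double) std::fabs(cpu[i] - gpu[i]));
    }
    printf("\nsum cpu: %f, sum gpu: %f, max diff: %f\n", sum_cpu, sum_gpu, max_diff);
    return max_diff <= tol;
}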
@@ -67,7 +67,6 @@ kernel void kernel_soft_max(
    }
}

// TODO: not tested
kernel void kernel_get_rows_q4_0(
        device const void * src0,
        device const int  * src1,
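The kernel losing its "not tested" marker here gathers rows of a q4_0 matrix by index and dequantizes them. As a reference for what the kernel is expected to compute, a CPU-side sketch follows; it assumes the classic q4_0 layout (one scale plus 32 packed 4-bit quants per block, value = (nibble - 8) * d) and redeclares the block struct locally, so it may differ from the actual ggml type:

#include <cstdint>

// Classic q4_0 block layout (assumed; the real struct lives in ggml).
constexpr int QK4_0 = 32;
struct block_q4_0 {
    float   d;              // per-block scale
    uint8_t qs[QK4_0 / 2];  // element j in the low nibble, j + 16 in the high nibble
};

// Reference get_rows for q4_0: for each index in src1, dequantize the
// corresponding row of src0 into dst (row-major floats, ne0 per row).
void get_rows_q4_0_ref(const block_q4_0 * src0, const int * src1,
                       float * dst, int nrows_dst, int ne0) {
    const int nb = ne0 / QK4_0; // blocks per row
    for (int i = 0; i < nrows_dst; i++) {
        const block_q4_0 * row = src0 + src1[i] * nb;
        float * out = dst + i * ne0;
        for (int b = 0; b < nb; b++) {
            const float d = row[b].d;
            for (int j = 0; j < QK4_0 / 2; j++) {
                const int x0 = (row[b].qs[j] & 0x0F) - 8;
                const int x1 = (row[b].qs[j] >> 4)   - 8;
                out[b * QK4_0 + j]             = x0 * d;
                out[b * QK4_0 + j + QK4_0 / 2] = x1 * d;
            }
        }
    }
}

Confirming the Metal kernel then amounts to running it and a reference like this on the same indices and comparing outputs, e.g. with the head-plus-sum printout used elsewhere in this commit.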
llama.cpp (27 changed lines)
@@ -1252,6 +1252,7 @@ static bool llama_eval_internal(
    memcpy(embd->data, tokens, N*ggml_element_size(embd));

    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
    ggml_set_name(inpL, "mtl-check");

    for (int il = 0; il < n_layer; ++il) {
        struct ggml_tensor * inpSA = inpL;
@@ -1269,9 +1270,9 @@
        }

        // TODO: TMP !!!!
        if (il == 0) {
            ggml_set_name(cur, "mtl-check");
        }
        //if (il == 0) {
        //    ggml_set_name(cur, "mtl-check");
        //}

        // self-attention
        {
@@ -1437,6 +1438,26 @@
    // lets export a smaller graph to get things rolling -- baby steps first
    ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));

    // print
    {
        auto print_t = [&](struct ggml_tensor * t) {
            float * data = (float *)t->data;
            printf("data: ");
            for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
                printf("%f ", data[i]);
            }
            printf("\n");
            double sum = 0.0;
            for (int i = 0; i < ggml_nelements(t); i++) {
                sum += data[i];
            }
            printf("sum: %f\n", sum);
        };

        ggml_graph_compute(ctx0, &gf_export);
        print_t(ggml_get_tensor(ctx0, "mtl-check"));
    }

    if (cgraph_fname) {
        //ggml_graph_export(&gf, cgraph_fname);
        ggml_graph_export(&gf_export, cgraph_fname);