mtl : confirmed get_rows_q4_0 is working correctly

Georgi Gerganov 2023-05-30 18:41:21 +03:00
parent a8fd9dc128
commit 794704e409
4 changed files with 60 additions and 4 deletions

View file

@@ -5,6 +5,8 @@
#include <cstring>
#include <cstdlib>
+#include <vector> // tmp
int main(int argc, char ** argv) {
ggml_time_init();
@@ -37,6 +39,17 @@ int main(int argc, char ** argv) {
// this allocates all Metal resources and memory buffers
auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);
+    // TODO: tmp to match the input used when creating the cgraph
+    {
+        const int n_ctx = 128;
+        const int n_batch = 32;
+        const std::vector<int> tmp(n_batch, 1); // BOS
+        struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
+        memcpy(input->data, tmp.data(), tmp.size() * sizeof(int));
+    }
// the actual inference happens here
llama_mtl_eval(ctx_mtl, &gf);
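
Taken together, the fragments above form the host-side test harness: load the exported graph, fill its "embd" input with the same 32 BOS tokens used when the graph was created, then evaluate on the GPU. A condensed sketch of that flow follows; llama_mtl_init, llama_mtl_eval, and the "embd" tensor name come from this diff, everything else (the wrapper function, its name) is an assumption:

#include <cstring>
#include <vector>

#include "ggml.h"

// Hypothetical wrapper around the calls shown above. The exported cgraph was
// built with fixed shapes, so the eval-time input must match them exactly.
int run_mtl_check(ggml_context * ctx_data, ggml_context * ctx_eval,
                  ggml_context * ctx_work, ggml_cgraph * gf) {
    auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, gf);

    const int n_batch = 32;
    const std::vector<int> tmp(n_batch, 1); // token id 1 = BOS for LLaMA

    // write the tokens into the graph's input tensor before evaluating
    ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
    memcpy(input->data, tmp.data(), tmp.size() * sizeof(int));

    return llama_mtl_eval(ctx_mtl, gf); // runs the graph on the GPU
}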

View file

@@ -357,6 +357,8 @@ int llama_mtl_eval(
// extract results from the GPU
{
+        fprintf(stderr, "%s: extract results from the GPU\n", __func__);
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
@@ -367,6 +369,8 @@ int llama_mtl_eval(
id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
id<MTLBuffer> id_dst = ctx->out;
+    printf("XXXXX n = %lld\n", (long long) ggml_nelements(out));
id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
[encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
[encoder_blit endEncoding];
@@ -383,5 +387,24 @@ int llama_mtl_eval(
// TODO
const float * logits = ctx->out.contents;
+    {
+        struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
+        float * data = (float *) ctx->out.contents;
+        printf("data: ");
+        int n = t->ne[0];
+        if (n > 10) {
+            n = 10;
+        }
+        for (int i = 0; i < n; i++) {
+            printf("%f ", data[i]);
+        }
+        printf("\n");
+        double sum = 0.0;
+        for (int i = 0; i < ggml_nelements(t); i++) {
+            sum += data[i];
+        }
+        printf("sum: %f\n", sum);
+    }
return 0;
}
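
The block above is the GPU half of the check; an identical "data:"/"sum:" printout is added to the CPU path in llama.cpp further down in this commit. The whole-tensor sum acts as a cheap checksum: if both runs print the same leading values and the same sum, the Metal get_rows_q4_0 output matches the CPU reference. A sketch of that duplicated printout factored into one helper (hypothetical, not part of the commit):

#include <algorithm>
#include <cstdio>

#include "ggml.h"

// Print the first few elements plus a whole-tensor sum, so a GPU run and a
// CPU run of the same graph can be compared line by line.
static void print_tensor_check(const struct ggml_tensor * t, const float * data) {
    printf("data: ");
    const int n = std::min((int) t->ne[0], 10);
    for (int i = 0; i < n; i++) {
        printf("%f ", data[i]);
    }
    printf("\n");

    double sum = 0.0;
    for (int i = 0; i < ggml_nelements(t); i++) {
        sum += data[i];
    }
    printf("sum: %f\n", sum);
}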

View file

@@ -67,7 +67,6 @@ kernel void kernel_soft_max(
}
}
-// TODO: not tested
kernel void kernel_get_rows_q4_0(
device const void * src0,
device const int * src1,

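For reference, Q4_0 stores weights in blocks of 32: one fp16 scale d plus 16 bytes of packed 4-bit values, so get_rows has to dequantize each row selected by src1 block by block. A CPU sketch of what the kernel computes (block layout as in ggml's Q4_0 at the time; treat the exact nibble ordering as an assumption):

#include <cstdint>

#include "ggml.h"

#define QK4_0 32

// one Q4_0 block: 32 weights sharing a single fp16 scale
typedef struct {
    ggml_fp16_t d;            // scale
    uint8_t     qs[QK4_0/2];  // 32 x 4-bit quants, two per byte
} block_q4_0;

// dequantize one row of k weights; get_rows does this for row src1[i] of src0
static void dequantize_row_q4_0_ref(const block_q4_0 * x, float * y, int k) {
    const int nb = k / QK4_0;
    for (int i = 0; i < nb; i++) {
        const float d = ggml_fp16_to_fp32(x[i].d);
        for (int j = 0; j < QK4_0/2; j++) {
            const int x0 = (x[i].qs[j] & 0x0F) - 8; // low nibble
            const int x1 = (x[i].qs[j] >>   4) - 8; // high nibble
            y[i*QK4_0 + j          ] = x0*d;
            y[i*QK4_0 + j + QK4_0/2] = x1*d;
        }
    }
}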
View file

@@ -1252,6 +1252,7 @@ static bool llama_eval_internal(
memcpy(embd->data, tokens, N*ggml_element_size(embd));
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+    ggml_set_name(inpL, "mtl-check");
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
@ -1269,9 +1270,9 @@ static bool llama_eval_internal(
}
// TODO: TMP !!!!
-        if (il == 0) {
-            ggml_set_name(cur, "mtl-check");
-        }
+        //if (il == 0) {
+        // ggml_set_name(cur, "mtl-check");
+        //}
// self-attention
{
@@ -1437,6 +1438,26 @@ static bool llama_eval_internal(
// lets export a smaller graph to get things rolling -- baby steps first
ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));
+    // print
+    {
+        auto print_t = [&](struct ggml_tensor * t) {
+            float * data = (float *)t->data;
+            printf("data: ");
+            for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
+                printf("%f ", data[i]);
+            }
+            printf("\n");
+            double sum = 0.0;
+            for (int i = 0; i < ggml_nelements(t); i++) {
+                sum += data[i];
+            }
+            printf("sum: %f\n", sum);
+        };
+        ggml_graph_compute(ctx0, &gf_export);
+        print_t(ggml_get_tensor(ctx0, "mtl-check"));
+    }
if (cgraph_fname) {
//ggml_graph_export(&gf, cgraph_fname);
ggml_graph_export(&gf_export, cgraph_fname);
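
The exported file is what main() at the top of this commit consumes. The import side uses ggml's graph serialization API, roughly as follows (a sketch; the variable names are assumptions, ggml_graph_import is the real counterpart of ggml_graph_export):

// Sketch of the consuming side: re-create the data/eval contexts and the
// graph from the file written by ggml_graph_export above.
struct ggml_context * ctx_data = NULL;
struct ggml_context * ctx_eval = NULL;

struct ggml_cgraph gf = ggml_graph_import(cgraph_fname, &ctx_data, &ctx_eval);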