mtl : simplify implementation
This commit is contained in:
parent
627605732c
commit
03c2d72867
1 changed files with 56 additions and 139 deletions
|
@ -257,7 +257,7 @@ id<MTLBuffer> llama_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_te
|
||||||
|
|
||||||
const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval);
|
const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval);
|
||||||
|
|
||||||
const size_t t_size = ggml_nbytes(t);
|
//const size_t t_size = ggml_nbytes(t);
|
||||||
const size_t t_offs = is_data ? offs_data : offs_eval;
|
const size_t t_offs = is_data ? offs_data : offs_eval;
|
||||||
|
|
||||||
id<MTLBuffer> result;
|
id<MTLBuffer> result;
|
||||||
|
@ -271,7 +271,7 @@ id<MTLBuffer> llama_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_te
|
||||||
}
|
}
|
||||||
|
|
||||||
if (result == nil) {
|
if (result == nil) {
|
||||||
//fprintf(stderr, "%s: error: buffer is nil\n", __func__);
|
fprintf(stderr, "%s: error: buffer is nil\n", __func__);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,9 +296,9 @@ int llama_mtl_eval(
|
||||||
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
|
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
|
||||||
id<MTLComputeCommandEncoder> encoder = nil;
|
id<MTLComputeCommandEncoder> encoder = nil;
|
||||||
|
|
||||||
size_t offs_src0;
|
size_t offs_src0 = 0;
|
||||||
size_t offs_src1;
|
size_t offs_src1 = 0;
|
||||||
size_t offs_dst;
|
size_t offs_dst = 0;
|
||||||
|
|
||||||
// copy the input data to the GPU
|
// copy the input data to the GPU
|
||||||
{
|
{
|
||||||
|
@ -312,6 +312,48 @@ int llama_mtl_eval(
|
||||||
for (int i = 0; i < gf->n_nodes; ++i) {
|
for (int i = 0; i < gf->n_nodes; ++i) {
|
||||||
//fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
//fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
||||||
|
|
||||||
|
struct ggml_tensor * src0 = gf->nodes[i]->src0;
|
||||||
|
struct ggml_tensor * src1 = gf->nodes[i]->src1;
|
||||||
|
struct ggml_tensor * dst = gf->nodes[i];
|
||||||
|
|
||||||
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
||||||
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
||||||
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
||||||
|
const int64_t ne03 = src0 ? src0->ne[3] : 0;
|
||||||
|
|
||||||
|
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
|
||||||
|
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
|
||||||
|
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
|
||||||
|
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
|
||||||
|
|
||||||
|
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
||||||
|
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
||||||
|
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
||||||
|
//const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
||||||
|
|
||||||
|
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
||||||
|
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
||||||
|
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
||||||
|
//const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
||||||
|
|
||||||
|
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
||||||
|
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
||||||
|
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
||||||
|
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
||||||
|
|
||||||
|
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
||||||
|
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
||||||
|
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
||||||
|
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
||||||
|
|
||||||
|
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
||||||
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
||||||
|
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
||||||
|
|
||||||
|
id<MTLBuffer> id_src0 = src0 ? llama_mtl_get_buffer(ctx, src0, &offs_src0) : nil;
|
||||||
|
id<MTLBuffer> id_src1 = src1 ? llama_mtl_get_buffer(ctx, src1, &offs_src1) : nil;
|
||||||
|
id<MTLBuffer> id_dst = dst ? llama_mtl_get_buffer(ctx, dst, &offs_dst) : nil;
|
||||||
|
|
||||||
switch (gf->nodes[i]->op) {
|
switch (gf->nodes[i]->op) {
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
case GGML_OP_VIEW:
|
case GGML_OP_VIEW:
|
||||||
|
@ -326,10 +368,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_add];
|
[encoder setComputePipelineState:ctx->pipeline_add];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
@ -345,14 +383,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
|
||||||
|
|
||||||
const int64_t ne10 = gf->nodes[i]->src1->ne[0];
|
|
||||||
|
|
||||||
if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
|
if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
|
||||||
// src1 is a row
|
// src1 is a row
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
||||||
|
@ -374,9 +404,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const float scale = *(const float *) gf->nodes[i]->src1->data;
|
const float scale = *(const float *) gf->nodes[i]->src1->data;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_scale];
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
||||||
|
@ -394,9 +421,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
@ -411,11 +435,8 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_relu];
|
[encoder setComputePipelineState:ctx->pipeline_relu];
|
||||||
[encoder setBuffer:id_src offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(gf->nodes[i]);
|
const int64_t n = ggml_nelements(gf->nodes[i]);
|
||||||
|
@ -428,18 +449,10 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
|
||||||
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
|
|
||||||
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
|
|
||||||
const int64_t ne03 = gf->nodes[i]->src0->ne[3];
|
|
||||||
|
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||||
[encoder setBuffer:id_src offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
|
@ -454,15 +467,8 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
|
||||||
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
|
|
||||||
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||||
[encoder setBuffer:id_src offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
|
@ -472,38 +478,6 @@ int llama_mtl_eval(
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
|
||||||
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
|
|
||||||
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
|
|
||||||
|
|
||||||
const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
|
|
||||||
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
|
||||||
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
|
|
||||||
|
|
||||||
const int64_t ne10 = gf->nodes[i]->src1->ne[0];
|
|
||||||
const int64_t ne11 = gf->nodes[i]->src1->ne[1];
|
|
||||||
const int64_t ne12 = gf->nodes[i]->src1->ne[2];
|
|
||||||
|
|
||||||
const uint64_t nb10 = gf->nodes[i]->src1->nb[0];
|
|
||||||
const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
|
|
||||||
const uint64_t nb12 = gf->nodes[i]->src1->nb[2];
|
|
||||||
|
|
||||||
const int64_t ne0 = gf->nodes[i]->ne[0];
|
|
||||||
const int64_t ne1 = gf->nodes[i]->ne[1];
|
|
||||||
const int64_t ne2 = gf->nodes[i]->ne[2];
|
|
||||||
|
|
||||||
const uint64_t nb0 = gf->nodes[i]->nb[0];
|
|
||||||
const uint64_t nb1 = gf->nodes[i]->nb[1];
|
|
||||||
const uint64_t nb2 = gf->nodes[i]->nb[2];
|
|
||||||
|
|
||||||
const enum ggml_type src0t = gf->nodes[i]->src0->type;
|
|
||||||
const enum ggml_type src1t = gf->nodes[i]->src1->type;
|
|
||||||
const enum ggml_type dstt = gf->nodes[i]->type;
|
|
||||||
|
|
||||||
//fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0));
|
//fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0));
|
||||||
//fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1));
|
//fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1));
|
||||||
//fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
|
//fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
|
||||||
|
@ -613,10 +587,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
switch (gf->nodes[i]->src0->type) {
|
switch (gf->nodes[i]->src0->type) {
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
default: {
|
default: {
|
||||||
|
@ -642,11 +612,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
|
||||||
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
|
||||||
const float eps = 1e-6f;
|
const float eps = 1e-6f;
|
||||||
|
|
||||||
const int nth = 256;
|
const int nth = 256;
|
||||||
|
@ -669,30 +634,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
|
||||||
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
|
|
||||||
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
|
|
||||||
const int64_t ne03 = gf->nodes[i]->src0->ne[3];
|
|
||||||
|
|
||||||
const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
|
|
||||||
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
|
||||||
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
|
|
||||||
const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
|
|
||||||
|
|
||||||
const int64_t ne0 = gf->nodes[i]->ne[0];
|
|
||||||
const int64_t ne1 = gf->nodes[i]->ne[1];
|
|
||||||
const int64_t ne2 = gf->nodes[i]->ne[2];
|
|
||||||
const int64_t ne3 = gf->nodes[i]->ne[3];
|
|
||||||
|
|
||||||
const uint64_t nb0 = gf->nodes[i]->nb[0];
|
|
||||||
const uint64_t nb1 = gf->nodes[i]->nb[1];
|
|
||||||
const uint64_t nb2 = gf->nodes[i]->nb[2];
|
|
||||||
const uint64_t nb3 = gf->nodes[i]->nb[3];
|
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
|
|
||||||
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
|
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
|
||||||
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];
|
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];
|
||||||
|
|
||||||
|
@ -731,32 +672,6 @@ int llama_mtl_eval(
|
||||||
encoder = [command_buffer computeCommandEncoder];
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
|
||||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
|
||||||
|
|
||||||
const int64_t ne00 = gf->nodes[i]->src0->ne[0];
|
|
||||||
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
|
|
||||||
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
|
|
||||||
const int64_t ne03 = gf->nodes[i]->src0->ne[3];
|
|
||||||
|
|
||||||
const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
|
|
||||||
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
|
|
||||||
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
|
|
||||||
const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
|
|
||||||
|
|
||||||
const int64_t ne0 = gf->nodes[i]->ne[0];
|
|
||||||
const int64_t ne1 = gf->nodes[i]->ne[1];
|
|
||||||
const int64_t ne2 = gf->nodes[i]->ne[2];
|
|
||||||
const int64_t ne3 = gf->nodes[i]->ne[3];
|
|
||||||
|
|
||||||
const uint64_t nb0 = gf->nodes[i]->nb[0];
|
|
||||||
const uint64_t nb1 = gf->nodes[i]->nb[1];
|
|
||||||
const uint64_t nb2 = gf->nodes[i]->nb[2];
|
|
||||||
const uint64_t nb3 = gf->nodes[i]->nb[3];
|
|
||||||
|
|
||||||
const enum ggml_type src0t = gf->nodes[i]->src0->type;
|
|
||||||
const enum ggml_type dstt = gf->nodes[i]->type;
|
|
||||||
|
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
||||||
//fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
|
//fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
|
||||||
|
@ -835,6 +750,7 @@ int llama_mtl_eval(
|
||||||
// TODO
|
// TODO
|
||||||
const float * logits = ctx->out.contents;
|
const float * logits = ctx->out.contents;
|
||||||
|
|
||||||
|
#if 1
|
||||||
printf("logits: ");
|
printf("logits: ");
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
printf("%8.4f ", logits[i]);
|
printf("%8.4f ", logits[i]);
|
||||||
|
@ -851,6 +767,7 @@ int llama_mtl_eval(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
|
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
|
||||||
|
#endif
|
||||||
|
|
||||||
//{
|
//{
|
||||||
// struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
|
// struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue