mtl : optimize rms_norm and soft_max kernels

parent 9665429e94
commit f0196a7e7a

5 changed files with 166 additions and 47 deletions
@@ -41,13 +41,15 @@ int main(int argc, char ** argv) {
     // TODO: tmp to match the input used when creating the cgraph
     {
-        const int n_past = 128;
-        const int n_batch = 32;
+        const int n_batch = 1;
+        const int n_past = 512 - n_batch;

         const std::vector<int> tmp(n_batch, 1); // BOS

         // the actual inference happens here
-        llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
+        for (int i = 0; i < 10; ++i) {
+            llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past);
+        }
     }

     llama_mtl_free(ctx_mtl);

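With the graph inputs fixed, the inference call is now repeated 10 times, presumably to get a steadier read on the per-call GPU time that llama_mtl_eval prints (see the hunks further down). A minimal host-side sketch for averaging the wall-clock time of those calls with std::chrono; the helper name and the lambda are hypothetical, only the llama_mtl_eval call itself comes from the diff above:

    #include <chrono>

    // average wall-clock milliseconds per call over n_runs invocations of fn
    template <typename F>
    static double avg_time_ms(F && fn, int n_runs) {
        const auto t_start = std::chrono::high_resolution_clock::now();
        for (int i = 0; i < n_runs; ++i) {
            fn(); // e.g. [&] { llama_mtl_eval(ctx_mtl, &gf, tmp.data(), tmp.size(), n_past); }
        }
        const auto t_end = std::chrono::high_resolution_clock::now();
        return std::chrono::duration<double, std::milli>(t_end - t_start).count() / n_runs;
    }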
@@ -429,14 +429,17 @@ int llama_mtl_eval(
                     const int64_t ne02 = gf->nodes[i]->src0->ne[2];
                     const int64_t ne03 = gf->nodes[i]->src0->ne[3];

+                    const int nth = 32;
+
                     [encoder setComputePipelineState:ctx->pipeline_soft_max];
                     [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                     [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                     [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             case GGML_OP_DIAG_MASK_INF:
                 {
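The soft_max encoder now uses nth = 32 threads per threadgroup, binds nth*sizeof(float) of threadgroup memory at index 0 as scratch for the kernel's reductions (see the kernel hunk further down), and still launches one threadgroup per row. For reference, the row-wise operation being computed is the standard max-subtracted soft_max; a CPU sketch that could be used to sanity-check the Metal output (hypothetical helper, not part of the commit):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // reference soft_max over one row of length ne00:
    // dst[i] = exp(src[i] - max(src)) / sum_j exp(src[j] - max(src))
    static void soft_max_row_ref(const float * src, float * dst, int64_t ne00) {
        const float vmax = *std::max_element(src, src + ne00);
        double sum = 0.0;
        for (int64_t i = 0; i < ne00; ++i) {
            dst[i] = std::exp(src[i] - vmax);
            sum += dst[i];
        }
        for (int64_t i = 0; i < ne00; ++i) {
            dst[i] = (float) (dst[i] / sum);
        }
    }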
@@ -494,10 +497,10 @@ int llama_mtl_eval(
                     const enum ggml_type src1t = gf->nodes[i]->src1->type;
                     const enum ggml_type dstt = gf->nodes[i]->type;

-                    printf("mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
-                    printf("mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
-                    printf("mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
-                    printf("mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
+                    fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
+                    fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
+                    fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
+                    fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));

                     GGML_ASSERT(ne00 == ne10);
                     GGML_ASSERT(ne02 == ne12);
@@ -599,16 +602,19 @@ int llama_mtl_eval(
                     const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
                     const float eps = 1e-6f;

+                    const int nth = 32;
+
                     [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                     [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                     [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                     [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+                    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

                     const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);

-                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             case GGML_OP_ROPE:
                 {
@@ -643,9 +649,9 @@ int llama_mtl_eval(
                     const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
                     const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];

-                    printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+                    fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);

                     [encoder setComputePipelineState:ctx->pipeline_rope];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -704,11 +710,13 @@ int llama_mtl_eval(
                     const enum ggml_type src0t = gf->nodes[i]->src0->type;
                     const enum ggml_type dstt = gf->nodes[i]->type;

-                    printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
-                    printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
-                    printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
+                    const int nth = 32;
+
+                    fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
+                    fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
+                    fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));

                     switch (src0t) {
                         case GGML_TYPE_F32:
@@ -741,7 +749,7 @@ int llama_mtl_eval(
                     [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                     [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
@@ -764,8 +772,6 @@ int llama_mtl_eval(
     id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
     id<MTLBuffer> id_dst = ctx->out;

-    printf("XXXXX n = %d\n", ggml_nelements(out));
-
     id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
     [encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
     [encoder_blit endEncoding];
@@ -776,12 +782,29 @@ int llama_mtl_eval(

     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
-        fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+        printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
     }

     // TODO
     const float * logits = ctx->out.contents;

+    printf("logits: ");
+    for (int i = 0; i < 100; i++) {
+        printf("%8.4f ", logits[i]);
+    }
+    printf("\n");
+    double sum = 0.0;
+    int imax = 0;
+    double vmax = -INFINITY;
+    for (int i = 0; i < 32000; i++) {
+        sum += (double) logits[i];
+        if (logits[i] > vmax) {
+            vmax = logits[i];
+            imax = i;
+        }
+    }
+    printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
+
     //{
     //    struct ggml_tensor * t = ggml_get_tensor(ctx->ctx_eval, "mtl-check");
     //    if (t->type == GGML_TYPE_F32) {
@@ -87,25 +87,80 @@ kernel void kernel_soft_max(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-        uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i03 = tpig[2];
-    const int64_t i02 = tpig[1];
-    const int64_t i01 = tpig[0];
+        threadgroup float * buf [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];

     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

-    float max = 0.0f;
-    for (int i = 0; i < ne00; i++) {
-        max = MAX(max, psrc0[i]);
+    //float max = 0.0f;
+    //for (int i = 0; i < ne00; i++) {
+    //    max = MAX(max, psrc0[i]);
+    //}
+    //float sum = 0.0f;
+    //for (int i = 0; i < ne00; i++) {
+    //    pdst[i] = exp(psrc0[i] - max);
+    //    sum += pdst[i];
+    //}
+    //for (int i = 0; i < ne00; i++) {
+    //    pdst[i] /= sum;
+    //}
+
+    // parallel max
+    buf[tpitg[0]] = -INFINITY;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
     }
-    float sum = 0.0f;
-    for (int i = 0; i < ne00; i++) {
-        pdst[i] = exp(psrc0[i] - max);
-        sum += pdst[i];
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    for (int i = 0; i < ne00; i++) {
-        pdst[i] /= sum;
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float max = buf[0];
+
+    // parallel sum
+    buf[tpitg[0]] = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    }
+
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg[0]/2; i > 0; i /= 2) {
+        if (tpitg[0] < i) {
+            buf[tpitg[0]] += buf[tpitg[0] + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg[0] == 0) {
+        buf[0] = buf[0];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float sum = buf[0];
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] = exp(psrc0[i00] - max) / sum;
     }
 }

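The rewritten kernel gives each row one threadgroup: every thread first folds a strided subset of the row into its own slot of the threadgroup buffer (first for the max, then for the sum of exponentials), and the for (uint i = ntg[0]/2; i > 0; i /= 2) loops then combine the slots pairwise. This tree reduction assumes the threadgroup size is a power of two, which holds for the nth = 32 set on the host side. A small CPU illustration of the same halving pattern (illustrative only, not part of the commit):

    #include <vector>

    // pairwise (tree) reduction mirroring the threadgroup loop in the kernel;
    // buf.size() must be a non-zero power of two, the result lands in buf[0]
    static float reduce_sum_tree(std::vector<float> buf) {
        for (size_t i = buf.size()/2; i > 0; i /= 2) {
            for (size_t t = 0; t < i; ++t) { // on the GPU each t is a separate thread
                buf[t] += buf[t + i];
            }
            // a threadgroup_barrier(mem_flags::mem_threadgroup) separates the rounds on the GPU
        }
        return buf[0];
    }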
@@ -149,19 +204,39 @@ kernel void kernel_rms_norm(
         constant int64_t & ne00,
         constant uint64_t & nb01,
         constant float & eps,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
+        threadgroup float * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);

-    float sum = 0.0f;
-    for (int i00 = 0; i00 < ne00; i00++) {
-        sum += x[i00] * x[i00];
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00] * x[i00];
     }

-    const float mean = sum/ne00;
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const float mean = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);

-    device float * y = dst + tpig*ne00;
-    for (int i00 = 0; i00 < ne00; i00++) {
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
 }
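kernel_rms_norm follows the same structure: each thread accumulates a strided partial sum of squares into its slot of the threadgroup array, the tree reduction collapses the slots, thread 0 divides by ne00 to obtain the mean, and after a final barrier every thread scales its strided slice by 1/sqrt(mean + eps). The per-row math, written out as a CPU reference (hypothetical helper, not part of the commit):

    #include <cmath>
    #include <cstdint>

    // reference RMS norm over one row: y[i] = x[i] / sqrt(mean(x^2) + eps)
    static void rms_norm_row_ref(const float * x, float * y, int64_t ne00, float eps) {
        double sum = 0.0;
        for (int64_t i = 0; i < ne00; ++i) {
            sum += (double) x[i] * x[i];
        }
        const float mean  = (float) (sum / ne00);
        const float scale = 1.0f / std::sqrt(mean + eps);
        for (int64_t i = 0; i < ne00; ++i) {
            y[i] = x[i] * scale;
        }
    }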
ggml.c (4 changes)
@@ -14647,8 +14647,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
 }

 void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
-    assert(cgraph->work == NULL);
-    assert(cgraph->work_size == 0);
+    //assert(cgraph->work == NULL);
+    //assert(cgraph->work_size == 0);

     uint64_t size_eval = 0;

llama.cpp (25 changes)
@@ -1506,6 +1506,25 @@ static bool llama_eval_internal(

     if (cgraph_fname) {
         ggml_graph_export(&gf, cgraph_fname);
+
+        float * logits = (float *) ggml_get_data(inpL);
+
+        printf("logits: ");
+        for (int i = 0; i < 10; i++) {
+            printf("%8.4f ", logits[i]);
+        }
+        printf("\n");
+        double sum = 0.0;
+        int imax = 0;
+        double vmax = -INFINITY;
+        for (int i = 0; i < 32000; i++) {
+            sum += (double) logits[i];
+            if (logits[i] > vmax) {
+                vmax = logits[i];
+                imax = i;
+            }
+        }
+        printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
     }

 #ifdef GGML_PERF
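The block above prints the same sum / argmax / max fingerprint of the logits that the Metal path prints after llama_mtl_eval in the earlier hunk, presumably so the two backends can be compared run against run. The same check, pulled out into a reusable helper (hypothetical, not part of the commit):

    #include <cmath>
    #include <cstdio>

    // print a quick fingerprint of n logits: total sum, argmax index and the max value
    static void print_logits_fingerprint(const float * logits, int n) {
        double sum  = 0.0;
        int    imax = 0;
        double vmax = -INFINITY;
        for (int i = 0; i < n; ++i) {
            sum += (double) logits[i];
            if (logits[i] > vmax) {
                vmax = logits[i];
                imax = i;
            }
        }
        printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
    }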
@@ -3002,11 +3021,11 @@ int llama_eval(

 int llama_eval_export(struct llama_context * ctx, const char * fname) {
     // these values determine the maximum inference sizes of the exported computation graph
-    // TODO: TMP !!!
+    // TODO: need to increase buffers to support the full context
     //const int n_ctx = ctx->model.hparams.n_ctx;
     //const int n_batch = 512;
-    const int n_ctx = 128;
-    const int n_batch = 32;
+    const int n_batch = 1;
+    const int n_ctx = 512 - n_batch;

     const std::vector<llama_token> tmp(n_batch, llama_token_bos());
