remove trailing whitespace

xaedes 2023-05-08 00:04:54 +02:00
parent 7c8768f819
commit 2936dd60a4
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
4 changed files with 120 additions and 120 deletions


@@ -381,7 +381,7 @@ void randomize_model(struct llama_model * model, int seed, float mean, float std
    randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
    randomize_tensor_normal(model->norm, model->norm->n_dims, model->norm->ne, &rnd);
    randomize_tensor_normal(model->output, model->output->n_dims, model->output->ne, &rnd);
    for (uint32_t i = 0; i < n_layer; ++i) {
        auto & layer = model->layers[i];
        randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
@@ -415,7 +415,7 @@ void randomize_model_lora(struct llama_model_lora * model, int seed, float mean,
    randomize_tensor_normal(model->norm, model->norm->n_dims, model->norm->ne, &rnd);
    randomize_tensor_normal(model->outputa, model->outputa->n_dims, model->outputa->ne, &rnd);
    randomize_tensor_normal(model->outputb, model->outputb->n_dims, model->outputb->ne, &rnd);
    for (uint32_t i = 0; i < n_layer; ++i) {
        auto & layer = model->layers[i];
        randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
@@ -508,14 +508,14 @@ bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora *
}

struct ggml_tensor * forward(
        struct llama_model * model,
        struct llama_kv_cache * cache,
        struct ggml_context * ctx0,
        struct ggml_cgraph * gf,
        struct ggml_tensor * tokens_input,
        const int n_tokens,
        const int n_past) {
    const int N = n_tokens;
    struct llama_kv_cache& kv_self = *cache;
@@ -569,11 +569,11 @@ struct ggml_tensor * forward(
        // Vcur shape [n_embd, N, 1, 1]
        struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N)));
        // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
        // kv_self.v shape [n_embd * n_ctx * n_layer, 1]
        // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0]
        // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0]
        /* {
            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
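
For reference, the offset passed to ggml_view_1d above addresses layer il at position n_past inside the flat K cache. A minimal standalone sketch of the same arithmetic (kv_cache_k_offset_bytes is a hypothetical helper, not part of this file):

#include <stddef.h>

// Byte offset of kv_self.k[:, n_past, il, 0]: the cache is one flat buffer of
// n_layer blocks, each holding n_ctx rows of n_embd elements, so the row for
// (il, n_past) starts at element (il*n_ctx + n_past)*n_embd.
static size_t kv_cache_k_offset_bytes(size_t element_size, int n_embd, int n_ctx, int il, int n_past) {
    return element_size * (size_t) n_embd * ((size_t) il * (size_t) n_ctx + (size_t) n_past);
}
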
@@ -597,7 +597,7 @@ struct ggml_tensor * forward(
                    Qcur,
                    0, 2, 1, 3);
        // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
        // K shape [n_embd/n_head, n_past + N, n_head, 1]
        struct ggml_tensor * K =
            ggml_permute(ctx0,
@@ -641,7 +641,7 @@ struct ggml_tensor * forward(
        // KQV_merged = KQV.permute(0, 2, 1, 3)
        // KQV_merged shape [n_embd/n_head, n_head, N, 1]
        struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
        // KQV_merged shape
        // cur = KQV_merged.contiguous().view(n_embd, N)
        // cur shape [n_embd,N,1,1]
@@ -734,14 +734,14 @@ struct ggml_tensor * forward(
struct ggml_tensor * forward_lora(
        struct llama_model_lora * model,
        struct llama_kv_cache * cache,
        struct ggml_context * ctx0,
        struct ggml_cgraph * gf,
        struct ggml_tensor * tokens_input,
        const int n_tokens,
        const int n_past) {
    const int N = n_tokens;
    struct llama_kv_cache& kv_self = *cache;
@@ -784,23 +784,23 @@ struct ggml_tensor * forward_lora(
        // wk shape [n_embd, n_embd, 1, 1]
        // Qcur shape [n_embd/n_head, n_head, N, 1]
        // Kcur shape [n_embd/n_head, n_head, N, 1]
        struct ggml_tensor * Qcur = ggml_rope(ctx0,
                                        ggml_reshape_3d(ctx0,
                                            ggml_mul_mat(ctx0,
                                                model->layers[il].wqa,
                                                ggml_mul_mat(ctx0,
                                                    model->layers[il].wqb,
                                                    cur)),
                                            n_embd/n_head, n_head, N),
                                        n_past, n_rot, 0);
        struct ggml_tensor * Kcur = ggml_rope(ctx0,
                                        ggml_reshape_3d(ctx0,
                                            ggml_mul_mat(ctx0,
                                                model->layers[il].wka,
                                                ggml_mul_mat(ctx0,
                                                    model->layers[il].wkb,
                                                    cur)),
                                            n_embd/n_head, n_head, N),
                                        n_past, n_rot, 0);
        // store key and value to memory
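
In words (a sketch of the intent, glossing over ggml_mul_mat's layout conventions): the single dense wq/wk projections of the non-LoRA forward() are replaced by two chained, smaller factors, so

$Q_{\text{cur}} = \operatorname{rope}\big(\operatorname{reshape}(W_{qa}\,(W_{qb}\,\text{cur}))\big), \qquad K_{\text{cur}} = \operatorname{rope}\big(\operatorname{reshape}(W_{ka}\,(W_{kb}\,\text{cur}))\big)$
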
@@ -808,21 +808,21 @@ struct ggml_tensor * forward_lora(
        // compute the transposed [N, n_embd] V matrix
        // wv shape [n_embd, n_embd, 1, 1]
        // Vcur shape [n_embd, N, 1, 1]
        struct ggml_tensor * Vcur = ggml_cont(ctx0,
                                        ggml_transpose(ctx0,
                                            ggml_reshape_2d(ctx0,
                                                ggml_mul_mat(ctx0,
                                                    model->layers[il].wva,
                                                    ggml_mul_mat(ctx0,
                                                        model->layers[il].wvb,
                                                        cur)),
                                                n_embd, N)));
        // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
        // kv_self.v shape [n_embd * n_ctx * n_layer, 1]
        // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0]
        // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0]
        /* {
            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
@@ -846,7 +846,7 @@ struct ggml_tensor * forward_lora(
                    Qcur,
                    0, 2, 1, 3);
        // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
        // K shape [n_embd/n_head, n_past + N, n_head, 1]
        struct ggml_tensor * K =
            ggml_permute(ctx0,
@@ -890,7 +890,7 @@ struct ggml_tensor * forward_lora(
        // KQV_merged = KQV.permute(0, 2, 1, 3)
        // KQV_merged shape [n_embd/n_head, n_head, N, 1]
        struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
        // KQV_merged shape
        // cur = KQV_merged.contiguous().view(n_embd, N)
        // cur shape [n_embd,N,1,1]
@@ -974,10 +974,10 @@ struct ggml_tensor * forward_lora(
    // lm_head
    // inpL shape [n_vocab,N,1,1]
    inpL = ggml_mul_mat(ctx0,
                model->outputa,
                ggml_mul_mat(ctx0,
                    model->outputb,
                    inpL));
    // ggml_set_scratch(ctx0, { 0, 0, nullptr, });
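
The LoRA variant factorizes the lm_head the same way: instead of the single output matrix used in forward(), logits come from two chained matmuls, roughly (again ignoring ggml's transposition conventions)

$\text{logits} = W_{\text{outputa}}\,\big(W_{\text{outputb}}\;\text{inpL}\big)$
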
@@ -1094,12 +1094,12 @@ struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_te
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
    const float eps = 1e-3;
    return
        ggml_sum(ctx,
            ggml_neg(ctx,
                ggml_sum_rows(ctx,
                    ggml_mul(ctx,
                        ggml_soft_max(ctx, a),
                        ggml_log(ctx,
                            ggml_add1(ctx,
                                ggml_soft_max(ctx, b),
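
Written out, the op chain above computes (the hunk cuts off before the final argument of ggml_add1, presumably a scalar tensor holding eps):

$L = -\sum_{r}\sum_{c} \operatorname{softmax}(a)_{rc}\,\log\big(\operatorname{softmax}(b)_{rc} + \epsilon\big), \qquad \epsilon = 10^{-3}$

i.e. a cross entropy between the two softmax distributions, with eps keeping the logarithm away from zero.
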
@@ -1169,7 +1169,7 @@ int main(int argc, char ** argv) {
    */
    // key + value cache for the self attention
    struct llama_kv_cache kv_self;
    printf("init_kv_cache\n");
    kv_self.ctx = model.ctx;
    init_kv_cache(&kv_self, &model);
@@ -1221,17 +1221,17 @@ int main(int argc, char ** argv) {
        struct ggml_tensor * logits2 = forward(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past);
        // struct ggml_tensor * logits3 = forward(&model, &kv_self, ctx0, &gf, tokens_input3, n_tokens, n_past);
        // struct ggml_tensor * logits4 = forward(&model, &kv_self, ctx0, &gf, tokens_input4, n_tokens, n_past);
        // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets1, logits1);
        // struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
        struct ggml_tensor * e = ggml_add(ctx0,
                                    square_error_loss(ctx0, targets1, logits1),
                                    square_error_loss(ctx0, targets2, logits2));
        // struct ggml_tensor * e = ggml_add(ctx0,
        //                             cross_entropy_loss(ctx0, targets1, logits1),
        //                             cross_entropy_loss(ctx0, targets2, logits2));
        // struct ggml_tensor * e = ggml_add(ctx0,
        //                             ggml_add(ctx0,
        //                                 cross_entropy_loss(ctx0, targets1, logits1),
        //                                 cross_entropy_loss(ctx0, targets2, logits2)),
@@ -1260,7 +1260,7 @@ int main(int argc, char ** argv) {
        opt_params_lbfgs.lbfgs.n_iter = 16;
        // ggml_opt(ctx0, opt_params_adam, e);
        ggml_opt(ctx0, opt_params_lbfgs, e);
        //
        ggml_build_forward_expand(&gf, e);
        ggml_graph_compute(ctx0, &gf);
@@ -1292,7 +1292,7 @@ int main(int argc, char ** argv) {
    struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
    struct ggml_tensor * targets = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
    get_example_targets(137, tokens_input, targets);
    for (int i=sample_ctx; i<n_tokens; ++i) {
        ggml_set_i32_1d(tokens_input, i, n_vocab/2);
@@ -1327,14 +1327,14 @@ int main(int argc, char ** argv) {
        // int sample_at = n_tokens-1;
        int token = ggml_get_i32_1d(best_samples, sample_ctx-1);
        // print_row(probs, sample_at);
        print_token(token, n_vocab);
        lshift_examples(tokens_input, targets, 1);
        ggml_set_i32_1d(tokens_input, 0, 0);
        ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
        // printf("---\n");
        // for (int i=0; i<sample_ctx-1; ++i) {
        //     print_token(ggml_get_i32_1d(tokens_input, i), model.hparams.n_vocab);
@@ -1350,7 +1350,7 @@ int main(int argc, char ** argv) {
        }
        printf("important (dont optimize it away, compiler!) : %d\n", important_sum);
    }
    print_matrix(model.tok_embeddings);
    printf("done\n");

ggml.c (88 changes)

@@ -7161,7 +7161,7 @@ static void ggml_compute_forward_dup_same_cont(
            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
    }
}
static void ggml_compute_forward_dup_f16(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
@@ -7818,7 +7818,7 @@ static void ggml_compute_forward_add_f32(
            vDSP_vadd(
                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
                    (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
                    ne0);
#else
            ggml_vec_add_f32(ne0,
@@ -8177,7 +8177,7 @@ static void ggml_compute_forward_add1_f32(
            vDSP_vadd(
                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
                    (float *) ((char *) src1->data), 0,
                    (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
                    ne0);
#else
            ggml_vec_add1_f32(ne0,
@@ -8438,17 +8438,17 @@ static void ggml_compute_forward_acc_f32(
        struct ggml_tensor * dst) {
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_nelements(opt0) == 5);
    // view src0 and dst with these strides and data offset inbytes during acc
    // nb0 is implicitely element_size because src0 and dst are contiguous
    size_t nb1 = ((int32_t *) opt0->data)[0];
    size_t nb2 = ((int32_t *) opt0->data)[1];
    size_t nb3 = ((int32_t *) opt0->data)[2];
    size_t offset = ((int32_t *) opt0->data)[3];
    bool inplace = (bool) ((int32_t *) opt0->data)[4];
    if (!inplace && (params->type == GGML_TASK_INIT)) {
        // memcpy needs to be synchronized across threads to avoid race conditions.
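
Both ggml_compute_forward_acc_f32 above and ggml_compute_forward_set_f32 further down decode the same 5-element I32 parameter tensor. A minimal sketch of that layout in plain C (struct and helper names are illustrative, not ggml API):

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

// Mirror of the opt0 layout read above: nb1/nb2/nb3 strides in bytes,
// a byte offset into dst, and an inplace flag.
struct acc_view_params {
    size_t nb1;
    size_t nb2;
    size_t nb3;
    size_t offset;
    bool   inplace;
};

static struct acc_view_params decode_acc_view_params(const int32_t opt0_data[5]) {
    struct acc_view_params p;
    p.nb1     = (size_t) opt0_data[0];
    p.nb2     = (size_t) opt0_data[1];
    p.nb3     = (size_t) opt0_data[2];
    p.offset  = (size_t) opt0_data[3];
    p.inplace = opt0_data[4] != 0;
    return p;
}
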
@@ -8596,7 +8596,7 @@ static void ggml_compute_forward_sub_f32(
            vDSP_vsub(
                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
                    (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
                    ne0);
#else
            ggml_vec_sub_f32(ne0,
@@ -8692,7 +8692,7 @@ static void ggml_compute_forward_mul_f32(
            vDSP_vmul(
                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
                    (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
                    ne0);
#else
            ggml_vec_mul_f32(ne0,
@@ -8788,7 +8788,7 @@ static void ggml_compute_forward_div_f32(
            vDSP_vdiv(
                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
                    (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
                    ne0);
#else
            ggml_vec_div_f32(ne0,
@@ -9189,9 +9189,9 @@ static void ggml_compute_forward_repeat_f32(
    const size_t nb01 = src0->nb[1];
    const size_t nb02 = src0->nb[2];
    const size_t nb03 = src0->nb[3];
    // guaranteed to be an integer due to the check in ggml_can_repeat
    const int nr0 = (int)(ne0/ne00);
    const int nr1 = (int)(ne1/ne01);
    const int nr2 = (int)(ne2/ne02);
    const int nr3 = (int)(ne3/ne03);
@@ -9850,12 +9850,12 @@ static void ggml_compute_forward_rms_norm_back_f32(
        {
            // z = rms_norm(x)
            //
            // rms_norm(src0) =
            //     scale(
            //         src0,
            //         div(
            //             1,
            //             sqrt(
            //                 add(
            //                     scale(
@@ -9868,17 +9868,17 @@ static void ggml_compute_forward_rms_norm_back_f32(
            // postorder:
            // ## op    args        grad
            // 00 param src0        grad[#00]
            // 01 const 1
            // 02 sqr   (#00)       grad[#02]
            // 03 sum   (#02)       grad[#03]
            // 04 const 1/N
            // 05 scale (#03, #04)  grad[#05]
            // 06 const eps
            // 07 add   (#05, #06)  grad[#07]
            // 08 sqrt  (#07)       grad[#08]
            // 09 div   (#01,#08)   grad[#09]
            // 10 scale (#00,#09)   grad[#10]
            //
            // backward pass, given grad[#10]
            // #10: scale
            // grad[#00] += scale(grad[#10],#09)
@@ -9893,7 +9893,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
            // grad[#03] += scale(grad[#05],#04)
            // #03: sum
            // grad[#02] += repeat(grad[#03], #02)
            // #02:
            // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
            //
            // substitute and simplify:
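
Reading the postorder table back into a formula, the forward op being differentiated here is (N the row length, eps as in op #06):

$\operatorname{rms\_norm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon}}$

and the grad[...] bookkeeping above accumulates the gradient with respect to src0 by walking that graph in reverse.
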
@@ -10716,17 +10716,17 @@ static void ggml_compute_forward_set_f32(
        struct ggml_tensor * dst) {
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_nelements(opt0) == 5);
    // view src0 and dst with these strides and data offset inbytes during set
    // nb0 is implicitely element_size because src0 and dst are contiguous
    size_t nb1 = ((int32_t *) opt0->data)[0];
    size_t nb2 = ((int32_t *) opt0->data)[1];
    size_t nb3 = ((int32_t *) opt0->data)[2];
    size_t offset = ((int32_t *) opt0->data)[3];
    bool inplace = (bool) ((int32_t *) opt0->data)[4];
    if (!inplace && (params->type == GGML_TASK_INIT)) {
        // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -13420,7 +13420,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
        case GGML_OP_ROPE_BACK:
            {
                ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
            } break;
        case GGML_OP_ALIBI:
            {
                ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
@@ -13521,7 +13521,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                    src1->grad->ne[2],
                    src1->grad->ne[3],
                    nb1, nb2, nb3, offset);
                src1->grad =
                    ggml_add_impl(ctx,
                        src1->grad,
@@ -13664,7 +13664,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                // transpose [nc0*nr0,1,1]
                // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
                // add to src0->grad
                int64_t ne[4] = {nc0,ncr,nr0,nrr};
                struct ggml_tensor* F00 = tensor->grad;
@@ -13846,7 +13846,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
                struct ggml_tensor * tensor_grad_view = NULL;
                if (src0->grad || src1->grad) {
                    GGML_ASSERT(src0->type == tensor->type);
                    GGML_ASSERT(tensor->grad->type == tensor->type);
@@ -13862,10 +13862,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                }
                if (src0->grad) {
                    src0->grad = ggml_add_impl(ctx,
                        src0->grad,
                        ggml_acc_impl(ctx,
                            tensor->grad,
                            ggml_neg(ctx, tensor_grad_view),
                            nb1, nb2, nb3, offset, false),
                        inplace);
@@ -13944,7 +13944,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                    nb2 = (nb2 / n0) * ng;
                    nb3 = (nb3 / n0) * ng;
                }
                src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
            }
        } break;
@@ -14040,18 +14040,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            // necessary for llama
            if (src0->grad) {
                // y = softmax(x)
                //
                // Jii = yi - yi*yi
                // Jij = -yi*yj
                // J = diag(y)-y.*y
                // dx = J * dy
                // dxk = sum(Jkj * dyk)
                int64_t ne2[4] = {
                    tensor->ne[0],
                    1,
                    tensor->ne[1]*tensor->ne[2],
                    tensor->ne[3]
                };
                struct ggml_tensor * tensor2 = ggml_cont(ctx,
                    ggml_reshape_4d(ctx,
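
The Jacobian comments above reduce to the standard softmax backward identity: with y = softmax(x),

$\frac{\partial y_i}{\partial x_j} = y_i(\delta_{ij} - y_j), \qquad dx_k = y_k\Big(dy_k - \sum_j y_j\,dy_j\Big)$
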

ggml.h (6 changes)

@@ -649,7 +649,7 @@ extern "C" {
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            struct ggml_tensor * b);
    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_scale_inplace(
            struct ggml_context * ctx,
@@ -787,7 +787,7 @@ extern "C" {
            int64_t ne3,
            size_t nb1, // row stride in bytes
            size_t nb2, // slice stride in bytes
            size_t nb3,
            size_t offset);
    GGML_API struct ggml_tensor * ggml_permute(
@@ -862,7 +862,7 @@ extern "C" {
            int n_dims,
            int mode);
    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_rope_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor * a,


@@ -156,7 +156,7 @@ struct ggml_tensor * get_random_tensor_int(
float get_element(const struct ggml_tensor * t, int idx) {
    if (t->type == GGML_TYPE_F32) {
        return ((float *)t->data)[idx];
    } else if (t->type == GGML_TYPE_I32) {
        return ((int32_t *)t->data)[idx];
    } else {
        assert(false);
@@ -591,9 +591,9 @@ int main(int argc, const char ** argv) {
#ifdef GGML_SILU_FP16
            // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
            check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
#else
            check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
#endif
        }
    }
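
check_gradient compares the graph's analytic gradients against a numerical estimate built from finite differences; the exact stencil lives in check_gradient, but a central difference of the form below is the usual choice (presumably the 1e-3f argument is the step size h, the next argument the absolute error bound). With GGML_SILU_FP16 the SiLU goes through an fp16 lookup table, which limits how accurate that estimate can be, hence the relaxed 0.5 bound:

$\frac{\partial f}{\partial x_i} \approx \frac{f(x + h\,e_i) - f(x - h\,e_i)}{2h}$
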
@@ -610,7 +610,7 @@ int main(int argc, const char ** argv) {
            struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
            check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
        }
    }
@@ -630,7 +630,7 @@ int main(int argc, const char ** argv) {
            struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], x[1]));
            check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
        }
    }
@@ -975,10 +975,10 @@ int main(int argc, const char ** argv) {
        int64_t ne2[4];
        const int nargs = 1;
        for (int ndims = 1; ndims <= 4; ++ndims)
        {
            // ggml_permute will set axes of dimensions below n_dims to 1.
            // to make ggml_permute work correctly on all axes,
            // the input tensor needs maximal n_dim of 4.
            for (int i=0; i<ndims; ++i) {
                ne2[i] = ne[i];
@@ -1008,10 +1008,10 @@ int main(int argc, const char ** argv) {
        int64_t ne2[4];
        const int nargs = 1;
        for (int ndims = 1; ndims <= 4; ++ndims)
        {
            // ggml_transpose will set axes of dimensions below n_dims to 1.
            // to make ggml_transpose work correctly on all axes,
            // the input tensor needs maximal n_dim of 4.
            for (int i=0; i<ndims; ++i) {
                ne2[i] = ne[i];
@@ -1038,7 +1038,7 @@ int main(int argc, const char ** argv) {
        const int ndims = 2;
        x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
        x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]);
        ggml_set_param(ctx0, x[0]);
        struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
@@ -1079,7 +1079,7 @@ int main(int argc, const char ** argv) {
    // softmax
    {
        const int nargs = 1;
        int64_t ne2[4];
        get_random_dims(ne2, 4);
@@ -1121,7 +1121,7 @@ int main(int argc, const char ** argv) {
                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode));
                GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
            }
        }
    }