minor : fix compiler warnings + indentation style

This commit is contained in:
Georgi Gerganov 2023-05-13 09:55:17 +03:00
parent b9ef08ccab
commit f977243ded
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 63 additions and 56 deletions

View file

@ -134,7 +134,7 @@ struct llama_hparams {
};
uint32_t get_n_ff(const struct llama_hparams* hparams) {
uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult;
const uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult;
return n_ff;
}
@ -241,7 +241,7 @@ void init_model(struct llama_model * model) {
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab;
uint32_t n_ff = get_n_ff(&hparams);
const uint32_t n_ff = get_n_ff(&hparams);
struct ggml_context * ctx = model->ctx;
@ -265,7 +265,7 @@ void init_model(struct llama_model * model) {
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".ffn_norm.weight", {n_embd});
layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w1.weight", {n_embd, n_ff});
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); // (layers_i + ".feed_forward.w2.weight", { n_ff, n_embd});
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); // (layers_i + ".feed_forward.w2.weight", { n_ff, n_embd});
layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w3.weight", {n_embd, n_ff});
}
}
@ -275,18 +275,19 @@ void init_model_lora(struct llama_model_lora * model) {
const auto & hparams = model->hparams;
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_mult = hparams.n_mult;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab;
const uint32_t n_lora = hparams.n_lora;
uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
const uint32_t n_ff = ((2*(4*n_embd)/3 + n_mult - 1)/n_mult)*n_mult;
struct ggml_context * ctx = model->ctx;
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // ("tok_embeddings.weight", {n_embd, n_vocab});
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // ("norm.weight", {n_embd});
model->outputa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_vocab); // ("output.weight", {n_embd, n_vocab});
model->outputb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // ("output.weight", {n_embd, n_vocab});
model->outputa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_vocab); // ("output.weight", {n_embd, n_vocab});
model->outputb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // ("output.weight", {n_embd, n_vocab});
model->layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
@ -296,26 +297,28 @@ void init_model_lora(struct llama_model_lora * model) {
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".attention_norm.weight", {n_embd});
layer.wqa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wq.weight", {n_embd, n_embd});
layer.wqb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wq.weight", {n_embd, n_embd});
layer.wka = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wk.weight", {n_embd, n_embd});
layer.wkb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wk.weight", {n_embd, n_embd});
layer.wva = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wv.weight", {n_embd, n_embd});
layer.wvb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wv.weight", {n_embd, n_embd});
layer.woa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wo.weight", {n_embd, n_embd});
layer.wob = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wo.weight", {n_embd, n_embd});
layer.wqa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wq.weight", {n_embd, n_embd});
layer.wqb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wq.weight", {n_embd, n_embd});
layer.wka = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wk.weight", {n_embd, n_embd});
layer.wkb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wk.weight", {n_embd, n_embd});
layer.wva = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wv.weight", {n_embd, n_embd});
layer.wvb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wv.weight", {n_embd, n_embd});
layer.woa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wo.weight", {n_embd, n_embd});
layer.wob = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wo.weight", {n_embd, n_embd});
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".ffn_norm.weight", {n_embd});
layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w1.weight", {n_embd, n_ff});
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); // (layers_i + ".feed_forward.w2.weight", { n_ff, n_embd});
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); // (layers_i + ".feed_forward.w2.weight", { n_ff, n_embd});
layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w3.weight", {n_embd, n_ff});
}
}
void set_param_model(struct llama_model * model) {
const auto& hparams = model->hparams;
const uint32_t n_layer = hparams.n_layer;
struct ggml_context* ctx = model->ctx;
ggml_set_param(ctx, model->tok_embeddings);
@ -339,7 +342,9 @@ void set_param_model(struct llama_model * model) {
void set_param_model_lora(struct llama_model_lora * model) {
const auto& hparams = model->hparams;
const uint32_t n_layer = hparams.n_layer;
struct ggml_context* ctx = model->ctx;
ggml_set_param(ctx, model->tok_embeddings);
@ -369,11 +374,7 @@ void set_param_model_lora(struct llama_model_lora * model) {
void randomize_model(struct llama_model * model, int seed, float mean, float std, float min, float max) {
const auto & hparams = model->hparams;
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab;
uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
struct random_normal_distribution rnd;
init_random_normal_distribution(&rnd, seed, mean, std, min, max);
@ -402,11 +403,7 @@ void randomize_model(struct llama_model * model, int seed, float mean, float std
void randomize_model_lora(struct llama_model_lora * model, int seed, float mean, float std, float min, float max) {
const auto & hparams = model->hparams;
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab;
uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
struct random_normal_distribution rnd;
init_random_normal_distribution(&rnd, seed, mean, std, min, max);
@ -438,9 +435,10 @@ void randomize_model_lora(struct llama_model_lora * model, int seed, float mean,
bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) {
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const uint32_t n_ctx = hparams.n_ctx;
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_layer = hparams.n_layer;
const int64_t n_mem = n_layer*n_ctx*n_batch;
const int64_t n_elements = n_embd*n_mem;
@ -473,9 +471,10 @@ bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int
bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora * model, int n_batch) {
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const uint32_t n_ctx = hparams.n_ctx;
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_layer = hparams.n_layer;
const int64_t n_mem = n_layer*n_ctx*n_batch;
const int64_t n_elements = n_embd*n_mem;
@ -1062,12 +1061,12 @@ struct ggml_tensor * forward_lora(
struct llama_kv_cache& kv_self = *cache;
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_head = hparams.n_head;
const int n_rot = hparams.n_rot;
const int n_lora = hparams.n_lora;
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens));
@ -1310,7 +1309,7 @@ void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, str
assert(logits->ne[1] == best_samples->ne[0]);
assert(logits->ne[0] == probs->ne[0]);
assert(logits->ne[1] == probs->ne[1]);
for (int i=0; i< logits->ne[1]; ++i) {
for (int i = 0; i < logits->ne[1]; ++i) {
float max_logit = ggml_get_f32_1d(logits, i * logits->ne[0]);
ggml_set_i32_1d(best_samples, i, 0);
for (int k = 0; k < logits->ne[0]; ++k) {
@ -1347,18 +1346,18 @@ void sample_softmax_batch(struct ggml_context * ctx, struct ggml_tensor * logits
GGML_ASSERT(n_tokens == probs->ne[1]);
GGML_ASSERT(n_batch == probs->ne[2]);
for (int k=0; k<n_batch; ++k) {
for (int k = 0; k < n_batch; ++k) {
struct ggml_tensor * best_samples_k = ggml_view_1d(ctx,
best_samples,
best_samples->ne[0],
k*best_samples->nb[1]);
struct ggml_tensor * logits_k = ggml_view_2d(ctx,
struct ggml_tensor * logits_k = ggml_view_2d(ctx,
logits,
logits->ne[0],
logits->ne[1],
logits->nb[1],
k*logits->nb[2]);
struct ggml_tensor * probs_k = ggml_view_2d(ctx,
struct ggml_tensor * probs_k = ggml_view_2d(ctx,
probs,
probs->ne[0],
probs->ne[1],
@ -1378,7 +1377,7 @@ void print_row(struct ggml_tensor * probs, int i) {
void print_matrix(struct ggml_tensor * probs) {
assert(probs->n_dims == 2);
for (int i=0; i<probs->ne[1]; ++i) {
for (int i = 0; i < probs->ne[1]; ++i) {
for (int k = 0; k < probs->ne[0]; ++k) {
float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
printf(" %.2f", p);
@ -1431,7 +1430,6 @@ void get_example_targets_batch(struct ggml_context * ctx, int example_id, struct
GGML_ASSERT( targets->n_dims == 3);
int n_tokens = tokens_input->ne[0];
int n_batch = tokens_input->ne[1];
int n_vocab = targets->ne[0];
GGML_ASSERT(n_tokens == targets->ne[1]);
GGML_ASSERT(n_batch == targets->ne[2]);
@ -1481,6 +1479,12 @@ struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_t
}
int main(int argc, char ** argv) {
if (argc < 1) {
fprintf(stderr, "usage: %s\n", argv[0]);
return 1;
}
struct ggml_init_params lcparams;
lcparams.mem_size = 1024ll*1024ll*1024ll;
lcparams.mem_buffer = NULL;
@ -1565,7 +1569,6 @@ int main(int argc, char ** argv) {
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);

42
ggml.c
View file

@ -3978,12 +3978,12 @@ inline static float ggml_silu_f32(float x) {
return x/(1.0f + expf(-x));
}
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
y[i] = table_silu_f16[i16[i]];
}
}
//inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
// const uint16_t * i16 = (const uint16_t *) x;
// for (int i = 0; i < n; ++i) {
// y[i] = table_silu_f16[i16[i]];
// }
//}
#ifdef GGML_SILU_FP16
inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
@ -4512,9 +4512,9 @@ static inline int ggml_up32(int n) {
return (n + 31) & ~31;
}
static inline int ggml_up64(int n) {
return (n + 63) & ~63;
}
//static inline int ggml_up64(int n) {
// return (n + 63) & ~63;
//}
static inline int ggml_up(int n, int m) {
// assert m is a power of 2
@ -8165,6 +8165,8 @@ static void ggml_compute_forward_add1_f32(
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
#ifdef GGML_USE_ACCELERATE
UNUSED(ggml_vec_add1_f32);
vDSP_vadd(
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) src1->data), 0,
@ -8680,6 +8682,8 @@ static void ggml_compute_forward_mul_f32(
#ifdef GGML_USE_ACCELERATE
UNUSED(ggml_vec_mul_f32);
vDSP_vmul(
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
@ -9831,15 +9835,15 @@ static void ggml_compute_forward_rms_norm_back_f32(
sum_xdz += (ggml_float)(x[i00] * dz[i00]);
}
const ggml_float mean = sum_xx/ne00;
const ggml_float mean_eps = sum_xx/ne00 + eps;
const ggml_float sum_eps = sum_xx + eps*ne00;
const ggml_float mean_xdz = sum_xdz/ne00;
//const float mean = (float)(sum_xx)/ne00;
const float mean_eps = (float)(sum_xx)/ne00 + eps;
const float sum_eps = (float)(sum_xx) + eps*ne00;
//const float mean_xdz = (float)(sum_xdz)/ne00;
// we could cache rms from forward pass to improve performance.
// to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
const ggml_float rms = sqrtf(mean_eps);
const ggml_float rrms = 1.0f / sqrtf(mean_eps);
const ggml_float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
//const float rms = sqrtf(mean_eps);
const float rrms = 1.0f / sqrtf(mean_eps);
//const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
{
// z = rms_norm(x)
@ -9937,10 +9941,10 @@ static void ggml_compute_forward_rms_norm_back_f32(
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
ggml_vec_cpy_f32(ne00, dx, x);
ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, -sum_xdz/sum_eps);
ggml_vec_acc_f32(ne00, dx, dz);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
ggml_vec_acc_f32 (ne00, dx, dz);
ggml_vec_scale_f32(ne00, dx, rrms);
}
}