Fix import of llama2.c models that don't share weights between embedding layers

This commit is contained in:
ochafik 2023-08-22 01:56:58 +01:00
parent 930523c8e1
commit 0f7cb95352

View file

@ -49,10 +49,10 @@ typedef struct {
// float* freq_cis_real; // (seq_len, dim/2) // float* freq_cis_real; // (seq_len, dim/2)
// float* freq_cis_imag; // (seq_len, dim/2) // float* freq_cis_imag; // (seq_len, dim/2)
// (optional) classifier weights for the logits, on the last layer // (optional) classifier weights for the logits, on the last layer
//float* wcls; float* wcls;
} TransformerWeights; } TransformerWeights;
void malloc_weights(TransformerWeights* w, Config* p) { void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
// we calloc instead of malloc to keep valgrind happy // we calloc instead of malloc to keep valgrind happy
w->token_embedding_table = new float[p->vocab_size * p->dim](); w->token_embedding_table = new float[p->vocab_size * p->dim]();
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
@ -86,9 +86,16 @@ void malloc_weights(TransformerWeights* w, Config* p) {
w->rms_final_weight = new float[p->dim](); w->rms_final_weight = new float[p->dim]();
printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim); printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
if (shared_weights) {
w->wcls = NULL;
} else {
w->wcls = new float[p->vocab_size * p->dim]();
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
}
} }
int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) { int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) {
if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1; if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1; if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1;
if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1; if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
@ -100,6 +107,22 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1; if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1;
if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1; if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1;
if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1; if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1;
// Skip freq_cis_real & freq_cis_imag
int head_size = p->dim / p->n_heads;
fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
if (!shared_weights && fread(w->wcls, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
// Check we didn't forget to read anything
auto curr = ftell(f);
fseek(f, 0, SEEK_END);
auto end = ftell(f);
if (curr != end) {
printf("Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", curr, end);
return 1;
}
return 0; return 0;
} }
@ -115,6 +138,7 @@ void free_weights(TransformerWeights* w) {
delete w->w2; delete w->w2;
delete w->w3; delete w->w3;
delete w->rms_final_weight; delete w->rms_final_weight;
if (w->wcls) delete w->wcls;
} }
void print_sample_weights(TransformerWeights *w){ void print_sample_weights(TransformerWeights *w){
@ -131,6 +155,7 @@ void print_sample_weights(TransformerWeights *w){
printf("%f\n", w->w2[0]); printf("%f\n", w->w2[0]);
printf("%f\n", w->w3[0]); printf("%f\n", w->w3[0]);
printf("%f\n", w->rms_att_weight[0]); printf("%f\n", w->rms_att_weight[0]);
if (w->wcls) printf("%f\n", w->wcls[0]);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -617,7 +642,7 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
// // w->token_embedding_table -> model->tok_embeddings // // w->token_embedding_table -> model->tok_embeddings
// // float* -> struct ggml_tensor // // float* -> struct ggml_tensor
// stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table); // stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
// stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table); // stuff_karpathy_weights_into_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table);
// //
// stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight); // stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
// //print_row(model->norm, 0); // //print_row(model->norm, 0);
@ -791,9 +816,12 @@ int main(int argc, char ** argv) {
if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; } if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
// read in the config header // read in the config header
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; } if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
auto shared_weights = config.vocab_size > 0;
config.vocab_size = abs(config.vocab_size);
// read in the Transformer weights // read in the Transformer weights
malloc_weights(&weights, &config); malloc_weights(&weights, &config, shared_weights);
if(checkpoint_init_weights(&weights, &config, file)) { return 1; } if(checkpoint_init_weights(&weights, &config, file, shared_weights)) { return 1; }
fclose(file); fclose(file);
} }