Fix import of llama2.c models that don't share weights between embedding layers
This commit is contained in:
parent
930523c8e1
commit
0f7cb95352
1 changed files with 34 additions and 6 deletions
|
@ -49,10 +49,10 @@ typedef struct {
|
||||||
// float* freq_cis_real; // (seq_len, dim/2)
|
// float* freq_cis_real; // (seq_len, dim/2)
|
||||||
// float* freq_cis_imag; // (seq_len, dim/2)
|
// float* freq_cis_imag; // (seq_len, dim/2)
|
||||||
// (optional) classifier weights for the logits, on the last layer
|
// (optional) classifier weights for the logits, on the last layer
|
||||||
//float* wcls;
|
float* wcls;
|
||||||
} TransformerWeights;
|
} TransformerWeights;
|
||||||
|
|
||||||
void malloc_weights(TransformerWeights* w, Config* p) {
|
void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
|
||||||
// we calloc instead of malloc to keep valgrind happy
|
// we calloc instead of malloc to keep valgrind happy
|
||||||
w->token_embedding_table = new float[p->vocab_size * p->dim]();
|
w->token_embedding_table = new float[p->vocab_size * p->dim]();
|
||||||
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
|
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
|
||||||
|
@ -86,9 +86,16 @@ void malloc_weights(TransformerWeights* w, Config* p) {
|
||||||
|
|
||||||
w->rms_final_weight = new float[p->dim]();
|
w->rms_final_weight = new float[p->dim]();
|
||||||
printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
|
printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
|
||||||
|
|
||||||
|
if (shared_weights) {
|
||||||
|
w->wcls = NULL;
|
||||||
|
} else {
|
||||||
|
w->wcls = new float[p->vocab_size * p->dim]();
|
||||||
|
printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
|
int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) {
|
||||||
if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
|
if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
|
||||||
if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1;
|
if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1;
|
||||||
if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
|
if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
|
||||||
|
@ -100,6 +107,22 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
|
||||||
if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1;
|
if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1;
|
||||||
if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1;
|
if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1;
|
||||||
if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1;
|
if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1;
|
||||||
|
|
||||||
|
// Skip freq_cis_real & freq_cis_imag
|
||||||
|
int head_size = p->dim / p->n_heads;
|
||||||
|
fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
|
||||||
|
|
||||||
|
if (!shared_weights && fread(w->wcls, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
|
||||||
|
|
||||||
|
// Check we didn't forget to read anything
|
||||||
|
auto curr = ftell(f);
|
||||||
|
fseek(f, 0, SEEK_END);
|
||||||
|
auto end = ftell(f);
|
||||||
|
if (curr != end) {
|
||||||
|
printf("Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", curr, end);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,6 +138,7 @@ void free_weights(TransformerWeights* w) {
|
||||||
delete w->w2;
|
delete w->w2;
|
||||||
delete w->w3;
|
delete w->w3;
|
||||||
delete w->rms_final_weight;
|
delete w->rms_final_weight;
|
||||||
|
if (w->wcls) delete w->wcls;
|
||||||
}
|
}
|
||||||
|
|
||||||
void print_sample_weights(TransformerWeights *w){
|
void print_sample_weights(TransformerWeights *w){
|
||||||
|
@ -131,6 +155,7 @@ void print_sample_weights(TransformerWeights *w){
|
||||||
printf("%f\n", w->w2[0]);
|
printf("%f\n", w->w2[0]);
|
||||||
printf("%f\n", w->w3[0]);
|
printf("%f\n", w->w3[0]);
|
||||||
printf("%f\n", w->rms_att_weight[0]);
|
printf("%f\n", w->rms_att_weight[0]);
|
||||||
|
if (w->wcls) printf("%f\n", w->wcls[0]);
|
||||||
}
|
}
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
@ -617,7 +642,7 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
|
||||||
// // w->token_embedding_table -> model->tok_embeddings
|
// // w->token_embedding_table -> model->tok_embeddings
|
||||||
// // float* -> struct ggml_tensor
|
// // float* -> struct ggml_tensor
|
||||||
// stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
|
// stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
|
||||||
// stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table);
|
// stuff_karpathy_weights_into_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table);
|
||||||
//
|
//
|
||||||
// stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
|
// stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
|
||||||
// //print_row(model->norm, 0);
|
// //print_row(model->norm, 0);
|
||||||
|
@ -791,9 +816,12 @@ int main(int argc, char ** argv) {
|
||||||
if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
|
if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
|
||||||
// read in the config header
|
// read in the config header
|
||||||
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
|
||||||
|
auto shared_weights = config.vocab_size > 0;
|
||||||
|
config.vocab_size = abs(config.vocab_size);
|
||||||
|
|
||||||
// read in the Transformer weights
|
// read in the Transformer weights
|
||||||
malloc_weights(&weights, &config);
|
malloc_weights(&weights, &config, shared_weights);
|
||||||
if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
|
if(checkpoint_init_weights(&weights, &config, file, shared_weights)) { return 1; }
|
||||||
fclose(file);
|
fclose(file);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue