fix training get_example_targets
predict the next token, not the current token!
This commit is contained in:
parent
80223d98fd
commit
73fd66e9e5
1 changed files with 17 additions and 8 deletions
|
@ -3,6 +3,11 @@
|
|||
#include <assert.h>
|
||||
#include <random>
|
||||
|
||||
#undef MIN
|
||||
#undef MAX
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
float frand() {
|
||||
return (float)rand()/(float)RAND_MAX;
|
||||
}
|
||||
|
@ -576,15 +581,18 @@ void get_example_targets(int example_id, struct ggml_tensor * tokens_input, stru
|
|||
int n_vocab = targets->ne[0];
|
||||
float randomness = 0.0f;
|
||||
ggml_set_zero(targets);
|
||||
for (int i=0; i<n_tokens; ++i) {
|
||||
float x = example_id + i * 3.14159f * 2.0f * 4.0f / n_tokens;
|
||||
ggml_set_i32_1d(tokens_input, 0, 0);
|
||||
for (int i=1; i<n_tokens+1; ++i) {
|
||||
float x = example_id + i * 3.14159f * 2.0f * 1.0f / n_tokens;
|
||||
float y = sinf(x);//*cosf(x*1.1f+1.0f);
|
||||
float z = (y+1.0f)*0.5f; // scale to [0..1]
|
||||
z += (frand()-0.5f)*(randomness/n_tokens);
|
||||
z += (frand()-0.5f)*(randomness/n_vocab);
|
||||
z = (z < 0.0f) ? 0.0f : (z > 1.0f) ? 1.0f : z; // clamp to [0..1]
|
||||
int token = (int)(z*(float)(n_vocab-1));
|
||||
ggml_set_f32_1d(targets, i*n_vocab + token, +1.0f);
|
||||
ggml_set_i32_1d(tokens_input, i, token);
|
||||
int token = MAX(1,MIN(1+(int)(z*(float)(n_vocab-1)), n_vocab-1));
|
||||
ggml_set_f32_1d(targets, (i-1)*n_vocab + token, +1.0f);
|
||||
if (i<n_tokens) {
|
||||
ggml_set_i32_1d(tokens_input, i, token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -601,7 +609,7 @@ void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * tar
|
|||
|
||||
int main(int argc, char ** argv) {
|
||||
struct ggml_init_params lcparams;
|
||||
lcparams.mem_size = 1024*1024*1024;
|
||||
lcparams.mem_size = 1024ll*1024ll*1024ll;
|
||||
lcparams.mem_buffer = NULL;
|
||||
lcparams.no_alloc = false;
|
||||
|
||||
|
@ -634,7 +642,7 @@ int main(int argc, char ** argv) {
|
|||
init_kv_cache(&kv_self, &model);
|
||||
|
||||
|
||||
size_t compute_size = 1024*1024*1024;
|
||||
size_t compute_size = 1024ll*1024ll*1024ll;
|
||||
uint8_t * compute_addr = new uint8_t[compute_size];
|
||||
|
||||
int n_examples = 32;
|
||||
|
@ -756,6 +764,7 @@ int main(int argc, char ** argv) {
|
|||
print_token(token, model.hparams.n_vocab);
|
||||
|
||||
lshift_examples(tokens_input, targets, 1);
|
||||
ggml_set_i32_1d(tokens_input, 0, 0);
|
||||
ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
|
||||
|
||||
// printf("---\n");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue