fix training get_example_targets

predict the next token, not the current token!
This commit is contained in:
xaedes 2023-05-07 01:18:17 +02:00
parent 80223d98fd
commit 73fd66e9e5
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -3,6 +3,11 @@
#include <assert.h>
#include <random>
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
float frand() {
return (float)rand()/(float)RAND_MAX;
}
@ -576,15 +581,18 @@ void get_example_targets(int example_id, struct ggml_tensor * tokens_input, stru
int n_vocab = targets->ne[0];
float randomness = 0.0f;
ggml_set_zero(targets);
for (int i=0; i<n_tokens; ++i) {
float x = example_id + i * 3.14159f * 2.0f * 4.0f / n_tokens;
ggml_set_i32_1d(tokens_input, 0, 0);
for (int i=1; i<n_tokens+1; ++i) {
float x = example_id + i * 3.14159f * 2.0f * 1.0f / n_tokens;
float y = sinf(x);//*cosf(x*1.1f+1.0f);
float z = (y+1.0f)*0.5f; // scale to [0..1]
z += (frand()-0.5f)*(randomness/n_tokens);
z += (frand()-0.5f)*(randomness/n_vocab);
z = (z < 0.0f) ? 0.0f : (z > 1.0f) ? 1.0f : z; // clamp to [0..1]
int token = (int)(z*(float)(n_vocab-1));
ggml_set_f32_1d(targets, i*n_vocab + token, +1.0f);
ggml_set_i32_1d(tokens_input, i, token);
int token = MAX(1,MIN(1+(int)(z*(float)(n_vocab-1)), n_vocab-1));
ggml_set_f32_1d(targets, (i-1)*n_vocab + token, +1.0f);
if (i<n_tokens) {
ggml_set_i32_1d(tokens_input, i, token);
}
}
}
@ -601,7 +609,7 @@ void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * tar
int main(int argc, char ** argv) {
struct ggml_init_params lcparams;
lcparams.mem_size = 1024*1024*1024;
lcparams.mem_size = 1024ll*1024ll*1024ll;
lcparams.mem_buffer = NULL;
lcparams.no_alloc = false;
@ -634,7 +642,7 @@ int main(int argc, char ** argv) {
init_kv_cache(&kv_self, &model);
size_t compute_size = 1024*1024*1024;
size_t compute_size = 1024ll*1024ll*1024ll;
uint8_t * compute_addr = new uint8_t[compute_size];
int n_examples = 32;
@ -756,6 +764,7 @@ int main(int argc, char ** argv) {
print_token(token, model.hparams.n_vocab);
lshift_examples(tokens_input, targets, 1);
ggml_set_i32_1d(tokens_input, 0, 0);
ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
// printf("---\n");