Commit graph

603 commits

Author SHA1 Message Date
xaedes
70c08318af
test flash attention backward pass
need to set loose error bounds to pass.
the finite differences are close to numeric limits and often return quite different values from the backward pass.
reducing eps further lets the gradients vanish completely.
likewise, setting eps too big results in even less accurate values.
the softmax in the middle of the function is probably most responsible for the numeric issues with finite differences.
2023-05-29 23:51:40 +02:00
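
A minimal sketch of the central finite-difference estimate such a test compares against the analytical backward pass; the function type and the way eps enters are illustrative, not the actual ggml test harness:

```c
#include <stddef.h>

typedef float (*loss_fn)(const float * x, size_t n);

// central finite difference: df/dx_i ~= (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
// too small an eps and the difference drowns in float rounding (the gradient "vanishes");
// too large an eps and the truncation error of the approximation dominates.
static float finite_diff_grad(loss_fn f, float * x, size_t n, size_t i, float eps) {
    const float x0 = x[i];
    x[i] = x0 + eps;
    const float fp = f(x, n);
    x[i] = x0 - eps;
    const float fm = f(x, n);
    x[i] = x0;                       // restore the input
    return (fp - fm) / (2.0f * eps); // compared against the backward pass with loose bounds
}
```
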
xaedes
38560b6d51
bugfixes for backward pass of flash attention 2023-05-29 23:45:58 +02:00
xaedes
22a7279ffb
implement backward pass of flash attention 2023-05-29 22:00:40 +02:00
xaedes
56895e28f6
get vocabulary for exporting training checkpoint to llama compatible model file 2023-05-29 02:25:18 +02:00
xaedes
4b81c32d5b
add export of training checkpoint to llama compatible model file 2023-05-29 01:27:09 +02:00
xaedes
2da5c8cf24
set default model.type for unknown models with few layers 2023-05-29 01:21:01 +02:00
xaedes
bf4d9b3b81
add llama_get_vocab to get the vocabulary as output parameters 2023-05-29 01:20:26 +02:00
xaedes
89475fb320
slightly improve how cross entropy loss is computed
btw: the directly implemented cross entropy loss seems to have way lower magnitudes than when implemented with softmax and log.
probably the input to log gets closer to zero due to float numerics.
maybe the multiplication by (1.0-eps)/sum is more accurate.
2023-05-28 22:40:58 +02:00
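
Written out as a formula (a hedged reading of the (1.0-eps)/sum remark above, not the exact ggml kernel), the direct implementation guards the argument of the log so it stays bounded away from zero:

```latex
L = -\sum_i t_i \log p_i,
\qquad
p_i = (1-\epsilon)\,\frac{e^{x_i - m}}{\sum_k e^{x_k - m}} + \epsilon,
\qquad
m = \max_j x_j
```

whereas composing softmax and log lets the softmax output underflow towards zero first and only then takes the log.
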
xaedes
5f5aa20078
remove trailing whitespace 2023-05-28 22:00:56 +02:00
xaedes
1fbd19abe1
use ggml_cross_entropy_loss in text training example 2023-05-28 22:00:26 +02:00
xaedes
f056a04a80
add tests for cross_entropy_loss backward pass
finite differences regularly result in an estimated gradient of zero, despite the backward pass giving a non-zero gradient.
_probably_ the finite differences fail due to numerical issues
2023-05-28 21:59:17 +02:00
xaedes
71aaf8dedf
add ggml_cross_entropy_loss with backward pass for faster training
cross entropy loss can also be implemented using softmax and log, but as a dedicated operation it is faster and in particular avoids unnecessary memory overhead.
2023-05-28 21:57:38 +02:00
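
A scalar sketch of such a fused operation for one row of logits x against target probabilities t, using the same eps-guarded formula as in the note further up; it illustrates why a dedicated op avoids memory overhead (softmax and log happen in one pass, so no intermediate softmax tensor is materialized), but it is not the actual ggml kernel:

```c
#include <math.h>
#include <stddef.h>

// fused cross entropy loss over one row: max for numerical stability, normalizer,
// then the eps-guarded log -- no intermediate softmax tensor is ever stored.
static float cross_entropy_loss_fused(const float * x, const float * t, size_t n, float eps) {
    float max = x[0];
    for (size_t i = 1; i < n; ++i) {
        if (x[i] > max) max = x[i];
    }
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum += expf(x[i] - max);
    }
    float loss = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float p = (1.0f - eps) * expf(x[i] - max) / sum + eps;
        loss -= t[i] * logf(p);
    }
    return loss;
}
```
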
xaedes
05cb629c8e
replace inefficient repeat backward pass with dedicated repeat_back operation 2023-05-28 18:00:17 +02:00
xaedes
c47df09842
simplify backward pass for SQRT 2023-05-28 17:32:01 +02:00
xaedes
6d40cc3a44
remove trailing whitespace 2023-05-22 20:56:35 +02:00
xaedes
d3acbf644e
simplify code 2023-05-22 20:53:57 +02:00
xaedes
0651679302
save checkpoint only when it was trained 2023-05-22 16:56:28 +02:00
xaedes
cc440bd438
fix bug in get_samples which corrupted training targets 2023-05-22 16:55:52 +02:00
xaedes
b763d6f1f2
remove unused functions 2023-05-22 16:54:21 +02:00
xaedes
42d9b4cfc2
store optimizer state in training checkpoint and add learning schedule
persistent optimizer state allows resuming training without resetting the optimizer
the learning schedule consists of a linear warmup ramp followed by cosine decay with restarts
2023-05-21 21:36:04 +02:00
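
A hedged sketch of such a schedule, returning a factor that scales the base learning rate; the parameter names and the way cycle lengths grow between restarts are illustrative, not the exact values used by the training example:

```c
#include <math.h>

// linear warmup ramp (0 -> 1) followed by cosine decay with restarts.
// assumes cycle_mult >= 1, so each restarted cycle is at least as long as the last.
static float learning_schedule(int step, int warmup, int first_cycle,
                               float min_sched, float cycle_mult) {
    const float pi = 3.14159265358979f;
    if (warmup > 0 && step < warmup) {
        return (float) step / (float) warmup;  // warmup: ramp linearly up to 1
    }
    float s   = (float) (step - warmup);
    float len = (float) first_cycle;
    while (s >= len) {                         // restart: begin a new, longer cycle
        s   -= len;
        len *= cycle_mult;
    }
    const float t = s / len;                   // position within the current cycle, in [0, 1)
    return min_sched + 0.5f * (1.0f - min_sched) * (1.0f + cosf(pi * t));
}
```
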
xaedes
37c69435f0
print suppressed newline tokens as string "\n"
printing too many actual newlines is suppressed to avoid flooding the console.
2023-05-21 21:17:46 +02:00
xaedes
93eb8f7752
add forward function without using cache, for more performant training
during training on whole samples no cache is required.
removing the cache and simplifying the remaining code improves performance and memory usage.
2023-05-21 21:14:49 +02:00
xaedes
2afd218479
fix bug in llama_sample_token_mirostat_v2
when all candidates are filtered out by the mu threshold, the following soft_max operation will fail.
so keep at least one.
2023-05-21 21:12:10 +02:00
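
A sketch of the "keep at least one" guard; the struct mirrors llama.cpp's llama_token_data_array, but the code is an illustration of the idea, not the actual llama_sample_token_mirostat_v2 implementation:

```c
#include <math.h>
#include <stddef.h>

// illustrative stand-ins for llama_token_data / llama_token_data_array
typedef struct { int id; float logit; float p; } token_data;
typedef struct { token_data * data; size_t size; int sorted; } token_data_array;

// drop candidates whose surprise -log2(p) exceeds mu, but always keep at least one
// so the following softmax has something to normalize.
static void truncate_by_mu(token_data_array * candidates, float mu) {
    size_t kept = 0;
    for (size_t i = 0; i < candidates->size; ++i) {
        if (-log2f(candidates->data[i].p) <= mu) {
            candidates->data[kept++] = candidates->data[i];
        }
    }
    if (kept == 0 && candidates->size > 0) {
        kept = 1; // candidates are sorted by probability, so index 0 is the best one left
    }
    candidates->size = kept;
}
```
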
xaedes
ec1783c3e0
add ggml_opt_context, so that we can properly resume training
otherwise the optimizer states, tracking statistics about the error function and its derivatives,
will reset to zero each time ggml_opt is called, hindering convergence on resumed training.

now the optimizer context and all its memory is stored in a separate struct.
2023-05-21 21:10:16 +02:00
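
An illustrative sketch of why the persistent context matters: the state that must survive between ggml_opt calls is essentially the Adam moment estimates plus some bookkeeping. Field names are made up for the illustration, not the actual struct ggml_opt_context layout:

```c
#include <stdint.h>
#include <stddef.h>

// state that would be lost (and convergence hurt) if it were reset on every ggml_opt call
struct opt_context_sketch {
    int64_t iter;      // global step count, needed for Adam's bias correction
    size_t  nx;        // number of parameters being optimized
    float * m;         // first-moment estimate, one value per parameter
    float * v;         // second-moment estimate, one value per parameter
    float   best_loss; // best loss seen so far, for convergence / early-stop checks
};
```
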
xaedes
1eee9255e7
add missing default parameters for adam optimizer 2023-05-21 15:03:51 +02:00
xaedes
57c2f4f909
fix random weight initialization scale 2023-05-21 12:18:47 +02:00
xaedes
96514971dd
use inplace operations in cross_entropy_loss 2023-05-21 12:17:57 +02:00
xaedes
ef17d99f65
implement AdamW in ggml_opt_adam by adding weight decay parameter (default 0.001f)
also add a schedule parameter (default 1.0f) that can be used to scale alpha and decay according to a learning schedule.
setting the decay parameter to zero disables AdamW, resulting in the normal Adam optimizer.

since the difference between Adam and AdamW is minimal, it is not implemented as a separate optimizer but integrated into the existing Adam optimizer.
2023-05-20 14:54:57 +02:00
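
A scalar sketch of the resulting update rule, with sched scaling both the step size and the decoupled weight decay; variable names are illustrative, not the actual ggml_opt_adam internals, and setting decay to zero recovers plain Adam:

```c
#include <math.h>

// one AdamW step for a single parameter x with gradient g; m and v are the
// running first/second moment estimates, t is the (1-based) step count.
static float adamw_step(float x, float g, float * m, float * v, int t,
                        float alpha, float beta1, float beta2, float eps,
                        float decay, float sched) {
    *m = beta1 * (*m) + (1.0f - beta1) * g;
    *v = beta2 * (*v) + (1.0f - beta2) * g * g;
    const float mhat = *m / (1.0f - powf(beta1, (float) t)); // bias-corrected moments
    const float vhat = *v / (1.0f - powf(beta2, (float) t));
    x -= sched * decay * x;                                   // decoupled weight decay (the "W" in AdamW)
    x -= sched * alpha * mhat / (sqrtf(vhat) + eps);          // plain Adam step, scaled by sched
    return x;
}
```
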
xaedes
f4e9ce7998
enable gradient propagation for inplace add1 and scale operations
those functions' backward passes don't need the original src0, so they also work when the forward pass is inplace
2023-05-20 14:49:30 +02:00
xaedes
a6aafdd719
add ggml_add1_inplace to header 2023-05-20 14:47:56 +02:00
xaedes
08a330a136
add cmake target for baby-llama-text 2023-05-19 18:41:26 +02:00
xaedes
332003584e
sample with non-greedy sampling parameters at the end of training 2023-05-19 18:41:06 +02:00
xaedes
e19ead6e3f
print used memory before and after optimization 2023-05-19 18:40:20 +02:00
xaedes
da86a1d736
fix cross entropy loss
- add target probabilities for each sample, which are then used in the cross entropy loss
2023-05-19 18:39:38 +02:00
xaedes
09b304d015
remove duplicate include 2023-05-19 18:36:05 +02:00
xaedes
37f5b76df1
ggml fixes to support backward pass on inplace operations 2023-05-19 18:35:40 +02:00
xaedes
44d83558bc
use different arguments for input and output checkpoint 2023-05-19 18:34:18 +02:00
xaedes
d8b0666429
initialize rng with srand 2023-05-19 18:29:47 +02:00
xaedes
25fe1c3815
use inplace functions where possible 2023-05-19 14:53:21 +02:00
xaedes
b241b9cb6c
save trained model to checkpoint and load model to be trained from checkpoint 2023-05-17 13:49:32 +02:00
xaedes
d328472f16
fix get_samples call, add model tensor names, increase model size, start training samples after newline 2023-05-17 12:52:20 +02:00
xaedes
e063135d0b
add llama sampler, shuffle samples and constrain sampling to tokens occurring in train data 2023-05-15 21:12:28 +02:00
xaedes
ec881156f6
improve ggml_out_prod performance
- change iteration order (>15s -> 10s runtime)
- parallelize over one more dimension: over dst matrix rows (10s -> <5s runtime)
2023-05-15 14:42:24 +02:00
xaedes
19fb91899b
better weight initialization improves training convergence at start 2023-05-15 14:19:38 +02:00
xaedes
f3cf7df21f
better weight initialization improves training convergence at start 2023-05-15 14:18:57 +02:00
xaedes
efa4bb78ea
add ggml_out_prod and use it for mul_mat backward pass for improved performance
performance stats report improvement from 37 seconds to 16 seconds runtime during my training tests
2023-05-15 14:17:42 +02:00
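
For context, in conventional matrix notation (a sketch of the idea, not ggml's exact index conventions) the backward pass of C = A B needs

```latex
\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\, B^{\top},
\qquad
\frac{\partial L}{\partial B} = A^{\top}\, \frac{\partial L}{\partial C}
```

and both right-hand sides are sums of rank-1 outer products, which is what a dedicated out_prod operation can accumulate directly instead of going through explicit transposed copies and a general mul_mat.
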
xaedes
a703d7a85f
activate threading in baby-llama-text 2023-05-14 21:00:55 +02:00
xaedes
d9b5268728
avoid printing too many newlines in baby-llama-text 2023-05-14 20:57:47 +02:00
xaedes
c054079fb8
improve performance of mul_mat backward pass
avoid transpose by using mul_mat with swapped arguments
2023-05-14 20:56:50 +02:00
xaedes
1f2b76de01
fix bug in ggml_compute_forward_soft_max_back_f32 on DEBUG build 2023-05-14 20:55:24 +02:00