llama/ggml: add LLM training support

more compact progress bar

refactor: llama_prepare_sbatch/ubatch

llama_save_model_to_file

gqa_mode arg for repeat_back

llama_opt_param_filter

ggml_graph_dup force_grads

refactor ggml_opt, fix test-opt
This commit is contained in:
Johannes Gäßler 2024-11-17 14:58:51 +01:00
parent a5203b4465
commit c25557362a
26 changed files with 1294 additions and 339 deletions

View file

@ -37,13 +37,16 @@ extern "C" {
// ====== Dataset ======
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
int64_t ne_datapoint, // number of elements per datapoint
int64_t ne_label, // number of elements per label
int64_t ndata, // total number of datapoints/labels
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
enum ggml_type type_data, // the type for the internal data tensor
enum ggml_type type_label, // the type for the internal labels tensor
int64_t ne_datapoint, // number of elements per datapoint
int64_t ne_label, // number of elements per label
int64_t ndata, // total number of datapoints/labels
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
// get underlying tensors that store the data
GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
@ -56,13 +59,19 @@ extern "C" {
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
int64_t ibatch);
GGML_API void ggml_opt_dataset_get_batch_host(
ggml_opt_dataset_t dataset,
void * data_batch,
size_t nb_data_batch,
void * labels_batch,
int64_t ibatch);
// ====== Model / Context ======
enum ggml_opt_build_type {
GGML_OPT_BUILD_TYPE_FORWARD,
GGML_OPT_BUILD_TYPE_GRAD,
GGML_OPT_BUILD_TYPE_OPT,
GGML_OPT_BUILD_TYPE_FORWARD = 10,
GGML_OPT_BUILD_TYPE_GRAD = 20,
GGML_OPT_BUILD_TYPE_OPT = 30,
};
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@ -81,20 +90,22 @@ extern "C" {
// userdata can be used to pass arbitrary data
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
// returns the default optimizer params (constant)
// returns the default optimizer params (constant, hard-coded values)
// userdata is not used
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
// casts userdata to ggml_opt_optimizer_params and returns it
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
// parameters for initializing a new optimization context
struct ggml_opt_params {
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
// the forward graph is defined by inputs and outputs
// those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
// by default the forward graph needs to be reconstructed for each eval
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
struct ggml_context * ctx_compute;
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
enum ggml_opt_loss_type loss_type;
enum ggml_opt_build_type build_type;
@ -107,12 +118,9 @@ extern "C" {
// get parameters for an optimization context with defaults set where possible
// parameters for which no sensible defaults exist are supplied as arguments to this function
GGML_API ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
struct ggml_context * ctx_compute,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs,
enum ggml_opt_loss_type loss_type);
GGML_API struct ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
enum ggml_opt_loss_type loss_type);
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@ -121,6 +129,7 @@ extern "C" {
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
// get underlying tensors that store data
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
@ -128,11 +137,12 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
// get the gradient accumulator for a node from the forward graph
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
// ====== Optimization Result ======
GGML_API ggml_opt_result_t ggml_opt_result_init();
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
@ -144,11 +154,20 @@ extern "C" {
// ====== Computation ======
// do forward pass, increment result if not NULL
GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// if not using static graphs, this function must be called prior to ggml_opt_alloc
GGML_API void ggml_opt_prepare_alloc(
ggml_opt_context_t opt_ctx,
struct ggml_context * ctx_compute,
struct ggml_cgraph * gf,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs);
// do forward pass, increment result if not NULL, do backward pass
GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// allocate the next graph for evaluation, either forward or forward + backward
// must be called exactly once prior to calling ggml_opt_eval
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
// do forward pass, increment result if not NULL, do backward pass if allocated
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// ############################################################################
// ## The high-level functions start here. They do not depend on any private ##
@ -200,9 +219,9 @@ extern "C" {
// fit model defined by inputs and outputs to dataset
GGML_API void ggml_opt_fit(
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
enum ggml_opt_loss_type loss_type, // loss to minimize
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)