This commit is contained in:
Johannes Gäßler 2025-02-10 11:26:56 +01:00 committed by GitHub
commit c67beb169e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 1294 additions and 339 deletions

View file

@ -2196,3 +2196,19 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result;
}
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}

View file

@ -715,3 +715,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
}
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

View file

@ -53,6 +53,7 @@ else()
add_subdirectory(tokenize)
add_subdirectory(tts)
add_subdirectory(gen-docs)
add_subdirectory(training)
if (NOT GGML_BACKEND_DL)
# these examples use the backends directly and cannot be built with dynamic loading
add_subdirectory(convert-llama2c-to-ggml)

View file

@ -0,0 +1,5 @@
set(TARGET llama-finetune)
add_executable(${TARGET} finetune.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View file

@ -0,0 +1,17 @@
# llama.cpp/examples/training
This directory contains examples related to language model training using llama.cpp/GGML.
So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP.
Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
**For CPU training, compile llama.cpp without any additional backends such as CUDA.**
**For CUDA training, use the maximum number of GPU layers.**
Proof of concept:
``` sh
export model_name=llama_3.2-1b && export quantization=f32
./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
```
The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.

View file

@ -0,0 +1,97 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
int main(int argc, char ** argv) {
common_params params;
params.logits_all = true;
params.escape = false;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
return 1;
}
if (params.use_mmap) {
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
params.use_mmap = false;
}
if (params.cache_type_k == GGML_TYPE_F16) {
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_k = GGML_TYPE_F32;
}
if (params.cache_type_v == GGML_TYPE_F16) {
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_v = GGML_TYPE_F32;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
// load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);
return 1;
}
// print system information
{
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}
constexpr float val_split = 0.05f;
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
optimizer_params.adamw.alpha = 1e-7f; // learning rate
struct llama_opt_params lopt_params {
/*n_ctx_train =*/ 0,
/*param_filter =*/ llama_opt_param_filter_all,
/*param_filter_ud =*/ nullptr,
/*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
/*get_opt_pars_ud =*/ &optimizer_params,
};
llama_opt_init(ctx.get(), model.get(), lopt_params);
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
ggml_opt_result_t result_train = ggml_opt_result_init();
ggml_opt_result_t result_eval = ggml_opt_result_init();
for (int epoch = 0; epoch < 2; ++epoch) {
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
fprintf(stderr, "\n");
ggml_opt_result_reset(result_train);
ggml_opt_result_reset(result_eval);
}
ggml_opt_result_free(result_train);
ggml_opt_result_free(result_eval);
llama_model_save_to_file(model.get(), "finetuned-model.gguf");
llama_backend_free();
return 0;
}

View file

@ -37,13 +37,16 @@ extern "C" {
// ====== Dataset ======
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
int64_t ne_datapoint, // number of elements per datapoint
int64_t ne_label, // number of elements per label
int64_t ndata, // total number of datapoints/labels
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
enum ggml_type type_data, // the type for the internal data tensor
enum ggml_type type_label, // the type for the internal labels tensor
int64_t ne_datapoint, // number of elements per datapoint
int64_t ne_label, // number of elements per label
int64_t ndata, // total number of datapoints/labels
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
// get underlying tensors that store the data
GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
@ -56,13 +59,19 @@ extern "C" {
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
int64_t ibatch);
GGML_API void ggml_opt_dataset_get_batch_host(
ggml_opt_dataset_t dataset,
void * data_batch,
size_t nb_data_batch,
void * labels_batch,
int64_t ibatch);
// ====== Model / Context ======
enum ggml_opt_build_type {
GGML_OPT_BUILD_TYPE_FORWARD,
GGML_OPT_BUILD_TYPE_GRAD,
GGML_OPT_BUILD_TYPE_OPT,
GGML_OPT_BUILD_TYPE_FORWARD = 10,
GGML_OPT_BUILD_TYPE_GRAD = 20,
GGML_OPT_BUILD_TYPE_OPT = 30,
};
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@ -81,20 +90,22 @@ extern "C" {
// userdata can be used to pass arbitrary data
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
// returns the default optimizer params (constant)
// returns the default optimizer params (constant, hard-coded values)
// userdata is not used
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
// casts userdata to ggml_opt_optimizer_params and returns it
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
// parameters for initializing a new optimization context
struct ggml_opt_params {
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
// the forward graph is defined by inputs and outputs
// those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
// by default the forward graph needs to be reconstructed for each eval
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
struct ggml_context * ctx_compute;
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
enum ggml_opt_loss_type loss_type;
enum ggml_opt_build_type build_type;
@ -107,12 +118,9 @@ extern "C" {
// get parameters for an optimization context with defaults set where possible
// parameters for which no sensible defaults exist are supplied as arguments to this function
GGML_API ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
struct ggml_context * ctx_compute,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs,
enum ggml_opt_loss_type loss_type);
GGML_API struct ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
enum ggml_opt_loss_type loss_type);
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@ -121,6 +129,7 @@ extern "C" {
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
// get underlying tensors that store data
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
@ -128,11 +137,12 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
// get the gradient accumulator for a node from the forward graph
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
// ====== Optimization Result ======
GGML_API ggml_opt_result_t ggml_opt_result_init();
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
@ -144,11 +154,20 @@ extern "C" {
// ====== Computation ======
// do forward pass, increment result if not NULL
GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// if not using static graphs, this function must be called prior to ggml_opt_alloc
GGML_API void ggml_opt_prepare_alloc(
ggml_opt_context_t opt_ctx,
struct ggml_context * ctx_compute,
struct ggml_cgraph * gf,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs);
// do forward pass, increment result if not NULL, do backward pass
GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// allocate the next graph for evaluation, either forward or forward + backward
// must be called exactly once prior to calling ggml_opt_eval
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
// do forward pass, increment result if not NULL, do backward pass if allocated
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// ############################################################################
// ## The high-level functions start here. They do not depend on any private ##
@ -200,9 +219,9 @@ extern "C" {
// fit model defined by inputs and outputs to dataset
GGML_API void ggml_opt_fit(
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
enum ggml_opt_loss_type loss_type, // loss to minimize
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)

View file

@ -763,7 +763,7 @@ extern "C" {
// Tensor flags
GGML_API void ggml_set_input(struct ggml_tensor * tensor);
GGML_API void ggml_set_output(struct ggml_tensor * tensor);
GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
GGML_API void ggml_set_param(struct ggml_tensor * tensor);
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
//
@ -933,7 +933,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
// concat a and b along dim
// used in stable-diffusion
@ -2054,15 +2054,14 @@ extern "C" {
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(
struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
struct ggml_context * ctx_compute, // context for gradient computation
struct ggml_cgraph * cgraph,
bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
struct ggml_context * ctx, // context for gradient computation
struct ggml_cgraph * cgraph,
struct ggml_tensor ** grad_accs);
// graph allocation in a context
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);

View file

@ -1107,7 +1107,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
const int node_backend_id = tensor_backend_id(node);
assert(node_backend_id != -1); // all nodes should be assigned by now
assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
// check if we should start a new split based on the sources of the current node
bool need_new_split = false;

View file

@ -28,16 +28,19 @@ struct ggml_opt_dataset {
};
struct ggml_opt_context {
ggml_backend_sched_t backend_sched = nullptr;
ggml_cgraph * allocated_graph = nullptr;
ggml_cgraph * allocated_graph_copy = nullptr;
struct ggml_context * ctx_static = nullptr;
struct ggml_context * ctx_static_cpu = nullptr;
struct ggml_context * ctx_compute = nullptr;
struct ggml_context * ctx_copy = nullptr;
ggml_backend_buffer_t buf_static = nullptr;
ggml_backend_buffer_t buf_static_cpu = nullptr;
std::mt19937 rng;
ggml_backend_sched_t backend_sched = nullptr;
ggml_cgraph * allocated_graph = nullptr;
ggml_cgraph * allocated_graph_copy = nullptr;
struct ggml_context * ctx_static = nullptr;
struct ggml_context * ctx_cpu = nullptr;
struct ggml_context * ctx_compute = nullptr;
struct ggml_context * ctx_copy = nullptr;
ggml_backend_buffer_t buf_static = nullptr;
ggml_backend_buffer_t buf_cpu = nullptr;
std::mt19937 rng;
enum ggml_opt_loss_type loss_type;
enum ggml_opt_build_type build_type;
enum ggml_opt_build_type build_type_alloc;
struct ggml_tensor * inputs = nullptr;
struct ggml_tensor * outputs = nullptr;
@ -50,6 +53,11 @@ struct ggml_opt_context {
struct ggml_cgraph * gf = nullptr;
struct ggml_cgraph * gb_grad = nullptr;
struct ggml_cgraph * gb_opt = nullptr;
bool static_graphs = false;
bool eval_ready = false;
std::vector<struct ggml_tensor *> grad_accs;
std::vector<struct ggml_tensor *> grad_m;
std::vector<struct ggml_tensor *> grad_v;
int64_t iter = 1;
int32_t opt_period = 1;
@ -73,7 +81,13 @@ struct ggml_opt_result {
// ====== Dataset ======
ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
ggml_opt_dataset_t ggml_opt_dataset_init(
enum ggml_type type_data,
enum ggml_type type_label,
int64_t ne_datapoint,
int64_t ne_label,
int64_t ndata,
int64_t ndata_shard) {
GGML_ASSERT(ne_datapoint > 0);
GGML_ASSERT(ne_label >= 0);
GGML_ASSERT(ndata > 0);
@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
result->ctx = ggml_init(params);
}
result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
if (ne_label > 0) {
result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
} else {
result->labels = nullptr;
@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
delete dataset;
}
int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
return dataset->ndata;
}
struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
return dataset->data;
}
@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
GGML_ASSERT( data_batch->type == dataset->data->type);
GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
const size_t nb_data_batch = ggml_nbytes(data_batch);
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
}
}
void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
if (!labels_batch) {
continue;
}
const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
}
}
// ====== Model / Context ======
struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
return result;
}
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
return *((struct ggml_opt_optimizer_params *) userdata);
}
struct ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
struct ggml_context * ctx_compute,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs,
enum ggml_opt_loss_type loss_type) {
return {
/*backend_sched =*/ backend_sched,
/*ctx_compute =*/ ctx_compute,
/*inputs =*/ inputs,
/*logits =*/ outputs,
/*ctx_compute =*/ nullptr,
/*inputs =*/ nullptr,
/*logits =*/ nullptr,
/*loss_type =*/ loss_type,
/*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
/*opt_period =*/ 1,
@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
return dst;
}
static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
GGML_ASSERT(graph);
if (opt_ctx->allocated_graph == graph) {
return;
}
static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
{
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ggml_free(opt_ctx->ctx_copy);
opt_ctx->ctx_copy = ggml_init(params);
}
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
opt_ctx->allocated_graph = graph;
}
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
ggml_opt_context_t result = new struct ggml_opt_context;
result->backend_sched = params.backend_sched;
result->ctx_compute = params.ctx_compute;
result->inputs = params.inputs;
result->outputs = params.outputs;
result->opt_period = params.opt_period;
result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud;
GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
GGML_ASSERT(result->opt_period >= 1);
const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
(params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
ggml_set_input(result->inputs);
ggml_set_output(result->outputs);
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
ggml_build_forward_expand(result->gf, result->outputs);
ggml_set_input(opt_ctx->inputs);
ggml_set_output(opt_ctx->outputs);
int n_param = 0;
for (int i = 0; i < result->gf->n_nodes; ++i) {
if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
n_param++;
}
GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
}
{
if (!opt_ctx->ctx_static) {
// The static context is used for:
// - gradients (1 tensor per param if using gradient accumulation)
// - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
// - optimizer momenta (2 tensors per param)
// - labels
// - loss + its gradient (up to 5 tensors)
// - pred
// - ncorrect (2 tensors).
const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
// - labels (if using static graphs)
// - loss (if using static graphs, up to 5 tensors)
// - pred (if using static graphs)
// - ncorrect (if using static graphs, 2 tensors).
constexpr size_t n_loss = 1;
const size_t tensors_per_param = (accumulate ? 1 : 0) +
(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
result->ctx_static = ggml_init(params);
opt_ctx->ctx_static = ggml_init(params);
}
GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
{
// The static cpu context is used for:
// - optimizer parameters (1 for the entire context)
// The cpu context is allocated statically if using static graphs, dynamically otherwise.
// It is used for:
// - optimizer parameters (1 shared for all optimizer invocations)
const size_t size_meta = 1 * ggml_tensor_overhead();
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
result->ctx_static_cpu = ggml_init(params);
ggml_free(opt_ctx->ctx_cpu);
opt_ctx->ctx_cpu = ggml_init(params);
ggml_backend_buffer_free(opt_ctx->buf_cpu);
opt_ctx->buf_cpu = nullptr;
}
struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
switch (params.loss_type) {
switch (opt_ctx->loss_type) {
case GGML_OPT_LOSS_TYPE_MEAN: {
result->loss = ggml_sum(result->ctx_static, result->outputs);
ggml_set_name(result->loss, "loss_sum");
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
ggml_set_name(result->loss, "loss_mean");
result->loss_per_datapoint = true;
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
ggml_set_name(opt_ctx->loss, "loss_sum");
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
ggml_set_name(opt_ctx->loss, "loss_mean");
opt_ctx->loss_per_datapoint = true;
break;
}
case GGML_OPT_LOSS_TYPE_SUM: {
result->loss = ggml_sum(result->ctx_static, result->outputs);
ggml_set_name(result->loss, "loss_sum");
result->loss_per_datapoint = false;
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
ggml_set_name(opt_ctx->loss, "loss_sum");
opt_ctx->loss_per_datapoint = false;
break;
}
case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
ggml_set_input(result->labels);
ggml_set_name(result->labels, "labels");
result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
ggml_set_name(result->loss, "loss_cross_entropy");
if (result->opt_period > 1) {
result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
ggml_set_name(result->loss, "loss_cross_entropy_scaled");
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
ggml_set_input(opt_ctx->labels);
ggml_set_name(opt_ctx->labels, "labels");
opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
if (opt_ctx->opt_period > 1) {
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
}
result->loss_per_datapoint = true;
opt_ctx->loss_per_datapoint = true;
break;
}
case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
ggml_set_input(result->labels);
ggml_set_name(result->labels, "labels");
result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
ggml_set_name(result->loss, "loss_error");
result->loss = ggml_sqr(result->ctx_static, result->loss);
ggml_set_name(result->loss, "loss_squared_error");
result->loss = ggml_sum(result->ctx_static, result->loss);
ggml_set_name(result->loss, "loss_sum_squared_error");
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
ggml_set_name(result->loss, "loss_mean_squared_error");
result->loss_per_datapoint = true;
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
ggml_set_input(opt_ctx->labels);
ggml_set_name(opt_ctx->labels, "labels");
opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
ggml_set_name(opt_ctx->loss, "loss_error");
opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
ggml_set_name(opt_ctx->loss, "loss_squared_error");
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
opt_ctx->loss_per_datapoint = true;
break;
}
}
ggml_set_output(result->loss);
ggml_set_loss(result->loss);
ggml_build_forward_expand(result->gf, result->loss);
ggml_set_output(opt_ctx->loss);
ggml_set_loss(opt_ctx->loss);
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
result->pred = ggml_argmax(result->ctx_static, result->outputs);
ggml_set_name(result->pred, "pred");
ggml_set_output(result->pred);
ggml_build_forward_expand(result->gf, result->pred);
if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
ggml_set_name(opt_ctx->pred, "pred");
ggml_set_output(opt_ctx->pred);
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
if (result->labels) {
result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
ggml_set_name(result->ncorrect, "ncorrect");
ggml_set_output(result->ncorrect);
ggml_build_forward_expand(result->gf, result->ncorrect);
} else {
result->ncorrect = nullptr;
opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
ggml_set_name(opt_ctx->ncorrect, "ncorrect");
ggml_set_output(opt_ctx->ncorrect);
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
}
if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
return result;
if (opt_ctx->buf_static) {
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
return;
}
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
return;
}
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
if (opt_ctx->grad_accs.empty()) {
GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
ggml_graph_reset(result->gb_grad);
return result;
}
const int n_nodes = opt_ctx->gf->n_nodes;
opt_ctx->grad_accs.resize(n_nodes);
for (int i = 0; i < n_nodes; ++i) {
ggml_tensor * node = opt_ctx->gf->nodes[i];
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
} else {
opt_ctx->grad_accs[i] = nullptr;
}
}
GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
ggml_set_input(result->adamw_params);
ggml_set_name(result->adamw_params, "adamw_params");
for (int i = result->gf->n_nodes-1; i >= 0; --i) {
struct ggml_tensor * node = result->gb_opt->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
ggml_build_forward_expand(result->gb_opt, opt_step);
if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
opt_ctx->grad_m.resize(n_nodes);
opt_ctx->grad_v.resize(n_nodes);
for (int i = 0; i < n_nodes; ++i) {
ggml_tensor * node = opt_ctx->gf->nodes[i];
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
} else {
opt_ctx->grad_m[i] = nullptr;
opt_ctx->grad_v[i] = nullptr;
}
}
}
}
result->buf_static = ggml_backend_alloc_ctx_tensors(
result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
if (opt_ctx->buf_static) {
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
return;
}
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
ggml_graph_reset(opt_ctx->gb_grad);
}
ggml_graph_reset(result->gb_opt);
GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
ggml_set_input(opt_ctx->adamw_params);
ggml_set_name(opt_ctx->adamw_params, "adamw_params");
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
struct ggml_tensor * m = opt_ctx->grad_m[i];
struct ggml_tensor * v = opt_ctx->grad_v[i];
struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
}
}
if (!opt_ctx->buf_static) {
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
ggml_graph_reset(opt_ctx->gb_opt);
}
opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
}
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
ggml_opt_context_t result = new struct ggml_opt_context;
result->backend_sched = params.backend_sched;
result->ctx_compute = params.ctx_compute;
result->loss_type = params.loss_type;
result->build_type = params.build_type;
result->build_type_alloc = params.build_type;
result->inputs = params.inputs;
result->outputs = params.outputs;
result->opt_period = params.opt_period;
result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud;
GGML_ASSERT(result->opt_period >= 1);
result->static_graphs = result->ctx_compute;
if (!result->static_graphs) {
GGML_ASSERT(!result->inputs);
GGML_ASSERT(!result->outputs);
return result;
}
GGML_ASSERT(result->inputs);
GGML_ASSERT(result->outputs);
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
ggml_build_forward_expand(result->gf, result->outputs);
ggml_opt_build(result);
return result;
}
@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
return;
}
ggml_backend_buffer_free(opt_ctx->buf_static);
ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
ggml_backend_buffer_free(opt_ctx->buf_cpu);
ggml_free(opt_ctx->ctx_static);
ggml_free(opt_ctx->ctx_static_cpu);
ggml_free(opt_ctx->ctx_cpu);
delete opt_ctx;
}
@ -582,8 +679,80 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
// ====== Computation ======
static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
if (graph != opt_ctx->gf) {
void ggml_opt_prepare_alloc(
ggml_opt_context_t opt_ctx,
struct ggml_context * ctx_compute,
struct ggml_cgraph * gf,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs) {
GGML_ASSERT(!opt_ctx->static_graphs);
opt_ctx->ctx_compute = ctx_compute;
opt_ctx->gf = gf;
opt_ctx->inputs = inputs;
opt_ctx->outputs = outputs;
}
void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
GGML_ASSERT(!opt_ctx->eval_ready);
if (backward) {
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
} else {
opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
}
if (!opt_ctx->static_graphs) {
ggml_opt_build(opt_ctx);
}
struct ggml_cgraph * graph = nullptr;
switch (opt_ctx->build_type) {
case GGML_OPT_BUILD_TYPE_FORWARD: {
graph = opt_ctx->gf;
} break;
case GGML_OPT_BUILD_TYPE_GRAD: {
graph = opt_ctx->gb_grad;
} break;
case GGML_OPT_BUILD_TYPE_OPT: {
graph = opt_ctx->gb_opt;
} break;
}
GGML_ASSERT(graph);
if (opt_ctx->allocated_graph == graph) {
opt_ctx->eval_ready = true;
return;
}
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
if (opt_ctx->static_graphs) {
ggml_init_params params = {
/*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ggml_free(opt_ctx->ctx_copy);
opt_ctx->ctx_copy = ggml_init(params);
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
} else {
opt_ctx->allocated_graph_copy = graph;
}
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
opt_ctx->allocated_graph = graph;
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
ggml_graph_reset(opt_ctx->gb_grad);
}
opt_ctx->eval_ready = true;
}
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
GGML_ASSERT(opt_ctx->eval_ready);
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
@ -609,9 +778,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
adamw_par_data[6] = beta2h;
}
ggml_opt_alloc_graph(opt_ctx, graph);
ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
if (!opt_ctx->static_graphs) {
opt_ctx->gf = nullptr;
opt_ctx->gb_grad = nullptr;
opt_ctx->gb_opt = nullptr;
opt_ctx->allocated_graph = nullptr;
opt_ctx->allocated_graph_copy = nullptr;
}
opt_ctx->eval_ready = false;
if (!result) {
return;
@ -635,12 +814,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
result->loss.push_back(loss);
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
std::vector<int32_t> pred(ndata);
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
if (opt_ctx->pred) {
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
std::vector<int32_t> pred(ndata);
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
}
if (!opt_ctx->labels || result->ncorrect < 0) {
if (!opt_ctx->ncorrect || result->ncorrect < 0) {
result->ncorrect = -1;
return;
}
@ -652,26 +833,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
result->ncorrect += ncorrect;
}
void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
}
void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
if (opt_ctx->opt_period == 1) {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
return;
}
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
if (opt_i_next == 0) {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
} else {
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
}
opt_ctx->opt_i = opt_i_next;
}
// ====== High-Level Functions ======
void ggml_opt_epoch(
@ -700,16 +861,18 @@ void ggml_opt_epoch(
int64_t ibatch = 0;
int64_t t_loop_start = ggml_time_us();
for (; ibatch < ibatch_split; ++ibatch) {
ggml_opt_alloc(opt_ctx, /*backward =*/ true);
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
ggml_opt_forward_backward(opt_ctx, result_train);
ggml_opt_eval(opt_ctx, result_train);
if (callback_train) {
callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
}
}
t_loop_start = ggml_time_us();
for (; ibatch < nbatches; ++ibatch) {
ggml_opt_alloc(opt_ctx, /*backward =*/ false);
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
ggml_opt_forward(opt_ctx, result_eval);
ggml_opt_eval(opt_ctx, result_eval);
if (callback_eval) {
callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
}
@ -726,13 +889,26 @@ void ggml_opt_epoch_callback_progress_bar(
int64_t t_start_us) {
fprintf(stderr, "%s[", train ? "train: " : "val: ");
constexpr int64_t bar_length = 25;
// The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
constexpr int64_t bar_length = 8;
const int64_t ibatch8 = 8 * ibatch;
for (int64_t j = 0; j < bar_length; ++j) {
const int64_t ibatch_j = ibatch_max * j/bar_length;
if (ibatch_j < ibatch) {
fprintf(stderr, "=");
} else if (ibatch_max * (j - 1)/bar_length < ibatch) {
fprintf(stderr, ">");
if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
fprintf(stderr, "\u2588"); // full block
} else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
fprintf(stderr, "\u2589"); // 7/8 filled
} else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
fprintf(stderr, "\u258A"); // 6/8 filled
} else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
fprintf(stderr, "\u258B"); // 5/8 filled
} else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
fprintf(stderr, "\u258C"); // 4/8 filled
} else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
fprintf(stderr, "\u258D"); // 3/8 filled
} else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
fprintf(stderr, "\u258E"); // 2/8 filled
} else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
fprintf(stderr, "\u258F"); // 1/8 filled
} else {
fprintf(stderr, " ");
}
@ -764,8 +940,8 @@ void ggml_opt_epoch_callback_progress_bar(
const int64_t t_eta_m = t_eta_s / 60;
t_eta_s -= t_eta_m * 60;
fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
if (ibatch == ibatch_max) {
@ -806,7 +982,10 @@ void ggml_opt_fit(
int64_t epoch = 1;
ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
params.ctx_compute = ctx_compute;
params.inputs = inputs;
params.outputs = outputs;
params.opt_period = opt_period;
params.get_opt_pars = get_opt_pars;
params.get_opt_pars_ud = &epoch;

View file

@ -5792,10 +5792,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
}
void ggml_build_backward_expand(
struct ggml_context * ctx_static,
struct ggml_context * ctx_compute,
struct ggml_cgraph * cgraph,
bool accumulate) {
struct ggml_context * ctx,
struct ggml_cgraph * cgraph,
struct ggml_tensor ** grad_accs) {
GGML_ASSERT(cgraph->n_nodes > 0);
GGML_ASSERT(cgraph->grads);
GGML_ASSERT(cgraph->grad_accs);
@ -5868,21 +5867,24 @@ void ggml_build_backward_expand(
GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(igrad != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
cgraph->grads[igrad] = cgraph->grad_accs[igrad];
ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(ihash != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
if (grad_accs && grad_accs[i]) {
cgraph->grad_accs[ihash] = grad_accs[i];
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
} else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
// loss tensors always need a gradient accumulator
cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
}
grads_needed[igrad] = true;
grads_needed[ihash] = true;
}
for (int i = n_nodes_f - 1; i >= 0; --i) {
// inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
// use allocator to automatically make inplace operations
ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
ggml_compute_backward(ctx, cgraph, i, grads_needed);
}
free(grads_needed);
@ -6028,8 +6030,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
}
}
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
ggml_graph_cpy(cgraph, result);
return result;
}
@ -6357,8 +6359,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
}
void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
GGML_UNUSED(ctx); // TODO: remove this parameter
void ggml_set_param(struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->op == GGML_OP_NONE);
tensor->flags |= GGML_TENSOR_FLAG_PARAM;
}

View file

@ -4,6 +4,7 @@
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
#include "ggml-opt.h"
#include <stddef.h>
#include <stdint.h>
@ -429,6 +430,10 @@ extern "C" {
size_t n_paths,
struct llama_model_params params);
LLAMA_API void llama_model_save_to_file(
const struct llama_model * model,
const char * path_model);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead");
@ -848,7 +853,7 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
// Processes a batch of tokens with the encoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error. the KV cache state is restored to the state before this call
@ -856,7 +861,7 @@ extern "C" {
struct llama_context * ctx,
struct llama_batch batch);
// Positive return values does not mean a fatal error, but rather a warning.
// A positive return value does not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error. the KV cache state is restored to the state before this call
@ -1328,6 +1333,37 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
//
// training
//
// function that returns whether or not a given tensor is a trainable parameter
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
// always returns true
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
struct llama_opt_params {
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
llama_opt_param_filter param_filter; // callback for determining which tensors are trainable parameters
void * param_filter_ud; // userdata for determining which tensors are trainable parameters
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
#ifdef __cplusplus
}
#endif

View file

@ -20,6 +20,7 @@ add_library(llama
llama-kv-cache.cpp
llama-mmap.cpp
llama-model-loader.cpp
llama-model-saver.cpp
llama-model.cpp
llama-quant.cpp
llama-sampling.cpp

View file

@ -583,6 +583,7 @@ void llama_output_reorder(struct llama_context & ctx) {
//
void llama_free(struct llama_context * ctx) {
ggml_opt_free(ctx->opt_ctx);
delete ctx;
}

View file

@ -8,6 +8,7 @@
#include "llama-adapter.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
#include <map>
#include <unordered_map>
@ -107,6 +108,9 @@ struct llama_context {
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
// training
ggml_opt_context_t opt_ctx = nullptr;
};
// TODO: make these methods of llama_context

View file

@ -301,12 +301,12 @@ namespace GGUFMeta {
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
switch (arr_info.gt) {
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT(
(std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
default:
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
}
result.resize(arr_info.length);
@ -330,12 +330,12 @@ namespace GGUFMeta {
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
switch (arr_info.gt) {
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT(
(std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
default:
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
}
if (arr_info.length > N_MAX) {

281
src/llama-model-saver.cpp Normal file
View file

@ -0,0 +1,281 @@
#include "llama-model-saver.h"
#include "gguf.h"
#include "llama.h"
#include "llama-hparams.h"
#include "llama-model.h"
#include "llama-vocab.h"
#include <string>
llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) {
gguf_ctx = gguf_init_empty();
}
llama_model_saver::~llama_model_saver() {
gguf_free(gguf_ctx);
}
void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) {
gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) {
gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const float value) {
gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const bool value) {
gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value);
}
void llama_model_saver::add_kv(const enum llm_kv key, const char * value) {
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value);
}
[[noreturn]]
void llama_model_saver::add_kv(const enum llm_kv key, const char value) {
GGML_UNUSED(key);
GGML_UNUSED(value);
GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile
}
template <typename Container>
void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) {
const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size();
GGML_ASSERT(n_values <= value.size());
if (n_values == 0) {
return;
}
if (per_layer) {
bool all_values_the_same = true;
for (size_t i = 1; i < n_values; ++i) {
if (value[i] != value[0]) {
all_values_the_same = false;
break;
}
}
if (all_values_the_same) {
add_kv(key, value[0]);
return;
}
}
if (std::is_same<typename Container::value_type, uint8_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, int8_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, uint32_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, int32_t>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
} else if (std::is_same<typename Container::value_type, float>::value) {
gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values);
} else if (std::is_same<Container, std::string>::value) {
gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast<const char *>(value.data()));
} else {
GGML_ABORT("fatal error");
}
}
void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) {
std::vector<const char *> tmp(value.size());
for (size_t i = 0; i < value.size(); ++i) {
tmp[i] = value[i].c_str();
}
gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size());
}
void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) {
if (!tensor) {
return;
}
if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) {
GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME
return;
}
gguf_add_tensor(gguf_ctx, tensor);
}
void llama_model_saver::add_kv_from_model() {
const llama_hparams & hparams = model.hparams;
const llama_vocab & vocab = model.vocab;
const int32_t n_vocab = vocab.n_tokens();
std::vector<std::string> tokens(n_vocab);
std::vector<float> scores(n_vocab);
std::vector<int32_t> token_types(n_vocab);
for (int32_t id = 0; id < n_vocab; ++id) {
const llama_vocab::token_data & token_data = vocab.get_token_data(id);
tokens[id] = token_data.text;
scores[id] = token_data.score;
switch(token_data.attr) {
case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break;
case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break;
case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break;
case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break;
case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break;
case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break;
case LLAMA_TOKEN_ATTR_UNDEFINED:
default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break;
}
}
// add_kv(LLM_KV_GENERAL_TYPE, ???);
add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name());
// add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???);
// add_kv(LLM_KV_GENERAL_ALIGNMENT, ???);
add_kv(LLM_KV_GENERAL_NAME, model.name);
// add_kv(LLM_KV_GENERAL_AUTHOR, ???);
// add_kv(LLM_KV_GENERAL_VERSION, ???);
// add_kv(LLM_KV_GENERAL_URL, ???);
// add_kv(LLM_KV_GENERAL_DESCRIPTION, ???);
// add_kv(LLM_KV_GENERAL_LICENSE, ???);
// add_kv(LLM_KV_GENERAL_SOURCE_URL, ???);
// add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???);
add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
// add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???);
add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert);
add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used);
add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type));
add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id);
add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping);
add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping);
add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm);
add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers);
add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true);
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor);
add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor);
add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn);
add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned);
add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
// TODO: implement split file support
// add_kv(LLM_KV_SPLIT_NO, ???);
// add_kv(LLM_KV_SPLIT_COUNT, ???);
// add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???);
add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms);
add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model());
add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre());
add_kv(LLM_KV_TOKENIZER_LIST, tokens);
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types);
add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types());
add_kv(LLM_KV_TOKENIZER_SCORES, scores);
add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges());
// FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though
add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos()));
add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos()));
add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot()));
add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom()));
add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk()));
add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep()));
add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad()));
// add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated
// add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
// add_kv(LLM_KV_TOKENIZER_HF_JSON, ???);
// add_kv(LLM_KV_TOKENIZER_RWKV, ???);
add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre()));
add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf()));
add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid()));
add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad()));
add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep()));
add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep()));
// TODO: implement LoRA support
// add_kv(LLM_KV_ADAPTER_TYPE, ???);
// add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???);
// deprecated
// add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???);
// add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???);
// add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???);
}
void llama_model_saver::add_tensors_from_model() {
if (std::string(model.output->name) != std::string(model.tok_embd->name)) {
add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output
}
add_tensor(model.type_embd);
add_tensor(model.pos_embd);
add_tensor(model.tok_norm);
add_tensor(model.tok_norm_b);
add_tensor(model.output_norm);
add_tensor(model.output_norm_b);
add_tensor(model.output);
add_tensor(model.output_b);
add_tensor(model.output_norm_enc);
add_tensor(model.cls);
add_tensor(model.cls_b);
add_tensor(model.cls_out);
add_tensor(model.cls_out_b);
for (const struct llama_layer & layer : model.layers) {
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]);
}
}
}
void llama_model_saver::save(const std::string & path_model) {
gguf_write_to_file(gguf_ctx, path_model.c_str(), false);
}

37
src/llama-model-saver.h Normal file
View file

@ -0,0 +1,37 @@
#pragma once
#include "llama.h"
#include "llama-arch.h"
#include <vector>
struct llama_model_saver {
struct gguf_context * gguf_ctx = nullptr;
const struct llama_model & model;
const struct LLM_KV llm_kv;
llama_model_saver(const struct llama_model & model);
~llama_model_saver();
void add_kv(enum llm_kv key, uint32_t value);
void add_kv(enum llm_kv key, int32_t value);
void add_kv(enum llm_kv key, float value);
void add_kv(enum llm_kv key, bool value);
void add_kv(enum llm_kv key, const char * value);
[[noreturn]]
void add_kv(enum llm_kv key, char value); // needed to make the template below compile
template <typename Container>
void add_kv(enum llm_kv key, const Container & value, bool per_layer = false);
void add_kv(enum llm_kv key, const std::vector<std::string> & value);
void add_tensor(const struct ggml_tensor * tensor);
void add_kv_from_model();
void add_tensors_from_model();
void save(const std::string & path_model);
};

View file

@ -98,6 +98,10 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
{ LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
};
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) {
return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type);
}
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
if (kv.second == name) {
@ -3586,7 +3590,7 @@ uint64_t llama_model::n_elements() const {
}
void llama_model::print_info() const {
const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train);
auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
bool is_var = false;
@ -3645,7 +3649,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);

View file

@ -80,6 +80,8 @@ enum llm_type {
LLM_TYPE_27B,
};
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
struct llama_layer_posnet {
// resnet
struct ggml_tensor * norm1 = nullptr;

View file

@ -512,7 +512,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
nthread = std::thread::hardware_concurrency();
}
// mmap consistently increases speed Linux, and also increases speed on Windows with
// mmap consistently increases speed on Linux, and also increases speed on Windows with
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
#if defined(__linux__) || defined(_WIN32)
constexpr bool use_mmap = true;
@ -522,7 +522,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
llama_model_kv_override * kv_overrides = nullptr;
if (params->kv_overrides) {
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data();
}

View file

@ -1,5 +1,7 @@
#include "llama-vocab.h"
#include "ggml.h"
#include "gguf.h"
#include "llama-impl.h"
#include "llama-model-loader.h"
@ -1204,6 +1206,9 @@ struct fragment_buffer_variant {
struct llama_vocab::impl {
uint32_t n_token_types = 0; // for BERT-style token types
std::string tokenizer_model;
std::string tokenizer_pre;
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
@ -1339,9 +1344,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
// determine vocab type
{
std::string tokenizer_model;
std::string tokenizer_pre;
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
@ -1436,7 +1438,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
if (precompiled_charsmap_keyidx != -1) {
size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8);
const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN
@ -2728,6 +2733,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
pimpl->load(ml, kv);
}
std::string llama_vocab::get_tokenizer_model() const {
return pimpl->tokenizer_model;
}
std::string llama_vocab::get_tokenizer_pre() const {
return pimpl->tokenizer_pre;
}
enum llama_vocab_type llama_vocab::get_type() const {
return pimpl->type;
}
@ -2950,6 +2963,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
return it->second;
}
std::vector<std::string> llama_vocab::get_bpe_merges() const {
std::vector<std::string> result(pimpl->bpe_ranks.size());
for (const auto & pair : pimpl->bpe_ranks) {
result[pair.second] = pair.first.first + " " + pair.first.second;
}
return result;
}
std::vector<char> llama_vocab::get_precompiled_charsmap() const {
return pimpl->precompiled_charsmap;
}
int32_t llama_vocab::tokenize(
const char * text,
int32_t text_len,

View file

@ -21,6 +21,9 @@ struct llama_vocab {
void load(llama_model_loader & ml, const LLM_KV & kv);
std::string get_tokenizer_model() const;
std::string get_tokenizer_pre() const;
enum llama_vocab_type get_type() const;
enum llama_vocab_pre_type get_pre_type() const;
@ -80,6 +83,9 @@ struct llama_vocab {
int max_token_len() const;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
std::vector<std::string> get_bpe_merges() const;
std::vector<char> get_precompiled_charsmap() const;
int32_t tokenize(
const char * text,

View file

@ -7,6 +7,7 @@
#include "llama-sampling.h"
#include "llama-kv-cache.h"
#include "llama-model-loader.h"
#include "llama-model-saver.h"
#include "llama-model.h"
#include "ggml.h"
@ -1142,18 +1143,18 @@ struct llm_build_context {
ctx0 = ggml_init(params);
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_mask_swa = nullptr;
lctx.inp_K_shift = nullptr;
lctx.inp_mean = nullptr;
lctx.inp_cls = nullptr;
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_mask_swa = nullptr;
lctx.inp_K_shift = nullptr;
lctx.inp_mean = nullptr;
lctx.inp_cls = nullptr;
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_pos_bucket = nullptr;
lctx.inp_embd_enc = nullptr;
lctx.inp_KQ_mask_cross = nullptr;
@ -9544,6 +9545,13 @@ struct llama_model * llama_model_load_from_splits(
return llama_model_load_from_file_impl(splits.front(), splits, params);
}
void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
llama_model_saver ms(*model);
ms.add_kv_from_model();
ms.add_tensors_from_model();
ms.save(path_model);
}
struct llama_context * llama_init_from_model(
struct llama_model * model,
struct llama_context_params params) {
@ -10124,3 +10132,197 @@ void llama_perf_context_reset(struct llama_context * ctx) {
ctx->t_eval_us = ctx->n_eval = 0;
ctx->t_p_eval_us = ctx->n_p_eval = 0;
}
//
// training
//
bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
GGML_UNUSED(tensor);
GGML_UNUSED(userdata);
return true;
}
static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
if (!tensor || tensor->type != GGML_TYPE_F32) {
return;
}
if (!param_filter(tensor, userdata)) {
return;
}
if (strcmp(tensor->name, "token_embd.weight") == 0) {
return; // FIXME
}
if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
return; // FIXME
}
ggml_set_param(tensor);
}
void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params) {
GGML_ASSERT(!lctx->opt_ctx);
model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : llama_n_ctx(lctx);
const uint32_t n_batch = std::min(llama_n_batch(lctx), model->hparams.n_ctx_train);
const uint32_t n_ubatch = std::min(llama_n_ubatch(lctx), n_batch);
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
GGML_ASSERT(n_batch % n_ubatch == 0);
ggml_opt_params opt_params = ggml_opt_default_params(lctx->sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
opt_params.opt_period = n_batch / n_ubatch;
opt_params.get_opt_pars = lopt_params.get_opt_pars;
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
lctx->opt_ctx = ggml_opt_init(opt_params);
llama_opt_param_filter param_filter = lopt_params.param_filter;
void * param_filter_ud = lopt_params.param_filter_ud;
// llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
llama_set_param(model->type_embd, param_filter, param_filter_ud);
llama_set_param(model->pos_embd, param_filter, param_filter_ud);
llama_set_param(model->tok_norm, param_filter, param_filter_ud);
llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
llama_set_param(model->output_norm, param_filter, param_filter_ud);
llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
llama_set_param(model->output, param_filter, param_filter_ud);
llama_set_param(model->output_b, param_filter, param_filter_ud);
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
llama_set_param(model->cls, param_filter, param_filter_ud);
llama_set_param(model->cls_b, param_filter, param_filter_ud);
llama_set_param(model->cls_out, param_filter, param_filter_ud);
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
for (struct llama_layer & layer : model->layers) {
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
}
}
}
static void llama_opt_epoch_iter(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
const bool train,
const int64_t idata_in_loop,
const int64_t ndata_in_loop,
const int64_t t_loop_start) {
GGML_ASSERT(lctx->opt_ctx);
const uint32_t n_ctx = llama_model_n_ctx_train(&lctx->model);
const uint32_t n_batch = std::min(llama_n_batch(lctx), n_ctx);
const uint32_t n_ubatch = std::min(llama_n_ubatch(lctx), n_batch);
lctx->is_encoding = false;
llama_kv_cache_clear(lctx);
llama_kv_slot_restorer kv_slot_restorer(lctx->kv_self);
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
batch.n_tokens = n_batch;
for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
batch.pos [pos_batch] = pos_ctx + pos_batch;
batch.n_seq_id[pos_batch] = 1;
batch.seq_id [pos_batch][0] = 0;
batch.logits [pos_batch] = true;
}
uint32_t n_outputs = 0;
{
const int err_code = llama_prepare_sbatch(*lctx, batch, n_outputs);
GGML_ASSERT(err_code == 0);
}
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
struct llama_ubatch ubatch;
{
const int err_code = llama_prepare_ubatch(*lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
GGML_ASSERT(err_code == 0);
}
struct ggml_cgraph * gf = llama_build_graph(*lctx, ubatch, false);
struct ggml_context * ctx_compute;
{
const size_t size_gf = ggml_graph_size(gf);
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
struct ggml_init_params params = {
/*.mem_size =*/ size_meta,
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ctx_compute = ggml_init(params);
}
ggml_opt_prepare_alloc(lctx->opt_ctx, ctx_compute, gf, lctx->inp_tokens, ggml_graph_node(gf, -1));
ggml_opt_alloc(lctx->opt_ctx, train);
llama_set_inputs(*lctx, ubatch);
{
struct ggml_tensor * labels = ggml_opt_labels(lctx->opt_ctx);
GGML_ASSERT(labels->ne[1] == n_ubatch);
ggml_set_zero(labels);
const float onef = 1.0f;
for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
}
}
ggml_opt_eval(lctx->opt_ctx, result);
if (callback) {
callback(train, lctx->opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
}
ggml_free(ctx_compute);
}
}
}
void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
const uint32_t n_ctx = llama_n_ctx(lctx);
const uint32_t n_batch = std::min(lctx->cparams.n_batch, n_ctx);
const uint32_t n_ubatch = std::min(lctx->cparams.n_ubatch, n_batch);
const int64_t ndata = ggml_opt_dataset_ndata(dataset);
GGML_ASSERT(idata_split >= 0);
GGML_ASSERT(idata_split <= ndata);
const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
std::vector<llama_token> tokens(n_ctx);
std::vector<llama_token> labels_sparse(n_ctx);
int64_t idata = 0;
int64_t t_loop_start = ggml_time_us();
int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
for (; idata < idata_split; ++idata) {
constexpr bool train = true;
const int64_t idata_in_loop = idata*ubatch_per_ctx;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
llama_opt_epoch_iter(lctx, dataset, result_train, tokens, labels_sparse, batch,
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
t_loop_start = ggml_time_us();
ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
for (; idata < ndata; ++idata) {
constexpr bool train = false;
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
llama_opt_epoch_iter(lctx, dataset, result_eval, tokens, labels_sparse, batch,
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
}
llama_batch_free(batch);
}

View file

@ -810,7 +810,7 @@ struct test_case {
ggml_build_forward_expand(gf, out);
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, ctx, gb, false);
ggml_build_backward_expand(ctx, gb, nullptr);
if (expect.size() != 1 || expect[0] != 0.0f) {
GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
@ -996,7 +996,7 @@ struct test_example : public test_case {
// Step 3: return the output tensor.
return out;
}
// In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
// In order to also check the gradients for your op, add calls like ggml_set_param(a)
// immediately after you create the tensors.
// This is optional and only makes sense if a backward pass has actually been implemented for the new op.
};
@ -1028,7 +1028,7 @@ struct test_unary : public test_case {
auto ne = ne_a; ne[0] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
if (grad_supported) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
@ -1037,7 +1037,7 @@ struct test_unary : public test_case {
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
if (grad_supported) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
}
@ -1103,7 +1103,7 @@ struct test_get_rows : public test_case {
const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
if (grad_supported) {
ggml_set_param(ctx, in);
ggml_set_param(in);
// rows is a constant input -> no gradients
}
@ -1292,7 +1292,7 @@ struct test_repeat : public test_case {
ggml_set_name(target, "target");
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
ggml_tensor * out = ggml_repeat(ctx, src, target);
@ -1376,7 +1376,7 @@ struct test_dup : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
if (_use_permute) {
@ -1412,7 +1412,7 @@ struct test_set : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
auto ne_dst = ne;
@ -1420,7 +1420,7 @@ struct test_set : public test_case {
ne_dst[i] *= 2;
}
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
ggml_set_param(ctx, dst);
ggml_set_param(dst);
ggml_set_name(dst, "dst");
size_t offset = 0;
@ -1464,7 +1464,7 @@ struct test_cpy : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
if (_src_use_permute) {
@ -1497,7 +1497,7 @@ struct test_cont : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, src);
ggml_set_param(src);
ggml_set_name(src, "src");
src = ggml_transpose(ctx, src);
@ -1543,8 +1543,8 @@ struct test_bin_bcast : public test_case {
// The backward pass supports broadcasting only for GGML_ADD:
const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
if (grad_supported) {
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_param(a);
ggml_set_param(b);
}
ggml_tensor * out = op(ctx, a, b);
@ -1592,11 +1592,11 @@ struct test_add1 : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
// ggml_set_param(ctx, b); // TODO: implement
// ggml_set_param(b); // TODO: implement
ggml_set_name(b, "b");
ggml_tensor * out = ggml_add1(ctx, a, b);
@ -1627,7 +1627,7 @@ struct test_scale : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_scale(ctx, a, scale);
@ -1722,7 +1722,7 @@ struct test_rms_norm : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
if (v) {
@ -1951,9 +1951,9 @@ struct test_mul_mat : public test_case {
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_param(ctx, b);
ggml_set_param(b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
@ -1967,9 +1967,9 @@ struct test_mul_mat : public test_case {
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_param(ctx, b);
ggml_set_param(b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
@ -2118,7 +2118,7 @@ struct test_sqr : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sqr(ctx, a);
@ -2147,7 +2147,7 @@ struct test_sqrt : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sqrt(ctx, a);
@ -2187,7 +2187,7 @@ struct test_log : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_log(ctx, a);
@ -2223,7 +2223,7 @@ struct test_sin : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sin(ctx, a);
@ -2266,7 +2266,7 @@ struct test_cos : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_cos(ctx, a);
@ -2346,7 +2346,7 @@ struct test_diag_mask_inf : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
@ -2385,7 +2385,7 @@ struct test_soft_max : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * mask = nullptr;
@ -2467,7 +2467,7 @@ struct test_rope : public test_case {
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
if (forward) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
@ -2476,7 +2476,7 @@ struct test_rope : public test_case {
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
if (forward) {
ggml_set_param(ctx, a);
ggml_set_param(a);
}
ggml_set_name(a, "a");
}
@ -2588,7 +2588,7 @@ struct test_pool2d : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_set_param(ctx, input);
ggml_set_param(input);
ggml_set_name(input, "input");
ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
@ -2664,7 +2664,7 @@ struct test_im2col : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_set_param(ctx, input);
ggml_set_param(input);
ggml_set_name(input, "input");
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
@ -2799,7 +2799,7 @@ struct test_sum : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sum(ctx, a);
@ -2828,7 +2828,7 @@ struct test_sum_rows : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_sum_rows(ctx, a);
@ -2853,7 +2853,7 @@ struct test_mean : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_mean(ctx, a);
@ -2970,11 +2970,11 @@ struct test_acc : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(ctx, a);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
ggml_set_param(ctx, b);
ggml_set_param(b);
ggml_set_name(b, "b");
ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
@ -3206,7 +3206,7 @@ struct test_cross_entropy_loss : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, logits);
ggml_set_param(logits);
ggml_set_name(logits, "logits");
ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
@ -3288,7 +3288,7 @@ struct test_opt_step_adamw : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
ggml_set_param(a); // Despite tensor a having gradients the output tensor will not.
ggml_set_name(a, "a");
ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);

View file

@ -57,7 +57,8 @@ static helper_ctx_data helper_get_ctx_data(
enum ggml_opt_loss_type loss_type = GGML_OPT_LOSS_TYPE_SUM) {
std::vector<ggml_opt_dataset_t> datasets(ndata);
for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard);
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
GGML_TYPE_F32, GGML_TYPE_F32, ne_datapoint, ne_label, ndata, ndata_shard);
float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset));
float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@ -74,7 +75,8 @@ static helper_ctx_data helper_get_ctx_data(
datasets[ndata_shard-1] = dataset;
}
ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1);
ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(
GGML_TYPE_F32, GGML_TYPE_F32, 1, 0, ndata, /*ndata_shard =*/ 1);
float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));
@ -113,7 +115,7 @@ static helper_ctx_data helper_get_ctx_data(
struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
ggml_set_name(weights, "weights");
ggml_set_param(ctx_static, weights);
ggml_set_param(weights);
struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);
@ -127,8 +129,11 @@ static helper_ctx_data helper_get_ctx_data(
GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
const int32_t opt_period = nbatch_logical / nbatch_physical;
struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
opt_params.opt_period = opt_period;
struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, loss_type);
opt_params.ctx_compute = ctx_compute;
opt_params.inputs = inputs;
opt_params.outputs = outputs;
opt_params.opt_period = opt_period;
if (!optimizer_defaults) {
opt_params.get_opt_pars = helper_get_test_opt_pars;
}
@ -264,8 +269,9 @@ static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_ba
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
}
@ -334,8 +340,9 @@ static std::pair<int, int> test_forward_backward(
} else {
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
}
@ -367,7 +374,8 @@ static std::pair<int, int> test_forward_backward(
float w0;
ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
for (int i = 0; i < 10; ++i) {
ggml_opt_forward_backward(cd.opt_ctx, nullptr);
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_opt_eval(cd.opt_ctx, cd.result);
}
ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
@ -387,8 +395,9 @@ static std::pair<int, int> test_forward_backward(
} else {
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
}
@ -492,14 +501,16 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
int idata = 0;
for (; idata < idata_split; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
for (; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
ggml_opt_forward(cd.opt_ctx, cd.result2);
ggml_opt_eval(cd.opt_ctx, cd.result2);
ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
}
}
@ -584,15 +595,17 @@ static std::pair<int, int> test_gradient_accumulation(
if (nbatch_physical == 1) {
for (int idata = 0; idata < ndata; ++idata) {
const float idataf = idata;
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));
}
} else if (nbatch_physical == 2) {
for (int idata = 0; idata < ndata; idata += 2) {
const float idataf[2] = {float(idata + 0), float(idata + 1)};
ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));
ggml_opt_forward_backward(cd.opt_ctx, cd.result);
ggml_opt_eval(cd.opt_ctx, cd.result);
grad_history[idata + 0] = 0.0f;
ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float));
@ -617,7 +630,7 @@ static std::pair<int, int> test_gradient_accumulation(
}
subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0, atol);
} else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
if (nbatch_physical == 1) {
subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);
@ -630,7 +643,7 @@ static std::pair<int, int> test_gradient_accumulation(
}
subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol);
subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0/ndata, atol);
} else {
GGML_ASSERT(false);
}
@ -692,7 +705,8 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
std::mt19937 gen(12345);
std::normal_distribution<float> nd{0.0f, 0.1f};
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression);
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
GGML_TYPE_F32, GGML_TYPE_F32, 1, 1, ndata_regression, ndata_regression);
float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset));
float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@ -733,15 +747,14 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
ggml_set_name(a, "a");
ggml_set_param(ctx_static, a);
ggml_set_param(a);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
ggml_set_name(b, "b");
ggml_set_param(ctx_static, b);
ggml_set_param(b);
struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);
ggml_set_name(f, "f");
ggml_set_param(ctx_static, f);
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
const float a0 = 1.0f;