move common train functions into common/train.[h|cpp]

xaedes 2023-09-16 14:58:34 +02:00
parent 00b656f6db
commit 9f4b1bf88d
7 changed files with 1279 additions and 1993 deletions

Makefile

@@ -485,6 +485,9 @@ console.o: common/console.cpp common/console.h
grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
$(CXX) $(CXXFLAGS) -c $< -o $@
+train.o: common/train.cpp common/train.h
+$(CXX) $(CXXFLAGS) -c $< -o $@
libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
@@ -532,7 +535,7 @@ embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-te
gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o $(OBJS)
+train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp ggml.o llama.o $(OBJS)
@@ -541,13 +544,13 @@ convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggm
llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
+baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o common.o train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)

common/CMakeLists.txt

@@ -9,6 +9,8 @@ add_library(${TARGET} OBJECT
console.cpp
grammar-parser.h
grammar-parser.cpp
+train.h
+train.cpp
)
if (BUILD_SHARED_LIBS)

common/train.cpp (new file, 914 lines)

@@ -0,0 +1,914 @@
#include "train.h"
#include "common.h"
#include <random>
#include <sstream>
#include <functional>
struct random_normal_distribution {
std::mt19937 gen;
std::normal_distribution<float> rd;
float min;
float max;
};
struct random_uniform_distribution {
std::mt19937 gen;
std::uniform_real_distribution<float> rd;
};
struct random_normal_distribution * init_random_normal_distribution(int seed, float mean, float std, float min, float max) {
struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
rnd->gen = std::mt19937(seed);
rnd->rd = std::normal_distribution<float>{mean, std};
rnd->min = min;
rnd->max = max;
return rnd;
}
struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
rnd->gen = std::mt19937(seed);
rnd->rd = std::uniform_real_distribution<float>{min, max};
return rnd;
}
void free_random_normal_distribution (struct random_normal_distribution * rnd) {
free(rnd);
}
void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
free(rnd);
}
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
float scale = 1.0f; // xavier
switch (tensor->n_dims) {
case 1:
scale /= sqrtf((float) tensor->ne[0]);
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
*dst = scale * frand_normal(rnd);
}
break;
case 2:
scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
*dst = scale * frand_normal(rnd);
}
}
break;
case 3:
scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
*dst = scale * frand_normal(rnd);
}
}
}
break;
case 4:
scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
*dst = scale * frand_normal(rnd);
}
}
}
}
break;
default:
GGML_ASSERT(!"Unsupported tensor->n_dims");
};
return tensor;
}
struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
switch (tensor->n_dims) {
case 1:
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
*dst = frand_uniform(rnd);
}
break;
case 2:
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
*dst = frand_uniform(rnd);
}
}
break;
case 3:
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
*dst = frand_uniform(rnd);
}
}
}
break;
case 4:
for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
*dst = frand_uniform(rnd);
}
}
}
}
break;
default:
GGML_ASSERT(!"Unsupported tensor->n_dims");
};
return tensor;
}
float frand() {
return (float)rand()/(float)RAND_MAX;
}
float frand_normal(struct random_normal_distribution * rnd) {
return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
}
float frand_uniform(struct random_uniform_distribution * rnd) {
return rnd->rd(rnd->gen);
}
int clamp(const int v, const int min, const int max) {
return ((v < min) ? (min) : (v > max) ? (max) : v);
}
float fclamp(const float v, const float min, const float max) {
return ((v < min) ? (min) : (v > max) ? (max) : v);
}
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
GGML_ASSERT(tensor->n_dims == 1);
GGML_ASSERT(tensor->ne[0] == ne0);
}
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
GGML_ASSERT(tensor->n_dims == 2);
GGML_ASSERT(tensor->ne[0] == ne0);
GGML_ASSERT(tensor->ne[1] == ne1);
}
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
GGML_ASSERT(tensor->n_dims == 3);
GGML_ASSERT(tensor->ne[0] == ne0);
GGML_ASSERT(tensor->ne[1] == ne1);
GGML_ASSERT(tensor->ne[2] == ne2);
}
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
GGML_ASSERT(tensor->n_dims == 4);
GGML_ASSERT(tensor->ne[0] == ne0);
GGML_ASSERT(tensor->ne[1] == ne1);
GGML_ASSERT(tensor->ne[2] == ne2);
GGML_ASSERT(tensor->ne[3] == ne3);
}
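// expected shapes (see the asserts below):
//   tokens_input : [n_tokens, n_batch]          - token ids for the batch
//   target_probs : [n_vocab, n_tokens, n_batch] - one-hot target probabilities
// returns the number of samples consumed from samples_begin/samples_size.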
int64_t get_example_targets_batch(
struct llama_context * lctx,
struct ggml_tensor * tokens_input,
struct ggml_tensor * target_probs,
int64_t example_id,
const size_t * samples_begin,
const size_t * samples_size,
size_t samples_count,
const llama_token * train_data,
size_t n_train_data,
bool separate_with_eos,
bool separate_with_bos,
bool fill_with_next_samples) {
GGML_ASSERT(tokens_input->n_dims == 2);
GGML_ASSERT(target_probs->n_dims == 3);
int64_t n_vocab = target_probs->ne[0];
int64_t n_tokens = tokens_input->ne[0];
int64_t n_batch = tokens_input->ne[1];
GGML_ASSERT(n_vocab == target_probs->ne[0]);
GGML_ASSERT(n_tokens == target_probs->ne[1]);
GGML_ASSERT(n_batch == target_probs->ne[2]);
int64_t used_samples = 0;
ggml_set_f32(target_probs, 0.0f);
llama_token bos = llama_token_bos(lctx);
llama_token eos = llama_token_eos(lctx);
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k);
size_t sample_offs = 0;
size_t sample_idx = (example_id + used_samples) % samples_count;
size_t sample_begin = samples_begin[sample_idx];
size_t sample_size = samples_size[sample_idx];
++used_samples;
// printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);
ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
bool sample_separation_eos = !separate_with_eos;
bool sample_separation_bos = !separate_with_bos;
for (int64_t i=0; i<n_tokens; ++i) {
llama_token token = eos;
if (sample_offs >= sample_size && fill_with_next_samples) {
if (!sample_separation_eos) {
// insert eos token to separate samples
sample_separation_eos = true;
} else if (!sample_separation_bos) {
// insert bos token to separate samples
sample_separation_bos = true;
token = bos;
} else {
// sample separation is done, continue with next sample
sample_separation_eos = !separate_with_eos;
sample_separation_bos = !separate_with_bos;
sample_offs = 0;
sample_idx = (example_id + used_samples) % samples_count;
sample_begin = samples_begin[sample_idx];
sample_size = samples_size[sample_idx];
++used_samples;
}
}
// note: no else-if here
if (sample_offs < sample_size) {
token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
++sample_offs;
}
ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f);
if (i+1<n_tokens) {
ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
}
}
}
return used_samples;
}
void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
std::stringstream s_rng_state;
s_rng_state.imbue(std::locale::classic());
s_rng_state.exceptions(std::stringstream::failbit);
s_rng_state.str(rng_state);
s_rng_state >> rng;
}
std::string mt19937_get_state(const std::mt19937& rng) {
std::stringstream s_rng_state;
s_rng_state.imbue(std::locale::classic());
s_rng_state << rng;
return s_rng_state.str();
}
std::string mt19937_seed_to_state(unsigned seed) {
std::mt19937 rng(seed);
return mt19937_get_state(rng);
}
std::string shuffle_samples(
const std::string & rng_state,
size_t * shuffled_begins,
size_t * shuffled_sizes,
const size_t * begins,
const size_t * sizes,
size_t count) {
if (count == 0) return rng_state;
std::mt19937 rng;
mt19937_set_state(rng, rng_state);
// sort indices by random value for each index
std::vector<size_t> idcs;
{
std::vector<unsigned> rnd;
idcs.resize(count);
rnd.resize(count);
for (unsigned i=0; i<count; ++i) {
idcs[i] = i;
rnd[i] = rng();
}
std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
// stable sort for reproducibility
return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
});
}
// reorder begins and sizes by sorted indices
for (unsigned i=0; i<count; ++i) {
shuffled_begins[i] = begins[idcs[i]];
}
for (unsigned i=0; i<count; ++i) {
shuffled_sizes[i] = sizes[idcs[i]];
}
return mt19937_get_state(rng);
}
size_t hash_combine(size_t h1, size_t h2) {
return h1 ^ (h2 << 1);
}
size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
std::hash<std::string> h_string;
std::hash<unsigned long long> h_ull;
size_t h = h_string(std::string(fn));
h = hash_combine(h, h_ull((unsigned long long) sample_count));
for (size_t i=0; i< sample_count; ++i) {
h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
}
return h;
}
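// note: only the first occurrence of needle is replaced;
// later occurrences are left unchanged.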
std::string replace_str(const char * s, const char * needle, const char * replacement) {
std::string str = s;
size_t pos = str.find(needle);
if (pos != std::string::npos) {
str.replace(pos, strlen(needle), replacement);
}
return str;
}
void print_duration(double fmillis) {
if (fmillis < 1000.0f) {
printf("%.1fms", (float) fmillis);
return;
}
const int64_t one_sec = 1000;
const int64_t one_min = one_sec * 60;
const int64_t one_hour = one_min * 60;
const int64_t one_day = one_hour * 24;
int64_t millis = (int64_t) fmillis;
int64_t days = millis/one_day;
int64_t hours = (millis - days*one_day)/one_hour;
int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;
// to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
if (days > 0) {
printf("%lldd ", (long long int) days);
}
printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
}
float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
if (step > decay_steps) {
step = decay_steps;
}
const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
const float decay = (1 - minimum)*cosine_decay + minimum;
return decay;
}
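// cosine decay with warm restarts: whenever step exceeds the current decay period,
// the period length is multiplied by restart_step_mult and the decay starts over.
// illustrative example: with decay_steps=100 and restart_step_mult=2.0f the
// restarts happen at steps 100, 300, 700, ... (period lengths 100, 200, 400, ...).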
float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
while (step > decay_steps) {
step -= decay_steps;
decay_steps = (int64_t) (restart_step_mult * decay_steps);
}
return cosine_decay(step, decay_steps, minimum);
}
float learning_schedule(
int64_t step,
int64_t warmup_steps,
int64_t cos_decay_steps,
float learning_rate,
float overall_minimum,
float cos_decay_minimum,
float cos_decay_restart_step_mult,
bool enable_restart) {
float result =
(step < warmup_steps)
? (float) step / (float) warmup_steps
: enable_restart
? cosine_decay_restart(
step - warmup_steps,
cos_decay_steps,
cos_decay_minimum,
cos_decay_restart_step_mult)
: cosine_decay(
step,
cos_decay_steps,
cos_decay_minimum);
float min = overall_minimum / learning_rate;
result = min + result * (1.0f - min);
return result;
}
static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
GGML_ASSERT(a != NULL);
GGML_ASSERT(b != NULL);
GGML_ASSERT(a->type == b->type);
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
return true;
}
void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
if (dst == NULL) {
return;
}
struct ggml_tensor * t = ggml_get_tensor(ctx, name);
GGML_ASSERT(are_same_layout(dst, t));
memcpy(dst->data, t->data, ggml_nbytes(t));
if (strlen(ggml_get_name(dst)) == 0) {
ggml_set_name(dst, name);
}
}
// gguf constants
static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
static const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";
static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
static const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
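// reads the gguf KV pair `key` of gguf type `type` into `dst` using getter `func`;
// if `req` is true a missing key is a fatal error, otherwise `dst` is left unchanged.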
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
{ \
const std::string skey(key); \
const int kid = gguf_find_key(ctx, skey.c_str()); \
if (kid >= 0) { \
enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
if (ktype != (type)) { \
die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
} \
(dst) = func(ctx, kid); \
} else if (req) { \
die_fmt("key not found in model: %s", skey.c_str()); \
} \
}
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
// NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data cannot be read
uint32_t file_version;
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
GGML_ASSERT(file_version == 0);
GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
uint64_t nx;
GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
opt->nx = (size_t) nx;
// don't call ggml_opt_init until optimizer type and optimizer-specific parameters are known
std::string opt_type;
GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
opt->params.type = GGML_OPT_ADAM;
GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
copy_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
} else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
opt->params.type = GGML_OPT_LBFGS;
GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
copy_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
copy_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
copy_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
copy_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
copy_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
copy_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
copy_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
copy_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
} else {
throw std::runtime_error("unknown optimizer type\n");
}
}
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
if (opt->adam.pf) {
ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
}
gguf_add_tensor(fctx, opt->adam.m);
gguf_add_tensor(fctx, opt->adam.v);
if (opt->adam.pf) {
gguf_add_tensor(fctx, opt->adam.pf);
}
} break;
case GGML_OPT_LBFGS:
{
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
if (opt->lbfgs.pf) {
ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
}
ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
gguf_add_tensor(fctx, opt->lbfgs.x);
gguf_add_tensor(fctx, opt->lbfgs.xp);
gguf_add_tensor(fctx, opt->lbfgs.g);
gguf_add_tensor(fctx, opt->lbfgs.gp);
gguf_add_tensor(fctx, opt->lbfgs.d);
if (opt->lbfgs.pf) {
gguf_add_tensor(fctx, opt->lbfgs.pf);
}
gguf_add_tensor(fctx, opt->lbfgs.lmal);
gguf_add_tensor(fctx, opt->lbfgs.lmys);
gguf_add_tensor(fctx, opt->lbfgs.lms);
gguf_add_tensor(fctx, opt->lbfgs.lmy);
} break;
}
}
struct llama_file {
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
size_t size;
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
size = 0;
} else {
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
}
size_t tell() const {
#ifdef _WIN32
__int64 ret = _ftelli64(fp);
#else
long ret = std::ftell(fp);
#endif
GGML_ASSERT(ret != -1); // this really shouldn't fail
return (size_t) ret;
}
void seek(size_t offset, int whence) {
#ifdef _WIN32
int ret = _fseeki64(fp, (__int64) offset, whence);
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
GGML_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, size, 1, fp);
if (ferror(fp)) {
die_fmt("read error: %s", strerror(errno));
}
if (ret != 1) {
die_fmt("unexpectedly reached end of file");
}
}
std::uint32_t read_u32() {
std::uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
std::string read_string(std::uint32_t len) {
std::vector<char> chars(len);
read_raw(chars.data(), len);
return std::string(chars.data(), len);
}
void write_raw(const void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, size, 1, fp);
if (ret != 1) {
die_fmt("write error: %s", strerror(errno));
}
}
void write_u32(std::uint32_t val) {
write_raw(&val, sizeof(val));
}
~llama_file() {
if (fp) {
std::fclose(fp);
}
}
};
static size_t utf8_len(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
return lookup[highbits];
}
// mark each byte with its utf8 unit number.
// returns the number of utf8 characters.
// e.g. when bytes == '\x61\xD0\xB0\x62',
// then utf8_units will become [0,0,1,0]
// utf8_nunits will become [1,2,2,1] and 3 is returned.
// bytes where utf8_units is zero are the beginning of a utf8 character.
static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
size_t offs = 0;
size_t count_utf8 = 0;
while(offs < count) {
int len = (int) utf8_len(bytes[offs]);
for (int i=0; i<len; ++i) {
utf8_units[offs+i] = i;
utf8_nunits[offs+i] = len;
}
offs += len;
++count_utf8;
}
return count_utf8;
}
size_t tokenize_file(
struct llama_context * lctx,
const char * filename,
const std::string & sample_start,
bool include_sample_start,
bool overlapping_samples,
unsigned context_length,
std::vector<llama_token> & out_tokens,
std::vector<size_t> & out_samples_begin,
std::vector<size_t> & out_samples_size) {
struct llama_file f(filename, "rb");
if (f.size == 0) {
out_tokens.clear();
out_samples_begin.clear();
out_samples_size.clear();
printf("%s: warning: empty or not existing training data file '%s'\n",
__func__, filename);
return out_tokens.size();
}
// account for possible leading whitespace that will be added by tokenizer
// e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
const int n_max_tokens_overhead = 1;
std::vector<char> buf;
buf.resize(f.size+1);
f.read_raw(buf.data(), f.size);
buf[f.size] = '\0';
std::vector<int> utf8_units;
std::vector<int> utf8_nunits;
utf8_units.resize(buf.size());
utf8_nunits.resize(buf.size());
size_t n_utf8_chars = mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
if (sample_start.size() == 0) {
// tokenize all data at once
out_tokens.resize(buf.size() + n_max_tokens_overhead);
int n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), (int) out_tokens.size(), false);
if (n_tokens < 0) {
out_tokens.resize(-n_tokens);
n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), (int) out_tokens.size(), false);
}
if (n_tokens >= 0) {
out_tokens.resize(n_tokens);
}
// generate sample starts at all token positions
out_samples_begin.clear();
out_samples_begin.push_back(0);
out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
out_samples_begin.push_back(sample_begin);
out_samples_size.push_back(context_length);
}
} else {
// split data into samples and tokenize each sample
std::string data_str(buf.data(), buf.size()-1);
out_samples_begin.clear();
out_samples_size.clear();
out_tokens.clear();
// find all positions of pattern sample_start
size_t sample_begin = data_str.find(sample_start, 0);
while (sample_begin != std::string::npos) {
out_samples_begin.push_back(sample_begin);
const size_t search_start = sample_begin + sample_start.size();
sample_begin = data_str.find(sample_start, search_start);
}
if (out_samples_begin.size() == 0) {
printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
__func__, sample_start.c_str());
out_samples_begin.push_back(0);
}
out_samples_size.resize(out_samples_begin.size(), 0);
std::vector<char> buf_sample;
std::vector<llama_token> tok_sample;
const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
size_t found_too_big_sample = 0;
size_t found_too_small_sample = 0;
size_t found_empty_sample = 0;
size_t found_min_sample_size = SIZE_MAX;
size_t found_max_sample_size = 0;
size_t max_token_text_size = 0;
int n_vocab = llama_n_vocab(lctx);
for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max(
max_token_text_size,
strlen(llama_token_get_text(lctx, token)));
}
// upper bound of context byte length.
// strings with this byte length should always tokenize to at least context_length tokens.
size_t context_byte_len = max_token_text_size*context_length;
for (unsigned i=0; i<out_samples_begin.size(); ++i) {
// determine sample begin and end from pattern positions
size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
size_t sample_end = overlapping_samples
? std::min(
data_str.size(),
sample_begin + context_byte_len)
: (i+1 < out_samples_begin.size()
? out_samples_begin[i+1]
: data_str.size());
if (utf8_units[sample_end] > 0) {
// sample end is in the middle of an utf8 character.
// advance sample_end to the beginning of the next utf8 character.
sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
}
size_t sample_size = sample_end - sample_begin;
if (sample_size == 0) {
++found_empty_sample;
}
if (sample_size > 0) {
// llama_tokenize expects zero terminated string,
// copy sample into buffer and zero terminate it.
buf_sample.resize(sample_size+1);
memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);
buf_sample[sample_size] = '\0';
// printf("sample: '%s'\n", buf_sample.data());
// tokenize the sample
tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
int n_tokens = llama_tokenize(lctx,
buf_sample.data(),
tok_sample.data(),
(int) tok_sample.size(), false);
if (n_tokens < 0) {
tok_sample.resize(-n_tokens);
n_tokens = llama_tokenize(lctx,
buf_sample.data(),
tok_sample.data(),
(int) tok_sample.size(), false);
GGML_ASSERT(n_tokens >= 0);
}
GGML_ASSERT(n_tokens <= (int) tok_sample.size());
if ((size_t) n_tokens > context_length) {
++found_too_big_sample;
} else if ((size_t) n_tokens < context_length) {
++found_too_small_sample;
}
found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);
// write out tokens, start and size of sample
// overwrite the string start position with the token start position
out_samples_begin[i] = out_tokens.size();
out_samples_size[i] = (size_t) n_tokens;
out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
} else {
out_samples_begin[i] = out_tokens.size();
out_samples_size[i] = 0;
}
}
if (found_too_big_sample > 0) {
printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
__func__, found_too_big_sample, found_max_sample_size, context_length);
}
if (found_too_small_sample > 0) {
printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
__func__, found_too_small_sample, found_min_sample_size, context_length);
}
if (found_empty_sample) {
printf("%s: warning: found %zu empty samples.\n",
__func__, found_empty_sample);
}
}
printf("%s: total number of samples: %zu\n",
__func__, out_samples_begin.size());
GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());
return out_tokens.size();
}
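
Taken together, these functions form the data pipeline that the training examples build on: tokenize_file produces the token stream and sample boundaries, shuffle_samples reorders them deterministically from a serializable rng state, and get_example_targets_batch fills the model inputs. A minimal sketch of how they might be wired up; the file name, seed, n_ctx, and the pre-allocated tokens_input/target_probs tensors are illustrative assumptions, not part of this commit:

// hypothetical wiring of the common/train data pipeline
// assumes: lctx is an initialized llama_context, n_ctx is the training context length,
// tokens_input [n_tokens, n_batch] and target_probs [n_vocab, n_tokens, n_batch]
// are pre-allocated ggml tensors
std::vector<llama_token> tokens;
std::vector<size_t> begins;
std::vector<size_t> sizes;
tokenize_file(lctx, "train.txt", "<s>", false, false, n_ctx, tokens, begins, sizes);

// deterministic shuffle; the rng state is a string so it can be checkpointed
std::vector<size_t> shuf_begins(begins.size());
std::vector<size_t> shuf_sizes(sizes.size());
mt19937_state rng_state = mt19937_seed_to_state(1337);
rng_state = shuffle_samples(rng_state,
    shuf_begins.data(), shuf_sizes.data(),
    begins.data(), sizes.data(), begins.size());

// fill one training batch; the return value advances the example counter
int64_t example_id = 0;
example_id += get_example_targets_batch(lctx, tokens_input, target_probs, example_id,
    shuf_begins.data(), shuf_sizes.data(), shuf_begins.size(),
    tokens.data(), tokens.size(),
    /*separate_with_eos*/ false, /*separate_with_bos*/ true,
    /*fill_with_next_samples*/ false);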

common/train.h (new file, 113 lines)

@@ -0,0 +1,113 @@
// Various helper functions and utilities for training
#pragma once
#include <string>
#include <random>
#include <vector>
#include "ggml.h"
#include "llama.h"
struct random_normal_distribution;
struct random_uniform_distribution;
struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max);
struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max);
void free_random_normal_distribution (struct random_normal_distribution * rnd);
void free_random_uniform_distribution(struct random_uniform_distribution * rnd);
struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd);
struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd);
float frand();
float frand_normal (struct random_normal_distribution * rnd);
float frand_uniform(struct random_uniform_distribution * rnd);
int clamp (const int v, const int min, const int max);
float fclamp(const float v, const float min, const float max);
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0);
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1);
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2);
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);
size_t tokenize_file(
struct llama_context * lctx,
const char * filename,
const std::string & sample_start,
bool include_sample_start,
bool overlapping_samples,
unsigned context_length,
std::vector<llama_token> & out_tokens,
std::vector<size_t> & out_samples_begin,
std::vector<size_t> & out_samples_size);
int64_t get_example_targets_batch(
struct llama_context * lctx,
struct ggml_tensor * tokens_input,
struct ggml_tensor * target_probs,
int64_t example_id,
const size_t * samples_begin,
const size_t * samples_size,
size_t samples_count,
const llama_token * train_data,
size_t n_train_data,
bool separate_with_eos,
bool separate_with_bos,
bool fill_with_next_samples);
typedef std::string mt19937_state;
void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
mt19937_state mt19937_get_state(const std::mt19937& rng);
mt19937_state mt19937_seed_to_state(unsigned seed);
mt19937_state shuffle_samples(
const mt19937_state & rng_state,
size_t * shuffled_begins,
size_t * shuffled_sizes,
const size_t * begins,
const size_t * sizes,
size_t count);
size_t hash_combine(size_t h1, size_t h2);
size_t compute_samples_hash(
const char* fn,
const size_t* samples_begin,
const size_t* samples_size,
size_t sample_count);
std::string replace_str(const char * s, const char * needle, const char * replacement);
void print_duration(double milliseconds);
float cosine_decay(
int64_t step,
int64_t decay_steps,
float minimum);
float cosine_decay_restart(
int64_t step,
int64_t decay_steps,
float minimum,
float restart_step_mult);
float learning_schedule(
int64_t step,
int64_t warmup_steps,
int64_t decay_steps,
float learning_rate,
float overall_minimum,
float cos_decay_minimum,
float cos_decay_restart_step_mult,
bool enable_restart);
void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name);
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);
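
For reference, the multiplier computed by learning_schedule can be written out as follows, with $s$ the step, $w$ = warmup_steps, $T$ = decay_steps, $\mu$ = cos_decay_minimum and $\rho$ = cos_decay_restart_step_mult (the returned value is a multiplier for the base learning rate):

$$\mathrm{sched}(s) = m + (1 - m)\,r(s), \qquad m = \frac{\text{overall\_minimum}}{\text{learning\_rate}}$$

$$r(s) = \begin{cases} s/w & s < w \\ \mathrm{cosine\_decay\_restart}(s - w,\, T,\, \mu,\, \rho) & \text{restarts enabled} \\ \mathrm{cosine\_decay}(s,\, T,\, \mu) & \text{otherwise} \end{cases}$$

$$\mathrm{cosine\_decay}(s, T, \mu) = \mu + (1 - \mu)\cdot\tfrac{1}{2}\left(1 + \cos\!\left(\pi\,\frac{\min(s, T)}{T}\right)\right)$$

Note that, mirroring the code, the non-restart branch passes the raw step $s$ rather than $s - w$.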

examples/baby-llama/baby-llama.cpp

@@ -1,4 +1,5 @@
#include "ggml.h"
+#include "train.h"
#include <vector>
#include <cassert>
#include <random>
@@ -14,29 +15,6 @@ static const float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS;
static const float rms_norm_eps = 5e-6f;
#endif
-float frand() {
-return (float)rand()/(float)RAND_MAX;
-}
-struct random_normal_distribution {
-std::mt19937 gen;
-std::normal_distribution<float> nd;
-float min;
-float max;
-};
-void init_random_normal_distribution(struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max) {
-rnd->gen = std::mt19937(seed);
-rnd->nd = std::normal_distribution<float>{mean, std};
-rnd->min = min;
-rnd->max = max;
-}
-float frand_normal(struct random_normal_distribution * rnd) {
-const float r = rnd->nd(rnd->gen);
-return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
-}
void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
@@ -95,56 +73,6 @@ struct ggml_tensor * randomize_tensor(
return tensor;
}
-struct ggml_tensor * randomize_tensor_normal(
-struct ggml_tensor * tensor,
-int ndims,
-const int64_t ne[],
-struct random_normal_distribution * rnd) {
-float scale = 1.0; // xavier
-switch (ndims) {
-case 1:
-scale /= sqrtf(ne[0]);
-for (int i0 = 0; i0 < ne[0]; i0++) {
-((float *)tensor->data)[i0] = scale * frand_normal(rnd);
-}
-break;
-case 2:
-scale /= sqrtf(ne[0]+ne[1]);
-for (int i1 = 0; i1 < ne[1]; i1++) {
-for (int i0 = 0; i0 < ne[0]; i0++) {
-((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd);
-}
-}
-break;
-case 3:
-scale /= sqrtf(ne[0]+ne[1]);
-for (int i2 = 0; i2 < ne[2]; i2++) {
-for (int i1 = 0; i1 < ne[1]; i1++) {
-for (int i0 = 0; i0 < ne[0]; i0++) {
-((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
-}
-}
-}
-break;
-case 4:
-scale /= sqrtf(ne[0]+ne[1]);
-for (int i3 = 0; i3 < ne[3]; i3++) {
-for (int i2 = 0; i2 < ne[2]; i2++) {
-for (int i1 = 0; i1 < ne[1]; i1++) {
-for (int i0 = 0; i0 < ne[0]; i0++) {
-((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
-}
-}
-}
-}
-break;
-default:
-assert(false);
-};
-return tensor;
-}
struct llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; // this is provided as user input?
@@ -402,27 +330,29 @@ void randomize_model(struct llama_model * model, int seed, float mean, float std
const uint32_t n_layer = hparams.n_layer;
-struct random_normal_distribution rnd;
-init_random_normal_distribution(&rnd, seed, mean, std, min, max);
-randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
-randomize_tensor_normal(model->norm, model->norm->n_dims, model->norm->ne, &rnd);
-randomize_tensor_normal(model->output, model->output->n_dims, model->output->ne, &rnd);
+struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+randomize_tensor_normal(model->tok_embeddings , rnd);
+randomize_tensor_normal(model->norm           , rnd);
+randomize_tensor_normal(model->output         , rnd);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];
-randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
+randomize_tensor_normal(layer.attention_norm, rnd);
-randomize_tensor_normal(layer.wq, layer.wq->n_dims, layer.wq->ne, &rnd);
-randomize_tensor_normal(layer.wk, layer.wk->n_dims, layer.wk->ne, &rnd);
-randomize_tensor_normal(layer.wv, layer.wv->n_dims, layer.wv->ne, &rnd);
-randomize_tensor_normal(layer.wo, layer.wo->n_dims, layer.wo->ne, &rnd);
+randomize_tensor_normal(layer.wq, rnd);
+randomize_tensor_normal(layer.wk, rnd);
+randomize_tensor_normal(layer.wv, rnd);
+randomize_tensor_normal(layer.wo, rnd);
-randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
+randomize_tensor_normal(layer.ffn_norm, rnd);
-randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
-randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
-randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+randomize_tensor_normal(layer.w1, rnd);
+randomize_tensor_normal(layer.w2, rnd);
+randomize_tensor_normal(layer.w3, rnd);
}
+free_random_normal_distribution(rnd);
}
@@ -431,32 +361,34 @@ void randomize_model_lora(struct llama_model_lora * model, int seed, float mean,
const uint32_t n_layer = hparams.n_layer;
-struct random_normal_distribution rnd;
-init_random_normal_distribution(&rnd, seed, mean, std, min, max);
-randomize_tensor_normal(model->tok_embeddings, model->tok_embeddings->n_dims, model->tok_embeddings->ne, &rnd);
-randomize_tensor_normal(model->norm, model->norm->n_dims, model->norm->ne, &rnd);
-randomize_tensor_normal(model->outputa, model->outputa->n_dims, model->outputa->ne, &rnd);
-randomize_tensor_normal(model->outputb, model->outputb->n_dims, model->outputb->ne, &rnd);
+struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
+randomize_tensor_normal(model->tok_embeddings, rnd);
+randomize_tensor_normal(model->norm          , rnd);
+randomize_tensor_normal(model->outputa       , rnd);
+randomize_tensor_normal(model->outputb       , rnd);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];
-randomize_tensor_normal(layer.attention_norm, layer.attention_norm->n_dims, layer.attention_norm->ne, &rnd);
+randomize_tensor_normal(layer.attention_norm, rnd);
-randomize_tensor_normal(layer.wqa, layer.wqa->n_dims, layer.wqa->ne, &rnd);
-randomize_tensor_normal(layer.wqb, layer.wqb->n_dims, layer.wqb->ne, &rnd);
-randomize_tensor_normal(layer.wka, layer.wka->n_dims, layer.wka->ne, &rnd);
-randomize_tensor_normal(layer.wkb, layer.wkb->n_dims, layer.wkb->ne, &rnd);
-randomize_tensor_normal(layer.wva, layer.wva->n_dims, layer.wva->ne, &rnd);
-randomize_tensor_normal(layer.wvb, layer.wvb->n_dims, layer.wvb->ne, &rnd);
-randomize_tensor_normal(layer.woa, layer.woa->n_dims, layer.woa->ne, &rnd);
-randomize_tensor_normal(layer.wob, layer.wob->n_dims, layer.wob->ne, &rnd);
+randomize_tensor_normal(layer.wqa, rnd);
+randomize_tensor_normal(layer.wqb, rnd);
+randomize_tensor_normal(layer.wka, rnd);
+randomize_tensor_normal(layer.wkb, rnd);
+randomize_tensor_normal(layer.wva, rnd);
+randomize_tensor_normal(layer.wvb, rnd);
+randomize_tensor_normal(layer.woa, rnd);
+randomize_tensor_normal(layer.wob, rnd);
-randomize_tensor_normal(layer.ffn_norm, layer.ffn_norm->n_dims, layer.ffn_norm->ne, &rnd);
+randomize_tensor_normal(layer.ffn_norm, rnd);
-randomize_tensor_normal(layer.w1, layer.w1->n_dims, layer.w1->ne, &rnd);
-randomize_tensor_normal(layer.w2, layer.w2->n_dims, layer.w2->ne, &rnd);
-randomize_tensor_normal(layer.w3, layer.w3->n_dims, layer.w3->ne, &rnd);
+randomize_tensor_normal(layer.w1, rnd);
+randomize_tensor_normal(layer.w2, rnd);
+randomize_tensor_normal(layer.w3, rnd);
}
+free_random_normal_distribution(rnd);
}
bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) {
@@ -756,32 +688,6 @@ struct ggml_tensor * forward(
return inpL;
}
-void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
-GGML_ASSERT(tensor->n_dims == 1);
-GGML_ASSERT(tensor->ne[0] == ne0);
-}
-void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
-GGML_ASSERT(tensor->n_dims == 2);
-GGML_ASSERT(tensor->ne[0] == ne0);
-GGML_ASSERT(tensor->ne[1] == ne1);
-}
-void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
-GGML_ASSERT(tensor->n_dims == 3);
-GGML_ASSERT(tensor->ne[0] == ne0);
-GGML_ASSERT(tensor->ne[1] == ne1);
-GGML_ASSERT(tensor->ne[2] == ne2);
-}
-void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
-GGML_ASSERT(tensor->n_dims == 4);
-GGML_ASSERT(tensor->ne[0] == ne0);
-GGML_ASSERT(tensor->ne[1] == ne1);
-GGML_ASSERT(tensor->ne[2] == ne2);
-GGML_ASSERT(tensor->ne[3] == ne3);
-}
struct ggml_tensor * forward_batch(
struct llama_model * model,
struct llama_kv_cache * cache,

examples/train-text-from-scratch/train-text-from-scratch.cpp (file diff suppressed because it is too large)

examples/finetune/finetune.cpp (file diff suppressed because it is too large)