add llama sampler, shuffle samples and constrain sampling to tokens occurring in train data
This commit is contained in:
parent ec881156f6
commit e063135d0b
1 changed file with 219 additions and 16 deletions
@@ -7,6 +7,7 @@
 #include <cstring>
 #include <stdexcept>
 #include <cstdarg>
+#include <algorithm>
 
 
 struct random_normal_distribution {
@@ -42,6 +43,10 @@ float fclamp(const float v, const float min, const float max) {
     return ((v < min) ? (min) : (v > max) ? (max) : v);
 }
 
+float frand() {
+    return (float)rand()/(float)RAND_MAX;
+}
+
 float frand_normal(struct random_normal_distribution * rnd) {
     return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
 }
@@ -162,6 +167,17 @@ uint32_t get_n_ff(const struct my_llama_hparams* hparams) {
     return n_ff;
 }
 
+void print_params(struct my_llama_hparams * params) {
+    printf("%s: n_vocab: %d\n", __func__, params->n_vocab);
+    printf("%s: n_ctx:   %d\n", __func__, params->n_ctx);
+    printf("%s: n_embd:  %d\n", __func__, params->n_embd);
+    printf("%s: n_mult:  %d\n", __func__, params->n_mult);
+    printf("%s: n_head:  %d\n", __func__, params->n_head);
+    printf("%s: n_ff:    %d\n", __func__, get_n_ff(params));
+    printf("%s: n_layer: %d\n", __func__, params->n_layer);
+    printf("%s: n_rot:   %d\n", __func__, params->n_rot);
+}
+
 struct my_llama_layer {
     // normalization
     struct ggml_tensor * attention_norm;
@@ -989,18 +1005,17 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
     }
 }
 
-void get_example_targets(const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
+void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
     int n_tokens = tokens_input->ne[0];
     int n_vocab = targets->ne[0];
 
-    int n_examples = (n_train_data / (size_t) n_tokens);
-    int begin = (example_id % n_examples) * n_tokens;
-    GGML_ASSERT(begin+n_tokens-1 < n_train_data);
+    int sample = train_samples[example_id % n_train_samples];
+    GGML_ASSERT(sample+n_tokens-1 < n_train_data);
 
-    ggml_set_f32(targets, -1.0f);
+    ggml_set_f32(targets, -1.0f/n_vocab);
     ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
     for (int i=1; i<n_tokens+1; ++i) {
-        int token = clamp(train_data[begin+i-1], 0, n_vocab-1);
+        int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
         ggml_set_f32_1d(targets, (i-1)*n_vocab + token, +1.0f);
         if (i<n_tokens) {
             ggml_set_i32_1d(tokens_input, i, token);
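Note on the hunk above: `targets` is an n_vocab x n_tokens matrix, and its baseline changed from -1.0f to -1.0f/n_vocab, so the off-target entries scale with the vocabulary size while the true next token still gets +1. A minimal standalone sketch of that encoding (plain C++, toy sizes; `next_tokens` and all names here are illustrative, not part of the commit):

#include <cstdio>
#include <vector>

int main() {
    const int n_vocab  = 4;                       // toy vocabulary size
    const int n_tokens = 3;                       // toy context length
    const int next_tokens[n_tokens] = {2, 0, 3};  // true next-token ids

    // Baseline -1/n_vocab everywhere, then overwrite the true token with +1,
    // mirroring ggml_set_f32(targets, -1.0f/n_vocab) + ggml_set_f32_1d(..., +1.0f).
    std::vector<float> targets(n_tokens * n_vocab, -1.0f / n_vocab);
    for (int i = 0; i < n_tokens; ++i) {
        targets[i*n_vocab + next_tokens[i]] = +1.0f;
    }

    for (int i = 0; i < n_tokens; ++i) {
        for (int t = 0; t < n_vocab; ++t) {
            printf("%+.2f ", targets[i*n_vocab + t]);
        }
        printf("\n");
    }
    return 0;
}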
@@ -1008,7 +1023,7 @@ void get_example_targets(const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
     }
 }
 
-void get_example_targets_batch(struct ggml_context * ctx, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
+void get_example_targets_batch(struct ggml_context * ctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
     GGML_ASSERT(tokens_input->n_dims == 2);
     GGML_ASSERT( targets->n_dims == 3);
     int n_tokens = tokens_input->ne[0];
@@ -1028,7 +1043,7 @@ void get_example_targets_batch(struct ggml_context * ctx, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
             targets->nb[1],
             k*targets->nb[2]);
 
-        get_example_targets(train_data, n_train_data,
+        get_example_targets(train_samples, n_train_samples, train_data, n_train_data,
             example_id*n_batch + k, tokens_input_k, targets_k);
     }
 }
@@ -1171,9 +1186,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) {
     struct llama_file f(filename, "rb");
 
     std::vector<char> buf;
-    buf.resize(f.size);
+    buf.resize(f.size+1);
 
     f.read_raw(buf.data(), f.size);
+    buf[f.size] = '\0';
 
     out.resize(buf.size());
 
@@ -1186,6 +1202,143 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vector<llama_token>& out) {
     return n_tokens;
 }
 
+void shuffle_ints(int * begin, int * end) {
+    if (end <= begin) return;
+    int max = begin[0];
+    for (int i=1; i<end-begin; ++i) {
+        if (begin[i] > max) {
+            max = begin[i];
+        }
+    }
+    std::vector<float> vals;
+    vals.resize(max+1);
+    for (int i=0; i<max+1; ++i) {
+       vals[i] = frand();
+    }
+    std::sort(begin, end, [&vals](auto a, auto b){
+       return vals.at(a) < vals.at(b);
+    });
+}
+
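`shuffle_ints` permutes by drawing one random key per value and sorting under those keys; since `train_samples` holds distinct indices, the keys are distinct per element and the result is a random permutation driven only by `frand()` and the `srand()` seed. For comparison, a minimal standalone sketch of the same operation with the standard library (an alternative, not what the commit uses):

#include <algorithm>
#include <cstdio>
#include <random>

int main() {
    int idx[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    std::mt19937 gen(1337);           // fixed seed: reproducible permutation
    std::shuffle(idx, idx + 8, gen);  // uniform Fisher-Yates shuffle
    for (int i : idx) printf("%d ", i);
    printf("\n");
    return 0;
}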
+struct my_llama_sampler_params {
+    float temp            = 0.0f;  // <= 0.0 disabled
+    int   top_k           = 20;    // <= 0 to use vocab size
+    float top_p           = 0.95f; // 1.0 = disabled
+    float tfs_z           = 1.00f; // 1.0 = disabled
+    float typical_p       = 1.00f; // 1.0 = disabled
+    int   repeat_last_n   = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float repeat_penalty  = 1.0f;  // 1.0 = disabled
+    float alpha_presence  = 0.0f;  // 0.0 = disabled
+    float alpha_frequency = 0.0f;  // 0.0 = disabled
+    int   mirostat        = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float mirostat_tau    = 5.00f; // target entropy
+    float mirostat_eta    = 0.10f; // learning rate
+    bool  penalize_nl     = true;  // consider newlines as a repeatable token
+};
+
+struct my_llama_sampler {
+    struct llama_context * ctx = NULL;
+    my_llama_sampler_params params;
+
+    int n_vocab = 0;
+    int n_ctx   = 0;
+
+    float mirostat_mu;
+
+    std::vector<llama_token_data> candidates;
+    llama_token_data_array candidates_p;
+};
+
+void init_sampler(struct my_llama_sampler * sampler, struct llama_context * ctx) {
+    sampler->ctx = ctx;
+    sampler->n_vocab = llama_n_vocab(sampler->ctx);
+    sampler->n_ctx   = llama_n_ctx(sampler->ctx);
+    sampler->mirostat_mu = 2.0f * sampler->params.mirostat_tau;
+}
+
+llama_token sample(struct my_llama_sampler * sampler, float * logits, const llama_token * last_tokens, int n_last_tokens) {
+    GGML_ASSERT(sampler->ctx != NULL);
+
+    struct llama_context * ctx = sampler->ctx;
+
+    sampler->candidates.resize(sampler->n_vocab);
+    for (llama_token token_id = 0; token_id < sampler->n_vocab; ++token_id) {
+        sampler->candidates[token_id].id = token_id;
+        sampler->candidates[token_id].logit = logits[token_id];
+        sampler->candidates[token_id].p = 0.0;
+    }
+
+    llama_token_data_array * candidates_p = & sampler->candidates_p;
+
+    candidates_p->data   = sampler->candidates.data();
+    candidates_p->size   = sampler->candidates.size();
+    candidates_p->sorted = false;
+
+    const auto params = sampler->params;
+
+    // Apply penalties
+    const float nl_logit = logits[llama_token_nl()];
+
+    const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx);
+
+    llama_sample_repetition_penalty(
+        ctx,
+        candidates_p,
+        last_tokens + n_last_tokens - n_last,
+        n_last,
+        params.repeat_penalty);
+    llama_sample_frequency_and_presence_penalties(
+        ctx,
+        candidates_p,
+        last_tokens + n_last_tokens - n_last,
+        n_last,
+        params.alpha_frequency,
+        params.alpha_presence);
+
+    if (!params.penalize_nl) {
+        logits[llama_token_nl()] = nl_logit;
+    }
+
+    llama_token token = 0;
+    if (params.temp <= 0) {
+        // Greedy sampling
+        token = llama_sample_token_greedy(ctx, candidates_p);
+    } else {
+        if (params.mirostat == 1) {
+            int mirostat_m = 100;
+            llama_sample_temperature(ctx, candidates_p, params.temp);
+            token = llama_sample_token_mirostat(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, mirostat_m, &sampler->mirostat_mu);
+        } else if (params.mirostat == 2) {
+            llama_sample_temperature(ctx, candidates_p, params.temp);
+            token = llama_sample_token_mirostat_v2(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, &sampler->mirostat_mu);
+        } else {
+            // Temperature sampling
+            llama_sample_top_k        (ctx, candidates_p, params.top_k, 1);
+            llama_sample_tail_free    (ctx, candidates_p, params.tfs_z, 1);
+            llama_sample_typical      (ctx, candidates_p, params.typical_p, 1);
+
+            llama_sample_top_p        (ctx, candidates_p, params.top_p, 1);
+            llama_sample_temperature  (ctx, candidates_p, params.temp);
+            token = llama_sample_token(ctx, candidates_p);
+        }
+    }
+    return token;
+}
+
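A hedged usage sketch of the sampler API added above, as main() drives it later in this diff. Only names introduced by the commit are used; `lctx`, `logits`, `last_tokens` are assumed to come from the surrounding training program, and the parameter values are illustrative:

// Assumed context: lctx is a valid llama_context, logits points at the
// n_vocab floats for one position, last_tokens/n_last is recent history.
llama_token sample_one(struct llama_context * lctx,
                       float * logits,
                       const llama_token * last_tokens, int n_last) {
    my_llama_sampler sampler;
    sampler.params.temp  = 0.7f;   // >0 enables the temperature path
    sampler.params.top_k = 40;     // illustrative value, not from the commit
    init_sampler(&sampler, lctx);  // binds ctx, n_vocab, n_ctx, mirostat_mu
    return sample(&sampler, logits, last_tokens, n_last);
}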
+void set_logits_masked(struct ggml_tensor * logits, std::vector<bool>& mask, float value) {
+    GGML_ASSERT(logits->ne[0] == mask.size());
+    for (int i2 = 0; i2 < logits->ne[2]; ++i2) {
+        for (int i1 = 0; i1 < logits->ne[1]; ++i1) {
+            for (int i0 = 0; i0 < logits->ne[0]; ++i0) {
+                if (!mask[i0]) continue;
+                float * ptr = (float *) ((char *) logits->data + i2*logits->nb[2] + i1*logits->nb[1] + i0*logits->nb[0]);
+                *ptr = value;
+            }
+        }
+    }
+}
+
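Why -1e9 works as the mask value passed to `set_logits_masked` below: after softmax, a logit that large and negative underflows to probability zero, so tokens never seen in the training data cannot be sampled. A standalone toy illustration (all numbers invented):

#include <cmath>
#include <cstdio>

int main() {
    float logits[4]   = {2.0f, 1.0f, 0.5f, 3.0f};
    bool  notavail[4] = {false, true, false, false}; // token 1 never seen in training

    for (int i = 0; i < 4; ++i) {
        if (notavail[i]) logits[i] = -1e9f;          // same masking as set_logits_masked
    }

    float denom = 0.0f;
    for (int i = 0; i < 4; ++i) denom += expf(logits[i]);
    for (int i = 0; i < 4; ++i) {
        printf("p[%d] = %f\n", i, expf(logits[i]) / denom);  // p[1] prints as 0
    }
    return 0;
}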
 int main(int argc, char ** argv) {
     const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
     const char * default_train = "shakespeare.txt";
@@ -1220,6 +1373,17 @@ int main(int argc, char ** argv) {
     model.hparams.n_layer = 1;
     model.hparams.n_rot = std::min(16u, model.hparams.n_embd / model.hparams.n_head);
 
+    print_params(&model.hparams);
+
+    std::vector<bool> token_occurs;
+    std::vector<bool> token_notavail;
+    token_occurs.resize(model.hparams.n_vocab, false);
+    token_notavail.resize(model.hparams.n_vocab, true);
+    for (int i=0; i<train_tokens.size(); ++i) {
+        token_occurs[train_tokens[i]] = true;
+        token_notavail[train_tokens[i]] = false;
+    }
+
     struct my_llama_kv_cache kv_self;
 
     int n_batch = 8;
@@ -1232,11 +1396,15 @@ int main(int argc, char ** argv) {
     model.ctx = ggml_init(lcparams);
     kv_self.ctx = model.ctx;
 
+    my_llama_sampler sampler;
+
     printf("%s: init model\n", __func__);
     init_model(&model);
     set_param_model(&model);
     randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
     init_kv_cache(&kv_self, &model, n_batch);
+    init_sampler(&sampler, lctx);
+
 
     size_t compute_size = 1024ll*1024ll*1024ll*32ll;
     uint8_t * compute_addr = new uint8_t[compute_size];
@@ -1245,9 +1413,25 @@ int main(int argc, char ** argv) {
     int n_tokens = model.hparams.n_ctx;
     int n_vocab = model.hparams.n_vocab;
 
+    std::vector<int> train_samples;
+    for (int i=0; i<train_tokens.size()-n_tokens; ++i) {
+        train_samples.push_back(i);
+    }
+    shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
+    for (int i=0; i<train_samples.size(); ++i) {
+        GGML_ASSERT(train_samples[i]+n_tokens-1 < train_tokens.size());
+    }
+
     printf("%s: begin training\n", __func__);
 
     for (int ex=0; ex<n_examples; ++ex) {
+        if (ex*n_batch >= train_samples.size()) {
+            shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
+            for (int i=0; i<train_samples.size(); ++i) {
+                GGML_ASSERT(train_samples[i]+n_tokens-1 < train_tokens.size());
+            }
+        }
+
         struct ggml_init_params params = {
             /*.mem_size   =*/ compute_size,
             /*.mem_buffer =*/ compute_addr,
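The sample set built above is simply every start offset whose n_tokens window fits inside the token stream, shuffled once up front and again when the batches outrun the sample count. A standalone sketch of the window enumeration (the function name is illustrative, not part of the commit):

#include <cstddef>
#include <vector>

// Every offset i with room for a full window [i, i+n_tokens) becomes one sample,
// matching the GGML_ASSERT(train_samples[i]+n_tokens-1 < train_tokens.size()) checks.
std::vector<int> make_sample_offsets(size_t n_train_tokens, int n_tokens) {
    std::vector<int> offsets;
    for (int i = 0; i + n_tokens < (int) n_train_tokens; ++i) {
        offsets.push_back(i);
    }
    return offsets;
}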
@@ -1302,14 +1486,27 @@ int main(int argc, char ** argv) {
         }
 
         if (ex % 2 == 0) {
-            sample_softmax_batch(ctx0, logits, after_opt_probs, after_opt_best_samples);
+            set_logits_masked(logits, token_notavail, -1e9);
+            for (int i=0; i<n_batch; ++i) {
+                init_sampler(&sampler, lctx);
+                for (int k=0; k<n_tokens; ++k) {
+                    int32_t token = sample(&sampler,
+                        (float *) ((char *) logits->data + i*logits->nb[2] + k*logits->nb[1]),
+                        (llama_token *) ((char *) tokens_input->data + i*tokens_input->nb[1]),
+                        k);
+                    * ((int32_t *) ((char *) after_opt_best_samples->data + i*after_opt_best_samples->nb[1] + k*after_opt_best_samples->nb[0])) = token;
+                }
+            }
+
+            // sample_softmax_batch(ctx0, logits, after_opt_probs, after_opt_best_samples);
             // printf("probabilities after optimization:\n");
             // print_matrix(after_opt_probs);
             printf("Example:\n---\n");
             print_tokens_batch(lctx, tokens_input);
             printf("\n---\n");
 
-            printf("best samples after optimization:\n---\n");
+            // printf("best samples after optimization:\n---\n");
+            printf("samples after optimization:\n---\n");
             print_tokens_batch(lctx, after_opt_best_samples);
             printf("\n---\n");
         }
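The raw pointer arithmetic in the sampling loop above follows ggml's layout convention: `nb[d]` is the byte stride of dimension `d`, so element (i0, i1, i2) of a 3-D tensor lives at `data + i0*nb[0] + i1*nb[1] + i2*nb[2]`. A small helper sketch of the same address computation (the name `logits_row` is illustrative, not part of the commit):

#include "ggml.h"

// Returns the logits row for batch entry i and token position k of a
// [n_vocab, n_tokens, n_batch] tensor, i.e. the same address computed
// inline above as logits->data + i*logits->nb[2] + k*logits->nb[1].
static float * logits_row(struct ggml_tensor * logits, int i, int k) {
    return (float *) ((char *) logits->data
                      + (size_t) i * logits->nb[2]
                      + (size_t) k * logits->nb[1]);
}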
@@ -1321,12 +1518,14 @@ int main(int argc, char ** argv) {
     int n_gen = 128;
     int sample_ctx = n_tokens - n_tokens/8;
 
+    init_sampler(&sampler, lctx);
+
     printf("Generating %d tokens.\n", n_gen);
 
     struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
     struct ggml_tensor * targets = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
 
-    get_example_targets(train_tokens.data(), train_tokens.size(), 137, tokens_input, targets);
+    get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), 137, tokens_input, targets);
     for (int i=sample_ctx; i<n_tokens; ++i) {
         ggml_set_i32_1d(tokens_input, i, n_vocab/2);
     }
@@ -1356,9 +1555,13 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
         struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
 
-        sample_softmax(logits, probs, best_samples);
-        int token = ggml_get_i32_1d(best_samples, sample_ctx-1);
+        set_logits_masked(logits, token_notavail, -1e9);
+        int token = sample(&sampler,
+            (float *) ((char *) logits->data + (sample_ctx-1)*logits->nb[1]),
+            (llama_token *) tokens_input->data,
+            sample_ctx-1);
+        // sample_softmax(logits, probs, best_samples);
+        //int token = ggml_get_i32_1d(best_samples, sample_ctx-1);
 
         // print_row(probs, sample_at);
         print_token(lctx, token);