examples : factor out plan allocation into a helper function

This commit is contained in:
Georgi Gerganov 2023-07-06 21:08:25 +03:00
parent 1b9994f809
commit a67404e749
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 31 additions and 65 deletions

View file

@ -31,6 +31,17 @@ float frand_normal(struct random_normal_distribution * rnd) {
return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
}
// Plan and execute `graph` on `n_threads` threads, using `buf` as the scratch
// work buffer. The buffer is grown to the plan's required size when non-zero,
// so callers can reuse one vector across many invocations.
void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

    const size_t required = plan.work_size;
    if (required > 0) {
        buf.resize(required);
        plan.work_data = buf.data();
    }

    ggml_graph_compute(graph, &plan);
}
struct ggml_tensor * randomize_tensor(
struct ggml_tensor * tensor,
int ndims,
@ -1596,15 +1607,7 @@ int main(int argc, char ** argv) {
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
ggml_build_forward_expand(&gf, e);
{
struct ggml_cplan pf = ggml_graph_plan(&gf, /*n_threads*/ 1);
if (pf.work_size > 0) {
work_buffer.resize(pf.work_size);
pf.work_data = work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
float error_before_opt = ggml_get_f32_1d(e, 0);
@ -1620,15 +1623,7 @@ int main(int argc, char ** argv) {
ggml_opt(ctx0, opt_params_lbfgs, e);
//
ggml_build_forward_expand(&gf, e);
{
struct ggml_cplan pf = ggml_graph_plan(&gf, /*n_threads*/ 1);
if (pf.work_size > 0) {
work_buffer.resize(pf.work_size);
pf.work_data = work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
float error_after_opt = ggml_get_f32_1d(e, 0);
@ -1681,15 +1676,7 @@ int main(int argc, char ** argv) {
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
ggml_build_forward_expand(&gf, logits);
{
struct ggml_cplan pf = ggml_graph_plan(&gf, /*n_threads*/ 1);
if (pf.work_size > 0) {
work_buffer.resize(pf.work_size);
pf.work_data = work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
@ -1711,10 +1698,11 @@ int main(int argc, char ** argv) {
}
print_matrix(model.tok_embeddings);
printf("done\n");
// ggml_free(kv_self.ctx);
// ggml_free(model_lora.ctx);
ggml_free(model.ctx);
return 0;
}

View file

@ -60,6 +60,17 @@ float frand_uniform(struct random_uniform_distribution * rnd) {
return rnd->rd(rnd->gen);
}
// Helper that builds a compute plan for `graph`, backs it with the shared
// scratch vector `buf` (resized on demand when work memory is needed), and
// runs the graph with `n_threads` threads.
void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

    const bool needs_work_memory = plan.work_size > 0;
    if (needs_work_memory) {
        buf.resize(plan.work_size);
        plan.work_data = buf.data();
    }

    ggml_graph_compute(graph, &plan);
}
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
float scale = 1.0f; // xavier
switch (tensor->n_dims) {
@ -3246,14 +3257,7 @@ int main(int argc, char ** argv) {
*gb = ggml_build_backward(ctx0, gf, true);
}
{
ggml_cplan pf = ggml_graph_plan(gf, params.n_threads);
if (pf.work_size > 0) {
work_buffer.resize(pf.work_size);
pf.work_data = work_buffer.data();
}
ggml_graph_compute(gf, &pf);
}
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
size_t used_mem_before_opt = ggml_used_mem(ctx0);
@ -3277,14 +3281,7 @@ int main(int argc, char ** argv) {
model.train_samples += n_batch;
model.train_tokens += n_batch * n_tokens;
{
ggml_cplan pf = ggml_graph_plan(gf, params.n_threads);
if (pf.work_size > 0) {
work_buffer.resize(pf.work_size);
pf.work_data = work_buffer.data();
}
ggml_graph_compute(gf, &pf);
}
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
float error_after_opt = ggml_get_f32_1d(loss, 0);
@ -3371,15 +3368,7 @@ int main(int argc, char ** argv) {
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
ggml_build_forward_expand(&gf, logits);
{
ggml_cplan pf = ggml_graph_plan(&gf, params.n_threads);
if (pf.work_size > 0) {
work_buffer.resize(pf.work_size);
pf.work_data = work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_helper(work_buffer, &gf, params.n_threads);
//struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
//struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);

13
ggml.h
View file

@ -65,18 +65,7 @@
// ggml_set_f32(a, 3.0f);
// ggml_set_f32(b, 4.0f);
//
// struct ggml_cplan pf = ggml_graph_compute_make_plan(&gf, n_threads);
//
// if (pf.work_size > 0) {
// pf.work_data = malloc(pf.work_size);
// GGML_ASSERT(pf.work_data);
// }
//
// ggml_graph_compute(&gf, &pf);
//
// if (pf.work_data) {
// free(pf.work_data);
// }
// ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
//
// printf("f = %f\n", ggml_get_f32_1d(f, 0));
//