remember to copy back the last_eigenvector
This commit is contained in:
parent 1a088fb0a5
commit 163916864c

1 changed file with 15 additions and 11 deletions
@@ -62,9 +62,6 @@ struct pca_model {
     struct ggml_tensor * dev_square;
     struct ggml_tensor * dev_eigenvector;
 
-    // tensors to store output data on host
-    struct ggml_tensor * host_eigenvector;
-
     pca_model(struct ggml_tensor * t_input) {
 #ifdef GGML_USE_CUDA
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
@@ -129,17 +126,16 @@ struct pca_model {
         }
 
-        // init host context
-        struct ggml_init_params host_params = {
-            /*.mem_size =*/ (n_embd * sizeof(float) + ggml_tensor_overhead()) * 2u,
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc =*/ false,
-        };
-        ctx_host = ggml_init(host_params);
-        host_eigenvector = ggml_new_tensor_1d(ctx_host, GGML_TYPE_F32, n_embd);
+        //struct ggml_init_params host_params = {
+        //    /*.mem_size =*/ (n_embd * sizeof(float) + ggml_tensor_overhead()) * 2u,
+        //    /*.mem_buffer =*/ NULL,
+        //    /*.no_alloc =*/ false,
+        //};
+        //ctx_host = ggml_init(host_params);
+        //host_eigenvector = ggml_new_tensor_1d(ctx_host, GGML_TYPE_F32, n_embd);
     }
 
     ~pca_model() {
         ggml_free(ctx_host);
         ggml_free(ctx);
         ggml_backend_buffer_free(buffer);
         ggml_backend_free(backend);
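The hunk above retires the host-side ggml context (the replacement commented out rather than deleted, matching the host_eigenvector declaration removed in the first hunk). For reference, a minimal sketch of how such a host context is sized when no_alloc is false; n_embd here is a hypothetical stand-in, and the mem_size formula (taken from the removed lines) reserves tensor metadata overhead plus data bytes for up to two F32 vectors:

#include "ggml.h"

int main(void) {
    const int n_embd = 4096; // hypothetical embedding size

    // with no_alloc = false, tensor data is allocated inside the context's
    // own memory pool, so mem_size must cover overhead + data per tensor
    struct ggml_init_params host_params = {
        /*.mem_size   =*/ (n_embd * sizeof(float) + ggml_tensor_overhead()) * 2u,
        /*.mem_buffer =*/ NULL,  // let ggml allocate the pool itself
        /*.no_alloc   =*/ false, // place tensor data in the pool
    };
    struct ggml_context * ctx_host = ggml_init(host_params);
    struct ggml_tensor * host_eigenvector = ggml_new_tensor_1d(ctx_host, GGML_TYPE_F32, n_embd);

    // ... read results via host_eigenvector->data ...
    (void) host_eigenvector;
    ggml_free(ctx_host);
    return 0;
}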
@@ -299,6 +295,14 @@ static void power_iteration(
             ggml_backend_tensor_set(model.dev_square, tmp_buf.data(), 0, tmp_buf.size());
         }
 
+        {
+            // copy last eigen vector and store as input for next iteration
+            GGML_ASSERT(last_eigenvector != NULL);
+            std::vector<uint8_t> tmp_buf(ggml_nbytes(last_eigenvector));
+            ggml_backend_tensor_get(last_eigenvector, tmp_buf.data(), 0, tmp_buf.size());
+            ggml_backend_tensor_set(model.dev_eigenvector, tmp_buf.data(), 0, tmp_buf.size());
+        }
+
         printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n",
             __func__, params.i_layer+1, params.n_layers, iter, n_iters, params.n_batch);
     }
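This added block is the fix named in the commit message: per its own comment, the computed eigenvector is copied back into model.dev_eigenvector so each iteration continues from the previous estimate instead of the original input vector. A minimal sketch of the same backend-to-backend copy staged through host memory (the helper name copy_tensor_via_host is hypothetical; the ggml calls are the ones used in the hunk):

#include <cstdint>
#include <vector>
#include "ggml.h"
#include "ggml-backend.h"

// stage a backend tensor's bytes in a host buffer, then write them into
// another backend tensor of the same byte size (hypothetical helper)
static void copy_tensor_via_host(struct ggml_tensor * src, struct ggml_tensor * dst) {
    GGML_ASSERT(src != NULL && dst != NULL);
    GGML_ASSERT(ggml_nbytes(src) == ggml_nbytes(dst));
    std::vector<uint8_t> tmp_buf(ggml_nbytes(src));
    ggml_backend_tensor_get(src, tmp_buf.data(), 0, tmp_buf.size()); // backend -> host
    ggml_backend_tensor_set(dst, tmp_buf.data(), 0, tmp_buf.size()); // host -> backend
}

ggml also exposes ggml_backend_tensor_copy(src, dst), which performs the same transfer and can avoid the host staging buffer when the backends involved support a direct copy.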