Add a complete implementation of the classical PCA algorithm (covariance matrix + power iteration), with a very simple test file
parent 09ecbcb596
commit 5c1d1177d3
4 changed files with 438 additions and 2 deletions
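A minimal sketch of the math that the new vanilla_pca.hpp implements (notation is mine, not part of the diff): the covariance of the centered data and the power-iteration update are

    C = X_c X_c^T / (n - 1),        X_c = X - mean(X)
    b_{k+1} = C b_k / ||C b_k||,    lambda_{k+1} = ||C b_k||

The iterates are unit vectors, so the loop stops once 1 - b_{k+1}^T b_k falls below the configured tolerance.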
Makefile (7 changes)
@@ -38,6 +38,7 @@ BUILD_TARGETS = \
	llama-tokenize \
	llama-vdot \
	llama-cvector-generator \
	llama-test-vanilla-pca \
	llama-gen-docs \
	tests/test-c.o

@@ -1479,6 +1480,12 @@ llama-cvector-generator: examples/cvector-generator/cvector-generator.cpp \
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

# TODO: Move to tests
llama-test-vanilla-pca: examples/cvector-generator/mini-tests/test-vanilla-pca.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

llama-convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
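With the new target above, the test should build and run from the repository root with something like: make llama-test-vanilla-pca && ./llama-test-vanilla-pca (binary name taken from the Makefile rule).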
examples/cvector-generator/cvector-generator.cpp

@@ -2,8 +2,7 @@
#include "common.h"
#include "llama.h"
#include "ggml.h"
#include "pca.hpp"
#include "mean.hpp"
#include "vanilla_pca.hpp"

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
examples/cvector-generator/mini-tests/test-vanilla-pca.cpp (new file, 116 lines)
@@ -0,0 +1,116 @@
#include "common.h"
#include "llama.h"
#include "ggml.h"
#include "../vanilla_pca.hpp"

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif

#include <cstdio>
#include <cstring>

// Initialize a small host-side ggml context that holds the test matrices.
// Note: ggml_init_params has no GPU flag; the PCA code selects the CUDA
// backend itself when it is available, so this context only needs CPU memory.
struct ggml_context * initialize_ggml_context() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
#ifdef GGML_USE_CUDA
    printf("Initializing with GPU backend...\n");
#else
    printf("Initializing with CPU backend...\n");
#endif
    return ggml_init(params);
}

// Helper function to create a 2D tensor from a row-major matrix
struct ggml_tensor * create_tensor(struct ggml_context * ctx, float * data, int rows, int cols) {
    struct ggml_tensor * tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols, rows);
    memcpy(tensor->data, data, ggml_nbytes(tensor));
    return tensor;
}

// Run PCA on one matrix and print the dominant eigenvector and eigenvalue
void run_pca_test(struct ggml_context * ctx, float * matrix, int rows, int cols) {
    struct ggml_tensor * input_tensor = create_tensor(ctx, matrix, rows, cols);

    PCA::pca_params pca_params;
    pca_params.n_threads    = 8;
    pca_params.n_batch      = 20;
    pca_params.n_iterations = 1000;
    pca_params.tolerance    = 1e-5;

    PCA::pca_result result;
    PCA::run_single_pca(pca_params, input_tensor, result);

    printf("\nPrincipal components:\n");
    float * b = (float *) result.principal_component->data;
    for (int i = 0; i < result.principal_component->ne[0]; i++) {
        printf("%f ", b[i]);
    }
    printf("\nEigenvalue: %f\n", result.explained_variance);
}

int main() {
    // Initialize ggml context
    struct ggml_context * ctx = initialize_ggml_context();
    if (ctx == NULL) {
        printf("Failed to initialize ggml context\n");
        return 1;
    }

    // Define test matrices
    float input_matrix1[16] = {
        -0.124132, 0.740341, -0.452462, 0.777050,
        1.045571, -0.342142, -0.926047, -0.512965,
        0.710109, 0.092479, 0.630075, 1.762937,
        0.230954, -0.808937, 1.057424, 0.051361
    };

    float input_matrix2[100] = {
        440152.493740, 122038.234845, 495176.910111, 34388.521115, 909320.402079, 258779.981600, 662522.284354, 311711.076089, 520068.021178, 546710.279343,
        184854.455526, 969584.627765, 775132.823361, 939498.941564, 894827.350428, 597899.978811, 921874.235023, 88492.502052, 195982.862419, 45227.288911,
        325330.330763, 388677.289689, 271349.031774, 828737.509152, 356753.326694, 280934.509687, 542696.083158, 140924.224975, 802196.980754, 74550.643680,
        986886.936601, 772244.769297, 198715.681534, 5522.117124, 815461.428455, 706857.343848, 729007.168041, 771270.346686, 74044.651734, 358465.728544,
        115869.059525, 863103.425876, 623298.126828, 330898.024853, 63558.350286, 310982.321716, 325183.322027, 729606.178338, 637557.471355, 887212.742576,
        472214.925162, 119594.245938, 713244.787223, 760785.048617, 561277.197569, 770967.179955, 493795.596364, 522732.829382, 427541.018359, 25419.126744,
        107891.426993, 31429.185687, 636410.411264, 314355.981076, 508570.691165, 907566.473926, 249292.229149, 410382.923036, 755551.138543, 228798.165492,
        76979.909829, 289751.452914, 161221.287254, 929697.652343, 808120.379564, 633403.756510, 871460.590188, 803672.076899, 186570.058886, 892558.998490,
        539342.241916, 807440.155164, 896091.299923, 318003.474972, 110051.924528, 227935.162542, 427107.788626, 818014.765922, 860730.583256, 6952.130531,
        510747.302578, 417411.003149, 222107.810471, 119865.367334, 337615.171404, 942909.703913, 323202.932021, 518790.621743, 703018.958895, 363629.602379
    };

    float input_matrix3[9] = {
        0.374540, 0.950714, 0.731994,
        0.598658, 0.156019, 0.155995,
        0.058084, 0.866176, 0.601115
    };

    float input_matrix4[9] = {
        10.000000, 0.000000, 0.000000,
        0.000000, 5.000000, 0.000000,
        0.000000, 0.000000, 1.000000
    };

    // Run PCA for each matrix
    printf("Testing Matrix 1:\n");
    run_pca_test(ctx, input_matrix1, 4, 4);

    printf("\nTesting Matrix 2:\n");
    run_pca_test(ctx, input_matrix2, 10, 10);

    printf("\nTesting Matrix 3:\n");
    run_pca_test(ctx, input_matrix3, 3, 3);

    printf("\nTesting Matrix 4:\n");
    run_pca_test(ctx, input_matrix4, 3, 3);

    // Cleanup
    ggml_free(ctx);
    return 0;
}
examples/cvector-generator/vanilla_pca.hpp (new file, 314 lines)
@@ -0,0 +1,314 @@
#include "common.h"
#include "llama.h"
#include "ggml.h"

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#include <cstdio>
#include <ctime>
#include <random>
#include <string>
#include <tuple>
#include <vector>
#include <algorithm>
#include <iostream>
#include <fstream>

#define DEBUG_POS 5

static void print_debug_tensor(struct ggml_tensor * t, bool with_data = true) {
    printf("%s: %s (%s): [%d, %d]\n", __func__, t->name, ggml_type_name(t->type), (int) t->ne[0], (int) t->ne[1]);
    if (!with_data) return;
    printf("%s: %s[0] = [", __func__, t->name);
    for (int i = 0; i <= DEBUG_POS; i++) {
        printf(" %f,", ggml_get_f32_nd(t, i, 0, 0, 0));
    }
    printf(" ... ]\n");
}

// begin vanilla pca namespace
namespace PCA {

// input params for PCA computations
struct pca_params {
    int   n_threads    = 1;
    int   n_batch      = 20;   // number of iterations to do in one batch; the larger the batch, the more memory is used
    int   n_iterations = 1000;
    float tolerance    = 1e-7f;
};

// result of a single PCA run
struct pca_result {
    struct ggml_tensor * principal_component; // dominant eigenvector of the covariance matrix
    float explained_variance;                 // corresponding eigenvalue
};

static void compute_covariance(struct pca_params & pca_params,
                               struct ggml_tensor * X,
                               struct ggml_tensor * covariance,
                               struct ggml_backend * backend) {

    // Memory allocation
    struct ggml_cgraph  * gf  = NULL;
    struct ggml_context * ctx = NULL;
    struct ggml_init_params ctx_params = {
        ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        NULL,
        true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
    };
    ctx = ggml_init(ctx_params);
    gf  = ggml_new_graph(ctx);

    // Step 0: Transpose the input because of the row-major layout
    X = ggml_cont(ctx, ggml_transpose(ctx, X));

    // Step 1: Compute the mean of each feature and center the data
    struct ggml_tensor * mean = ggml_repeat(ctx, ggml_mean(ctx, X), X); // broadcast the mean so it can simply be subtracted
    struct ggml_tensor * centered_data = ggml_sub(ctx, X, mean);

    // Step 2: Compute the covariance matrix
    struct ggml_tensor * cov = ggml_mul_mat(ctx, centered_data, centered_data); // C = X_c * X_c^T
    cov = ggml_scale(ctx, cov, 1.0/(X->ne[0]-1));
    ggml_build_forward_expand(gf, cov);

    // Step 3: Create ggml_gallocr for graph computation
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(allocr, gf);

    // Step 4: Set the thread count when running on CPU and compute the graph
    if (ggml_backend_is_cpu(backend)) {
        ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads);
    }
    ggml_backend_graph_compute(backend, gf);

    // Step 5: Store the covariance matrix in the output tensor's data pointer
    struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1);
    float * result_data = (float *) malloc(ggml_nbytes(result));
    ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
    covariance->data = result_data;

    // Step 6: Free memory
    ggml_gallocr_free(allocr);
    ggml_free(ctx);
}

static void compute_cross_covariance(struct pca_params & pca_params,
                                     struct ggml_tensor * A,
                                     struct ggml_tensor * B,
                                     struct ggml_tensor * cross_covariance,
                                     struct ggml_backend * backend) {

    // Memory allocation
    struct ggml_cgraph  * gf  = NULL;
    struct ggml_context * ctx = NULL;
    struct ggml_init_params ctx_params = {
        ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        NULL,
        true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
    };
    ctx = ggml_init(ctx_params);
    gf  = ggml_new_graph(ctx);

    // Step 1: Compute the matrices that make up the cross-covariance
    struct ggml_tensor * AT   = ggml_cont(ctx, ggml_transpose(ctx, A));
    struct ggml_tensor * BT   = ggml_cont(ctx, ggml_transpose(ctx, B));
    struct ggml_tensor * AT_B = ggml_mul_mat(ctx, AT, BT);
    struct ggml_tensor * BT_A = ggml_cont(ctx, ggml_transpose(ctx, AT_B));

    // Step 2: Compute the symmetrized cross-covariance matrix
    struct ggml_tensor * cross_cov = ggml_add(ctx, AT_B, BT_A);
    cross_cov = ggml_scale(ctx, cross_cov, 0.5);
    ggml_build_forward_expand(gf, cross_cov);

    // Step 3: Create ggml_gallocr for graph computation
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(allocr, gf);

    // Step 4: Set the thread count when running on CPU and compute the graph
    if (ggml_backend_is_cpu(backend)) {
        ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads);
    }
    ggml_backend_graph_compute(backend, gf);

    // Step 5: Store the cross-covariance matrix in the output tensor's data pointer
    struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1);
    float * result_data = (float *) malloc(ggml_nbytes(result));
    ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
    cross_covariance->data = result_data;

    // Step 6: Free memory
    ggml_gallocr_free(allocr);
    ggml_free(ctx);
}

// Find the dominant eigenvector of tensor M
static void power_iteration(struct pca_params & pca_params,
                            struct ggml_tensor * M,
                            struct pca_result & result,
                            struct ggml_backend * backend) {

    int m = M->ne[1];

    // Initialize a random vector
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
    float * b = (float *) malloc(m * sizeof(float));
    for (int i = 0; i < m; i++) {
        b[i] = dist(gen);
    }
    float eigenvalue = 0;

    // Iterate in rounds of n_batch matrix-vector products each
    int n_rounds = pca_params.n_iterations / pca_params.n_batch;
    for (int i = 0; i < n_rounds; i++) {

        // Memory allocation
        struct ggml_cgraph  * gf  = NULL;
        struct ggml_context * ctx = NULL;
        struct ggml_init_params params = {
            ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
            NULL,
            true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
        };
        ctx = ggml_init(params);
        gf  = ggml_new_graph(ctx);

        // Fill the current and previous eigenvector estimates
        struct ggml_tensor * e_curr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m);
        struct ggml_tensor * e_prev = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m);

        ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);

        ggml_backend_tensor_set(e_curr, b, 0, ggml_nbytes(e_curr));
        ggml_backend_tensor_set(e_prev, b, 0, ggml_nbytes(e_prev));

        struct ggml_tensor * e_next = NULL;
        struct ggml_tensor * e_norm = NULL;
        for (int j = 0; j < pca_params.n_batch; j++) {
            // Compute the next candidate vector by multiplying M with the current vector
            e_next = ggml_mul_mat(ctx, M, e_curr);

            // Compute the norm of the new vector and normalize it;
            // this gives the next eigenvector estimate and the eigenvalue
            e_norm = ggml_sqrt_inplace(ctx, ggml_sum_rows(ctx, ggml_sqr(ctx, e_next)));
            e_curr = ggml_div_inplace(ctx, e_next, e_norm);
            ggml_format_name(e_norm, "eigenvalue_%d", j);
            ggml_format_name(e_curr, "eigenvector_%d", j);

            // Update graph
            ggml_build_forward_expand(gf, e_curr);
        }

        // Compute the similarity between the current and the previous eigenvector (dot product)
        struct ggml_tensor * similarity = ggml_mul_mat(ctx, e_curr, e_prev);
        ggml_build_forward_expand(gf, similarity);

        // Create ggml_gallocr for graph computation
        ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
        ggml_gallocr_alloc_graph(allocr, gf);

        // Set the thread count when running on CPU and compute the graph
        if (ggml_backend_is_cpu(backend)) {
            ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads);
        }
        ggml_status graph_status = ggml_backend_graph_compute(backend, gf);

        // Get the graph results (eigenvector and eigenvalue) and store them in b and eigenvalue
        bool converged = false;
        if (graph_status == GGML_STATUS_SUCCESS) {
            // Similarity is the last node in the graph
            struct ggml_tensor * similarity_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1);
            float sim = ((float *) similarity_tensor->data)[0];

            // Eigenvector is the second to last node in the graph
            struct ggml_tensor * eigenvector_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-2);
            float * eigenvector_data = (float *) malloc(ggml_nbytes(eigenvector_tensor));
            ggml_backend_tensor_get(eigenvector_tensor, eigenvector_data, 0, ggml_nbytes(eigenvector_tensor));
            free(b);
            b = eigenvector_data;

            // The eigenvalue computation is one operation before the eigenvector computation
            struct ggml_tensor * eigenvalue_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-3);
            eigenvalue = ((float *) eigenvalue_tensor->data)[0];

            // If the similarity is close enough to 1 we have converged
            if (1 - sim < pca_params.tolerance) {
                converged = true;
            }
        }

        // Free this round's graph memory before (possibly) leaving the loop
        ggml_gallocr_free(allocr);
        ggml_free(ctx);

        if (converged) {
            break;
        }
    }

    // Store result
    result.principal_component->data = b;
    result.explained_variance = eigenvalue;
}

static void run_single_pca(struct pca_params & pca_params,
                           struct ggml_tensor * X,
                           struct pca_result & result) {

    ggml_set_name(X, "input_tensor");

    int m = X->ne[1]; // Number of features

    // Step 1. Initialize the ggml backend
    ggml_backend_t backend = NULL;
#ifdef GGML_USE_CUDA
    fprintf(stderr, "%s: using CUDA backend\n", __func__);
    backend = ggml_backend_cuda_init(0); // init device 0
    if (!backend) { fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); }
#endif
    // If there is no GPU backend, fall back to the CPU backend
    if (!backend) { backend = ggml_backend_cpu_init(); }

    // Compute the context size needed (two tensors: covariance and principal component)
    size_t ctx_size = 2 * ggml_tensor_overhead();

    // Step 2. Initialize the ggml context
    struct ggml_init_params ctx_params {
        ctx_size, // mem_size
        NULL,     // mem_buffer
        true,     // no_alloc
    };
    struct ggml_context * ctx = ggml_init(ctx_params);

    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // Step 3. Compute the covariance matrix of the data
    struct ggml_tensor * covariance = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, m, m);
    ggml_set_name(covariance, "covariance_tensor");
    compute_covariance(pca_params, X, covariance, backend);

    // Step 4. Power iteration on the covariance matrix
    result.principal_component = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m);
    power_iteration(pca_params, covariance, result, backend);

    // Free the ggml context and backend
    ggml_free(ctx);
    ggml_backend_free(backend);
}

static void run_pca(
        struct pca_params & params,
        const std::vector<struct ggml_tensor *> & v_input,   // shape of v_input[0]: [n_samples, n_embd]
        const std::vector<struct ggml_tensor *> & v_output) {

    for (size_t i = 0; i < v_input.size(); i++) {
        struct pca_result result;
        run_single_pca(params, v_input[i], result);
        ggml_backend_tensor_get(result.principal_component, v_output[i]->data, 0, ggml_nbytes(result.principal_component));
    }
}

// end namespace
}