Replace vector with C-style array and length in llama_split_layers_weighted
This commit is contained in:
parent
364b707130
commit
fda60ead35
3 changed files with 5 additions and 5 deletions
|
@@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_split_layers_weighted(ctx, params.mpi_layer_split);
|
llama_split_layers_weighted(ctx, params.mpi_layer_split.data(), params.mpi_layer_split.size());
|
||||||
|
|
||||||
std::string path_session = params.path_prompt_cache;
|
std::string path_session = params.path_prompt_cache;
|
||||||
std::vector<llama_token> session_tokens;
|
std::vector<llama_token> session_tokens;
|
||||||
|
|
|
@@ -13087,12 +13087,12 @@ struct llama_context * llama_new_context_with_model(
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_split_layers_weighted(struct llama_context * ctx, std::vector<float> device_weights) {
|
void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights) {
|
||||||
#ifdef GGML_USE_MPI
|
#ifdef GGML_USE_MPI
|
||||||
if (ggml_mpi_rank(ctx->ctx_mpi) == 0 && ggml_mpi_size(ctx->ctx_mpi) != device_weights.size()) {
|
if (ggml_mpi_rank(ctx->ctx_mpi) == 0 && ggml_mpi_size(ctx->ctx_mpi) != num_weights) {
|
||||||
GGML_ASSERT(false && "Must have same number of split percentages as devices");
|
GGML_ASSERT(false && "Must have same number of split percentages as devices");
|
||||||
}
|
}
|
||||||
uint16_t** ranges = ggml_mpi_split_range(ctx->ctx_mpi, 0, ctx->model.hparams.n_layer - 1, device_weights.data());
|
uint16_t** ranges = ggml_mpi_split_range(ctx->ctx_mpi, 0, ctx->model.hparams.n_layer - 1, device_weights);
|
||||||
ggml_mpi_scatter_layers(ctx->ctx_mpi, ranges);
|
ggml_mpi_scatter_layers(ctx->ctx_mpi, ranges);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
2
llama.h
2
llama.h
|
@@ -358,7 +358,7 @@ extern "C" {
|
||||||
const char * path_model,
|
const char * path_model,
|
||||||
struct llama_model_params params);
|
struct llama_model_params params);
|
||||||
|
|
||||||
LLAMA_API void llama_split_layers_weighted(struct llama_context * ctx, std::vector<float> device_weights);
|
LLAMA_API void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights);
|
||||||
|
|
||||||
LLAMA_API void llama_free_model(struct llama_model * model);
|
LLAMA_API void llama_free_model(struct llama_model * model);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue