fix consume_common_train_arg

This commit is contained in:
xaedes 2023-09-16 19:08:51 +02:00
parent bef1e97875
commit 7aa9ea7f20
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 5 additions and 3 deletions

View file

@ -1115,7 +1115,11 @@ void print_common_train_usage(int /*argc*/, char ** argv, const struct train_par
bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param) {
int& i = *idx;
char * arg = argv[i];
std::string arg = argv[i];
const std::string arg_prefix = "--";
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg == "--train-data") {
if (++i >= argc) {
*invalid_param = true;

View file

@ -514,8 +514,6 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
ggml_allocr_free(alloc);
}
static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
const uint32_t n_layer = lora->layers.size();