completed top nsigma sampler implementation
This commit is contained in:
		
							parent
							
								
									ddc3c2208a
								
							
						
					
					
						commit
						da038d8715
					
				
					 5 changed files with 112 additions and 79 deletions
				
			
		|  | @ -899,6 +899,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             params.sampling.min_p = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--top-nsigma"}, "N", | ||||
|         string_format("top-n-sigma sampling (default: %d, -1 = disabled)", params.sampling.top_n_sigma), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sampling.top_n_sigma = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--xtc-probability"}, "N", | ||||
|         string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), | ||||
|  |  | |||
|  | @ -95,7 +95,6 @@ enum common_sampler_type { | |||
|     COMMON_SAMPLER_TYPE_XTC         = 8, | ||||
|     COMMON_SAMPLER_TYPE_INFILL      = 9, | ||||
|     COMMON_SAMPLER_TYPE_PENALTIES   = 10, | ||||
|     COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11 | ||||
| }; | ||||
| 
 | ||||
| // dimensionality reduction methods, used by cvector-generator
 | ||||
|  | @ -129,7 +128,7 @@ struct common_params_sampling { | |||
|     int32_t dry_allowed_length = 2;     // tokens extending repetitions beyond this receive penalty
 | ||||
|     int32_t dry_penalty_last_n = -1;    // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
 | ||||
|     int32_t mirostat           = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
 | ||||
|     int32_t top_n_sigma        = 2; | ||||
|     int32_t top_n_sigma        = -1;    // -1 = disabled
 | ||||
|     float   mirostat_tau       = 5.00f; // target entropy
 | ||||
|     float   mirostat_eta       = 0.10f; // learning rate
 | ||||
|     bool    ignore_eos         = false; | ||||
|  | @ -148,7 +147,6 @@ struct common_params_sampling { | |||
|         COMMON_SAMPLER_TYPE_MIN_P, | ||||
|         COMMON_SAMPLER_TYPE_XTC, | ||||
|         COMMON_SAMPLER_TYPE_TEMPERATURE, | ||||
|         COMMON_SAMPLER_TYPE_TOP_N_SIGMA, | ||||
|     }; | ||||
| 
 | ||||
|     std::string grammar; // optional BNF-like grammar to constrain sampling
 | ||||
|  |  | |||
|  | @ -131,11 +131,11 @@ std::string common_params_sampling::print() const { | |||
|     snprintf(result, sizeof(result), | ||||
|             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" | ||||
|             "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" | ||||
|             "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" | ||||
|             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", | ||||
|             "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n" | ||||
|             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f,", | ||||
|             penalty_last_n, penalty_repeat, penalty_freq, penalty_present, | ||||
|             dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, | ||||
|             top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, | ||||
|             top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, | ||||
|             mirostat, mirostat_eta, mirostat_tau); | ||||
| 
 | ||||
|     return std::string(result); | ||||
|  | @ -162,49 +162,50 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co | |||
|                 params.logit_bias.data())); | ||||
| 
 | ||||
|     if (params.mirostat == 0) { | ||||
|         for (const auto & cnstr : params.samplers) { | ||||
|             switch (cnstr) { | ||||
|                 case COMMON_SAMPLER_TYPE_DRY: | ||||
|                     { | ||||
|                         std::vector<const char *> c_breakers; | ||||
|                         c_breakers.reserve(params.dry_sequence_breakers.size()); | ||||
|                         for (const auto & str : params.dry_sequence_breakers) { | ||||
|                             c_breakers.push_back(str.c_str()); | ||||
|                         } | ||||
|         if(params.top_n_sigma >= 0) { | ||||
|             llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); | ||||
|             llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma)); | ||||
|         } else { | ||||
|             for (const auto & cnstr : params.samplers) { | ||||
|                 switch (cnstr) { | ||||
|                     case COMMON_SAMPLER_TYPE_DRY: | ||||
|                         { | ||||
|                             std::vector<const char *> c_breakers; | ||||
|                             c_breakers.reserve(params.dry_sequence_breakers.size()); | ||||
|                             for (const auto & str : params.dry_sequence_breakers) { | ||||
|                                 c_breakers.push_back(str.c_str()); | ||||
|                             } | ||||
| 
 | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); | ||||
|                     } | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_TOP_K: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k      (params.top_k)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_TOP_P: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_p      (params.top_p, params.min_keep)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_MIN_P: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_min_p      (params.min_p, params.min_keep)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_XTC: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_xtc        (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_TYPICAL_P: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_typical    (params.typ_p, params.min_keep)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_TEMPERATURE: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext   (params.temp, params.dynatemp_range, params.dynatemp_exponent)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_INFILL: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill     (model)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_PENALTIES: | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_penalties  (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); | ||||
|                     break; | ||||
|                 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: | ||||
|                     // llama_sampler_chain_add(result->chain, )
 | ||||
|                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma)) | ||||
|                     break; | ||||
|                 default: | ||||
|                     GGML_ASSERT(false && "unknown sampler type"); | ||||
|                             llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); | ||||
|                         } | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_TOP_K: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_top_k      (params.top_k)); | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_TOP_P: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_top_p      (params.top_p, params.min_keep)); | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_MIN_P: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_min_p      (params.min_p, params.min_keep)); | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_XTC: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_xtc        (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_TYPICAL_P: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_typical    (params.typ_p, params.min_keep)); | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_TEMPERATURE: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext   (params.temp, params.dynatemp_range, params.dynatemp_exponent)); | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_INFILL: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_infill     (model)); | ||||
|                         break; | ||||
|                     case COMMON_SAMPLER_TYPE_PENALTIES: | ||||
|                         llama_sampler_chain_add(result->chain, llama_sampler_init_penalties  (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); | ||||
|                         break; | ||||
|                     default: | ||||
|                         GGML_ASSERT(false && "unknown sampler type"); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); | ||||
|  | @ -411,7 +412,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { | |||
|         case COMMON_SAMPLER_TYPE_XTC:         return 'x'; | ||||
|         case COMMON_SAMPLER_TYPE_INFILL:      return 'i'; | ||||
|         case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e'; | ||||
|         case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's'; | ||||
|         default : return '?'; | ||||
|     } | ||||
| } | ||||
|  | @ -427,7 +427,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { | |||
|         case COMMON_SAMPLER_TYPE_XTC:         return "xtc"; | ||||
|         case COMMON_SAMPLER_TYPE_INFILL:      return "infill"; | ||||
|         case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties"; | ||||
|         case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma"; | ||||
|         default : return ""; | ||||
|     } | ||||
| } | ||||
|  | @ -443,7 +442,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect | |||
|         { "xtc",         COMMON_SAMPLER_TYPE_XTC }, | ||||
|         { "infill",      COMMON_SAMPLER_TYPE_INFILL }, | ||||
|         { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES }, | ||||
|         { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, | ||||
|     }; | ||||
| 
 | ||||
|     // since samplers names are written multiple ways
 | ||||
|  | @ -458,9 +456,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect | |||
|         { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P }, | ||||
|         { "min-p",       COMMON_SAMPLER_TYPE_MIN_P }, | ||||
|         { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE }, | ||||
|         { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, | ||||
|         { "top-nsigma",  COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, | ||||
|         { "top_nsigma",  COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, | ||||
|     }; | ||||
| 
 | ||||
|     std::vector<common_sampler_type> samplers; | ||||
|  | @ -494,7 +489,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri | |||
|         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC }, | ||||
|         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL }, | ||||
|         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES }, | ||||
|         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA} | ||||
|     }; | ||||
| 
 | ||||
|     std::vector<common_sampler_type> samplers; | ||||
|  |  | |||
|  | @ -1133,6 +1133,9 @@ extern "C" { | |||
|     /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
 | ||||
|     LLAMA_API struct llama_sampler * llama_sampler_init_xtc        (float   p, float   t,     size_t min_keep, uint32_t seed); | ||||
| 
 | ||||
|     /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
 | ||||
|     LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n); | ||||
| 
 | ||||
|     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
 | ||||
|     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
 | ||||
|     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
 | ||||
|  |  | |||
|  | @ -301,6 +301,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) | |||
|     cur_p->size = k; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| static uint32_t get_rng_seed(uint32_t seed) { | ||||
|     if (seed == LLAMA_DEFAULT_SEED) { | ||||
|         // use system clock if std::random_device is not a true RNG
 | ||||
|  | @ -1657,35 +1658,65 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * | |||
| 
 | ||||
| static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | ||||
|     const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; | ||||
|     llama_sampler_top_n_sigma_impl(cur_p, ctx->n); | ||||
|     // 1. Find max logit: M
 | ||||
|     // 2. Find standard deviation of logits: sig
 | ||||
|     // 3. Create a mask where m[i] = 1 if ith logit  >= M - n (sig), else m[i] = 0
 | ||||
|     // 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf
 | ||||
|     // 5. p = softmax(l)
 | ||||
| 
 | ||||
|     // find max logit and calculate mean
 | ||||
|     int32_t max = cur_p->data[0].logit; | ||||
|     int32_t logits_sum = 0; | ||||
|     for (size_t i = 0; i < cur_p->size; ++i) { | ||||
|         if(cur_p->data[i].logit > max){ | ||||
|             max = cur_p->data[i].logit; | ||||
|         } | ||||
|         logits_sum += cur_p->data[i].logit; | ||||
|     } | ||||
|     int32_t mean = logits_sum/cur_p->size; | ||||
|      | ||||
|     // calculate standard deviation
 | ||||
|     int32_t acc = 0; | ||||
|     for(size_t i = 0; i < cur_p->size; ++i){ | ||||
|         acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean); | ||||
|     } | ||||
|     int32_t std = sqrt(acc/cur_p->size); | ||||
| 
 | ||||
|     //apply mask
 | ||||
|     for(size_t i = 0; i < cur_p->size; ++i){ | ||||
|         if(cur_p->data[i].logit < max - (ctx->n * std)) { | ||||
|             cur_p->data[i].logit = -INFINITY; | ||||
|         } | ||||
|     } | ||||
|     llama_sampler_softmax_impl(cur_p); | ||||
| } | ||||
| 
 | ||||
| // static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
 | ||||
| //     const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
 | ||||
| //     return llama_sampler_init_top_k(ctx->k);
 | ||||
| // }
 | ||||
| static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl){ | ||||
|     const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx; | ||||
|     return llama_sampler_init_top_n_sigma(ctx->n); | ||||
| } | ||||
| 
 | ||||
| // static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
 | ||||
| //     delete (llama_sampler_top_k *) smpl->ctx;
 | ||||
| // }
 | ||||
| static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { | ||||
|     delete (llama_sampler_top_n_sigma *) smpl->ctx; | ||||
| } | ||||
| 
 | ||||
| // static struct llama_sampler_i llama_sampler_top_k_i = {
 | ||||
| //     /* .name   = */ llama_sampler_top_k_name,
 | ||||
| //     /* .accept = */ nullptr,
 | ||||
| //     /* .apply  = */ llama_sampler_top_k_apply,
 | ||||
| //     /* .reset  = */ nullptr,
 | ||||
| //     /* .clone  = */ llama_sampler_top_k_clone,
 | ||||
| //     /* .free   = */ llama_sampler_top_k_free,
 | ||||
| // };
 | ||||
| static struct llama_sampler_i llama_sampler_top_n_sigma_i = { | ||||
|     /* .name   = */ llama_sampler_top_n_sigma_name, | ||||
|     /* .accept = */ nullptr, | ||||
|     /* .apply  = */ llama_sampler_top_n_sigma_apply, | ||||
|     /* .reset  = */ nullptr, | ||||
|     /* .clone  = */ llama_sampler_top_n_sigma_clone, | ||||
|     /* .free   = */ llama_sampler_top_n_sigma_free, | ||||
| }; | ||||
| 
 | ||||
| // struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
 | ||||
| //     return new llama_sampler {
 | ||||
| //         /* .iface = */ &llama_sampler_top_k_i,
 | ||||
| //         /* .ctx   = */ new llama_sampler_top_k {
 | ||||
| //             /* .k = */ k,
 | ||||
| //         },
 | ||||
| //     };
 | ||||
| // }
 | ||||
| struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) { | ||||
|     return new llama_sampler { | ||||
|         /* .iface = */ &llama_sampler_top_n_sigma_i, | ||||
|         /* .ctx   = */ new llama_sampler_top_n_sigma { | ||||
|             /* .n = */ n, | ||||
|                         }, | ||||
|     }; | ||||
| } | ||||
| 
 | ||||
| // DRY
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue