Apply code-format changes
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
parent f6be4dc661
commit 2175aebdb1

2 changed files with 10 additions and 8 deletions
@@ -3484,6 +3484,9 @@ class RWKV6Qwen2Model(Rwkv6Model):
 class Rwkv7Model(Rwkv6Model):
     model_arch = gguf.MODEL_ARCH.RWKV7
 
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
         head_size = self.hparams["head_size"]
@@ -3492,11 +3495,10 @@ class Rwkv7Model(Rwkv6Model):
         intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
 
         # ICLR: In-Context-Learning-Rate
-        calc_lora_rank = lambda exponent, multiplier: max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
-        lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else calc_lora_rank(0.5, 1.8)
-        lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else calc_lora_rank(0.5, 1.8)
-        lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else calc_lora_rank(0.5, 1.3)
-        lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else calc_lora_rank(0.8, 0.6)
+        lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+        lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+        lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+        lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
 
         # RWKV isn't context limited
         self.gguf_writer.add_context_length(1048576)
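
For reference, a minimal standalone sketch of the rank formula that this commit moves from an inline lambda into the calc_lora_rank method. The formula is taken verbatim from the diff; the hidden_size values and the print driver below are illustrative only, not drawn from any particular RWKV7 config:

# Sketch of the helper added above: rank = hidden_size ** exponent * multiplier,
# rounded to the nearest multiple of 32 and never below one 32-wide block.
def calc_lora_rank(hidden_size, exponent, multiplier):
    return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

if __name__ == "__main__":
    # Illustrative hidden sizes only.
    for hidden_size in (1024, 2048, 4096):
        print(hidden_size,
              calc_lora_rank(hidden_size, 0.5, 1.8),  # decay / ICLR defaults
              calc_lora_rank(hidden_size, 0.5, 1.3),  # value-residual-mix default
              calc_lora_rank(hidden_size, 0.8, 0.6))  # gate default

Because of the max(1, ...) * 32 guard, the smallest rank the helper can return is 32.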
The remaining three hunks touch the Metal RWKV kernels. Each modifies a single line whose visible text is unchanged (evidently a whitespace-only formatting fix), so only the shared context survives the rendering:

@@ -1397,7 +1397,7 @@ kernel void kernel_rwkv_wkv6_f32(
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3 ntg[[threads_per_threadgroup]]) {

    const uint head_size = 64; // rwkv6
    const uint batch_id = tgpig.x / H;
    const uint head_id = tgpig.x % H;

@@ -1438,7 +1438,7 @@ kernel void kernel_rwkv_wkv6_f32(

    const float v_val = v[t];
    float y = 0.0;

    #pragma unroll(64)
    for (uint j = 0; j < head_size; j += 4) {
        float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);

@@ -1484,7 +1484,7 @@ kernel void kernel_rwkv_wkv7_f32(
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint3 tpitg[[thread_position_in_threadgroup]],
    uint3 ntg[[threads_per_threadgroup]]) {

    const uint head_size = 64;
    const uint batch_id = tgpig.x / H;
    const uint head_id = tgpig.x % H;
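
As a side note on the kernel context above, a minimal Python sketch of the (batch, head) decomposition that both RWKV kernels derive from the flat threadgroup index (mirroring batch_id = tgpig.x / H and head_id = tgpig.x % H; the B and H values below are illustrative only):

# From the context above, each flat tgpig.x appears to index one (batch, head) pair.
def decompose(tgpig_x, H):
    batch_id = tgpig_x // H  # integer division, as with Metal's uint division
    head_id = tgpig_x % H
    return batch_id, head_id

# Illustrative sizes: 2 sequences in the batch, 4 heads each.
B, H = 2, 4
assert [decompose(x, H) for x in range(B * H)] == [
    (b, h) for b in range(B) for h in range(H)
]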