launcher for rope

This commit is contained in:
Concedo 2023-07-20 17:45:50 +08:00
parent 39dc1a46c4
commit e85557f798

View file

@ -764,6 +764,10 @@ def show_new_gui():
context_var = ctk.IntVar() context_var = ctk.IntVar()
customrope_var = ctk.IntVar()
customrope_scale = ctk.StringVar(value="1.0")
customrope_base = ctk.StringVar(value="10000")
model_var = ctk.StringVar() model_var = ctk.StringVar()
lora_var = ctk.StringVar() lora_var = ctk.StringVar()
lora_base_var = ctk.StringVar() lora_base_var = ctk.StringVar()
@ -904,6 +908,19 @@ def show_new_gui():
# context size # context size
makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, 4, 20, set=2) makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, 4, 20, set=2)
customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale)
customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base)
def togglerope(a,b,c):
items = [customrope_scale_label, customrope_scale_entry,customrope_base_label, customrope_base_entry]
for idx, item in enumerate(items):
if customrope_var.get() == 1:
item.grid(row=23 + int(idx/2), column=idx%2, padx=8, stick="nw")
else:
item.grid_forget()
makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope)
togglerope(1,1,1)
# Model Tab # Model Tab
model_tab = tabcontent["Model"] model_tab = tabcontent["Model"]
@ -996,6 +1013,9 @@ def show_new_gui():
args.mirostat = [int(mirostat_var.get()), float(mirostat_tau.get()), float(mirostat_eta.get())] if usemirostat.get()==1 else None args.mirostat = [int(mirostat_var.get()), float(mirostat_tau.get()), float(mirostat_eta.get())] if usemirostat.get()==1 else None
args.contextsize = int(contextsize_text[context_var.get()]) args.contextsize = int(contextsize_text[context_var.get()])
if customrope_var.get()==1:
args.ropeconfig = [float(customrope_scale.get()),float(customrope_base.get())]
args.model_param = None if model_var.get() == "" else model_var.get() args.model_param = None if model_var.get() == "" else model_var.get()
args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()]) args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
@ -1046,6 +1066,15 @@ def show_new_gui():
if dict["contextsize"]: if dict["contextsize"]:
context_var.set(contextsize_text.index(str(dict["contextsize"]))) context_var.set(contextsize_text.index(str(dict["contextsize"])))
if dict["ropeconfig"] and len(dict["ropeconfig"])>1:
if dict["ropeconfig"][0]>0:
customrope_var.set(1)
customrope_scale.set(str(dict["ropeconfig"][0]))
customrope_base.set(str(dict["ropeconfig"][1]))
else:
customrope_var.set(0)
if dict["blasbatchsize"]: if dict["blasbatchsize"]:
blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"]))) blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"])))
if dict["forceversion"]: if dict["forceversion"]: