reenabled sched_yield, reduced sampler warning msg to once per session

This commit is contained in:
Concedo 2023-07-18 20:26:18 +08:00
parent 6d32e7fc8b
commit 0a11f50da8
3 changed files with 8 additions and 5 deletions

View file

@ -144,7 +144,7 @@ ifdef LLAMA_CUBLAS
CUBLASLD_FLAGS = -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib CUBLASLD_FLAGS = -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
CUBLAS_OBJS = ggml-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o CUBLAS_OBJS = ggml-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
NVCC = nvcc NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
ifdef CUDA_DOCKER_ARCH ifdef CUDA_DOCKER_ARCH
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH) NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
else else
@ -358,7 +358,7 @@ koboldcpp_openblas: ggml_openblas.o ggml_v2_openblas.o ggml_v1.o expose.o common
$(OPENBLAS_BUILD) $(OPENBLAS_BUILD)
koboldcpp_failsafe: ggml_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_failsafe.o $(OBJS) koboldcpp_failsafe: ggml_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_failsafe.o $(OBJS)
$(FAILSAFE_BUILD) $(FAILSAFE_BUILD)
koboldcpp_openblas_noavx2: ggml_openblas_noavx2.o ggml_v2_openblas_noavx2.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter.o k_quants_noavx2.o $(OBJS) koboldcpp_openblas_noavx2: ggml_openblas_noavx2.o ggml_v2_openblas_noavx2.o ggml_v1_failsafe.o expose.o common.o gpttype_adapter_failsafe.o k_quants_noavx2.o $(OBJS)
$(OPENBLAS_NOAVX2_BUILD) $(OPENBLAS_NOAVX2_BUILD)
koboldcpp_clblast: ggml_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o common.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o k_quants.o $(OBJS) koboldcpp_clblast: ggml_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o common.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o k_quants.o $(OBJS)
$(CLBLAST_BUILD) $(CLBLAST_BUILD)

2
ggml.c
View file

@ -16383,7 +16383,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
// wait for other threads to finish // wait for other threads to finish
const int last = node_n; const int last = node_n;
do { do {
//sched_yield(); sched_yield();
node_n = atomic_load(&state->shared->node_n); node_n = atomic_load(&state->shared->node_n);
} while (node_n == last); } while (node_n == last);
} }

View file

@ -242,8 +242,10 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_
for i, sampler in enumerate(sampler_order): for i, sampler in enumerate(sampler_order):
inputs.sampler_order[i] = sampler inputs.sampler_order[i] = sampler
inputs.sampler_len = len(sampler_order) inputs.sampler_len = len(sampler_order)
if inputs.sampler_len>0 and (inputs.sampler_order[0]!=6 or inputs.sampler_order[inputs.sampler_len-1]!=5): global showsamplerwarning
print("\n(Warning!!! Poor sampler_order detected! You will have reduced quality. Recommended values are [6,0,1,3,4,2,5])") if showsamplerwarning and inputs.sampler_len>0 and (inputs.sampler_order[0]!=6 or inputs.sampler_order[inputs.sampler_len-1]!=5):
print("\n(Note: Sub-optimal sampler_order detected. You may have reduced quality. Recommended sampler values are [6,0,1,3,4,2,5]. This message will only show once per session.)")
showsamplerwarning = False
except TypeError as e: except TypeError as e:
print("ERROR: sampler_order must be a list of integers: " + str(e)) print("ERROR: sampler_order must be a list of integers: " + str(e))
inputs.seed = seed inputs.seed = seed
@ -277,6 +279,7 @@ modelbusy = False
defaultport = 5001 defaultport = 5001
KcppVersion = "1.36" KcppVersion = "1.36"
showdebug = True showdebug = True
showsamplerwarning = True
class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
sys_version = "" sys_version = ""