Fix for windows model unloading not releasing memory (#569)

* Run the model in a separate process so it can be killed when unloading, releasing its memory on Windows

* Fix from Henky
This commit is contained in:
ebolam 2023-12-19 02:55:41 -05:00 committed by GitHub
parent 4c274dc2fd
commit 6948da5a0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 26 deletions

View file

@ -9,7 +9,7 @@ import torch
import requests import requests
import numpy as np import numpy as np
from typing import List, Optional, Union from typing import List, Optional, Union
import os import os, time
from . import koboldcpp from . import koboldcpp
import utils import utils
@ -20,11 +20,9 @@ from modeling.inference_model import (
InferenceModel, InferenceModel,
) )
model_backend_name = "koboldcpp" #specific instead of ggml model_backend_name = "KoboldCPP" #specific instead of ggml
model_backend_type = "ggml" #This should be a generic name in case multiple model backends are compatible (think Hugging Face Custom and Basic Hugging Face) model_backend_type = "ggml" #This should be a generic name in case multiple model backends are compatible (think Hugging Face Custom and Basic Hugging Face)
kcpp_backend_loaded = False
class KoboldCppException(Exception): class KoboldCppException(Exception):
"""To be used for errors on cpp side of KoboldCpp.""" """To be used for errors on cpp side of KoboldCpp."""
@ -35,6 +33,7 @@ class KcppArgsObject:
class model_backend(InferenceModel): class model_backend(InferenceModel):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.kcpp_backend_loaded = False
def is_valid(self, model_name, model_path, menu_path): def is_valid(self, model_name, model_path, menu_path):
@ -257,26 +256,31 @@ class model_backend(InferenceModel):
def unload(self): def unload(self):
print("Attemping to unload library") print("Attemping to unload library")
koboldcpp.unload_libs() self.process.terminate()
global kcpp_backend_loaded
kcpp_backend_loaded = False
pass
def _load(self, save_model: bool, initial_load: bool) -> None: def _load(self, save_model: bool, initial_load: bool) -> None:
global kcpp_backend_loaded
self.tokenizer = self._get_tokenizer("gpt2") self.tokenizer = self._get_tokenizer("gpt2")
if not kcpp_backend_loaded: kcppargs = KcppArgsObject(model=self.kcpp_filename, model_param=self.kcpp_filename,
kcppargs = KcppArgsObject(model=self.kcpp_filename, model_param=self.kcpp_filename, port=5001, port_param=5001, host='', launch=False, lora=None, threads=self.kcpp_threads, blasthreads=self.kcpp_threads,
port=5001, port_param=5001, host='', launch=False, lora=None, threads=self.kcpp_threads, blasthreads=self.kcpp_threads, psutil_set_threads=False, highpriority=False, contextsize=self.kcpp_ctxsize,
highpriority=False, contextsize=self.kcpp_ctxsize, blasbatchsize=self.kcpp_blasbatchsize, ropeconfig=[self.kcpp_ropescale, self.kcpp_ropebase], blasbatchsize=self.kcpp_blasbatchsize, ropeconfig=[self.kcpp_ropescale, self.kcpp_ropebase], stream=False, smartcontext=self.kcpp_smartcontext,
smartcontext=self.kcpp_smartcontext, bantokens=None, forceversion=0, nommap=self.kcpp_nommap, unbantokens=False, bantokens=None, usemirostat=None, forceversion=0, nommap=self.kcpp_nommap,
usemlock=False, noavx2=self.kcpp_noavx2, debugmode=self.kcpp_debugmode, skiplauncher=True, hordeconfig=None, noblas=self.kcpp_noblas, usemlock=False, noavx2=self.kcpp_noavx2, debugmode=self.kcpp_debugmode, skiplauncher=True, hordeconfig=None, noblas=self.kcpp_noblas,
useclblast=self.kcpp_useclblast, usecublas=self.kcpp_usecublas, gpulayers=self.kcpp_gpulayers, tensor_split=self.kcpp_tensor_split, config=None, useclblast=self.kcpp_useclblast, usecublas=self.kcpp_usecublas, gpulayers=self.kcpp_gpulayers, tensor_split=self.kcpp_tensor_split, config=None,
onready='', multiuser=False, foreground=False) onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False)
koboldcpp.main(kcppargs,False) #initialize library without enabling Lite http server #koboldcpp.main(kcppargs,False) #initialize library without enabling Lite http server
kcpp_backend_loaded = True (self.output_queue, self.input_queue, self.process) = koboldcpp.start_in_seperate_process(kcppargs)
pass while True:
data = self.output_queue.get()
if data['command'] == 'load status':
utils.koboldai_vars.total_layers = data['data']['total']
utils.koboldai_vars.loaded_layers = data['data']['loaded']
elif data['command'] == 'complete':
break
time.sleep(0.02)
def _save_settings(self): def _save_settings(self):
pass pass
@ -297,16 +301,31 @@ class model_backend(InferenceModel):
# Store context in memory to use it for comparison with generated content # Store context in memory to use it for comparison with generated content
utils.koboldai_vars.lastctx = decoded_prompt utils.koboldai_vars.lastctx = decoded_prompt
genresult = koboldcpp.generate(decoded_prompt,max_new,utils.koboldai_vars.max_length, self.input_queue.put({'command': 'generate', 'data': [(decoded_prompt,max_new,utils.koboldai_vars.max_length,
gen_settings.temp,int(gen_settings.top_k),gen_settings.top_a,gen_settings.top_p, gen_settings.temp,int(gen_settings.top_k),gen_settings.top_a,gen_settings.top_p,
gen_settings.typical,gen_settings.tfs,gen_settings.rep_pen,gen_settings.rep_pen_range, gen_settings.typical,gen_settings.tfs,gen_settings.rep_pen,gen_settings.rep_pen_range),
sampler_order=gen_settings.sampler_order,use_default_badwordsids=utils.koboldai_vars.use_default_badwordsids) {"sampler_order": gen_settings.sampler_order, "use_default_badwordsids": utils.koboldai_vars.use_default_badwordsids}
]})
#genresult = koboldcpp.generate(decoded_prompt,max_new,utils.koboldai_vars.max_length,
#gen_settings.temp,int(gen_settings.top_k),gen_settings.top_a,gen_settings.top_p,
#gen_settings.typical,gen_settings.tfs,gen_settings.rep_pen,gen_settings.rep_pen_range,
#sampler_order=gen_settings.sampler_order,use_default_badwordsids=utils.koboldai_vars.use_default_badwordsids)
genresult = []
while True:
data = self.output_queue.get()
print(data)
if data['command'] == 'generated text':
genresult.append(data['data'])
if self.output_queue.empty():
break
time.sleep(0.02)
outputs = [genresult]
return GenerationResult( return GenerationResult(
model=self, model=self,
out_batches=np.array( out_batches=np.array(
[self.tokenizer.encode(x) for x in outputs] [self.tokenizer.encode(x) for x in genresult]
), ),
prompt=prompt_tokens, prompt=prompt_tokens,
is_whole_generation=True, is_whole_generation=True,

View file

@ -13,6 +13,7 @@ import os
import argparse import argparse
import json, sys, http.server, time, asyncio, socket, threading import json, sys, http.server, time, asyncio, socket, threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import multiprocessing
sampler_order_max = 7 sampler_order_max = 7
stop_token_max = 16 stop_token_max = 16
@ -2330,6 +2331,25 @@ def main(launch_args,start_server=True):
else: else:
print(f"Server was not started, main function complete. Idling.") print(f"Server was not started, main function complete. Idling.")
def run_in_queue(launch_args, input_queue, output_queue):
    """Worker-process entry point: load the model, then service requests forever.

    Runs koboldcpp's ``main`` without its HTTP server to load the model,
    signals readiness with a ``{'command': 'complete'}`` message, then polls
    ``input_queue`` for ``'generate'`` commands, pushing each result to
    ``output_queue`` as ``{'command': 'generated text', 'data': ...}``.

    This loop never returns; the parent releases model memory by terminating
    the whole process (the point of the Windows unload fix).

    Args:
        launch_args: KcppArgsObject-style namespace forwarded to ``main``.
        input_queue: multiprocessing.Queue of command dicts from the parent.
        output_queue: multiprocessing.Queue of status/result dicts to the parent.
    """
    main(launch_args, start_server=False)  # initialize library without enabling Lite http server
    output_queue.put({'command': 'complete'})
    while True:
        # Drain every pending request before sleeping.  (The previous
        # `if not input_queue.empty():` wrapper was redundant — the inner
        # while loop is already a no-op when the queue is empty.)
        while not input_queue.empty():
            data = input_queue.get()
            if data['command'] == 'generate':
                (args, kwargs) = data['data']
                output_queue.put({'command': 'generated text',
                                  'data': generate(*args, **kwargs)})
        time.sleep(0.2)  # idle poll interval; keeps CPU usage negligible
def start_in_seperate_process(launch_args):
    """Launch ``run_in_queue`` in a child process so it can be killed later.

    Returns a tuple ``(output_queue, input_queue, process)`` — the queue the
    parent reads results from, the queue the parent writes commands to, and
    the started ``multiprocessing.Process`` handle (terminate it to unload).
    """
    to_worker = multiprocessing.Queue()
    from_worker = multiprocessing.Queue()
    worker = multiprocessing.Process(
        target=run_in_queue,
        args=(launch_args, to_worker, from_worker),
    )
    worker.start()
    return (from_worker, to_worker, worker)
if __name__ == '__main__': if __name__ == '__main__':
print("***\nWelcome to KoboldCpp - Version " + KcppVersion) # just update version manually print("***\nWelcome to KoboldCpp - Version " + KcppVersion) # just update version manually
# print("Python version: " + sys.version) # print("Python version: " + sys.version)