Merge branch 'master' into concedo_experimental
# Conflicts:
#	.devops/tools.sh
#	.gitignore
#	CMakeLists.txt
#	Makefile
#	README.md
commit 4f40c226a0
9 changed files with 224 additions and 65 deletions
@@ -11,8 +11,13 @@ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git")
     if(NOT IS_DIRECTORY "${GIT_DIR}")
         file(READ ${GIT_DIR} REAL_GIT_DIR_LINK)
         string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK})
+        string(FIND "${REAL_GIT_DIR}" "/" SLASH_POS)
+        if (SLASH_POS EQUAL 0)
+            set(GIT_DIR "${REAL_GIT_DIR}")
+        else()
         set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${REAL_GIT_DIR}")
+        endif()
     endif()

     set(GIT_INDEX "${GIT_DIR}/index")
 else()
@@ -267,7 +267,7 @@ class Params:
         n_ctx = 2048

         return Params(
-            n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
+            n_vocab = model["tok_embeddings.weight"].shape[0],
             n_embd = config["dim"],
             n_layer = config["n_layers"],
             n_ctx = n_ctx,
@@ -1,4 +1,4 @@
 This is a swift clone of `examples/batched`.

 $ `make`
-$ `./swift MODEL_PATH [PROMPT] [PARALLEL]`
+$ `./batched_swift MODEL_PATH [PROMPT] [PARALLEL]`
@@ -5,7 +5,7 @@ import json
 import torch
 import numpy as np
 from gguf import *
-from transformers import CLIPModel, CLIPProcessor
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel

 TEXT = "clip.text"
 VISION = "clip.vision"
@@ -78,11 +78,19 @@ ap.add_argument("--text-only", action="store_true", required=False,
                 help="Save a text-only model. It can't be used to encode images")
 ap.add_argument("--vision-only", action="store_true", required=False,
                 help="Save a vision-only model. It can't be used to encode texts")
+ap.add_argument("--clip_model_is_vision", action="store_true", required=False,
+                help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
 ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
 ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
 ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
 ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
+# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
+default_image_mean = [0.48145466, 0.4578275, 0.40821073]
+default_image_std = [0.26862954, 0.26130258, 0.27577711]
+ap.add_argument('--image_mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
+ap.add_argument('--image_std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)

+# with proper
 args = ap.parse_args()

@@ -96,13 +104,20 @@ if args.use_f32:
 # output in the same directory as the model if output_dir is None
 dir_model = args.model_dir

+if args.clip_model_is_vision:
+    vocab = None
+    tokens = None
+else:
 with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
     vocab = json.load(f)
     tokens = [key for key in vocab]

 with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
     config = json.load(f)
+if args.clip_model_is_vision:
+    v_hparams = config
+    t_hparams = None
+else:
 v_hparams = config["vision_config"]
 t_hparams = config["text_config"]

@@ -117,7 +132,10 @@ ftype = 1
 if args.use_f32:
     ftype = 0

+if args.clip_model_is_vision:
+    model = CLIPVisionModel.from_pretrained(dir_model)
+    processor = None
+else:
 model = CLIPModel.from_pretrained(dir_model)
 processor = CLIPProcessor.from_pretrained(dir_model)

@@ -128,13 +146,13 @@ has_llava_projector = False
 if args.text_only:
     fname_middle = "text-"
     has_vision_encoder = False
-elif args.vision_only:
-    fname_middle = "vision-"
-    has_text_encoder = False
 elif args.llava_projector is not None:
     fname_middle = "mmproj-"
     has_text_encoder = False
     has_llava_projector = True
+elif args.vision_only:
+    fname_middle = "vision-"
+    has_text_encoder = False
 else:
     fname_middle = ""

@@ -182,8 +200,12 @@ if has_vision_encoder:
     block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)

-    image_mean = processor.image_processor.image_mean if args.image_mean is None else args.image_mean
-    image_std = processor.image_processor.image_std if args.image_std is None else args.image_std
+    if processor is not None:
+        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
+        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
+    else:
+        image_mean = args.image_mean if args.image_mean is not None else default_image_mean
+        image_std = args.image_std if args.image_std is not None else default_image_std
     fout.add_array("clip.vision.image_mean", image_mean)
     fout.add_array("clip.vision.image_std", image_std)

@@ -101,6 +101,12 @@ static void sigint_handler(int signo) {
 }
 #endif

+static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    LOG_TEE("%s", text);
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     g_params = &params;
@@ -114,6 +120,7 @@ int main(int argc, char ** argv) {
     log_set_target(log_filename_generator("main", "log"));
     LOG_TEE("Log start\n");
     log_dump_cmdline(argc, argv);
+    llama_log_set(llama_log_callback_logTee, nullptr);
 #endif // LOG_DISABLE_LOGS

     // TODO: Dump params ?
@@ -11,10 +11,10 @@ app = Flask(__name__)
 slot_id = -1

 parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
-parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
-parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
-parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
-parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
+parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
+parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
+parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
+parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
 parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
 parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
 parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
@@ -34,19 +34,19 @@ def is_present(json, key):

 #convert chat to prompt
 def convert_chat(messages):
-    prompt = "" + args.chat_prompt.replace("\\n", "\n")

-    system_n = args.system_name.replace("\\n", "\n")
-    user_n = args.user_name.replace("\\n", "\n")
-    ai_n = args.ai_name.replace("\\n", "\n")
-    stop = args.stop.replace("\\n", "\n")
+    system_n = args.system_name
+    user_n = args.user_name
+    ai_n = args.ai_name
+    stop = args.stop

+    prompt = "" + args.chat_prompt + stop

     for line in messages:
         if (line["role"] == "system"):
-            prompt += f"{system_n}{line['content']}"
+            prompt += f"{system_n}{line['content']}{stop}"
         if (line["role"] == "user"):
-            prompt += f"{user_n}{line['content']}"
+            prompt += f"{user_n}{line['content']}{stop}"
         if (line["role"] == "assistant"):
             prompt += f"{ai_n}{line['content']}{stop}"
     prompt += ai_n.rstrip()
@@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
             }
         ]
     }
-    slot_id = data["slot_id"]
+    slot_id = data.get("slot_id")
     if (chat):
         if (start):
             resData["choices"][0]["delta"] = {
@@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
     return resData


-@app.route('/chat/completions', methods=['POST'])
-@app.route('/v1/chat/completions', methods=['POST'])
+@app.route('/chat/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
 def chat_completions():
     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
         return Response(status=403)
+    if request.method == 'OPTIONS':
+        return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
     body = request.get_json()
     stream = False
     tokenize = False
@@ -177,20 +179,22 @@ def chat_completions():
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
        time_now = int(time.time())
        resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
-       yield 'data: {}\n'.format(json.dumps(resData))
+       yield 'data: {}\n\n'.format(json.dumps(resData))
        for line in data.iter_lines():
            if line:
                decoded_line = line.decode('utf-8')
                resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
-               yield 'data: {}\n'.format(json.dumps(resData))
+               yield 'data: {}\n\n'.format(json.dumps(resData))
-   return Response(generate(), mimetype='text/event-stream')
+   return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})


-@app.route('/completions', methods=['POST'])
-@app.route('/v1/completions', methods=['POST'])
+@app.route('/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/completions', methods=['POST', 'OPTIONS'])
 def completion():
     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
         return Response(status=403)
+    if request.method == 'OPTIONS':
+        return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
     body = request.get_json()
     stream = False
     tokenize = False
@@ -216,8 +220,8 @@ def completion():
            if line:
                decoded_line = line.decode('utf-8')
                resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
-               yield 'data: {}\n'.format(json.dumps(resData))
+               yield 'data: {}\n\n'.format(json.dumps(resData))
-   return Response(generate(), mimetype='text/event-stream')
+   return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})

 if __name__ == '__main__':
     app.run(args.host, port=args.port)
@@ -156,15 +156,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };

 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };

+struct task_multi {
+    int id;
+    std::set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
@@ -407,6 +415,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms

+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -530,7 +541,8 @@ struct llama_server_context

     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
-    std::mutex mutex_tasks;
+    std::vector<task_multi> queue_multitasks;
+    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
     std::mutex mutex_results;

     ~llama_server_context()
@@ -1113,17 +1125,40 @@ struct llama_server_context
        return slot.images.size() > 0;
    }

-   void send_error(int id, std::string error)
+   void send_error(task_server& task, std::string error)
    {
        std::lock_guard<std::mutex> lock(mutex_results);
        task_result res;
-       res.id = id;
+       res.id = task.id;
+       res.multitask_id = task.multitask_id;
        res.stop = false;
        res.error = true;
        res.result_json = { { "content", error } };
        queue_results.push_back(res);
    }

+   void add_multi_task(int id, std::vector<int>& sub_ids)
+   {
+       std::lock_guard<std::mutex> lock(mutex_tasks);
+       task_multi multi;
+       multi.id = id;
+       std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+       queue_multitasks.push_back(multi);
+   }
+
+   void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+   {
+       std::lock_guard<std::mutex> lock(mutex_tasks);
+       for (auto& multitask : queue_multitasks)
+       {
+           if (multitask.id == multitask_id)
+           {
+               multitask.subtasks_remaining.erase(subtask_id);
+               multitask.results.push_back(result);
+           }
+       }
+   }
+
    json get_model_props()
    {
        return get_formated_generation(slots[0]);
@@ -1168,6 +1203,7 @@ struct llama_server_context
        std::lock_guard<std::mutex> lock(mutex_results);
        task_result res;
        res.id = slot.task_id;
+       res.multitask_id = slot.multitask_id;
        res.error = false;
        res.stop = false;

@@ -1207,6 +1243,7 @@ struct llama_server_context
        std::lock_guard<std::mutex> lock(mutex_results);
        task_result res;
        res.id = slot.task_id;
+       res.multitask_id = slot.multitask_id;
        res.error = false;
        res.stop = true;

@@ -1252,6 +1289,12 @@ struct llama_server_context
            res.result_json["model"] = slot.oaicompat_model;
        }

+       // parent multitask, if any, needs to be updated
+       if (slot.multitask_id != -1)
+       {
+           update_multi_task(slot.multitask_id, slot.task_id, res);
+       }
+
        queue_results.push_back(res);
    }

@@ -1260,6 +1303,7 @@ struct llama_server_context
        std::lock_guard<std::mutex> lock(mutex_results);
        task_result res;
        res.id = slot.task_id;
+       res.multitask_id = slot.multitask_id;
        res.error = false;
        res.stop = true;

@@ -1286,9 +1330,9 @@ struct llama_server_context
        queue_results.push_back(res);
    }

-   int request_completion(json data, bool infill, bool embedding)
+   int request_completion(json data, bool infill, bool embedding, int multitask_id)
    {
-       std::lock_guard<std::mutex> lock(mutex_tasks);
+       std::unique_lock<std::mutex> lock(mutex_tasks);
        task_server task;
        task.id = id_gen++;
        task.target_id = 0;
@@ -1296,6 +1340,16 @@ struct llama_server_context
        task.infill_mode = infill;
        task.embedding_mode = embedding;
        task.type = COMPLETION_TASK;
+       task.multitask_id = multitask_id;
+
+       // when a completion task's prompt array is not a singleton, we split it into multiple requests
+       if (task.data.at("prompt").size() > 1)
+       {
+           lock.unlock(); // entering new func scope
+           return split_multiprompt_task(task);
+       }
+
+       // otherwise, it's a single-prompt task, we actually queue it
        queue_tasks.push_back(task);
        return task.id;
    }
@@ -1314,8 +1368,17 @@ struct llama_server_context

        for (int i = 0; i < (int) queue_results.size(); i++)
        {
+           // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+           if (queue_results[i].multitask_id == task_id)
+           {
+               update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+               queue_results.erase(queue_results.begin() + i);
+               continue;
+           }
+
            if (queue_results[i].id == task_id)
            {
+               assert(queue_results[i].multitask_id == -1);
                task_result res = queue_results[i];
                queue_results.erase(queue_results.begin() + i);
                return res;
@@ -1405,6 +1468,27 @@ struct llama_server_context
        queue_tasks.push_back(task);
    }

+   int split_multiprompt_task(task_server& multiprompt_task)
+   {
+       auto prompt_count = multiprompt_task.data.at("prompt").size();
+       assert(prompt_count > 1);
+
+       int multitask_id = id_gen++;
+       std::vector<int> subtask_ids(prompt_count);
+       for (int i = 0; i < prompt_count; i++)
+       {
+           json subtask_data = multiprompt_task.data;
+           subtask_data["prompt"] = subtask_data["prompt"][i];
+
+           // subtasks inherit everything else (infill mode, embedding mode, etc.)
+           subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+       }
+
+       // queue up the multitask so we can track its subtask progression
+       add_multi_task(multitask_id, subtask_ids);
+       return multitask_id;
+   }
+
    void process_tasks()
    {
        std::lock_guard<std::mutex> lock(mutex_tasks);
@@ -1420,7 +1504,7 @@ struct llama_server_context
                {
                    LOG_TEE("slot unavailable\n");
                    // send error result
-                   send_error(task.id, "slot unavailable");
+                   send_error(task, "slot unavailable");
                    return;
                }

@@ -1434,11 +1518,12 @@ struct llama_server_context
                slot->infill = task.infill_mode;
                slot->embedding = task.embedding_mode;
                slot->task_id = task.id;
+               slot->multitask_id = task.multitask_id;

                if (!launch_slot_with_data(slot, task.data))
                {
                    // send error result
-                   send_error(task.id, "internal_error");
+                   send_error(task, "internal_error");
                    break;
                }
            } break;
@@ -1454,6 +1539,38 @@ struct llama_server_context
            } break;
        }
    }
+
+       // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+       auto queue_iterator = queue_multitasks.begin();
+       while (queue_iterator != queue_multitasks.end())
+       {
+           if (queue_iterator->subtasks_remaining.empty())
+           {
+               // all subtasks done == multitask is done
+               task_result aggregate_result;
+               aggregate_result.id = queue_iterator->id;
+               aggregate_result.stop = true;
+               aggregate_result.error = false;
+
+               // collect json results into one json result
+               std::vector<json> result_jsons;
+               for (auto& subres : queue_iterator->results)
+               {
+                   result_jsons.push_back(subres.result_json);
+                   aggregate_result.error = aggregate_result.error && subres.error;
+               }
+               aggregate_result.result_json = json{ "results", result_jsons };
+
+               std::lock_guard<std::mutex> lock(mutex_results);
+               queue_results.push_back(aggregate_result);
+
+               queue_iterator = queue_multitasks.erase(queue_iterator);
+           }
+           else
+           {
+               ++queue_iterator;
+           }
+       }
    }

    bool update_slots() {
@@ -1845,6 +1962,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    printf(" -spf FNAME, --system-prompt-file FNAME\n");
    printf(" Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
    printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
+   printf(" --log-disable disables logging to a file.\n");
    printf("\n");
 }

@@ -2199,6 +2317,11 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
        }
        params.mmproj = argv[i];
    }
+   else if (arg == "--log-disable")
+   {
+       log_set_target(stdout);
+       LOG_INFO("logging to file is disabled.", {});
+   }
    else
    {
        fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -2597,7 +2720,7 @@ int main(int argc, char **argv)
    svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
    {
        json data = json::parse(req.body);
-       const int task_id = llama.request_completion(data, false, false);
+       const int task_id = llama.request_completion(data, false, false, -1);
        if (!json_value(data, "stream", false)) {
            std::string completion_text;
            task_result result = llama.next_result(task_id);
@@ -2686,7 +2809,7 @@ int main(int argc, char **argv)
    {
        json data = oaicompat_completion_params_parse(json::parse(req.body));

-       const int task_id = llama.request_completion(data, false, false);
+       const int task_id = llama.request_completion(data, false, false, -1);

        if (!json_value(data, "stream", false)) {
            std::string completion_text;
@@ -2755,7 +2878,7 @@ int main(int argc, char **argv)
    svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
    {
        json data = json::parse(req.body);
-       const int task_id = llama.request_completion(data, true, false);
+       const int task_id = llama.request_completion(data, true, false, -1);
        if (!json_value(data, "stream", false)) {
            std::string completion_text;
            task_result result = llama.next_result(task_id);
@@ -2859,7 +2982,7 @@ int main(int argc, char **argv)
        {
            prompt = "";
        }
-       const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
+       const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
        task_result result = llama.next_result(task_id);
        return res.set_content(result.result_json.dump(), "application/json");
    });
@@ -1,21 +1,19 @@
+#include "ggml.h"
 #include "ggml-opencl.h"

 #include <array>
 #include <atomic>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
 #include <sstream>
 #include <vector>
-#include <limits>

 #define CL_TARGET_OPENCL_VERSION 110
 #include <clblast.h>
 #include <clblast_c.h>

-#include <stdlib.h>
-#include <stdio.h>
-#include <string.h>
-
-#include "ggml.h"
-
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
@@ -47,7 +47,6 @@
 #endif
 #include <windows.h>
 #include <io.h>
-#include <stdio.h> // for _fseeki64
 #endif

 #include <algorithm>
@@ -7290,6 +7289,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
    // Replace the data in candidates with the new_candidates data
    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
    candidates->size = new_candidates.size();
+   candidates->sorted = false;

    if (ctx) {
        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
Loading…
Add table
Add a link
Reference in a new issue