openai: update after merge

typos
Olivier Chafik 2024-04-30 18:29:08 +01:00
parent 7675ac6cf4
commit 312e20b54a
4 changed files with 16 additions and 15 deletions


@@ -29,7 +29,7 @@ class Tool(BaseModel):
 
 class ResponseFormat(BaseModel):
     type: Literal["json_object"]
-    schema: Optional[Json[Any]] = None # type: ignore
+    schema: Optional[dict[str, Any]] = None # type: ignore
 
 class LlamaCppParams(BaseModel):
     n_predict: Optional[int] = None
@@ -67,7 +67,7 @@ class ChatCompletionRequest(LlamaCppParams):
 class Choice(BaseModel):
     index: int
     message: Message
-    logprobs: Optional[Json[Any]] = None
+    logprobs: Optional[dict[str, Any]] = None
     finish_reason: Union[Literal["stop"], Literal["tool_calls"]]
 
 class Usage(BaseModel):
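
The change from Json[Any] to dict[str, Any] is more than cosmetic: pydantic's Json type validates a JSON-encoded string and parses it during validation, while these fields arrive from OpenAI-style clients as already-parsed JSON objects. A minimal sketch of the difference, assuming pydantic v2 (payload is an illustrative field name; the real field is called schema, which shadows BaseModel.schema and is why the # type: ignore is needed):

from typing import Any, Optional
from pydantic import BaseModel, Json

class WithJsonString(BaseModel):
    payload: Optional[Json[Any]] = None        # expects a JSON *string*, parses it

class WithDict(BaseModel):
    payload: Optional[dict[str, Any]] = None   # accepts a parsed object as-is

WithJsonString(payload='{"type": "object"}')   # ok: the string is parsed into a dict
WithDict(payload={"type": "object"})           # ok: the dict is accepted directly
# WithJsonString(payload={"type": "object"})   # ValidationError: not a JSON string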


@@ -8,13 +8,12 @@ from gguf.constants import Keys
 class GGUFKeyValues:
     def __init__(self, model: Path):
-        reader = GGUFReader(model.as_posix())
-        self.fields = reader.fields
+        self.reader = GGUFReader(model.as_posix())
 
     def __getitem__(self, key: str):
         if '{arch}' in key:
             key = key.replace('{arch}', self[Keys.General.ARCHITECTURE])
-        return self.fields[key].read()
+        return self.reader.read_field(self.reader.fields[key])
     def __contains__(self, key: str):
-        return key in self.fields
+        return key in self.reader.fields
     def keys(self):
-        return self.fields.keys()
+        return self.reader.fields.keys()
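
For context, GGUFKeyValues is a small mapping-style wrapper over gguf's GGUFReader: after this change the reader itself is kept and field values are decoded through it on access. A hypothetical usage sketch (the model path is a placeholder; the key constants come from gguf.constants as in the code above):

from pathlib import Path
from gguf.constants import Keys

metadata = GGUFKeyValues(Path("models/some-model.gguf"))  # placeholder path

if Keys.Tokenizer.CHAT_TEMPLATE in metadata:       # __contains__
    print(metadata[Keys.Tokenizer.CHAT_TEMPLATE])  # __getitem__ decodes the field
# Keys like Keys.LLM.CONTEXT_LENGTH contain '{arch}', which __getitem__ expands:
print(metadata[Keys.LLM.CONTEXT_LENGTH])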


@@ -180,7 +180,7 @@ class ChatTemplate(BaseModel):
 
 class ChatHandlerArgs(BaseModel):
     chat_template: ChatTemplate
-    response_schema: Optional[Json[Any]] = None
+    response_schema: Optional[dict[str,Any]] = None
     tools: Optional[list[Tool]] = None
 
 class ChatHandler(ABC):
@@ -719,7 +719,7 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
         raise ValueError(f"Unsupported tool call style: {tool_style}")
 
 # os.environ.get('NO_TS')
-def _please_respond_with_schema(schema: Json[Any]) -> str:
+def _please_respond_with_schema(schema: dict[str, Any]) -> str:
     sig = json.dumps(schema, indent=2)
     # _ts_converter = SchemaToTypeScriptConverter()
     # # _ts_converter.resolve_refs(schema, 'schema')
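
Judging from the body shown, _please_respond_with_schema renders a JSON schema dict into a prompt suffix (a TypeScript-signature variant is commented out). A minimal sketch of the idea; the exact instruction wording below is an assumption, not the repo's:

import json
from typing import Any

def _please_respond_with_schema(schema: dict[str, Any]) -> str:
    sig = json.dumps(schema, indent=2)
    # Assumed phrasing; the actual string in the source may differ.
    return f'Please respond in JSON format with the following schema: {sig}'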


@@ -37,12 +37,8 @@ def main(
 ):
     import uvicorn
 
-    if endpoint:
-        sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
-        assert template_hf_model_id_fallback, "template_hf_model_id_fallback is required when using an endpoint"
-        chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
-
-    else:
+    chat_template = None
+    if model:
         metadata = GGUFKeyValues(Path(model))
 
         if not context_length:
@@ -58,6 +54,12 @@ def main(
         if verbose:
             sys.stderr.write(f"# CHAT TEMPLATE:\n\n{chat_template}\n\n")
 
+    if not chat_template:
+        sys.stderr.write(f"# WARNING: Unsure which model we're talking to, fetching its chat template from HuggingFace tokenizer of {template_hf_model_id_fallback}\n")
+        assert template_hf_model_id_fallback or chat_template, "template_hf_model_id_fallback is required when using an endpoint without a model"
+        chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)
+
+    if not endpoint:
         if verbose:
             sys.stderr.write(f"# Starting C++ server with model {model} on {server_host}:{server_port}\n")
         cmd = [
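
Taken together, the two hunks decouple chat-template resolution from server spawning: the template is read from GGUF metadata when a local model is given, falls back to the HuggingFace tokenizer otherwise, and the local C++ server is only launched when no remote endpoint is set. Condensed as a sketch (the elided parts are assumptions about the surrounding file):

chat_template = None
if model:
    metadata = GGUFKeyValues(Path(model))
    ...  # context_length and, presumably, the chat template come from metadata

if not chat_template:
    # e.g. an endpoint with no local model: fall back to the HF tokenizer's template
    chat_template = ChatTemplate.from_huggingface(template_hf_model_id_fallback)

if not endpoint:
    ...  # build cmd and spawn the C++ server for the model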