diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index fedcfe5ae..5f81d256a 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -60,6 +60,19 @@ Feature: llama.cpp server """ Then embeddings are generated + Scenario: OAI Embeddings compatibility with multiple inputs + Given a model tinyllama-2 + Given a prompt: + """ + In which country Paris is located ? + """ + And a prompt: + """ + Is Madrid the capital of Spain ? + """ + When an OAI compatible embeddings computation request for multiple inputs + Then embeddings are generated + Scenario: Tokenize / Detokenize When tokenizing: diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 732c8785e..9c825fdbc 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -1,4 +1,5 @@ import asyncio +import collections import json import os import re @@ -321,7 +322,15 @@ async def step_compute_embedding(context): @step(u'embeddings are generated') def step_assert_embeddings(context): - assert_embeddings(context.embeddings) + if len(context.prompts) == 0: + assert_embeddings(context.embeddings) + else: + assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n" + f"context.prompts={context.prompts}\n" + f"context.embeddings={context.embeddings}") + for embedding in context.embeddings: + context.prompts.pop() + assert_embeddings(embedding) @step(u'an OAI compatible embeddings computation request for') @@ -333,6 +342,15 @@ async def step_oai_compute_embeddings(context): model=context.model) +@step(u'an OAI compatible embeddings computation request for multiple inputs') +@async_run_until_complete +async def step_oai_compute_embeddings_multiple_inputs(context): + context.embeddings = await request_oai_embeddings(context.prompts, + base_url=context.base_url, + user_api_key=context.user_api_key, + model=context.model) + + @step(u'concurrent embedding requests') @async_run_until_complete() async def step_concurrent_embedding_requests(context): @@ -607,10 +625,17 @@ async def request_oai_embeddings(input, else: openai.api_key = user_api_key openai.api_base = f'{base_url}/v1' - embeddings = openai.Embedding.create( + oai_embeddings = openai.Embedding.create( model=model, input=input, ) + + if isinstance(input, collections.abc.Sequence): + embeddings = [] + for an_oai_embeddings in oai_embeddings.data: + embeddings.append(an_oai_embeddings.embedding) + else: + embeddings = oai_embeddings.data.embedding return embeddings