From 0e87ae24cd497907ecf5eac33647cecfe070e7bf Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 27 Dec 2024 00:07:58 +0000 Subject: [PATCH] rm trailing spaces --- common/minja.hpp | 4 +-- examples/agent/run.py | 2 +- examples/agent/tools/memory.py | 30 +++++++++---------- examples/server/tests/pytest.ini | 2 +- .../server/tests/unit/test_chat_completion.py | 2 +- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/common/minja.hpp b/common/minja.hpp index c5472a0ae..26f20fdc9 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -1009,7 +1009,7 @@ public: throw std::runtime_error("Filter must be a callable: " + filter_value.dump()); } std::string rendered_body = body->render(context); - + ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; auto result = filter_value.call(context, filter_args); out << result.to_str(); @@ -1181,7 +1181,7 @@ public: case Op::Expansion: case Op::ExpansionDict: throw std::runtime_error("Expansion operator is only supported in function calls and collections"); - + } throw std::runtime_error("Unknown unary operator"); } diff --git a/examples/agent/run.py b/examples/agent/run.py index 1cf94ede1..3330f1b7a 100644 --- a/examples/agent/run.py +++ b/examples/agent/run.py @@ -80,7 +80,7 @@ async def main( api_key = os.environ.get(provider_info['api_key_env']) tool_map, tools = await discover_tools(tool_endpoints or [], verbose) - + if think: tools.append({ 'type': 'function', diff --git a/examples/agent/tools/memory.py b/examples/agent/tools/memory.py index 3a3e87ce9..d3d0e600c 100644 --- a/examples/agent/tools/memory.py +++ b/examples/agent/tools/memory.py @@ -2,33 +2,33 @@ Memory tools that use sqlite-vec as a vector database (combined w/ sqlite-lembed or sqlite-rembed for embeddings). Note: it's best to run this in a silo w/: - + ./examples/agent/serve_tools_inside_docker.sh # Run w/o other tools: - + ## Prerequisites: - + pip install aiosqlite "fastapi[standard]" sqlite-lembed sqlite-rembed sqlite-vec uvicorn - + ## Usage w/ sqlite-rembed: - + ./llama-server --port 8081 -fa -c 0 --embeddings --rope-freq-scale 0.75 \ -hfr nomic-ai/nomic-embed-text-v1.5-GGUF -hff nomic-embed-text-v1.5.Q4_K_M.gguf MEMORY_SQLITE_DB=memory_rembed.db \ EMBEDDINGS_DIMS=768 \ EMBEDDINGS_ENDPOINT=http://localhost:8081/v1/embeddings \ python examples/agent/tools/memory.py - + ## Usage w/ sqlite-lembed: - + MEMORY_SQLITE_DB=memory_lembed.db \ EMBEDDINGS_DIMS=768 \ EMBEDDINGS_MODEL_FILE=~/Library/Caches/llama.cpp/nomic-embed-text-v1.5.Q4_K_M.gguf \ python examples/agent/tools/memory.py ## Test: - + curl -X POST "http://localhost:8000/memorize" -H "Content-Type: application/json" -d '["User is Olivier Chafik", "User is a Software Engineer"]' curl -X POST "http://localhost:8000/search_memory?text=What%20do%20we%20do%3F" ''' @@ -65,7 +65,7 @@ else: async def setup_db(db: aiosqlite.Connection): - + await db.enable_load_extension(True) await db.load_extension(sqlite_vec.loadable_path()) if local: @@ -75,7 +75,7 @@ async def setup_db(db: aiosqlite.Connection): await db.enable_load_extension(False) client_name = 'default' - + if local: await db.execute(f''' INSERT INTO lembed_models(name, model) VALUES ( @@ -88,7 +88,7 @@ async def setup_db(db: aiosqlite.Connection): '{client_name}', rembed_client_options('format', 'llamafile', 'url', ?, 'key', ?) ); ''', (embeddings_endpoint, embeddings_api_key)) - + async def create_vector_index(table_name, text_column, embedding_column): ''' Create an sqlite-vec virtual table w/ an embedding column @@ -145,7 +145,7 @@ async def setup_db(db: aiosqlite.Connection): JOIN {table_name} USING (rowid) ''', (text, top_n) - ) + ) return search await db.execute(''' @@ -155,9 +155,9 @@ async def setup_db(db: aiosqlite.Connection): ) ''') facts_search = await create_vector_index('facts', 'content', 'embedding') - + await db.commit() - + return dict( facts_search=facts_search, ) @@ -185,7 +185,7 @@ async def search_memory(text: str, top_n: int = 10): results = await cursor.fetchall() cols = [c[0] for c in cursor.description] return [dict(zip(cols, row)) for row in results] - + # This main entry point is just here for easy debugging if __name__ == '__main__': diff --git a/examples/server/tests/pytest.ini b/examples/server/tests/pytest.ini index 6510c8d98..6df308df7 100644 --- a/examples/server/tests/pytest.ini +++ b/examples/server/tests/pytest.ini @@ -1,4 +1,4 @@ [pytest] markers = slow: marks tests as slow (deselect with '-m "not slow"') - serial \ No newline at end of file + serial diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 154176d32..f9db84957 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -231,7 +231,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool: {"role": "user", "content": "Write an example"}, ], "tool_choice": "required", - "tools": [tool], + "tools": [tool], "parallel_tool_calls": False, }) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"