|
import argparse |
|
import json |
|
import os |
|
from typing import Annotated, List, Literal |
|
|
|
import mariadb |
|
from fastmcp import Context, FastMCP |
|
from pydantic import Field |
|
|
|
from mcp_server_mariadb_vector.app_context import app_lifespan |
|
from mcp_server_mariadb_vector.embeddings.factory import create_embedding_provider |
|
from mcp_server_mariadb_vector.settings import EmbeddingSettings |
|
|
|
# Packages handed to FastMCP through its `dependencies` argument.
_SERVER_DEPENDENCIES = ["mariadb", "openai", "pydantic", "pydantic-settings"]

# The MCP server object; every `@mcp.tool()` function in this module registers
# with it. `app_lifespan` is passed as the server lifespan, and the tools read
# the database connection from the resulting lifespan context.
mcp = FastMCP(
    "Mariadb Vector",
    lifespan=app_lifespan,
    dependencies=_SERVER_DEPENDENCIES,
)

# Embedding backend, built once at import time from `EmbeddingSettings()` and
# shared by all tools that need embeddings.
_embedding_settings = EmbeddingSettings()
embedding_provider = create_embedding_provider(_embedding_settings)
|
|
|
|
|
@mcp.tool()
def mariadb_create_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str,
        Field(description="The name of the vector store to create"),
    ],
    distance_function: Annotated[
        Literal["euclidean", "cosine"],
        Field(description="The distance function to use."),
    ] = "euclidean",
) -> str:
    """Create a vector store in the MariaDB database.

    The store is a table with an auto-increment id, the document text, its
    embedding and JSON metadata, plus a vector index on the embedding that
    uses the chosen distance function. Returns a status message; database
    errors are reported in the message instead of being raised.
    """
    # The distance function is spliced into the DDL below, so reject anything
    # outside the two supported keywords even if the caller bypassed the
    # Literal annotation (e.g. a direct Python call).
    if distance_function not in ("euclidean", "cosine"):
        return (
            f"Error creating vector store `{vector_store_name}`: "
            f"unsupported distance function {distance_function!r}"
        )

    # The vector column width is fixed by the configured embedding provider.
    embedding_length = embedding_provider.length_of_embedding()

    # Identifiers cannot be bound as query parameters. Doubling embedded
    # backticks keeps the caller-supplied name inside the quoted identifier
    # instead of letting it terminate the quote and inject SQL.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"

    schema_query = f"""
    CREATE TABLE {quoted_name} (
        id BIGINT UNSIGNED PRIMARY KEY AUTO_INCREMENT,
        document LONGTEXT NOT NULL,
        embedding VECTOR({embedding_length}) NOT NULL,
        metadata JSON NOT NULL,
        VECTOR INDEX (embedding) DISTANCE={distance_function}
    )
    """

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(schema_query)
    except mariadb.Error as e:
        return f"Error creating vector store `{vector_store_name}`: {e}"

    return f"Vector store `{vector_store_name}` created successfully."
|
|
|
|
|
def is_vector_store(conn, table: str, embedding_length: int) -> bool:
    """Report whether `table` has the layout of a vector store.

    A table qualifies when it has exactly the columns `id`, `document`,
    `embedding` and `metadata` with the expected types and NOT NULL
    constraints, the embedding column is `vector(embedding_length)`, and a
    VECTOR index exists on `embedding`.

    Args:
        conn: Database connection whose `cursor(dictionary=True)` yields rows
            as dicts keyed by column name.
        table: Name of the table to inspect.
        embedding_length: Vector dimension the embedding column must have.

    Returns:
        True if every check passes, False otherwise. Errors raised by the
        cursor are not caught here.
    """
    # Identifiers cannot be bound as parameters; doubling embedded backticks
    # keeps names that contain a backtick inside the quoted identifier.
    quoted_table = "`" + table.replace("`", "``") + "`"

    with conn.cursor(dictionary=True) as cur:
        cur.execute(f"SHOW COLUMNS FROM {quoted_table}")
        columns = {row["Field"]: row for row in cur}

        # Exactly these four columns, no more and no fewer.
        if set(columns) != {"id", "document", "embedding", "metadata"}:
            return False

        # id: BIGINT UNSIGNED, NOT NULL, primary key, auto-increment.
        id_col = columns["id"]
        if id_col["Type"].lower() != "bigint(20) unsigned":
            return False
        if (
            id_col["Null"] != "NO"
            or id_col["Key"] != "PRI"
            or "auto_increment" not in id_col["Extra"].lower()
        ):
            return False

        # document, embedding, metadata: expected reported type and NOT NULL.
        # metadata is compared against "longtext", the same type string the
        # document column is checked for.
        expected_types = {
            "document": "longtext",
            "embedding": f"vector({embedding_length})",
            "metadata": "longtext",
        }
        for name, expected_type in expected_types.items():
            column = columns[name]
            if column["Type"].lower() != expected_type or column["Null"] != "NO":
                return False

        # Finally, the embedding column must carry a VECTOR index.
        cur.execute(f"""
            SHOW INDEX FROM {quoted_table}
            WHERE Index_type = 'VECTOR' AND Column_name = 'embedding'
        """)
        return cur.fetchone() is not None
|
|
|
|
|
@mcp.tool()
def mariadb_list_vector_stores(ctx: Context) -> str:
    """List all vector stores in a MariaDB database.

    Returns a comma-separated list of the tables that have the vector-store
    layout for the configured embedding length. Database errors are reported
    in the returned message instead of being raised.
    """
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute("SHOW TABLES")
            tables = [row[0] for row in cursor]

        # Inspecting each table issues further queries, so it stays inside
        # the try block: a failure on any table is reported the same way as
        # a failure of SHOW TABLES rather than escaping as an exception.
        embedding_length = embedding_provider.length_of_embedding()
        vector_stores = [
            table for table in tables if is_vector_store(conn, table, embedding_length)
        ]
    except mariadb.Error as e:
        return f"Error listing vector stores: {e}"

    return "Vector stores: " + ", ".join(vector_stores)
|
|
|
|
|
@mcp.tool()
def mariadb_delete_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to delete.")
    ],
) -> str:
    """Delete a vector store in the MariaDB database.

    Drops the table with the given name. Returns a status message; database
    errors are reported in the message instead of being raised.
    """
    # Identifiers cannot be bound as query parameters. Doubling embedded
    # backticks keeps the caller-supplied name inside the quoted identifier
    # instead of letting it terminate the quote and inject SQL.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"

    # NOTE(review): this drops whatever table carries that name; nothing here
    # checks that it is actually a vector store. Confirm that is intended.
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(f"DROP TABLE {quoted_name}")
    except mariadb.Error as e:
        return f"Error deleting vector store `{vector_store_name}`: {e}"

    return f"Vector store `{vector_store_name}` deleted successfully."
|
|
|
|
|
@mcp.tool()
def mariadb_insert_documents(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to insert documents into.")
    ],
    documents: Annotated[
        List[str], Field(description="The documents to insert into the vector store.")
    ],
    metadata: Annotated[
        List[dict], Field(description="The metadata of the documents to insert.")
    ],
) -> str:
    """Insert documents into a vector store.

    Each document is embedded and stored together with the metadata entry at
    the same position, so `documents` and `metadata` must have the same
    length. Returns a status message; database errors are reported in the
    message instead of being raised.
    """
    # Validate before embedding so bad input never costs an embedding call.
    if not documents:
        return (
            f"Error inserting documents into `{vector_store_name}`: "
            "no documents were provided."
        )
    if len(documents) != len(metadata):
        # zip() below would otherwise silently drop the unmatched tail.
        return (
            f"Error inserting documents into `{vector_store_name}`: "
            f"got {len(documents)} documents but {len(metadata)} metadata entries."
        )

    # NOTE(review): embeddings are bound to VEC_FromText(%s) exactly as
    # embed_documents returns them, while the search tool wraps its embedding
    # in str(). Confirm embed_documents yields values the driver can bind.
    embeddings = embedding_provider.embed_documents(documents)

    metadata_json = [json.dumps(entry) for entry in metadata]

    # Identifiers cannot be bound as query parameters. Doubling embedded
    # backticks keeps the caller-supplied name inside the quoted identifier.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"

    insert_query = f"""
    INSERT INTO {quoted_name} (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)
    """
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.executemany(
                insert_query, list(zip(documents, embeddings, metadata_json))
            )
    except mariadb.Error as e:
        return f"Error inserting documents into `{vector_store_name}`: {e}"

    return f"Documents inserted into `{vector_store_name}` successfully."
|
|
|
|
|
@mcp.tool()
def mariadb_search_vector_store(
    ctx: Context,
    query: Annotated[str, Field(description="The query to search for.")],
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to search.")
    ],
    k: Annotated[int, Field(gt=0, description="The number of results to return.")] = 5,
) -> str:
    """Search a vector store for the most similar documents to a query.

    Returns up to `k` results, closest first, each with its document text,
    metadata and distance. Database errors are reported in the returned
    message instead of being raised.
    """
    embedding = embedding_provider.embed_query(query)

    # Identifiers cannot be bound as query parameters. Doubling embedded
    # backticks keeps the caller-supplied name inside the quoted identifier
    # instead of letting it terminate the quote and inject SQL.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"

    # NOTE(review): ranking always uses VEC_DISTANCE_EUCLIDEAN, even though
    # stores can be created with DISTANCE=cosine. Confirm whether searches on
    # cosine stores should use the matching distance function.
    search_query = f"""
    SELECT
        document,
        metadata,
        VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(%s)) AS distance
    FROM {quoted_name}
    ORDER BY distance ASC
    LIMIT %s
    """

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor(buffered=True) as cursor:
            # The embedding is sent as its text form for VEC_FromText.
            cursor.execute(search_query, (str(embedding), k))
            rows = cursor.fetchall()
    except mariadb.Error as e:
        return f"Error searching vector store `{vector_store_name}`: {e}"

    if not rows:
        return "No similar context found."

    return "\n\n".join(
        f"Document: {document}\nMetadata: {json.loads(raw_metadata)}\nDistance: {distance}"
        for document, raw_metadata, distance in rows
    )
|
|
|
|
|
def main():
    """Parse the command-line options and start the MCP server.

    `--transport` selects "stdio" (the default) or "sse". `--host` and
    `--port` (defaults 127.0.0.1 and 8000) are forwarded to the server only
    when the SSE transport is chosen.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--transport", choices=["stdio", "sse"], default="stdio")
    cli.add_argument("--host", type=str, default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    options = cli.parse_args()

    # Only the SSE transport receives network settings.
    run_kwargs = {"transport": options.transport}
    if options.transport == "sse":
        run_kwargs["host"] = options.host
        run_kwargs["port"] = options.port

    mcp.run(**run_kwargs)
|
|
|
|
|
# HTTP application object built from the MCP server, exposed at module level.
app = mcp.http_app()


def _dump_debug_info(server, http_app) -> None:
    """Write startup diagnostics about the MCP server to stderr.

    Output goes to stderr rather than stdout so it cannot mix with protocol
    messages when the server runs over the stdio transport. Environment
    variables are listed by name only, because their values commonly hold
    credentials such as database passwords or API keys.

    Args:
        server: The FastMCP server instance to inspect.
        http_app: The HTTP application whose routes are listed.
    """
    import sys  # local import so this block does not depend on file-level imports

    def debug(message: str) -> None:
        print(message, file=sys.stderr)

    debug(f"[DEBUG] FastMCP attributes: {dir(server)}")

    # `_tools` and `_tool_manager` are private attributes that may not exist,
    # hence the hasattr guards.
    if hasattr(server, "_tools"):
        debug("[DEBUG] Registered MCP tools:")
        for tool in server._tools:
            debug(f" - {tool}")
    else:
        debug("[DEBUG] No _tools attribute found on mcp.")

    debug("[DEBUG] Environment variables:")
    for name in os.environ:
        # Names only: values are redacted so secrets never reach the logs.
        debug(f" {name}=<redacted>")

    if hasattr(server, "_tool_manager"):
        tool_manager = server._tool_manager
        debug("[DEBUG] MCP _tool_manager contents:")
        debug(str(tool_manager))
        for key, value in getattr(tool_manager, "__dict__", {}).items():
            debug(f" {key}: {value}")
    else:
        debug("[DEBUG] No _tool_manager attribute found on mcp.")

    debug(f"[DEBUG] MCP app type: {type(http_app)}")

    debug("[DEBUG] Registered routes in MCP app:")
    try:
        for route in http_app.routes:
            debug(
                f" - path: {getattr(route, 'path', str(route))}, "
                f"methods: {getattr(route, 'methods', '')}, "
                f"name: {getattr(route, 'name', '')}, "
                f"endpoint: {getattr(route, 'endpoint', '')}"
            )
    except Exception as e:
        # Best-effort diagnostics: never let route inspection break startup.
        debug(f"[DEBUG] Could not inspect MCP app routes: {e}")


_dump_debug_info(mcp, app)

if __name__ == "__main__":
    main()
|
|