# Kaballas's picture
# ccc
# 702f422
import argparse
import json
import os
from typing import Annotated, List, Literal
import mariadb
from fastmcp import Context, FastMCP
from pydantic import Field
from mcp_server_mariadb_vector.app_context import app_lifespan
from mcp_server_mariadb_vector.embeddings.factory import create_embedding_provider
from mcp_server_mariadb_vector.settings import EmbeddingSettings
# The FastMCP server; every tool below registers itself through @mcp.tool().
mcp = FastMCP(
    "Mariadb Vector",
    dependencies=["mariadb", "openai", "pydantic", "pydantic-settings"],
    lifespan=app_lifespan,
)

# One embedding provider shared by all tools, built from EmbeddingSettings.
embedding_provider = create_embedding_provider(
    EmbeddingSettings()
)
@mcp.tool()
def mariadb_create_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str,
        Field(description="The name of the vector store to create"),
    ],
    distance_function: Annotated[
        Literal["euclidean", "cosine"],
        Field(description="The distance function to use."),
    ] = "euclidean",
) -> str:
    """Create a vector store in the MariaDB database."""
    # A vector store is a table whose VECTOR column is sized to the configured
    # embedding model, plus a VECTOR index using the requested distance.
    # Returns a status string; database errors are reported, not raised.
    embedding_length = embedding_provider.length_of_embedding()
    # vector_store_name comes from the MCP client: quote it as an identifier,
    # doubling any backtick so it cannot terminate the quoting early.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"
    # distance_function is restricted to a Literal in the tool signature.
    schema_query = f"""
        CREATE TABLE {quoted_name} (
            id BIGINT UNSIGNED PRIMARY KEY AUTO_INCREMENT,
            document LONGTEXT NOT NULL,
            embedding VECTOR({embedding_length}) NOT NULL,
            metadata JSON NOT NULL,
            VECTOR INDEX (embedding) DISTANCE={distance_function}
        )
        """
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(schema_query)
    except mariadb.Error as e:
        return f"Error creating vector store `{vector_store_name}`: {e}"
    return f"Vector store `{vector_store_name}` created successfully."
def is_vector_store(conn, table: str, embedding_length: int) -> bool:
    """
    True if `table` has the right schema, with vectors of the correct length, and a VECTOR index.

    The expected schema is the one `mariadb_create_vector_store` creates:
    exactly the columns id, document, embedding and metadata, all NOT NULL.
    `id` must be an auto-increment BIGINT UNSIGNED primary key, and
    `embedding` must be VECTOR(embedding_length) with a VECTOR index on it.

    Args:
        conn: Open MariaDB connection whose ``cursor(dictionary=True)`` returns
            a context-managed cursor yielding dict rows.
        table: Table name to inspect. Backticks in it are escaped.
        embedding_length: Expected dimension of the stored vectors.

    Raises:
        Whatever the cursor raises for the schema queries (e.g. mariadb.Error).
    """
    # Inside a backtick-quoted identifier a literal backtick must be doubled.
    # Without this, a table name containing "`" breaks (or injects into) the
    # statement.
    quoted = "`" + table.replace("`", "``") + "`"
    expected_types = {
        "id": "bigint(20) unsigned",
        "document": "longtext",
        "embedding": f"vector({embedding_length})",
        # Created as JSON; MariaDB's JSON type is an alias for LONGTEXT, so
        # SHOW COLUMNS reports it as longtext.
        "metadata": "longtext",
    }
    with conn.cursor(dictionary=True) as cur:
        # Check the columns: exact set, types and NOT NULL.
        cur.execute(f"SHOW COLUMNS FROM {quoted}")
        rows = {r["Field"]: r for r in cur}
        if set(rows) != set(expected_types):
            return False
        for name, expected_type in expected_types.items():
            column = rows[name]
            if column["Type"].lower() != expected_type or column["Null"] != "NO":
                return False
        id_column = rows["id"]
        if (
            id_column["Key"] != "PRI"
            or "auto_increment" not in id_column["Extra"].lower()
        ):
            return False
        # Check the vector index on the embedding column.
        cur.execute(
            f"SHOW INDEX FROM {quoted} "
            "WHERE Index_type = 'VECTOR' AND Column_name = 'embedding'"
        )
        return cur.fetchone() is not None
@mcp.tool()
def mariadb_list_vector_stores(ctx: Context) -> str:
    """List all vector stores in a MariaDB database."""
    # A table counts as a vector store when is_vector_store accepts its schema
    # for the current embedding length. Returns a status string; database
    # errors are reported, not raised.
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute("SHOW TABLES")
            tables = [table[0] for table in cursor]
        embedding_length = embedding_provider.length_of_embedding()
        # Schema inspection also queries the database, so it stays inside the
        # try: a failure must become an error message like the other tools,
        # not an uncaught exception.
        vector_stores = [
            table for table in tables if is_vector_store(conn, table, embedding_length)
        ]
    except mariadb.Error as e:
        return f"Error listing vector stores: {e}"
    return "Vector stores: " + ", ".join(vector_stores)
@mcp.tool()
def mariadb_delete_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to delete.")
    ],
) -> str:
    """Delete a vector store in the MariaDB database."""
    # Returns a status string; database errors are reported, not raised.
    # NOTE(review): this drops any table with this name, not only tables that
    # pass is_vector_store. Confirm that is intended.
    # The name comes from the MCP client: quote it as an identifier with any
    # backtick doubled so it cannot break out of the DROP TABLE statement.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(f"DROP TABLE {quoted_name}")
    except mariadb.Error as e:
        return f"Error deleting vector store `{vector_store_name}`: {e}"
    return f"Vector store `{vector_store_name}` deleted successfully."
@mcp.tool()
def mariadb_insert_documents(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to insert documents into.")
    ],
    documents: Annotated[
        List[str], Field(description="The documents to insert into the vector store.")
    ],
    metadata: Annotated[
        List[dict], Field(description="The metadata of the documents to insert.")
    ],
) -> str:
    """Insert documents into a vector store. Provide one metadata object per document."""
    # Each document is embedded, then stored with its embedding and its
    # metadata serialized as JSON. Returns a status string; database errors
    # are reported, not raised.
    # zip() below would silently drop trailing documents (or metadata) on a
    # length mismatch, so reject that up front.
    if len(documents) != len(metadata):
        return (
            f"Error inserting documents into `{vector_store_name}`: got "
            f"{len(documents)} documents but {len(metadata)} metadata entries."
        )
    if not documents:
        return "No documents to insert."
    embeddings = embedding_provider.embed_documents(documents)
    # VEC_FromText takes the textual "[x, y, ...]" form. Serialize the same way
    # mariadb_search_vector_store does for its query vector. This presumes
    # embed_documents returns lists of floats; TODO confirm.
    embedding_texts = [str(embedding) for embedding in embeddings]
    metadata_json = [json.dumps(item) for item in metadata]
    # Quote the client-supplied table name, doubling any backtick.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"
    insert_query = (
        f"INSERT INTO {quoted_name} (document, embedding, metadata) "
        "VALUES (%s, VEC_FromText(%s), %s)"
    )
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.executemany(
                insert_query, list(zip(documents, embedding_texts, metadata_json))
            )
        # Persist the rows even if the connection is not in autocommit mode.
        # Under autocommit this is a harmless no-op.
        conn.commit()
    except mariadb.Error as e:
        return f"Error inserting documents into `{vector_store_name}`: {e}"
    return f"Documents inserted into `{vector_store_name}` successfully."
@mcp.tool()
def mariadb_search_vector_store(
    ctx: Context,
    query: Annotated[str, Field(description="The query to search for.")],
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to search.")
    ],
    k: Annotated[int, Field(gt=0, description="The number of results to return.")] = 5,
) -> str:
    """Search a vector store for the most similar documents to a query."""
    # Embeds the query and returns up to k stored documents ordered by
    # ascending distance, formatted as text. Database errors are reported,
    # not raised.
    embedding = embedding_provider.embed_query(query)
    # Quote the client-supplied table name, doubling any backtick.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"
    # NOTE(review): the distance is always Euclidean, even for stores created
    # with distance_function="cosine". For those the ranking is Euclidean and
    # presumably the cosine VECTOR index is not used. Consider
    # VEC_DISTANCE_COSINE (or VEC_DISTANCE on servers that support it).
    search_query = f"""
        SELECT
            document,
            metadata,
            VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(%s)) AS distance
        FROM {quoted_name}
        ORDER BY distance ASC
        LIMIT %s
        """
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor(buffered=True) as cursor:
            # str() renders the query vector in the "[x, y, ...]" text form
            # VEC_FromText parses (assuming embed_query returns a list).
            cursor.execute(search_query, (str(embedding), k))
            rows = cursor.fetchall()
    except mariadb.Error as e:
        return f"Error searching vector store `{vector_store_name}`: {e}"
    if not rows:
        return "No similar context found."
    return "\n\n".join(
        f"Document: {document}\nMetadata: {json.loads(meta)}\nDistance: {distance}"
        for document, meta, distance in rows
    )
def main():
    """Parse command-line options and start the MCP server.

    --transport selects stdio (default) or sse. --host and --port are passed
    to the server only when the sse transport is chosen.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--transport", choices=["stdio", "sse"], default="stdio")
    cli.add_argument("--host", type=str, default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    options = cli.parse_args()

    run_kwargs = {"transport": options.transport}
    # Host and port only matter for the network (SSE) transport.
    if options.transport == "sse":
        run_kwargs.update(host=options.host, port=options.port)
    mcp.run(**run_kwargs)
app = mcp.http_app()  # Using http_app instead of deprecated sse_app


def _print_debug_info(asgi_app) -> None:
    """Print MCP server introspection details to stderr.

    Output goes to stderr, not stdout: with the stdio transport, stdout is the
    JSON-RPC channel, and anything printed there corrupts the protocol stream.
    Environment variable values are not printed because they commonly hold
    credentials (API keys, database passwords); only the names are listed.
    """
    import sys

    def debug(*args):
        print(*args, file=sys.stderr)

    debug("[DEBUG] FastMCP attributes:", dir(mcp))
    if hasattr(mcp, "_tools"):
        debug("[DEBUG] Registered MCP tools:")
        for tool in mcp._tools:
            debug(f" - {tool}")
    else:
        debug("[DEBUG] No _tools attribute found on mcp.")

    debug("[DEBUG] Environment variables:")
    for name in os.environ:
        debug(f" {name}")

    if hasattr(mcp, "_tool_manager"):
        tool_manager = mcp._tool_manager
        debug("[DEBUG] MCP _tool_manager contents:")
        debug(tool_manager)
        if hasattr(tool_manager, "__dict__"):
            for key, value in tool_manager.__dict__.items():
                debug(f" {key}: {value}")
    else:
        debug("[DEBUG] No _tool_manager attribute found on mcp.")

    debug(f"[DEBUG] MCP app type: {type(asgi_app)}")
    debug("[DEBUG] Registered routes in MCP app:")
    try:
        for route in asgi_app.routes:
            debug(
                f" - path: {getattr(route, 'path', str(route))}, "
                f"methods: {getattr(route, 'methods', '')}, "
                f"name: {getattr(route, 'name', '')}, "
                f"endpoint: {getattr(route, 'endpoint', '')}"
            )
    except Exception as e:  # best-effort introspection only
        debug(f"[DEBUG] Could not inspect MCP app routes: {e}")


_print_debug_info(app)

if __name__ == "__main__":
    main()