File size: 8,898 Bytes
644bdfe
 
702f422
644bdfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a02de1f
7b8fddf
8e8fff2
 
 
 
 
 
 
 
6c1d676
702f422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c1d676
 
 
702f422
6c1d676
 
 
644bdfe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import argparse
import json
import logging
import os
from typing import Annotated, List, Literal

import mariadb
from fastmcp import Context, FastMCP
from pydantic import Field

from mcp_server_mariadb_vector.app_context import app_lifespan
from mcp_server_mariadb_vector.embeddings.factory import create_embedding_provider
from mcp_server_mariadb_vector.settings import EmbeddingSettings

# FastMCP server instance; every `@mcp.tool()` function in this module registers
# itself on it. `app_lifespan` supplies the lifespan context from which the tools
# read the database connection (`ctx.request_context.lifespan_context.conn`).
mcp = FastMCP(
    "Mariadb Vector",
    lifespan=app_lifespan,
    # NOTE(review): presumably the packages FastMCP should install when this server
    # is deployed/packaged — confirm against the FastMCP docs for `dependencies`.
    dependencies=["mariadb", "openai", "pydantic", "pydantic-settings"],
)


# Single embedding provider shared by all tools, built once at import time from
# EmbeddingSettings (which settings source it reads is defined in
# mcp_server_mariadb_vector.settings, not here).
embedding_provider = create_embedding_provider(EmbeddingSettings())


@mcp.tool()
def mariadb_create_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str,
        Field(description="The name of the vector store to create"),
    ],
    distance_function: Annotated[
        Literal["euclidean", "cosine"],
        Field(description="The distance function to use."),
    ] = "euclidean",
) -> str:
    """Create a vector store in the MariaDB database.

    The store is a table with columns ``id`` (auto-increment primary key),
    ``document`` (LONGTEXT), ``embedding`` (a VECTOR whose dimension comes from
    the configured embedding provider) and ``metadata`` (JSON), plus a VECTOR
    index on ``embedding`` using ``distance_function``.

    Returns a human-readable success message, or an error message if the
    distance function is unsupported or the ``CREATE TABLE`` statement fails;
    database errors are reported in the returned string rather than raised.
    """
    # The distance function is interpolated into the SQL text, so re-check it
    # here: the Literal annotation is only enforced when the call goes through
    # the MCP/pydantic validation layer, not on direct Python calls.
    if distance_function not in ("euclidean", "cosine"):
        return (
            f"Error creating vector store `{vector_store_name}`: "
            f"unsupported distance function {distance_function!r}"
        )

    embedding_length = embedding_provider.length_of_embedding()

    # Identifiers cannot be bound as query parameters; quote the table name by
    # doubling embedded backticks so a name containing "`" cannot close the
    # quoted identifier and inject SQL.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"

    schema_query = f"""
    CREATE TABLE {quoted_name} (
        id BIGINT UNSIGNED PRIMARY KEY AUTO_INCREMENT,
        document LONGTEXT NOT NULL,
        embedding VECTOR({embedding_length}) NOT NULL,
        metadata JSON NOT NULL,
        VECTOR INDEX (embedding) DISTANCE={distance_function}
    )
    """

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(schema_query)
    except mariadb.Error as e:
        return f"Error creating vector store `{vector_store_name}`: {e}"

    return f"Vector store `{vector_store_name}` created successfully."


def is_vector_store(conn, table: str, embedding_length: int) -> bool:
    """
    True if `table` has the right schema, with vectors of the correct length, and a VECTOR index.

    The expected schema is the one ``mariadb_create_vector_store`` creates:
    exactly the columns ``id`` (``bigint(20) unsigned``, NOT NULL, primary key,
    auto-increment), ``document`` (LONGTEXT NOT NULL), ``embedding``
    (``vector(embedding_length)`` NOT NULL) and ``metadata`` (LONGTEXT NOT NULL;
    MariaDB reports its JSON type as LONGTEXT), plus a VECTOR index on
    ``embedding``.

    Args:
        conn: Open MariaDB connection; a ``dictionary=True`` cursor is used.
        table: Name of the table to inspect. Backticks in it are escaped.
        embedding_length: Required dimension of the ``embedding`` vector.

    Raises:
        Whatever the driver raises if the metadata queries fail (for example
        when the table no longer exists); callers are expected to handle it.
    """
    # Identifiers cannot be parameterized; doubling backticks keeps the name
    # inside the quoted identifier.
    quoted_table = "`" + table.replace("`", "``") + "`"

    with conn.cursor(dictionary=True) as cur:
        cur.execute(f"SHOW COLUMNS FROM {quoted_table}")
        columns = {row["Field"]: row for row in cur}

        if set(columns) != {"id", "document", "embedding", "metadata"}:
            return False

        id_col = columns["id"]
        if (
            id_col["Type"].lower() != "bigint(20) unsigned"
            or id_col["Null"] != "NO"
            or id_col["Key"] != "PRI"
            or "auto_increment" not in (id_col["Extra"] or "").lower()
        ):
            return False

        # The remaining columns only need the exact type and NOT NULL.
        expected_types = {
            "document": "longtext",
            "embedding": f"vector({embedding_length})",
            "metadata": "longtext",
        }
        for name, sql_type in expected_types.items():
            col = columns[name]
            if col["Type"].lower() != sql_type or col["Null"] != "NO":
                return False

        cur.execute(
            f"SHOW INDEX FROM {quoted_table} "
            "WHERE Index_type = 'VECTOR' AND Column_name = 'embedding'"
        )
        # fetchall() drains the result set so no unread rows stay pending on
        # the connection when the cursor is closed.
        has_vector_index = bool(cur.fetchall())

    return has_vector_index


@mcp.tool()
def mariadb_list_vector_stores(ctx: Context) -> str:
    """List all vector stores in a MariaDB database.

    Every table returned by ``SHOW TABLES`` is checked with ``is_vector_store``
    against the configured embedding length. Tables that match are listed, and
    tables that cannot be inspected are skipped.

    Returns ``"Vector stores: "`` followed by a comma-separated list of names,
    or an error message if ``SHOW TABLES`` itself fails.
    """
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute("SHOW TABLES")
            tables = [row[0] for row in cursor]
    except mariadb.Error as e:
        return f"Error listing vector stores: {e}"

    embedding_length = embedding_provider.length_of_embedding()

    vector_stores = []
    for table in tables:
        # is_vector_store queries the database too. A table that cannot be
        # inspected (e.g. a broken view, or one dropped since SHOW TABLES ran)
        # is treated as "not a vector store" rather than letting the driver
        # error escape the tool uncaught.
        try:
            if is_vector_store(conn, table, embedding_length):
                vector_stores.append(table)
        except mariadb.Error:
            continue

    return "Vector stores: " + ", ".join(vector_stores)


@mcp.tool()
def mariadb_delete_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to delete.")
    ],
) -> str:
    """Delete a vector store in the MariaDB database.

    Issues ``DROP TABLE`` for ``vector_store_name``. The table's schema is not
    checked first, so any table in the database can be dropped through this tool.

    Returns a human-readable success message, or an error message if the
    statement fails; database errors are reported in the returned string rather
    than raised.
    """
    # Identifiers cannot be bound as query parameters; doubling embedded
    # backticks keeps the name inside the quoted identifier, which prevents SQL
    # injection such as "x`; DROP TABLE other; --".
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(f"DROP TABLE {quoted_name}")
    except mariadb.Error as e:
        return f"Error deleting vector store `{vector_store_name}`: {e}"

    return f"Vector store `{vector_store_name}` deleted successfully."


@mcp.tool()
def mariadb_insert_documents(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to insert documents into.")
    ],
    documents: Annotated[
        List[str], Field(description="The documents to insert into the vector store.")
    ],
    metadata: Annotated[
        List[dict], Field(description="The metadata of the documents to insert.")
    ],
) -> str:
    """Insert documents, with their embeddings and metadata, into a vector store.

    ``documents`` and ``metadata`` are paired by position, so they must have the
    same length. Each document is embedded with the configured embedding
    provider and stored next to its metadata serialized as JSON.

    Returns a human-readable success message, or an error message when the
    inputs are mismatched or the insert fails; database errors are reported in
    the returned string rather than raised.
    """
    if len(documents) != len(metadata):
        # zip() below would silently drop the unmatched tail, losing documents.
        return (
            f"Error inserting documents into `{vector_store_name}`: got "
            f"{len(documents)} documents but {len(metadata)} metadata entries."
        )
    if not documents:
        # Nothing to embed or insert; avoid calling the embedder and
        # executemany() with an empty batch.
        return f"No documents to insert into `{vector_store_name}`."

    embeddings = embedding_provider.embed_documents(documents)
    # VEC_FromText parses the textual "[x, y, ...]" form. str() produces it for
    # a list of floats (the same conversion the search tool applies to query
    # embeddings) and leaves values that are already strings unchanged.
    embedding_texts = [str(embedding) for embedding in embeddings]
    metadata_json = [json.dumps(item) for item in metadata]

    # Identifiers cannot be parameterized; doubling backticks prevents injection.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"
    insert_query = f"""
    INSERT INTO {quoted_name} (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)
    """
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.executemany(
                insert_query, list(zip(documents, embedding_texts, metadata_json))
            )
    except mariadb.Error as e:
        return f"Error inserting documents into `{vector_store_name}`: {e}"

    return f"Documents inserted into `{vector_store_name}` successfully."


@mcp.tool()
def mariadb_search_vector_store(
    ctx: Context,
    query: Annotated[str, Field(description="The query to search for.")],
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to search.")
    ],
    k: Annotated[int, Field(gt=0, description="The number of results to return.")] = 5,
) -> str:
    """Search a vector store for the most similar documents to a query.

    The query is embedded with the configured embedding provider. At most ``k``
    rows are returned, ordered by ascending distance (closest first), each
    formatted as ``Document`` / ``Metadata`` / ``Distance`` lines and separated
    by blank lines.

    Returns ``"No similar context found."`` when the store has no rows, or an
    error message if the query fails; database errors are reported in the
    returned string rather than raised.
    """
    embedding = embedding_provider.embed_query(query)

    # Identifiers cannot be parameterized; doubling backticks prevents injection.
    quoted_name = "`" + vector_store_name.replace("`", "``") + "`"

    # VEC_DISTANCE picks Euclidean or cosine distance from the table's VECTOR
    # index, so stores created with distance_function="cosine" are ranked by
    # cosine distance and the index can serve the ORDER BY ... LIMIT.
    # A hard-coded VEC_DISTANCE_EUCLIDEAN ignored the store's metric.
    # NOTE(review): requires a MariaDB version that ships the generic
    # VEC_DISTANCE function; confirm the minimum server version.
    search_query = f"""
    SELECT
        document,
        metadata,
        VEC_DISTANCE(embedding, VEC_FromText(%s)) AS distance
    FROM {quoted_name}
    ORDER BY distance ASC
    LIMIT %s
    """

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor(buffered=True) as cursor:
            cursor.execute(
                search_query,
                (str(embedding), k),
            )
            rows = cursor.fetchall()
    except mariadb.Error as e:
        return f"Error searching vector store `{vector_store_name}`: {e}"

    if not rows:
        return "No similar context found."

    return "\n\n".join(
        f"Document: {document}\nMetadata: {json.loads(metadata)}\nDistance: {distance}"
        for document, metadata, distance in rows
    )


def main():
    """Parse the command line and run the MCP server on the selected transport.

    ``--transport`` picks ``stdio`` (default) or ``sse``. ``--host`` and
    ``--port`` are forwarded to ``mcp.run`` only for ``sse``; the stdio
    transport takes no bind address.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--transport", choices=["stdio", "sse"], default="stdio")
    cli.add_argument("--host", type=str, default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()

    run_kwargs = {"transport": opts.transport}
    if opts.transport == "sse":
        # Only the network transport needs to know where to listen.
        run_kwargs.update(host=opts.host, port=opts.port)
    mcp.run(**run_kwargs)


app = mcp.http_app()  # Using http_app instead of deprecated sse_app


def _log_startup_diagnostics() -> None:
    """Log registered MCP tools, tool-manager state, app type and HTTP routes.

    Everything goes through ``logging`` at DEBUG level instead of ``print``.
    With the stdio transport, stdout carries the MCP JSON-RPC stream, so stray
    output written there at import time corrupts the protocol. Nothing is
    emitted unless DEBUG logging has been enabled for this module.

    Environment variable *values* are never logged because they can hold
    secrets such as API keys or database credentials; only their names are.
    """
    logger = logging.getLogger(__name__)
    if not logger.isEnabledFor(logging.DEBUG):
        return

    logger.debug("FastMCP attributes: %s", dir(mcp))

    if hasattr(mcp, "_tools"):
        logger.debug("Registered MCP tools: %s", [str(tool) for tool in mcp._tools])
    else:
        logger.debug("No _tools attribute found on mcp.")

    logger.debug("Environment variable names: %s", sorted(os.environ))

    if hasattr(mcp, "_tool_manager"):
        tool_manager = mcp._tool_manager
        logger.debug("MCP _tool_manager contents: %s", tool_manager)
        for key, value in getattr(tool_manager, "__dict__", {}).items():
            logger.debug("  %s: %s", key, value)
    else:
        logger.debug("No _tool_manager attribute found on mcp.")

    logger.debug("MCP app type: %s", type(app))

    # Best-effort diagnostics: route objects vary by framework version, so any
    # failure here is logged and otherwise ignored.
    try:
        for route in app.routes:
            logger.debug(
                "Route - path: %s, methods: %s, name: %s, endpoint: %s",
                getattr(route, "path", str(route)),
                getattr(route, "methods", ""),
                getattr(route, "name", ""),
                getattr(route, "endpoint", ""),
            )
    except Exception:
        logger.debug("Could not inspect MCP app routes", exc_info=True)


_log_startup_diagnostics()

if __name__ == "__main__":
    main()