Model Context Protocol in OCaml
at tmp 22 kB view raw
open Mcp
open Mcp_sdk

(* MCP Server module for handling JSON-RPC communication *)

(** Server types *)
type transport_type =
  | Stdio (* Read/write to stdin/stdout *)
  | Http (* HTTP server - to be implemented *)

(** A server definition bound to a transport, together with the mutable flag
    that keeps its main loop alive. *)
type t = {
  server: Mcp_sdk.server;
  transport: transport_type;
  mutable running: bool;
}

(** Upper bound on how many characters of a raw payload are echoed to the
    log. *)
let max_logged_chars = 100

(** [truncate_for_log text] is [text] cut down to at most
    [max_logged_chars] characters. *)
let truncate_for_log text =
  String.sub text 0 (min max_logged_chars (String.length text))

(** [string_field key fields] is [Some s] when [fields] binds [key] to the
    JSON string [s], and [None] when the key is absent or bound to any other
    kind of JSON value. *)
let string_field key fields =
  match List.assoc_opt key fields with
  | Some (`String s) -> Some s
  | _ -> None

(** [text_tool_result ~is_error text] builds the JSON result of a tools/call
    request: a single text content item plus the isError flag. *)
let text_tool_result ~is_error text : Yojson.Safe.t =
  `Assoc [
    ("content", `List [Mcp.yojson_of_content (Text (TextContent.{
      text;
      annotations = None
    }))]);
    ("isError", `Bool is_error)
  ]

(** Answer an initialize request with the capabilities, identity and protocol
    version of [server].
    @raise any exception raised by [Initialize.Request.t_of_yojson] when
    [params] cannot be decoded. *)
let handle_initialize (server : Mcp_sdk.server) ~id ~(params : Yojson.Safe.t option) =
  Log.debug "Processing initialize request";
  let result =
    match params with
    | Some params -> begin
      Log.debug "Parsing initialize request params...";
      let req_params = Initialize.Request.t_of_yojson params in
      Log.debug (Printf.sprintf "Client info: %s v%s"
        req_params.client_info.name
        req_params.client_info.version);
      Log.debug (Printf.sprintf "Client protocol version: %s" req_params.protocol_version);

      (* Check protocol version compatibility. A mismatch is only logged;
         the server still answers with its own version. *)
      if req_params.protocol_version <> server.protocol_version then
        Log.debug (Printf.sprintf "Protocol version mismatch: client=%s server=%s"
          req_params.protocol_version server.protocol_version);

      Initialize.Result.create
        ~capabilities:server.capabilities
        ~server_info:Implementation.{ name = server.name; version = server.version }
        ~protocol_version:server.protocol_version
        ~instructions:"MCP Server" (* TODO: Allow customization *)
        ()
    end
    | None -> begin
      (* Missing params are logged but tolerated: the reply is built without
         instructions. *)
      Log.error "Missing params for initialize request";
      Initialize.Result.create
        ~capabilities:server.capabilities
        ~server_info:Implementation.{ name = server.name; version = server.version }
        ~protocol_version:server.protocol_version
        ()
    end
  in
  create_response ~id ~result:(Initialize.Result.yojson_of_t result)

(** Answer a tools/list request with every tool registered on [server]. *)
let handle_tools_list (server : Mcp_sdk.server) ~id =
  Log.debug "Processing tools/list request";
  let tools_json = List.map Mcp_sdk.Tool.to_json server.tools in
  create_response ~id ~result:(`Assoc [("tools", `List tools_json)])

(** Answer a tools/call request by running the named tool handler.

    Malformed params, including a missing or non-string name, produce an
    invalid_params error response. An unknown tool and a handler that returns
    [Error _] both produce a normal response whose isError flag is true. *)
let handle_tools_call (server : Mcp_sdk.server) ~id ~progress_token
    ~(params : Yojson.Safe.t option) =
  Log.debug "Processing tools/call request";
  match params with
  | Some (`Assoc fields) -> begin
    match string_field "name" fields with
    | None -> begin
      (* Reply with an error instead of raising, so that the client always
         gets an answer to its request. *)
      Log.error "Missing or invalid 'name' parameter in tool call";
      create_error ~id ~code:ErrorCode.invalid_params
        ~message:"Missing or invalid 'name' parameter" ()
    end
    | Some name -> begin
      Log.debug (Printf.sprintf "Tool name: %s" name);
      let args =
        match List.assoc_opt "arguments" fields with
        | Some args -> begin
          Log.debug (Printf.sprintf "Tool arguments: %s" (Yojson.Safe.to_string args));
          args
        end
        | None -> begin
          Log.debug "No arguments provided for tool call, using empty object";
          `Assoc [] (* Empty arguments is valid *)
        end
      in

      (* Find the tool *)
      match List.find_opt (fun t -> t.Mcp_sdk.Tool.name = name) server.tools with
      | Some tool -> begin
        Log.debug (Printf.sprintf "Found tool: %s" name);
        let ctx = Mcp_sdk.Context.create
          ~request_id:id
          ~lifespan_context:server.lifespan_context
          ()
        in
        ctx.progress_token <- progress_token;

        (* Call the handler *)
        let result =
          match tool.handler ctx args with
          | Ok json -> text_tool_result ~is_error:false (Yojson.Safe.to_string json)
          | Error err -> text_tool_result ~is_error:true err
        in
        create_response ~id ~result
      end
      | None -> begin
        Log.error (Printf.sprintf "Tool not found: %s" name);
        let result =
          text_tool_result ~is_error:true (Printf.sprintf "Unknown tool: %s" name)
        in
        create_response ~id ~result
      end
    end
  end
  | _ -> begin
    Log.error "Invalid params format for tools/call";
    create_error ~id ~code:ErrorCode.invalid_params
      ~message:"Invalid params for tools/call" ()
  end

(** Answer a resources/list request. A server without any resource replies
    with a method_not_found error. *)
let handle_resources_list (server : Mcp_sdk.server) ~id =
  Log.debug "Processing resources/list request";
  match server.resources with
  | [] ->
    create_error ~id ~code:ErrorCode.method_not_found
      ~message:"Resources not supported" ()
  | resources ->
    let resources_json = List.map Mcp_sdk.Resource.to_json resources in
    create_response ~id ~result:(`Assoc [("resources", `List resources_json)])

(** Answer a prompts/list request. A server without any prompt replies with a
    method_not_found error. *)
let handle_prompts_list (server : Mcp_sdk.server) ~id =
  Log.debug "Processing prompts/list request";
  match server.prompts with
  | [] ->
    create_error ~id ~code:ErrorCode.method_not_found
      ~message:"Prompts not supported" ()
  | prompts ->
    let prompts_json = List.map Mcp_sdk.Prompt.to_json prompts in
    create_response ~id ~result:(`Assoc [("prompts", `List prompts_json)])

(** Answer a prompts/get request by running the named prompt handler.

    Argument values that are JSON strings are passed through unchanged; any
    other JSON value is passed as its serialized text. A server without any
    prompt replies with method_not_found; malformed params, a missing name and
    an unknown prompt reply with invalid_params; a handler returning
    [Error _] replies with internal_error. *)
let handle_prompts_get (server : Mcp_sdk.server) ~id ~(params : Yojson.Safe.t option) =
  Log.debug "Processing prompts/get request";
  match server.prompts, params with
  | [], _ ->
    create_error ~id ~code:ErrorCode.method_not_found
      ~message:"Prompts not supported" ()
  | _ :: _, Some (`Assoc fields) -> begin
    (* Extract prompt name *)
    match string_field "name" fields with
    | None -> begin
      (* Reply with an error instead of raising, so that the client always
         gets an answer to its request. *)
      Log.error "Missing or invalid 'name' parameter in prompt request";
      create_error ~id ~code:ErrorCode.invalid_params
        ~message:"Missing or invalid 'name' parameter" ()
    end
    | Some name -> begin
      Log.debug (Printf.sprintf "Prompt name: %s" name);

      (* Extract arguments if any *)
      let arguments =
        match List.assoc_opt "arguments" fields with
        | Some (`Assoc args) -> begin
          Log.debug (Printf.sprintf "Prompt arguments: %s"
            (Yojson.Safe.to_string (`Assoc args)));
          List.map (fun (k, v) ->
            match v with
            | `String s -> (k, s)
            | _ -> (k, Yojson.Safe.to_string v)
          ) args
        end
        | _ -> []
      in

      (* Find the prompt *)
      match List.find_opt (fun p -> p.Mcp_sdk.Prompt.name = name) server.prompts with
      | None -> begin
        Log.error (Printf.sprintf "Prompt not found: %s" name);
        create_error ~id ~code:ErrorCode.invalid_params
          ~message:(Printf.sprintf "Prompt not found: %s" name) ()
      end
      | Some prompt -> begin
        Log.debug (Printf.sprintf "Found prompt: %s" name);
        let ctx = Mcp_sdk.Context.create
          ~request_id:id
          ~lifespan_context:server.lifespan_context
          ()
        in

        (* Call the prompt handler *)
        match prompt.handler ctx arguments with
        | Error err -> begin
          Log.error (Printf.sprintf "Error processing prompt: %s" err);
          create_error ~id ~code:ErrorCode.internal_error ~message:err ()
        end
        | Ok messages -> begin
          Log.debug (Printf.sprintf "Prompt handler returned %d messages"
            (List.length messages));

          (* Important: We need to directly use yojson_of_message which
             preserves MIME types *)
          let messages_json = List.map Prompt.yojson_of_message messages in

          (* Debug output *)
          Log.debug (Printf.sprintf "Messages JSON: %s"
            (Yojson.Safe.to_string (`List messages_json)));

          (* Log the shape of the first message, if there is one, to help
             check the structure. *)
          (match messages with
           | [] -> ()
           | first_msg :: _ ->
             let content_debug =
               match first_msg.content with
               | Text t -> Printf.sprintf "Text content: %s" t.text
               | Image i -> Printf.sprintf "Image content (mime: %s)" i.mime_type
               | Audio a -> Printf.sprintf "Audio content (mime: %s)" a.mime_type
               | Resource _ -> "Resource content"
             in
             Log.debug (Printf.sprintf "First message content type: %s" content_debug));

          let description =
            match prompt.description with
            | Some d -> `String d
            | None -> `Null
          in
          let result = `Assoc [
            ("messages", `List messages_json);
            ("description", description)
          ] in
          create_response ~id ~result
        end
      end
    end
  end
  | _ :: _, _ -> begin
    Log.error "Invalid params format for prompts/get"; 
    create_error ~id ~code:ErrorCode.invalid_params
      ~message:"Invalid params format" ()
  end

(** Route one request to the handler of [method_] and return the message to
    send back. Unknown methods produce a method_not_found error.
    @raise any exception raised by the selected handler. *)
let dispatch_request (server : Mcp_sdk.server) ~id ~progress_token ~params method_ =
  match method_ with
  | "initialize" -> handle_initialize server ~id ~params
  | "tools/list" -> handle_tools_list server ~id
  | "tools/call" -> handle_tools_call server ~id ~progress_token ~params
  | "resources/list" -> handle_resources_list server ~id
  | "prompts/list" -> handle_prompts_list server ~id
  | "prompts/get" -> handle_prompts_get server ~id ~params
  | "ping" ->
    Log.debug "Processing ping request";
    create_response ~id ~result:(`Assoc [])
  | unknown ->
    Log.error (Printf.sprintf "Unknown method received: %s" unknown);
    create_error ~id ~code:ErrorCode.method_not_found
      ~message:("Method not found: " ^ unknown) ()

(** Process a single message.

    [process_message server message] decodes [message] as a JSON-RPC message
    and returns [Some reply] for a request, or [None] for notifications,
    unexpected responses and errors, and messages that cannot be decoded.

    A decoded request always yields a reply: an exception raised while
    serving it is logged and converted into an internal_error response that
    carries the request id. This function itself does not raise. *)
let process_message server message =
  try
    Log.debug "Parsing message as JSONRPC message...";
    match JSONRPCMessage.t_of_yojson message with
    | JSONRPCMessage.Request req -> begin
      Log.debug (Printf.sprintf "Received request with method: %s" req.method_);
      let response =
        try
          dispatch_request server
            ~id:req.id
            ~progress_token:req.progress_token
            ~params:req.params
            req.method_
        with exc ->
          (* Capture the backtrace before any other call can overwrite it. *)
          let backtrace = Printexc.get_backtrace () in
          Log.error (Printf.sprintf "Exception during message processing: %s"
            (Printexc.to_string exc));
          Log.error (Printf.sprintf "Backtrace: %s" backtrace);
          create_error ~id:req.id ~code:ErrorCode.internal_error
            ~message:(Printexc.to_string exc) ()
      in
      Some response
    end

    | JSONRPCMessage.Notification notif -> begin
      Log.debug (Printf.sprintf "Received notification with method: %s" notif.method_);
      match notif.method_ with
      | "notifications/initialized" -> begin
        Log.debug "Client initialization complete - Server is now ready to receive requests";
        None
      end
      | _ -> begin
        Log.debug (Printf.sprintf "Ignoring notification: %s" notif.method_);
        None
      end
    end

    | JSONRPCMessage.Response _ -> begin
      Log.error "Unexpected response message received";
      None
    end

    | JSONRPCMessage.Error _ -> begin
      Log.error "Unexpected error message received";
      None
    end
  with
  | Failure msg -> begin
    Log.error (Printf.sprintf "JSON error in message processing: %s" msg);
    None
  end
  | exc -> begin
    let backtrace = Printexc.get_backtrace () in
    Log.error (Printf.sprintf "Exception during message processing: %s"
      (Printexc.to_string exc));
    Log.error (Printf.sprintf "Backtrace: %s" backtrace);
    None
  end

(** Read one line from stdin and classify it.

    Returns [`Json j] when the line parses as JSON, [`Eof] when the input
    stream has ended, and [`Skip] for an empty line, an unparsable line or
    any other read failure. Failures are logged, never raised. *)
let read_stdio_input () =
  try
    Log.debug "Reading line from stdin...";
    match read_line () with
    | "" -> begin
      Log.debug "Empty line received, ignoring";
      `Skip
    end
    | line -> begin
      Log.debug (Printf.sprintf "Raw input: %s" (truncate_for_log line));
      try
        let json = Yojson.Safe.from_string line in
        Log.debug "Successfully parsed JSON";
        `Json json
      with Yojson.Json_error msg ->
        Log.error (Printf.sprintf "Error parsing JSON: %s" msg);
        Log.error (Printf.sprintf "Input was: %s" (truncate_for_log line));
        `Skip
    end
  with
  | End_of_file -> begin
    Log.debug "End of file received on stdin";
    `Eof
  end
  | Sys_error msg -> begin
    Log.error (Printf.sprintf "System error while reading: %s" msg);
    `Skip
  end
  | exc -> begin
    Log.error (Printf.sprintf "Exception while reading: %s" (Printexc.to_string exc));
    `Skip
  end

(** Read a single message from stdin.

    Returns [Some json] for a line that parses as JSON and [None] otherwise
    (empty line, parse error, end of input or read failure). *)
let read_stdio_message () =
  match read_stdio_input () with
  | `Json json -> Some json
  | `Skip | `Eof -> None

(** Process one decoded stdin message and write the reply, if any, to stdout
    as a single line. If anything raises on the way, try to send an
    internal_error response using the id found in the raw JSON. *)
let handle_stdio_message mcp_server json =
  Log.debug "Processing message...";
  try
    match process_message mcp_server.server json with
    | Some response -> begin
      let response_json = JSONRPCMessage.yojson_of_t response in
      let response_str = Yojson.Safe.to_string response_json in
      Log.debug (Printf.sprintf "Sending response: %s" (truncate_for_log response_str));
      Printf.printf "%s\n" response_str;
      flush stdout;
      (* Give client time to process *)
      Unix.sleepf 0.01
    end
    | None -> Log.debug "No response needed"
  with exc -> begin
    let backtrace = Printexc.get_backtrace () in
    Log.error (Printf.sprintf "ERROR in message processing: %s" (Printexc.to_string exc));
    Log.error (Printf.sprintf "Backtrace: %s" backtrace);
    (* Try to extract ID and send an error response *)
    try
      let id_opt =
        match Yojson.Safe.Util.member "id" json with
        | `Int i -> Some (`Int i)
        | `String s -> Some (`String s)
        | _ -> None
      in
      match id_opt with
      | Some id -> begin
        let error_resp =
          create_error ~id ~code:ErrorCode.internal_error
            ~message:(Printexc.to_string exc) ()
        in
        let error_json = JSONRPCMessage.yojson_of_t error_resp in
        let error_str = Yojson.Safe.to_string error_json in
        Printf.printf "%s\n" error_str;
        flush stdout
      end
      | None -> Log.error "Could not extract request ID to send error response"
    with e ->
      Log.error (Printf.sprintf "Failed to send error response: %s" (Printexc.to_string e))
  end

(** Run stdio server with enhanced error handling.

    Serves messages read from stdin until [mcp_server.running] becomes false
    or stdin reaches end of file. The loop is iterative, so serving many
    messages does not grow the stack. *)
let run_stdio_server mcp_server =
  while mcp_server.running do
    try
      match read_stdio_input () with
      | `Json json -> handle_stdio_message mcp_server json
      | `Eof -> begin
        (* Nothing more can ever arrive on a closed stdin: stop instead of
           polling it forever. *)
        Log.debug "Input stream closed, stopping server";
        mcp_server.running <- false
      end
      | `Skip ->
        (* No message received, but server is still running *)
        if mcp_server.running then
          Unix.sleepf 0.1 (* Small sleep to prevent CPU spinning *)
    with exc -> begin
      let backtrace = Printexc.get_backtrace () in
      Log.error (Printf.sprintf "FATAL ERROR in server main loop: %s"
        (Printexc.to_string exc));
      Log.error (Printf.sprintf "Backtrace: %s" backtrace);
      (* Try to continue anyway *)
      if mcp_server.running then Unix.sleepf 0.1
    end
  done;
  Log.debug "Server stopped"

(** Create an MCP server. The returned server is not running until it is
    started. *)
let create ~server ~transport () =
  { server; transport; running = false }
(** HTTP transport stand-in.

    Logs that the transport is a placeholder, then blocks, waking once per
    second, until [mcp_server.running] becomes false. No socket is opened and
    [port] is only used in the log line. *)
let run_http_server mcp_server port =
  let banner =
    Printf.sprintf "%s HTTP server starting on port %d" mcp_server.server.name port
  in
  Log.info banner;
  Log.info "HTTP transport is a placeholder and not fully implemented yet";

  (* Sketch of the intended cohttp-based implementation, kept for reference: *)
  (*
  let callback _conn req body =
    let uri = req |> Cohttp.Request.uri in
    let meth = req |> Cohttp.Request.meth |> Cohttp.Code.string_of_method in

    (* Handle only POST /jsonrpc endpoint *)
    match (meth, Uri.path uri) with
    | "POST", "/jsonrpc" ->
      (* Read the body *)
      Cohttp_lwt.Body.to_string body >>= fun body_str ->

      (* Parse JSON *)
      let json = try Some (Yojson.Safe.from_string body_str) with _ -> None in
      match json with
      | Some json_msg ->
        (* Process the message *)
        let response_opt = process_message mcp_server.server json_msg in
        (match response_opt with
         | Some response ->
           let response_json = JSONRPCMessage.yojson_of_t response in
           let response_str = Yojson.Safe.to_string response_json in
           Cohttp_lwt_unix.Server.respond_string
             ~status:`OK
             ~body:response_str
             ~headers:(Cohttp.Header.init_with "Content-Type" "application/json")
             ()
         | None ->
           Cohttp_lwt_unix.Server.respond_string
             ~status:`OK
             ~body:"{}"
             ~headers:(Cohttp.Header.init_with "Content-Type" "application/json")
             ())
      | None ->
        Cohttp_lwt_unix.Server.respond_string
          ~status:`Bad_request
          ~body:"{\"error\":\"Invalid JSON\"}"
          ~headers:(Cohttp.Header.init_with "Content-Type" "application/json")
          ()
    | _ ->
      (* Return 404 for any other routes *)
      Cohttp_lwt_unix.Server.respond_string
        ~status:`Not_found
        ~body:"Not found"
        ()
  in

  (* Create and start the server *)
  let server = Cohttp_lwt_unix.Server.create
    ~mode:(`TCP (`Port port))
    (Cohttp_lwt_unix.Server.make ~callback ())
  in

  (* Run the server *)
  Lwt_main.run server
  *)

  (* Until the real transport exists, idle until the running flag is
     cleared. *)
  let rec wait_until_stopped () =
    if mcp_server.running then begin
      Unix.sleep 1;
      wait_until_stopped ()
    end
  in
  wait_until_stopped ()

(** Start the server on its configured transport.

    Marks the server as running, runs the startup hook when one is set,
    installs a SIGINT handler that clears the running flag, then enters the
    loop of the selected transport. The call returns when that loop
    returns. *)
let start server =
  server.running <- true;

  (* Startup hook, when the server definition provides one. *)
  (match server.server.startup_hook with
   | None -> ()
   | Some hook -> hook ());

  (* An interrupt only clears the flag; the transport loop notices it the
     next time it checks. *)
  let on_interrupt _signal =
    Log.debug "Received interrupt signal, stopping server...";
    server.running <- false
  in
  Sys.set_signal Sys.sigint (Sys.Signal_handle on_interrupt);

  match server.transport with
  | Http ->
    (* HTTP server placeholder *)
    run_http_server server 8080
  | Stdio ->
    (* Put stdout in text mode before any reply is written. *)
    set_binary_mode_out stdout false;
    let banner =
      Printf.sprintf "%s server started with stdio transport" server.server.name
    in
    Log.info banner;
    run_stdio_server server

(** Stop the server: clear the running flag, then run the shutdown hook when
    one is set. *)
let stop server =
  Log.info "Stopping server...";
  server.running <- false;

  match server.server.shutdown_hook with
  | None -> ()
  | Some hook -> hook ()