let src = Logs.Src.create "requests.cache" ~doc:"HTTP cache with cacheio"

module Log = (val Logs.src_log src : Logs.LOG)

(** A fully buffered response as kept in the in-memory table and returned by
    {!get}. *)
type cached_response = {
  status : Cohttp.Code.status_code;
  headers : Cohttp.Header.t;
  body : string;
}

(** Cache handle.

    [cacheio] is the optional [Cacheio] backend; [memory_cache] maps a cache
    key to a response paired with its absolute expiry time (seconds, as
    returned by [Unix.gettimeofday]). *)
type t = {
  sw : Eio.Switch.t;
  enabled : bool;
  cache_get_requests : bool;
  cache_range_requests : bool;
  cacheio : Cacheio.t option;
  memory_cache : (string, cached_response * float) Hashtbl.t;
}

(** Expiry used for the in-memory table when a response carries no usable
    [max-age]. *)
let default_ttl_seconds = 3600.0

(** [create ~sw ~enabled ?cache_get_requests ?cache_range_requests ~cache_dir ()]
    builds a cache handle.

    A [Cacheio] backend is only created when [enabled] is [true] and
    [cache_dir] is [Some dir]. If creating it raises, the failure is logged as
    a warning and the handle is left with the in-memory table only. *)
let create ~sw ~enabled ?(cache_get_requests = true)
    ?(cache_range_requests = true) ~cache_dir () =
  let cacheio =
    match cache_dir with
    | Some dir when enabled -> (
        (* Deliberate best-effort: any backend failure degrades to the
           memory-only cache instead of aborting. *)
        try Some (Cacheio.create ~base_dir:dir)
        with e ->
          Log.warn (fun m ->
              m "Failed to create cacheio backend: %s. Using memory cache only."
                (Printexc.to_string e));
          None)
    | _ -> None
  in
  {
    sw;
    enabled;
    cache_get_requests;
    cache_range_requests;
    cacheio;
    memory_cache = Hashtbl.create 100;
  }

(** [make_cache_key ~method_ ~url ~headers] builds the string key for a
    request: method name ([GET], [HEAD], or [OTHER] for anything else), the
    URL, and, when the request has a [range] header, a [_range:<value>]
    suffix. *)
let make_cache_key ~method_ ~url ~headers =
  let method_str =
    match method_ with `GET -> "GET" | `HEAD -> "HEAD" | _ -> "OTHER"
  in
  let url_str = Uri.to_string url in
  let range_str =
    match Cohttp.Header.get headers "range" with
    | Some r -> "_range:" ^ r
    | None -> ""
  in
  Printf.sprintf "%s_%s%s" method_str url_str range_str

(** [contains_substring ~sub s] is [true] iff [sub] occurs somewhere in [s]. *)
let contains_substring ~sub s =
  let sub_len = String.length sub and len = String.length s in
  let rec scan pos =
    pos + sub_len <= len
    && (String.equal (String.sub s pos sub_len) sub || scan (pos + 1))
  in
  scan 0

(** [is_cacheable ~method_ ~status ~headers] decides whether a response may be
    stored.

    Only [`GET] and [`HEAD] responses with a 2xx, 301 or 308 status qualify,
    and only when the response [cache-control] header (compared
    case-insensitively) mentions none of [no-store], [no-cache] or [private].
    The header check applies to every status, including the 301/308
    redirects. *)
let is_cacheable ~method_ ~status ~headers =
  match method_ with
  | `GET | `HEAD ->
      let code = Cohttp.Code.code_of_status status in
      let cacheable_status =
        (code >= 200 && code < 300) || code = 301 || code = 308
      in
      let forbidden_by_header =
        match Cohttp.Header.get headers "cache-control" with
        | None -> false
        | Some cc ->
            let cc = String.lowercase_ascii cc in
            List.exists
              (fun directive -> contains_substring ~sub:directive cc)
              [ "no-store"; "no-cache"; "private" ]
      in
      cacheable_status && not forbidden_by_header
  | _ -> false

(** [parse_max_age headers] returns the value, in seconds, of the first
    [max-age=<number>] directive of the [cache-control] header that parses as
    a float, or [None] when there is none.

    The directive name is matched case-insensitively, in line with the
    lowercasing done by {!is_cacheable}. *)
let parse_max_age headers =
  let prefix = "max-age=" in
  let prefix_len = String.length prefix in
  match Cohttp.Header.get headers "cache-control" with
  | None -> None
  | Some cc ->
      String.split_on_char ',' cc
      |> List.find_map (fun part ->
             let part = String.lowercase_ascii (String.trim part) in
             if String.starts_with ~prefix part then
               float_of_string_opt
                 (String.sub part prefix_len (String.length part - prefix_len))
             else None)

(** [serialize_metadata ~status ~headers] encodes the numeric status code and
    the header list as a JSON object with fields ["status_code"] and
    ["headers"]. *)
let serialize_metadata ~status ~headers =
  let status_code = Cohttp.Code.code_of_status status in
  let headers_assoc = Cohttp.Header.to_list headers in
  let json =
    `Assoc
      [
        ("status_code", `Int status_code);
        ( "headers",
          `Assoc (List.map (fun (k, v) -> (k, `String v)) headers_assoc) );
      ]
  in
  Yojson.Basic.to_string json

(** [deserialize_metadata json_str] is the inverse of {!serialize_metadata}.
    Returns [None] when the string cannot be decoded. *)
let deserialize_metadata json_str =
  (* Deliberate catch-all: a corrupt or foreign metadata entry is treated as
     a cache miss rather than an error. *)
  try
    let open Yojson.Basic.Util in
    let json = Yojson.Basic.from_string json_str in
    let status_code = json |> member "status_code" |> to_int in
    let status = Cohttp.Code.status_of_code status_code in
    let headers_json = json |> member "headers" |> to_assoc in
    let headers =
      headers_json
      |> List.map (fun (k, v) -> (k, to_string v))
      |> Cohttp.Header.of_list
    in
    Some (status, headers)
  with _ -> None

(** [describe_ttl ttl] renders an optional TTL for log messages; a missing
    TTL is shown as the default. *)
let describe_ttl = function
  | Some seconds -> Printf.sprintf "%.0fs" seconds
  | None -> "3600s"

(** [caching_active t ~method_] is [true] when the cache is enabled and the
    method is not a [`GET] excluded by [cache_get_requests = false]. *)
let caching_active t ~method_ =
  t.enabled && (t.cache_get_requests || method_ <> `GET)

(** [read_all ~initial_size source] copies [source] to its end into a buffer
    and returns the accumulated string. *)
let read_all ~initial_size source =
  let buf = Buffer.create initial_size in
  Eio.Flow.copy source (Eio.Flow.buffer_sink buf);
  Buffer.contents buf

(** [read_metadata cache ~key ~sw] fetches and decodes the metadata entry
    stored under [key]; [None] if the entry is absent or undecodable. *)
let read_metadata cache ~key ~sw =
  match Cacheio.get cache ~key ~sw with
  | Some source -> deserialize_metadata (read_all ~initial_size:256 source)
  | None -> None

(** [get t ~method_ ~url ~headers] looks up a fully buffered cached response.

    Returns [None] when the cache is disabled, when [method_] is [`GET] and
    GET caching is off, or on a miss. When a [Cacheio] backend is present only
    that backend is consulted (the ["<key>.meta"] and ["<key>.body"] entries
    must both exist); otherwise the in-memory table is used and entries whose
    expiry has passed are dropped on lookup. *)
let get t ~method_ ~url ~headers =
  if not (caching_active t ~method_) then None
  else
    let key = make_cache_key ~method_ ~url ~headers in
    match t.cacheio with
    | Some cache ->
        let metadata_key = key ^ ".meta" in
        let body_key = key ^ ".body" in
        if
          Cacheio.exists cache ~key:metadata_key
          && Cacheio.exists cache ~key:body_key
        then
          (* Both entries are read to completion inside a local switch, so
             nothing obtained from the backend outlives this call. *)
          Eio.Switch.run (fun sw ->
              match read_metadata cache ~key:metadata_key ~sw with
              | Some (status, resp_headers) -> (
                  match Cacheio.get cache ~key:body_key ~sw with
                  | Some source ->
                      let body = read_all ~initial_size:4096 source in
                      Log.debug (fun m ->
                          m "Cache hit for %s" (Uri.to_string url));
                      Some { status; headers = resp_headers; body }
                  | None ->
                      Log.debug (fun m ->
                          m "Cache body missing for %s" (Uri.to_string url));
                      None)
              | None ->
                  Log.debug (fun m ->
                      m "Cache metadata missing for %s" (Uri.to_string url));
                  None)
        else (
          Log.debug (fun m -> m "Cache miss for %s" (Uri.to_string url));
          None)
    | None -> (
        match Hashtbl.find_opt t.memory_cache key with
        | Some (response, expiry) when expiry > Unix.gettimeofday () ->
            Log.debug (fun m ->
                m "Memory cache hit for %s" (Uri.to_string url));
            Some response
        | Some _ ->
            (* Stale entry: drop it so the table does not accumulate expired
               responses. *)
            Hashtbl.remove t.memory_cache key;
            Log.debug (fun m -> m "Cache miss for %s" (Uri.to_string url));
            None
        | None ->
            Log.debug (fun m -> m "Cache miss for %s" (Uri.to_string url));
            None)

(** [get_stream t ~method_ ~url ~headers ~sw] is the streaming counterpart of
    {!get}: on a hit it returns [Some (status, headers, body_source)] where
    [body_source] was obtained from the backend under the caller's [sw].

    Only the [Cacheio] backend is consulted; without one the result is always
    [None]. *)
let get_stream t ~method_ ~url ~headers ~sw =
  if not (caching_active t ~method_) then None
  else
    let key = make_cache_key ~method_ ~url ~headers in
    match t.cacheio with
    | None -> None
    | Some cache -> (
        let metadata_key = key ^ ".meta" in
        let body_key = key ^ ".body" in
        if
          not
            (Cacheio.exists cache ~key:metadata_key
            && Cacheio.exists cache ~key:body_key)
        then None
        else
          match read_metadata cache ~key:metadata_key ~sw with
          | None -> None
          | Some (status, resp_headers) -> (
              match Cacheio.get cache ~key:body_key ~sw with
              | Some source ->
                  Log.debug (fun m ->
                      m "Streaming cache hit for %s" (Uri.to_string url));
                  Some (status, resp_headers, source)
              | None -> None))

(** [put t ~method_ ~url ~request_headers ~status ~headers ~body] stores a
    buffered response if the cache is enabled, the method is one that lookups
    will serve (see [cache_get_requests]) and {!is_cacheable} accepts it.

    The TTL comes from {!parse_max_age}. The response is written to the
    [Cacheio] backend (metadata under ["<key>.meta"], body under
    ["<key>.body"]) when there is one, and always recorded in the in-memory
    table with a default expiry of one hour when no TTL is given. *)
let put t ~method_ ~url ~request_headers ~status ~headers ~body =
  if caching_active t ~method_ && is_cacheable ~method_ ~status ~headers then begin
    let key = make_cache_key ~method_ ~url ~headers:request_headers in
    let ttl = parse_max_age headers in
    Log.debug (fun m ->
        m "Caching response for %s (ttl: %s)" (Uri.to_string url)
          (describe_ttl ttl));
    (match t.cacheio with
    | Some cache ->
        let metadata_source =
          Eio.Flow.string_source (serialize_metadata ~status ~headers)
        in
        Cacheio.put cache ~key:(key ^ ".meta") ~source:metadata_source ~ttl ();
        Cacheio.put cache ~key:(key ^ ".body")
          ~source:(Eio.Flow.string_source body) ~ttl ()
    | None -> ());
    (* NOTE(review): with a cacheio backend present, [get] never reads this
       table, so these entries are only reachable after the backend is
       absent — confirm whether the duplicate store is intended. *)
    let expiry =
      Unix.gettimeofday () +. Option.value ttl ~default:default_ttl_seconds
    in
    Hashtbl.replace t.memory_cache key ({ status; headers; body }, expiry)
  end

(** [put_stream t ~method_ ~url ~request_headers ~status ~headers ~body_source ~ttl]
    stores a response whose body is supplied as a flow.

    Subject to the same enabled / method / {!is_cacheable} checks as {!put}.
    Nothing is stored, and [body_source] is left untouched, when there is no
    [Cacheio] backend. *)
let put_stream t ~method_ ~url ~request_headers ~status ~headers ~body_source
    ~ttl =
  if caching_active t ~method_ && is_cacheable ~method_ ~status ~headers then begin
    let key = make_cache_key ~method_ ~url ~headers:request_headers in
    Log.debug (fun m ->
        m "Caching streamed response for %s (ttl: %s)" (Uri.to_string url)
          (describe_ttl ttl));
    match t.cacheio with
    | Some cache ->
        let metadata_source =
          Eio.Flow.string_source (serialize_metadata ~status ~headers)
        in
        Cacheio.put cache ~key:(key ^ ".meta") ~source:metadata_source ~ttl ();
        Cacheio.put cache ~key:(key ^ ".body") ~source:body_source ~ttl ()
    | None -> ()
  end

module Range = struct
  (** A single byte range; [end_ = None] means "to the end of the resource". *)
  type t = { start : int64; end_ : int64 option }

  (** [parse_offset s] parses a non-empty string of decimal digits as an
      [int64]; anything else (sign, hex/underscore forms, overflow) yields
      [None]. *)
  let parse_offset s =
    let is_digit = function '0' .. '9' -> true | _ -> false in
    if s <> "" && String.for_all is_digit s then Int64.of_string_opt s
    else None

  (** [of_header header] parses a [bytes=N-] or [bytes=N-M] range header
      value. Returns [None] for any other shape (including suffix and
      multi-range forms), for non-decimal offsets, and when [M < N]. *)
  let of_header header =
    let prefix = "bytes=" in
    if not (String.starts_with ~prefix header) then None
    else
      let prefix_len = String.length prefix in
      let spec =
        String.sub header prefix_len (String.length header - prefix_len)
      in
      match String.split_on_char '-' spec with
      | [ first; "" ] ->
          (* bytes=N- : from N to the end *)
          Option.map (fun start -> { start; end_ = None }) (parse_offset first)
      | [ first; last ] -> (
          (* bytes=N-M *)
          match (parse_offset first, parse_offset last) with
          | Some start, Some stop when Int64.compare start stop <= 0 ->
              Some { start; end_ = Some stop }
          | _ -> None)
      | _ -> None

  (** [to_header t] renders [t] back as a [bytes=...] header value. *)
  let to_header t =
    match t.end_ with
    | None -> Printf.sprintf "bytes=%Ld-" t.start
    | Some e -> Printf.sprintf "bytes=%Ld-%Ld" t.start e

  (** [to_cacheio_range t ~total_size] converts to a [Cacheio.Range.t],
      using [total_size - 1] for an open end and clamping an explicit end to
      that same bound. *)
  let to_cacheio_range t ~total_size =
    let last_byte = Int64.pred total_size in
    let end_ =
      match t.end_ with None -> last_byte | Some e -> Int64.min e last_byte
    in
    Cacheio.Range.create ~start:t.start ~end_
end

(** [drain_source source ~on_chunk] reads [source] in pieces of at most 8192
    bytes, passing each piece to [on_chunk] as a string, until the read
    raises [End_of_file] or returns a non-positive count.

    Only [End_of_file] raised by the read itself ends the loop; exceptions
    raised by [on_chunk] propagate to the caller. *)
let drain_source source ~on_chunk =
  (* One scratch buffer is enough: [Cstruct.to_string] copies the bytes out
     before the next read overwrites them. *)
  let chunk = Cstruct.create 8192 in
  let rec loop () =
    match Eio.Flow.single_read source chunk with
    | exception End_of_file -> ()
    | n when n > 0 ->
        on_chunk (Cstruct.to_string ~off:0 ~len:n chunk);
        loop ()
    | _ -> ()
  in
  loop ()

(** [download_range t ~sw ~url ~range ~on_chunk] serves [range] of [url] from
    the [Cacheio] backend, feeding the data to [on_chunk].

    Returns [Some true] when the backend supplied the data (either as one
    source or as a list of chunk sources, delivered in list order) and [None]
    when it reported [`Not_found] or there is no backend. *)
let download_range t ~sw ~url ~range ~on_chunk =
  let range_header = Range.to_header range in
  Log.debug (fun m ->
      m "Range request for %s: %s" (Uri.to_string url) range_header);
  match t.cacheio with
  | None -> None
  | Some cache -> (
      let key = Uri.to_string url in
      (* The total size is not known here, so [Int64.max_int] is used as the
         bound for open-ended ranges. *)
      let cacheio_range =
        Range.to_cacheio_range range ~total_size:Int64.max_int
      in
      match Cacheio.get_range cache ~key ~range:cacheio_range ~sw with
      | `Complete source ->
          drain_source source ~on_chunk;
          Some true
      | `Chunks chunk_sources ->
          List.iter
            (fun (_range, source) -> drain_source source ~on_chunk)
            chunk_sources;
          Some true
      | `Not_found -> None)

(** [put_chunk t ~url ~range ~data] stores [data] as the chunk of [url]
    covering [range]. Does nothing when the cache or range caching is
    disabled; logs a debug message when there is no [Cacheio] backend. *)
let put_chunk t ~url ~range ~data =
  if t.enabled && t.cache_range_requests then
    match t.cacheio with
    | Some cache ->
        let key = Uri.to_string url in
        let cacheio_range =
          Range.to_cacheio_range range ~total_size:Int64.max_int
        in
        Cacheio.put_chunk cache ~key ~range:cacheio_range
          ~source:(Eio.Flow.string_source data) ()
    | None ->
        Log.debug (fun m ->
            m "Cannot cache chunk for %s: no cacheio backend"
              (Uri.to_string url))

(** [has_complete t ~url ~total_size] asks the backend whether the stored
    chunks of [url] cover [total_size] bytes. [false] when the cache is
    disabled or has no backend. *)
let has_complete t ~url ~total_size =
  match t.cacheio with
  | Some cache when t.enabled ->
      Cacheio.has_complete_chunks cache ~key:(Uri.to_string url) ~total_size
  | _ -> false

(** [missing_ranges t ~url ~total_size] lists the byte ranges of [url] not
    yet present in the backend. When the cache is disabled or has no backend
    the single range [0 .. total_size - 1] is returned. *)
let missing_ranges t ~url ~total_size =
  match t.cacheio with
  | Some cache when t.enabled ->
      Cacheio.missing_ranges cache ~key:(Uri.to_string url) ~total_size
      |> List.map (fun r ->
             {
               Range.start = Cacheio.Range.start r;
               end_ = Some (Cacheio.Range.end_ r);
             })
  | _ -> [ { Range.start = 0L; end_ = Some (Int64.pred total_size) } ]

(** [coalesce_chunks t ~url] asks the backend to merge the chunks of [url]
    (with [~verify:true]) and waits for the outcome. Returns [true] on
    success, and [false] on failure, when disabled, or without a backend. *)
let coalesce_chunks t ~url =
  match t.cacheio with
  | Some cache when t.enabled -> (
      let key = Uri.to_string url in
      let promise = Cacheio.coalesce_chunks cache ~key ~verify:true () in
      match Eio.Promise.await promise with
      | Ok () ->
          Log.info (fun m -> m "Successfully coalesced chunks for %s" key);
          true
      | Error exn ->
          Log.warn (fun m ->
              m "Failed to coalesce chunks for %s: %s" key
                (Printexc.to_string exn));
          false)
  | _ -> false

(** [evict t ~url] removes the cached [`GET] response for [url] (the key
    built with no request headers) from the backend and the in-memory table.

    NOTE(review): [`HEAD] entries and entries keyed with a [range] header are
    not touched — confirm whether callers expect them to be evicted too. *)
let evict t ~url =
  if t.enabled then begin
    let key =
      make_cache_key ~method_:`GET ~url ~headers:(Cohttp.Header.init ())
    in
    (match t.cacheio with
    | Some cache ->
        Cacheio.delete cache ~key:(key ^ ".meta");
        Cacheio.delete cache ~key:(key ^ ".body")
    | None -> ());
    Log.debug (fun m -> m "Evicting cache for %s" (Uri.to_string url));
    Hashtbl.remove t.memory_cache key
  end

(** [clear t] empties the backend (if any) and the in-memory table. *)
let clear t =
  Log.info (fun m -> m "Clearing entire cache");
  Option.iter Cacheio.clear t.cacheio;
  Hashtbl.clear t.memory_cache

(** [stats t] reports cache configuration and counters as a JSON value.
    ["cacheio_stats"] is an empty object when there is no backend. *)
let stats t =
  let cacheio_stats =
    match t.cacheio with
    | None -> `Assoc []
    | Some cache ->
        let s = Cacheio.stats cache in
        `Assoc
          [
            ("total_entries", `Int (Cacheio.Stats.entry_count s));
            ("total_bytes", `Int (Int64.to_int (Cacheio.Stats.total_size s)));
            ("expired_entries", `Int (Cacheio.Stats.expired_count s));
            ("pinned_entries", `Int (Cacheio.Stats.pinned_count s));
            ("temporary_entries", `Int (Cacheio.Stats.temporary_count s));
          ]
  in
  let backend = if Option.is_some t.cacheio then "cacheio" else "memory" in
  `Assoc
    [
      ("memory_cache_entries", `Int (Hashtbl.length t.memory_cache));
      ("cache_backend", `String backend);
      ("enabled", `Bool t.enabled);
      ("cache_get_requests", `Bool t.cache_get_requests);
      ("cache_range_requests", `Bool t.cache_range_requests);
      ("cacheio_stats", cacheio_stats);
    ]