Geotessera library for OCaml
at main 12 kB view raw
open Eio

(* Path concatenation. NOTE: this shadows integer division for the rest of the
   file. *)
let ( / ) = Eio.Path.( / )

(* A latitude/longitude pair. Depending on context it denotes an arbitrary
   location, the centre of a block, or the coordinates encoded in a tile's
   file name. *)
type point = { lat : float; lon : float }

let pp_point fmt { lat; lon } = Fmt.pf fmt "{ lat: %.3f, lon: %.3f }" lat lon

module Bbox = struct
  (* A bounding box, stored as [| min_lat; min_lon; max_lat; max_lon |].
     Use [v] and the accessors below rather than indexing directly. *)
  type t = float array

  let v ~min_lon ~min_lat ~max_lon ~max_lat =
    [| min_lat; min_lon; max_lat; max_lon |]

  let min_lon t = t.(1)
  let min_lat t = t.(0)
  let max_lon t = t.(3)
  let max_lat t = t.(2)
end

module Registry = struct
  (* Handle on the local cache: where the manifests were cloned from and to,
     where downloaded data lives, and the HTTP client used to fetch it. *)
  type t = {
    git_url : string;
    data_url : string;
    git : Eio.Fs.dir_ty Eio.Path.t;
    data : Eio.Fs.dir_ty Eio.Path.t;
    client : Client.t;
    version : [ `v1 ];
  }

  let version_to_string = function `v1 -> "v1"

  (* Side length of a manifest block, in the same unit as [point] fields. *)
  let block_size = 5.

  (* [block_of_point p] snaps both coordinates of [p] down to the nearest
     multiple of [block_size]. *)
  let block_of_point point =
    {
      lon = Float.floor (point.lon /. block_size) *. block_size;
      lat = Float.floor (point.lat /. block_size) *. block_size;
    }

  (* [blocks_for_region bbox] lists every block touched by [bbox]. Blocks are
     referenced by their centre, and are produced with longitude as the outer
     (slowest-varying) axis, both axes ascending. *)
  let blocks_for_region (bbox : Bbox.t) =
    let snap x = floor (x /. block_size) *. block_size in
    let half = block_size /. 2. in
    (* All block origins from [lo] to [hi] inclusive, ascending. *)
    let range lo hi =
      let rec go acc x =
        if x <= hi then go (x :: acc) (x +. block_size) else List.rev acc
      in
      go [] lo
    in
    let lons = range (snap (Bbox.min_lon bbox)) (snap (Bbox.max_lon bbox)) in
    let lats = range (snap (Bbox.min_lat bbox)) (snap (Bbox.max_lat bbox)) in
    List.concat_map
      (fun lon ->
        List.map (fun lat -> { lat = lat +. half; lon = lon +. half }) lats)
      lons

  (* Runs [git clone url into] using the environment's process manager. *)
  let clone env ~into url =
    Eio.Process.run env#process_mgr [ "git"; "clone"; url; into ]

  (* [with_registry env fn] builds a registry rooted in the XDG cache
     directory (under "ocaml-geotessera"), cloning the manifests repository
     if it is not there yet, and passes it to [fn].

     Raises [Failure] if the manifests path exists but is not a directory. *)
  let with_registry ?(git_url = "https://github.com/ucam-eo/tessera-manifests")
      ?(data_url = "https://dl-2.tessera.wiki") ?(version = `v1) env fn =
    let xdg = Xdg.create ~env:Sys.getenv_opt () in
    let cache = Xdg.cache_dir xdg in
    let dir = env#fs / cache / "ocaml-geotessera" in
    let git = dir / "tessera-manifests" in
    let data = dir / "downloads" in
    let client = Client.v env#net in
    let v = { git_url; data_url; git; client; data; version } in
    (match Eio.Path.kind ~follow:false git with
    | `Directory -> ()
    | `Not_found ->
        (* [exists_ok] matters: if an earlier run created the downloads
           directory and then failed to clone, we must still be able to get
           as far as the clone on the next attempt. *)
        Eio.Path.mkdirs ~exists_ok:true ~perm:0o755 data;
        clone env ~into:(Eio.Path.native_exn git) v.git_url
    | _ -> Fmt.failwith "%a is not a directory or non-existent" Eio.Path.pp git);
    fn v

  (* [with_display ~sw label fn] calls [fn get_display], where [get_display]
     creates the display on first use and returns the same one afterwards.
     If a display was created it is finalised when [fn] returns or raises;
     if [get_display] is never called, no display is created at all. *)
  let with_display ~sw label fn =
    let display = ref None in
    let mux = Mutex.create () in
    let get_display () =
      match !display with
      | Some d -> d
      | None ->
          Mutex.use_rw ~protect:true mux (fun () ->
              (* Re-check under the lock: another fiber may have initialised
                 the display while this one was waiting, and initialising
                 twice would leave the first display un-finalised. *)
              match !display with
              | Some d -> d
              | None ->
                  let d = Display.init ~sw label in
                  display := Some d;
                  d)
    in
    let finalise_display () = Option.iter Display.finalise !display in
    Fun.protect ~finally:finalise_display (fun () -> fn get_display)

  (* [extract_lat_lon_from_grid_name s] reads the coordinates out of a file
     name of the form "grid_<lon>_<lat>.<ext>" or
     "grid_<lon>_<lat>_scales.npy". Only the base name of [s] is examined.

     Raises [Invalid_argument] if the name does not have that shape or the
     coordinates are not numbers. *)
  let extract_lat_lon_from_grid_name s =
    let fail () = Fmt.invalid_arg "Failed to extract lat/lon from %s" s in
    let coord str =
      match float_of_string_opt str with Some f -> f | None -> fail ()
    in
    match String.split_on_char '_' (Filename.basename s) with
    | [ "grid"; lon; lat_with_ext ] ->
        (* The extension is mandatory in this form. *)
        if Filename.extension lat_with_ext = "" then fail ()
        else
          {
            lon = coord lon;
            lat = coord (Filename.remove_extension lat_with_ext);
          }
    | [ "grid"; lon; lat; "scales.npy" ] -> { lon = coord lon; lat = coord lat }
    | _ -> fail ()

  (* [parse_manifest path] reads a manifest made of "<name> <sha256-hex>"
     lines and returns, for each entry, the coordinates taken from the name,
     the name itself and the parsed hash. Blank lines and surrounding
     whitespace are ignored.

     Raises [Failure] on a line that does not have exactly two fields. *)
  let parse_manifest path =
    let module R = Eio.Buf_read in
    Eio.Path.with_open_in path @@ fun f ->
    let r = R.of_flow ~max_size:max_int f in
    R.lines r
    |> Seq.filter_map (fun line ->
           let fields =
             String.split_on_char ' ' (String.trim line)
             |> List.filter (fun field -> field <> "")
           in
           match fields with
           | [] -> None
           | [ s; h ] ->
               Some
                 (extract_lat_lon_from_grid_name s, s, Digestif.SHA256.of_hex h)
           | _ -> failwith "Malformed manifest")
    (* Force the sequence before the file is closed. *)
    |> List.of_seq

  (* [find_manifest t ~year point] parses the embeddings manifest of the block
     containing [point] for [year]. *)
  let find_manifest t ~year point =
    let b = block_of_point point in
    let name =
      Fmt.str "embeddings_%i_lon%i_lat%i.txt" year (Float.to_int b.lon)
        (Float.to_int b.lat)
    in
    parse_manifest (t.git / "registry" / "embeddings" / name)

  (* [find_landmasks _t point] builds a manifest-shaped entry for the landmask
     of the tile at [point]. No hash is known for landmasks, so the third
     component is only a placeholder. *)
  let find_landmasks _t point =
    let name = Fmt.str "grid_%.2f_%.2f.tiff" point.lon point.lat in
    (point, name, Digestif.SHA256.empty)

  (* [ensure_cached t ~get_display ~subdir ~uri name] returns the path of
     [name] inside [t.data / subdir], downloading it from [uri] first if it is
     missing or empty.

     The body is written to a ".part" file and renamed into place afterwards,
     so an interrupted transfer never leaves a truncated file that a later
     run would mistake for a cached one.

     NOTE: the contents of an existing file are not verified against the
     manifest hash.

     Raises [Failure] if something other than a regular file is in the way. *)
  let ensure_cached t ~get_display ~subdir ~uri name =
    let data_dir = t.data / subdir / Filename.dirname name in
    Eio.Path.mkdirs ~exists_ok:true ~perm:0o755 data_dir;
    let base = Filename.basename name in
    let path = data_dir / base in
    let download () =
      let display = get_display () in
      let partial = data_dir / (base ^ ".part") in
      let () =
        Eio.Path.with_open_out ~create:(`Or_truncate 0o644) partial @@ fun w ->
        Client.with_body ~default:() t.client display ~uri @@ fun body ->
        Flow.copy body w
      in
      Eio.Path.rename partial path
    in
    let state =
      try
        let stat = Eio.Path.stat ~follow:false path in
        if Optint.Int63.(equal stat.size zero) then `Delete_and_download
        else `Check_hash stat.kind
      with Eio.Exn.Io (Eio.Fs.E (Not_found _), _) -> `Download
    in
    match state with
    | `Check_hash `Regular_file -> path
    | `Delete_and_download ->
        Eio.Path.unlink path;
        download ();
        path
    | `Download ->
        download ();
        path
    | `Check_hash _ -> Fmt.failwith "%s exists but is not a file!" name

  (* [download_embedding t v] makes sure every manifest entry in [v] is in the
     local cache (at most 10 transfers at a time) and returns, for each
     embedding, its coordinates together with the paths of the embedding and
     of its scales file.

     Raises [Invalid_argument] if an embedding has no scales file. *)
  let download_embedding t v =
    Eio.Switch.run @@ fun sw ->
    with_display ~sw "embeddings" @@ fun get_display ->
    Eio.traceln "Downloading and checking embeddings...";
    let paths =
      Fiber.List.map ~max_fibers:10
        (fun (_point, name, _hash) ->
          let uri =
            Fmt.str "%s/%s/global_0.1_degree_representation/%s" t.data_url
              (version_to_string t.version)
              name
          in
          ensure_cached t ~get_display ~subdir:"embeddings" ~uri name)
        v
    in
    let embeddings, scales =
      List.partition_map
        (fun p ->
          let native = Eio.Path.native_exn p in
          let entry = (extract_lat_lon_from_grid_name native, p) in
          if Filename.check_suffix native "_scales.npy" then Either.Right entry
          else Either.Left entry)
        paths
    in
    (* Pair each embedding with its scales by coordinates rather than by
       position, so the result does not depend on how the manifest happens
       to be ordered. *)
    let scales_by_point = Hashtbl.create (List.length scales) in
    List.iter (fun (point, p) -> Hashtbl.replace scales_by_point point p) scales;
    List.rev_map
      (fun (point, emb) ->
        match Hashtbl.find_opt scales_by_point point with
        | Some sca -> (point, (emb, sca))
        | None ->
            Fmt.invalid_arg "No scales file for the embedding at %a" pp_point
              point)
      embeddings

  (* [transformation_matrix e] packs the tie point and pixel scale of the IFD
     [e] into a 6-element float64 array:
     [| x origin; x scale; 0; y origin; 0; -(y scale) |].
     NOTE(review): this looks like the usual geotransform ordering; confirm
     against the consumer. *)
  let transformation_matrix e =
    let scale = Tiff.Ifd.pixel_scale e in
    let tiepoint = Tiff.Ifd.tiepoint e in
    let arr = Nx.zeros Nx.float64 [| 6 |] in
    (* X *)
    Nx.set_item [ 0 ] tiepoint.(3) arr;
    (* sX *)
    Nx.set_item [ 1 ] scale.(0) arr;
    (* Y *)
    Nx.set_item [ 3 ] tiepoint.(4) arr;
    (* -sY *)
    Nx.set_item [ 5 ] (-.scale.(1)) arr;
    arr

  (* [crs_and_transform_of_landmask lm] opens the TIFF at [lm] and returns its
     projected CRS key together with its transformation matrix. *)
  let crs_and_transform_of_landmask lm =
    Eio.Path.with_open_in lm @@ fun r ->
    let ro = Eio.File.pread_exact r in
    let tiff = Tiff.from_file Tiff.Float32 ro in
    let ifd = Tiff.ifd tiff in
    let geos = Tiff.Ifd.geo_key_directory ifd in
    (Tiff.Ifd.GeoKeys.projected_crs geos, transformation_matrix ifd)

  (* [download_landmasks t v] makes sure every landmask entry in [v] is in the
     local cache (at most 10 transfers at a time) and returns, in the order of
     [v], the coordinates encoded in each file name, the CRS and transform
     read from the file, and its path. *)
  let download_landmasks t v =
    Eio.Switch.run @@ fun sw ->
    with_display ~sw "landmasks" @@ fun get_display ->
    Eio.traceln "Downloading and checking landmasks...";
    Fiber.List.map ~max_fibers:10
      (fun (_point, name, _hash) ->
        let uri =
          Fmt.str "%s/%s/global_0.1_degree_tiff_all/%s" t.data_url
            (version_to_string t.version)
            name
        in
        let path = ensure_cached t ~get_display ~subdir:"landmasks" ~uri name in
        let crs = crs_and_transform_of_landmask path in
        (extract_lat_lon_from_grid_name (Eio.Path.native_exn path), crs, path))
      v
end

(* The capabilities [fetch] needs from the Eio environment. *)
type 'a env =
  < fs : Eio.Fs.dir_ty Eio.Path.t
  ; process_mgr : [> `Generic ] Eio.Process.mgr_ty Eio.Process.mgr
  ; net : [> `Generic ] Eio.Net.ty Eio.Net.t
  ; .. >
  as
  'a

type embedding = (int, Bigarray.int8_signed_elt) Nx.t
type scales = (float, Bigarray.float32_elt) Nx.t

(* [fetch_embedding registry ~year bbox] downloads (or reuses from the cache)
   every embedding tile for [year] that overlaps [bbox], and returns each
   tile's coordinates with its embedding and scales arrays.

   Raises [Invalid_argument] if a downloaded file cannot be read as the
   expected array type. *)
let fetch_embedding registry ~year (bbox : Bbox.t) =
  let load_npy kind file =
    let native = Eio.Path.native_exn file in
    match Npy.read_copy native |> Npy.to_bigarray C_layout kind with
    | Some arr -> Nx.of_bigarray arr
    | None ->
        Fmt.invalid_arg "%s could not be read as the requested array type"
          native
  in
  let blocks = Registry.blocks_for_region bbox in
  Fiber.List.map
    (fun point ->
      let manifest = Registry.find_manifest registry ~year point in
      (* A block is much larger than most requests, so keep only the tiles
         that actually overlap the bbox supplied. *)
      let manifest =
        List.filter
          (fun (p, _, _) ->
            (* Tiles span 0.1 degree *)
            let tile_min_lon = p.lon -. 0.05 and tile_max_lon = p.lon +. 0.05 in
            let tile_min_lat = p.lat -. 0.05 and tile_max_lat = p.lat +. 0.05 in
            tile_min_lon < Bbox.max_lon bbox
            && tile_min_lat < Bbox.max_lat bbox
            && tile_max_lat > Bbox.min_lat bbox
            && tile_max_lon > Bbox.min_lon bbox)
          manifest
      in
      Registry.download_embedding registry manifest
      |> List.map (fun (point, (emb, scale)) ->
             let emb = load_npy Bigarray.Int8_signed emb in
             let scales = load_npy Bigarray.Float32 scale in
             (point, (emb, scales))))
    blocks
  |> List.concat

(* [fetch_landmask registry points] downloads (or reuses from the cache) the
   landmask of the tile at each of [points]. *)
let fetch_landmask registry points =
  let landmask_files = List.map (Registry.find_landmasks registry) points in
  Registry.download_landmasks registry landmask_files

(* [scale (emb, sca)] multiplies the embedding, cast to float32, by its
   scales. [emb] must have three dimensions; [sca] may have three, or two, in
   which case it is broadcast along the last dimension of [emb].

   Raises [Invalid_argument] for any other combination of ranks. *)
let scale (emb, sca) =
  match (Array.length (Nx.dims emb), Array.length (Nx.dims sca)) with
  | 3, 2 ->
      let dim0 = Nx.dim 0 sca and dim1 = Nx.dim 1 sca in
      let bscales =
        Nx.broadcast_to (Nx.dims emb) (Nx.reshape [| dim0; dim1; 1 |] sca)
      in
      Nx.mul (Nx.cast Nx.float32 emb) bscales
  | 3, 3 -> Nx.mul (Nx.cast Nx.float32 emb) sca
  | a, b -> Fmt.invalid_arg "Wrong dimensions %i %i" a b

(* Raises [Invalid_argument] unless [p1] and [p2] are the same point. *)
let check_points p1 p2 =
  if p1 <> p2 then
    Fmt.invalid_arg "Different points: %a and %a" pp_point p1 pp_point p2

(* [fetch env ~year bbox] returns, for every tile overlapping [bbox], the
   CRS and transform of its landmask, its coordinates, the landmask path and
   the embedding with its scales.

   Raises [Invalid_argument] if embeddings and landmasks do not line up. *)
let fetch (env : _ env) ~year (bbox : Bbox.t) =
  Registry.with_registry env @@ fun registry ->
  let embs = fetch_embedding registry ~year bbox in
  Eio.traceln "Depickling and scaling embeddings...";
  let landmask = fetch_landmask registry (List.map fst embs) in
  List.map2
    (fun (p1, emb) (p2, crs, lm) ->
      check_points p1 p2;
      (crs, p1, lm, emb))
    embs landmask