(* Geotessera library for OCaml *)
1open Eio
2
(** Infix path concatenation: [parent / child] is [Eio.Path.( / ) parent child].
    Defined here so paths can be built without opening [Eio.Path]. *)
let ( / ) parent child = Eio.Path.( / ) parent child
4
(** A geographic coordinate: [lat] is the latitude and [lon] the longitude.
    Elsewhere in this file tiles are described as spanning 0.1 degree, so both
    fields are presumably in decimal degrees. *)
type point = { lat : float; lon : float }
6
(** [pp_point fmt p] pretty-prints [p] on [fmt], showing the latitude first and
    the longitude second, each with three digits after the decimal point. *)
let pp_point fmt point =
  Fmt.pf fmt "{ lat: %.3f, lon: %.3f }" point.lat point.lon
8
(** Axis-aligned bounding boxes over longitude/latitude. *)
module Bbox = struct
  (** A box is stored as a four-element array laid out as
      [[| min_lat; min_lon; max_lat; max_lon |]]. Use {!v} and the accessors
      below rather than indexing the array directly. *)
  type t = float array

  (** [v ~min_lon ~min_lat ~max_lon ~max_lat] builds a bounding box from its
      four edges. *)
  let v ~min_lon ~min_lat ~max_lon ~max_lat =
    (* Latitude comes before longitude in the stored layout. *)
    [| min_lat; min_lon; max_lat; max_lon |]

  (** Southern edge (slot 0). *)
  let min_lat t = Array.get t 0

  (** Western edge (slot 1). *)
  let min_lon t = Array.get t 1

  (** Northern edge (slot 2). *)
  let max_lat t = Array.get t 2

  (** Eastern edge (slot 3). *)
  let max_lon t = Array.get t 3
end
20
21module Registry = struct
(** Handle on the registry: the remote locations, the local cache paths and the
    network client used to reach the remote data. *)
type t = {
  git_url : string;  (** URL passed to [git clone] to obtain the manifests. *)
  data_url : string;  (** Base URL from which download URIs are built. *)
  git : Eio.Fs.dir_ty Eio.Path.t;  (** Local path of the manifests checkout. *)
  data : Eio.Fs.dir_ty Eio.Path.t;  (** Local directory for downloaded files. *)
  client : Client.t;  (** Client used to fetch remote files. *)
  version : [ `v1 ];  (** Dataset version; becomes a segment of download URIs. *)
}
30
(** [version_to_string v] is the textual form of the dataset version [v], as
    used inside download URIs. *)
let version_to_string version = match version with `v1 -> "v1"
(** Edge length of a registry block, in the same unit as {!point} coordinates. *)
let block_size = 5.0
33
(** [block_of_point p] is the lower corner of the block containing [p]: each
    coordinate is rounded down to a multiple of {!block_size}. *)
let block_of_point point =
  let snap coord = Float.floor (coord /. block_size) *. block_size in
  { lon = snap point.lon; lat = snap point.lat }
39
(** [blocks_for_region bbox] lists every block that [bbox] touches.

    Each block is reported by its centre point (lower corner plus 2.5 on both
    axes). The result is ordered by increasing longitude and, within one
    longitude, by increasing latitude. *)
let blocks_for_region (bbox : Bbox.t) =
  let snap coord = floor (coord /. block_size) *. block_size in
  let first_lon = snap (Bbox.min_lon bbox) in
  let first_lat = snap (Bbox.min_lat bbox) in
  let last_lon = snap (Bbox.max_lon bbox) in
  let last_lat = snap (Bbox.max_lat bbox) in
  (* Walk one column of blocks, prepending centres to [acc]. *)
  let rec column lon lat acc =
    if lat <= last_lat then
      (* Blocks are referenced by their centre *)
      column lon (lat +. block_size)
        ({ lat = lat +. 2.5; lon = lon +. 2.5 } :: acc)
    else acc
  in
  (* Walk the columns from west to east. *)
  let rec columns lon acc =
    if lon <= last_lon then columns (lon +. block_size) (column lon first_lat acc)
    else acc
  in
  (* The accumulator was built back to front. *)
  List.rev (columns first_lon [])
60
(** [clone env ~into url] runs [git clone url into] through the process manager
    of [env]. [into] is the destination given to git as a plain string. *)
let clone env ~into url =
  let argv = [ "git"; "clone"; url; into ] in
  Eio.Process.run env#process_mgr argv
63
(** [with_registry ?git_url ?data_url ?version env fn] builds a registry handle
    and returns [fn handle].

    The cache lives in the [ocaml-geotessera] directory under the XDG cache
    directory: the manifests checkout in [tessera-manifests] and downloaded
    files in [downloads]. If the manifests checkout does not exist yet, the
    downloads directory is created and [git_url] is cloned.

    The downloads directory is created with [~exists_ok:true] so that a
    previous run whose clone failed (leaving the directory behind) does not
    make every later run fail.

    @param git_url repository holding the manifests.
    @param data_url base URL of the data server.
    @param version dataset version.
    @raise Failure if the manifests path exists but is not a directory. *)
let with_registry ?(git_url = "https://github.com/ucam-eo/tessera-manifests")
    ?(data_url = "https://dl-2.tessera.wiki") ?(version = `v1) env fn =
  let xdg = Xdg.create ~env:Sys.getenv_opt () in
  let dir = env#fs / Xdg.cache_dir xdg / "ocaml-geotessera" in
  let git = dir / "tessera-manifests" in
  let data = dir / "downloads" in
  let client = Client.v env#net in
  let v = { git_url; data_url; git; client; data; version } in
  (match Eio.Path.kind ~follow:false git with
  | `Directory -> ()
  | `Not_found ->
      Eio.Path.mkdirs ~exists_ok:true ~perm:0o755 data;
      clone env ~into:(Eio.Path.native_exn git) v.git_url
  | _ -> Fmt.failwith "%a is not a directory or non-existent" Eio.Path.pp git);
  fn v
85
(** [with_display ~sw label fn] calls [fn get_display], where [get_display ()]
    returns a display created on first use with [Display.init ~sw label] and
    reused afterwards. If a display was created it is passed to
    [Display.finalise] when [fn] returns or raises.

    Creation is serialised by a mutex, and the cached value is re-checked once
    the lock is held: a caller that had to wait for the lock reuses the display
    built by the previous holder instead of creating a second one, which would
    overwrite the first and leave it un-finalised. *)
let with_display ~sw label fn =
  let display = ref None in
  let mux = Mutex.create () in
  let get_display () =
    match !display with
    | Some d -> d
    | None ->
        Mutex.use_rw ~protect:true mux (fun () ->
            (* Another caller may have initialised it while we waited. *)
            match !display with
            | Some d -> d
            | None ->
                let d = Display.init ~sw label in
                display := Some d;
                d)
  in
  let finalise_display () = Option.iter Display.finalise !display in
  Fun.protect ~finally:finalise_display (fun () -> fn get_display)
99
(** [extract_lat_lon_from_grid_name s] recovers the coordinates encoded in the
    base name of [s]. After splitting the base name on underscores, two shapes
    are recognised:
    - [grid], longitude, latitude-with-extension (the extension is removed
      with [Filename.chop_extension]);
    - [grid], longitude, latitude, [scales.npy].

    @raise Invalid_argument if the base name has neither shape.
    @raise Failure if a coordinate is not a valid float. *)
let extract_lat_lon_from_grid_name s =
  let to_point ~lon ~lat =
    { lon = float_of_string lon; lat = float_of_string lat }
  in
  match Filename.basename s |> String.split_on_char '_' with
  | [ "grid"; lon; lat_with_ext ] ->
      to_point ~lon ~lat:(Filename.chop_extension lat_with_ext)
  | [ "grid"; lon; lat; "scales.npy" ] -> to_point ~lon ~lat
  | _ -> Fmt.invalid_arg "Failed to extract lat/lon from %s" s
108
(** [parse_manifest path] reads the manifest at [path] and returns one
    [(point, name, hash)] triple per line, in file order.

    Every line must consist of exactly two fields separated by a single space:
    the file name, then a hexadecimal SHA-256 digest. [point] is derived from
    the file name with {!extract_lat_lon_from_grid_name}.

    @raise Failure if a line does not have exactly two fields. *)
let parse_manifest path =
  let entry_of_line line =
    match String.split_on_char ' ' line with
    | [ name; hex ] ->
        (extract_lat_lon_from_grid_name name, name, Digestif.SHA256.of_hex hex)
    | _ -> failwith "Malformed manifest"
  in
  Eio.Path.with_open_in path (fun flow ->
      (* The lines are consumed before the file is closed. *)
      Eio.Buf_read.of_flow ~max_size:max_int flow
      |> Eio.Buf_read.lines |> Seq.map entry_of_line |> List.of_seq)
120
(** [find_manifest t ~year point] parses the embeddings manifest for the block
    containing [point] and the given [year]. The manifest is looked up in the
    [registry/embeddings] directory of the manifests checkout, under a file
    name built from the year and the integer block corner. *)
let find_manifest t ~year point =
  let { lon; lat } = block_of_point point in
  let file =
    Fmt.str "embeddings_%i_lon%i_lat%i.txt" year (Float.to_int lon)
      (Float.to_int lat)
  in
  parse_manifest (t.git / "registry" / "embeddings" / file)
129
(** [find_landmasks t point] is the triple describing the landmask file for
    [point]: the point itself, a file name with both coordinates printed to two
    decimals (longitude first), and [Digestif.SHA256.empty] standing in where
    manifest entries carry a digest. The registry argument is unused. *)
let find_landmasks _t point =
  let { lon; lat } = point in
  (point, Fmt.str "grid_%.2f_%.2f.tiff" lon lat, Digestif.SHA256.empty)
133
(** [download_embedding t v] makes sure every entry of [v] (triples of point,
    file name and digest) is present in the [embeddings] directory of the local
    downloads cache, then pairs each embedding file with its scales file.

    For each entry, at most ten at a time:
    - a missing file is downloaded;
    - a zero-length file is deleted and downloaded again;
    - a non-empty regular file is kept as is.

    The digest carried by each entry is not checked. The previous version read
    and hashed every cached file but discarded the result, so that work has
    been removed.
    TODO: compare against the manifest digest once verification is wanted.

    Paths whose last underscore-separated component is [scales.npy] are
    treated as scales, all others as embeddings. The two groups are paired by
    position, and the result is in the reverse of the order of [v].

    @return a list of [(point, (embedding_path, scales_path))].
    @raise Failure if a cache path exists but is not a regular file.
    @raise Invalid_argument if the number of embeddings and scales differ. *)
let download_embedding t v =
  Eio.Switch.run @@ fun sw ->
  with_display ~sw "embeddings" @@ fun get_display ->
  Eio.traceln "Downloading and checking embeddings...";
  (* Ensure one manifest entry is cached and return its local path. *)
  let ensure_cached (_point, name, _hash) =
    let uri =
      Fmt.str "%s/%s/global_0.1_degree_representation/%s" t.data_url
        (version_to_string t.version)
        name
    in
    let data_dir = t.data / "embeddings" / Filename.dirname name in
    Eio.Path.mkdirs ~exists_ok:true ~perm:0o755 data_dir;
    let path = data_dir / Filename.basename name in
    let download () =
      let display = get_display () in
      Eio.Path.with_open_out ~create:(`If_missing 0o644) path @@ fun w ->
      Client.with_body ~default:() t.client display ~uri @@ fun body ->
      Flow.copy body w
    in
    let on_disk =
      try
        let stat = Eio.Path.stat ~follow:false path in
        if Optint.Int63.(equal stat.size zero) then `Empty
        else `Present stat.kind
      with Eio.Exn.Io (Eio.Fs.E (Not_found _), _) -> `Missing
    in
    (match on_disk with
    | `Present `Regular_file -> ()
    | `Present _ -> Fmt.failwith "%s exists but is not a file!" name
    | `Empty ->
        (* Left-over from an earlier attempt: start again from scratch. *)
        Eio.Path.unlink path;
        download ()
    | `Missing -> download ());
    path
  in
  let paths = Fiber.List.map ~max_fibers:10 ensure_cached v in
  let points, embeddings, scales =
    List.fold_left
      (fun (points, emb, sca) p ->
        let native = Eio.Path.native_exn p in
        let point = extract_lat_lon_from_grid_name native in
        match List.rev (String.split_on_char '_' native) with
        | "scales.npy" :: _ -> (points, emb, p :: sca)
        | _ -> (point :: points, p :: emb, sca))
      ([], [], []) paths
  in
  List.combine points (List.combine embeddings scales)
188
(** [transformation_matrix e] builds a six-element float64 array from the pixel
    scale and tiepoint of the IFD [e]:

    - slot 0: tiepoint element 3;
    - slot 1: pixel-scale element 0;
    - slot 3: tiepoint element 4;
    - slot 5: the negated pixel-scale element 1;
    - slots 2 and 4 stay at zero.

    NOTE(review): this looks like an origin/scale affine geotransform with no
    rotation terms; confirm against the consumers of the array. *)
let transformation_matrix e =
  let scale = Tiff.Ifd.pixel_scale e in
  let tiepoint = Tiff.Ifd.tiepoint e in
  let matrix = Nx.zeros Nx.float64 [| 6 |] in
  let put slot value = Nx.set_item [ slot ] value matrix in
  (* X *)
  put 0 tiepoint.(3);
  (* sX *)
  put 1 scale.(0);
  put 3 tiepoint.(4);
  put 5 (-.scale.(1));
  matrix
199
(** [crs_and_transform_of_landmask lm] opens the TIFF file at [lm] and returns
    the projected CRS found in its geo-key directory together with the array
    computed by {!transformation_matrix} for its IFD. *)
let crs_and_transform_of_landmask lm =
  Eio.Path.with_open_in lm (fun file ->
      let ifd =
        Eio.File.pread_exact file |> Tiff.from_file Tiff.Float32 |> Tiff.ifd
      in
      let geo_keys = Tiff.Ifd.geo_key_directory ifd in
      (Tiff.Ifd.GeoKeys.projected_crs geo_keys, transformation_matrix ifd))
207
(** [download_landmasks t v] makes sure every entry of [v] (triples of point,
    file name and digest) is present in the [landmasks] directory of the local
    downloads cache, and reads the CRS and transform of each file.

    For each entry, at most ten at a time:
    - a missing file is downloaded;
    - a zero-length file is deleted and downloaded again;
    - a non-empty regular file is kept as is.

    The digest carried by each entry is not checked. The previous version read
    and hashed every cached file but discarded the result, so that work has
    been removed.

    @return one [(point, crs_and_transform, path)] per entry, in the order of
    [v]; [point] is parsed back from the file name and [crs_and_transform] is
    the result of {!crs_and_transform_of_landmask}.
    @raise Failure if a cache path exists but is not a regular file. *)
let download_landmasks t v =
  Eio.Switch.run @@ fun sw ->
  with_display ~sw "landmasks" @@ fun get_display ->
  Eio.traceln "Downloading and checking landmasks...";
  (* Ensure one landmask is cached, then read its georeferencing. *)
  let ensure_cached (_point, name, _hash) =
    let uri =
      Fmt.str "%s/%s/global_0.1_degree_tiff_all/%s" t.data_url
        (version_to_string t.version)
        name
    in
    let data_dir = t.data / "landmasks" / Filename.dirname name in
    Eio.Path.mkdirs ~exists_ok:true ~perm:0o755 data_dir;
    let path = data_dir / Filename.basename name in
    let download () =
      let display = get_display () in
      Eio.Path.with_open_out ~create:(`If_missing 0o644) path @@ fun w ->
      Client.with_body ~default:() t.client display ~uri @@ fun body ->
      Flow.copy body w
    in
    let on_disk =
      try
        let stat = Eio.Path.stat ~follow:false path in
        if Optint.Int63.(equal stat.size zero) then `Empty
        else `Present stat.kind
      with Eio.Exn.Io (Eio.Fs.E (Not_found _), _) -> `Missing
    in
    (match on_disk with
    | `Present `Regular_file -> ()
    | `Present _ -> Fmt.failwith "%s exists but is not a file!" name
    | `Empty ->
        (* Left-over from an earlier attempt: start again from scratch. *)
        Eio.Path.unlink path;
        download ()
    | `Missing -> download ());
    (crs_and_transform_of_landmask path, path)
  in
  Fiber.List.map ~max_fibers:10 ensure_cached v
  |> List.map (fun (crs, p) ->
         (extract_lat_lon_from_grid_name (Eio.Path.native_exn p), crs, p))
264end
265
(** Capabilities an environment object must offer: a filesystem root ([fs]), a
    process manager ([process_mgr]) and a network ([net]). The object type is
    open, so objects with additional methods are accepted as well. *)
type 'a env =
  < fs : Eio.Fs.dir_ty Eio.Path.t
  ; process_mgr : [> `Generic ] Eio.Process.mgr_ty Eio.Process.mgr
  ; net : [> `Generic ] Eio.Net.ty Eio.Net.t
  ; .. >
  as
  'a
273
(** Embedding tensor with signed 8-bit integer elements. *)
type embedding = (int, Bigarray.int8_signed_elt) Nx.t
(** Scales tensor with 32-bit float elements. *)
type scales = (float, Bigarray.float32_elt) Nx.t
276
(** [fetch_embedding registry ~year bbox] loads the embeddings of every tile
    that overlaps [bbox] for the given [year].

    The blocks touched by [bbox] are processed concurrently. For each block the
    manifest is read, entries whose tile does not overlap [bbox] are dropped,
    the remaining files are downloaded (or taken from the cache) and each
    embedding/scales pair is read from its [.npy] file.

    @return a list of [(point, (embedding, scales))].
    @raise Invalid_argument if an [.npy] file does not have the expected
    element type or layout. *)
let fetch_embedding registry ~year (bbox : Bbox.t) =
  (* Tiles span 0.1 degree, i.e. 0.05 on either side of their centre. *)
  let half_tile = 0.05 in
  (* We need to check the tiles available are within the bbox supplied. *)
  let overlaps_bbox (p, _, _) =
    p.lon -. half_tile < Bbox.max_lon bbox
    && p.lat -. half_tile < Bbox.max_lat bbox
    && p.lat +. half_tile > Bbox.min_lat bbox
    && p.lon +. half_tile > Bbox.min_lon bbox
  in
  (* Read one .npy file as a tensor with elements of the given kind. *)
  let load_npy kind path =
    Npy.read_copy (Eio.Path.native_exn path)
    |> Npy.to_bigarray C_layout kind
    |> Option.get |> Nx.of_bigarray
  in
  let fetch_block block =
    Registry.find_manifest registry ~year block
    |> List.filter overlaps_bbox
    |> Registry.download_embedding registry
    |> List.map (fun (point, (emb_path, scales_path)) ->
           let emb = load_npy Bigarray.Int8_signed emb_path in
           let scales = load_npy Bigarray.Float32 scales_path in
           (point, (emb, scales)))
  in
  let blocks = Registry.blocks_for_region bbox in
  List.concat (Fiber.List.map fetch_block blocks)
316
(** [fetch_landmask registry points] downloads (or takes from the cache) the
    landmask of every point in [points] and returns what
    [Registry.download_landmasks] produces for them. *)
let fetch_landmask registry points =
  points
  |> List.map (Registry.find_landmasks registry)
  |> Registry.download_landmasks registry
320
(** [scale (emb, sca)] multiplies the embedding [emb], cast to float32, by its
    scales [sca]. [emb] must have three dimensions. A two-dimensional [sca] is
    reshaped with a trailing axis of size one and broadcast to the shape of
    [emb]; a three-dimensional [sca] is used directly.

    @raise Invalid_argument for any other combination of dimension counts. *)
let scale (emb, sca) =
  let times factors = Nx.mul (Nx.cast Nx.float32 emb) factors in
  match (Array.length (Nx.dims emb), Array.length (Nx.dims sca)) with
  | 3, 3 -> times sca
  | 3, 2 ->
      let rows = Nx.dim 0 sca and cols = Nx.dim 1 sca in
      let broadcast =
        Nx.broadcast_to (Nx.dims emb) (Nx.reshape [| rows; cols; 1 |] sca)
      in
      times broadcast
  | a, b -> Fmt.invalid_arg "Wrong dimensions %i %i" a b
331
(** [check_points p1 p2] does nothing when the two points are structurally
    equal.

    @raise Invalid_argument naming both points when they differ. *)
let check_points p1 p2 =
  match p1 = p2 with
  | true -> ()
  | false ->
      Fmt.invalid_arg "Different points: %a and %a" pp_point p1 pp_point p2
335
(** [fetch env ~year bbox] fetches the embeddings overlapping [bbox] for
    [year], then the landmask of each embedding point, and joins the two lists
    position by position.

    @return a list of [(crs, point, landmask_path, embedding)], where [crs] is
    the CRS/transform pair read from the landmask and [embedding] is the
    [(embedding, scales)] pair.
    @raise Invalid_argument if the two lists differ in length or if the points
    at the same position differ. *)
let fetch (env : _ env) ~year (bbox : Bbox.t) =
  let join (emb_point, emb) (mask_point, crs, lm) =
    check_points emb_point mask_point;
    (crs, emb_point, lm, emb)
  in
  Registry.with_registry env (fun registry ->
      let embs = fetch_embedding registry ~year bbox in
      Eio.traceln "Depickling and scaling embeddings...";
      let landmask = fetch_landmask registry (List.map fst embs) in
      List.map2 join embs landmask)