Programmatically generate SVG (vector) images, animations, and interactive Jupyter widgets
from io import StringIO
from collections import defaultdict
import random
import string

from . import Raster
from . import elements as elements_module
from . import jupyter


SVG_START_FMT = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     width="{}" height="{}" viewBox="{} {} {} {}"'''
SVG_END = '</svg>'
SVG_CSS_FMT = '<style><![CDATA[{}]]></style>'
SVG_JS_FMT = '<script><![CDATA[{}]]></script>'


class Drawing:
    '''
    A vector drawing.

    Append shapes and other elements with `.append()`.  The default coordinate
    system origin is at the top-left corner with x-values increasing to the
    right and y-values increasing downward.

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below.
    '''

    def __init__(self, width, height, origin=(0,0), id_prefix='d', **svg_args):
        '''Create an empty drawing.

        width, height: numeric size of the drawing in user units.  Non-numeric
            values fail the `float(x) == x` checks below.
        origin: either one of the preset strings 'center', 'top-left',
            'top-right', 'bottom-left', 'bottom-right', or an (x, y) pair
            used as the min-x / min-y of the SVG viewBox.
        id_prefix: prefix for every id generated when the drawing is written.
        svg_args: extra attributes for the root <svg> element.  Keyword
            names are mangled so Python identifiers can express SVG names:
            '__' becomes ':', '_' becomes '-', and one trailing '-' is
            dropped (so `class_=...` yields the attribute `class`).
        '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if isinstance(origin, str):
            # Map each named origin to the viewBox (min-x, min-y, w, h)
            # that places (0, 0) at that position.  Unknown names raise
            # KeyError.
            self.view_box = {
                'center': (-width/2, -height/2, width, height),
                'top-left': (0, 0, width, height),
                'top-right': (-width, 0, width, height),
                'bottom-left': (0, -height, width, height),
                'bottom-right': (-width, -height, width, height),
            }[origin]
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.view_box = origin + (width, height)
        self.elements = []                        # un-ordered elements
        self.ordered_elements = defaultdict(list) # z value -> elements
        self.other_defs = []                      # explicitly added defs
        self.css_list = []
        self.js_list = []
        self.pixel_scale = 1
        self.render_width = None
        self.render_height = None
        self.id_prefix = str(id_prefix)
        self.svg_args = {}
        for k, v in svg_args.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            if k[-1] == '-':
                k = k[:-1]
            self.svg_args[k] = v

    def set_render_size(self, w=None, h=None):
        '''Set the output width/height; None means "derive from the other".

        Returns self so calls can be chained.
        '''
        self.render_width = w
        self.render_height = h
        return self

    def set_pixel_scale(self, s=1):
        '''Scale output size by `s`, clearing any explicit render size.

        Returns self so calls can be chained.
        '''
        self.render_width = None
        self.render_height = None
        self.pixel_scale = s
        return self

    def calc_render_size(self):
        '''Return the (width, height) written to the root <svg> element.

        With no explicit render size, the drawing size times `pixel_scale`
        is used.  When only one render dimension is set, the other is
        scaled to preserve the drawing's aspect ratio.
        '''
        if self.render_width is None and self.render_height is None:
            return (self.width * self.pixel_scale,
                    self.height * self.pixel_scale)
        elif self.render_width is None:
            s = self.render_height / self.height
            return self.width * s, self.render_height
        elif self.render_height is None:
            s = self.render_width / self.width
            return self.render_width, self.height * s
        else:
            return self.render_width, self.render_height

    def draw(self, obj, *, z=None, **kwargs):
        '''Add any object that knows how to draw itself to the drawing.

        This object must implement the `to_drawables(**kwargs)` method
        that returns a `DrawingElement` or list of elements.  An object
        that is already an element (has `write_svg_element`) is added
        directly and must not be given kwargs.  `None` is ignored.
        '''
        if obj is None:
            return
        if not hasattr(obj, 'write_svg_element'):
            elements = obj.to_drawables(**kwargs)
        else:
            assert len(kwargs) == 0
            elements = obj
        if hasattr(elements, 'write_svg_element'):
            self.append(elements, z=z)
        else:
            self.extend(elements, z=z)

    def append(self, element, *, z=None):
        '''Add any `DrawingElement` to the drawing.

        Do not append a `DrawingDef` referenced by other elements.  These are
        included automatically.  Use `.append_def()` for an unreferenced
        `DrawingDef`.

        When `z` is given, the element is drawn after all un-ordered
        elements, grouped with other elements of the same z in ascending
        z order.
        '''
        if z is not None:
            self.ordered_elements[z].append(element)
        else:
            self.elements.append(element)

    def extend(self, iterable, *, z=None):
        '''Add every element of `iterable`; see `append` for `z`.'''
        if z is not None:
            self.ordered_elements[z].extend(iterable)
        else:
            self.elements.extend(iterable)

    # The list-like helpers below operate only on the un-ordered
    # `self.elements` list; elements added with a `z` value are unaffected.
    def insert(self, i, element):
        self.elements.insert(i, element)

    def remove(self, element):
        self.elements.remove(element)

    def clear(self):
        self.elements.clear()

    def index(self, *args, **kwargs):
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        return self.elements.count(element)

    def reverse(self):
        self.elements.reverse()

    def draw_def(self, obj, **kwargs):
        '''Like `draw`, but adds the result to the <defs> section.'''
        if not hasattr(obj, 'write_svg_element'):
            elements = obj.to_drawables(**kwargs)
        else:
            assert len(kwargs) == 0
            elements = obj
        if hasattr(elements, 'write_svg_element'):
            self.append_def(elements)
        else:
            self.other_defs.extend(elements)

    def append_def(self, element):
        '''Add an element to the <defs> section even if nothing uses it.'''
        self.other_defs.append(element)

    def append_title(self, text, **kwargs):
        '''Append a title element with the given text.'''
        # The elements module is imported as `elements_module`; the bare
        # name `elements` is not defined at this scope.
        self.append(elements_module.Title(text, **kwargs))

    def append_css(self, css_text):
        '''Add CSS text, emitted inside a <style> CDATA block.'''
        self.css_list.append(css_text)

    def append_javascript(self, js_text, onload=None):
        '''Add script text, emitted inside a <script> CDATA block.

        `onload`, if given, is appended (';'-separated) to any existing
        `onload` attribute of the root <svg> element.
        '''
        if onload:
            if self.svg_args.get('onload'):
                self.svg_args['onload'] = f'{self.svg_args["onload"]};{onload}'
            else:
                self.svg_args['onload'] = onload
        self.js_list.append(js_text)

    # Backward-compatible alias for the original misspelled method name.
    append_javascriipt = append_javascript

    def all_elements(self):
        '''Return self.elements and self.ordered_elements as a single list.

        Un-ordered elements come first, then z-ordered groups by ascending z.
        '''
        output = list(self.elements)
        for z in sorted(self.ordered_elements):
            output.extend(self.ordered_elements[z])
        return output

    def as_svg(self, output_file=None, randomize_ids=False):
        '''Write the drawing as SVG text.

        output_file: a writable text stream.  If None, the SVG is rendered
            into a string and returned instead.
        randomize_ids: use a random id prefix instead of `self.id_prefix`,
            so several drawings on one page do not share ids.
        '''
        if output_file is None:
            with StringIO() as f:
                self.as_svg(f, randomize_ids=randomize_ids)
                return f.getvalue()
        img_width, img_height = self.calc_render_size()
        start_str = SVG_START_FMT.format(img_width, img_height, *self.view_box)
        output_file.write(start_str)
        elements_module.write_xml_node_args(self.svg_args, output_file)
        output_file.write('>\n')
        if self.css_list:
            output_file.write(SVG_CSS_FMT.format('\n'.join(self.css_list)))
            output_file.write('\n')
        if self.js_list:
            output_file.write(SVG_JS_FMT.format('\n'.join(self.js_list)))
            output_file.write('\n')
        output_file.write('<defs>\n')
        # Write definition elements
        id_prefix = self._random_id() if randomize_ids else self.id_prefix
        id_index = 0

        def id_gen(base=''):
            # Produce sequential ids: prefix + optional base + counter.
            nonlocal id_index
            id_str = f'{id_prefix}{base}{id_index}'
            id_index += 1
            return id_str

        # Ids are created lazily on first lookup; which keys the elements
        # use is decided by their write methods (not visible here).
        id_map = defaultdict(id_gen)
        # Identity set of objects already seen.  Seeded with the explicit
        # defs, presumably so a referenced def that is also in other_defs
        # is not written twice.
        prev_set = set((id(defn) for defn in self.other_defs))

        def is_duplicate(obj):
            # True if `obj` was seen before; records it either way.
            # `nonlocal` because as_svg rebinds prev_set below.
            nonlocal prev_set
            dup = id(obj) in prev_set
            prev_set.add(id(obj))
            return dup

        for element in self.other_defs:
            if hasattr(element, 'write_svg_element'):
                element.write_svg_element(
                    id_map, is_duplicate, output_file, False)
                output_file.write('\n')
        all_elements = self.all_elements()
        for element in all_elements:
            if hasattr(element, 'write_svg_defs'):
                element.write_svg_defs(
                    id_map, is_duplicate, output_file, False)
        output_file.write('</defs>\n')
        # Generate ids for normal elements: a first pass with the final
        # argument True (presumably a dry-run flag) so ids exist before
        # anything referencing them is written.
        prev_def_set = set(prev_set)
        for element in all_elements:
            if hasattr(element, 'write_svg_element'):
                element.write_svg_element(
                    id_map, is_duplicate, output_file, True)
        # Restore the duplicate-tracking state from just after <defs>, so
        # the real write pass is not affected by the id-generation pass.
        prev_set = prev_def_set
        # Write normal elements
        for element in all_elements:
            if hasattr(element, 'write_svg_element'):
                element.write_svg_element(
                    id_map, is_duplicate, output_file, False)
                output_file.write('\n')
        output_file.write(SVG_END)

    @staticmethod
    def _random_id(length=8):
        '''Return a random id that starts with a letter (not security-grade).'''
        return (random.choice(string.ascii_letters)
                + ''.join(random.choices(
                    string.ascii_letters+string.digits, k=length-1)))

    def save_svg(self, fname, encoding='utf-8'):
        '''Write the SVG to the file `fname`.'''
        with open(fname, 'w', encoding=encoding) as f:
            self.as_svg(output_file=f)

    def save_png(self, fname):
        '''Rasterize the drawing and write it to the file `fname`.'''
        self.rasterize(to_file=fname)

    def rasterize(self, to_file=None):
        '''Rasterize via `Raster`; write to `to_file` if it is truthy.'''
        if to_file:
            return Raster.from_svg_to_file(self.as_svg(), to_file)
        else:
            return Raster.from_svg(self.as_svg())

    def _repr_svg_(self):
        '''Display in Jupyter notebook.'''
        return self.as_svg(randomize_ids=True)

    def display_inline(self):
        '''Display inline in the Jupyter web page.'''
        return jupyter.JupyterSvgInline(self.as_svg(randomize_ids=True))

    def display_iframe(self):
        '''Display within an iframe in the Jupyter web page.'''
        w, h = self.calc_render_size()
        return jupyter.JupyterSvgFrame(self.as_svg(), w, h)

    def display_image(self):
        '''Display within an img in the Jupyter web page.'''
        return jupyter.JupyterSvgImage(self.as_svg())