1from io import StringIO
2from collections import defaultdict
3import random
4import string
5
6from . import Raster
7from . import elements as elements_module
8from . import jupyter
9
10
11
# Opening of the root <svg> element. The six placeholders receive, in order,
# the rendered width, the rendered height, and the four viewBox numbers.
# The tag is intentionally left unterminated: extra root attributes are
# written after it and the closing '>' is emitted separately.
SVG_START_FMT = (
    '<?xml version="1.0" encoding="UTF-8"?>\n'
    '<svg xmlns="http://www.w3.org/2000/svg"'
    ' xmlns:xlink="http://www.w3.org/1999/xlink"\n'
    ' width="{}" height="{}" viewBox="{} {} {} {}"'
)
# Closing tag written after every element.
SVG_END = "</svg>"
# Wrappers for inline stylesheet and script text. CDATA sections let the
# embedded text contain '<' and '&' without XML escaping.
SVG_CSS_FMT = "<style><![CDATA[{}]]></style>"
SVG_JS_FMT = "<script><![CDATA[{}]]></script>"
18
19
class Drawing:
    '''
    A vector drawing.

    Append shapes and other elements with `.append()`. The default coordinate
    system origin is at the top-left corner with x-values increasing to the
    right and y-values increasing downward.

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below.
    '''
    def __init__(self, width, height, origin=(0,0), id_prefix='d', **svg_args):
        '''Create an empty drawing.

        Args:
            width, height: Size of the drawing in user units. Each must
                compare equal to its `float()` conversion (checked with
                `assert`).
            origin: One of the strings 'center', 'top-left', 'top-right',
                'bottom-left', 'bottom-right', or an (x, y) pair used as the
                first two viewBox numbers.
            id_prefix: Prefix for ids generated by `as_svg()`.
            **svg_args: Extra attributes for the root <svg> tag. Names are
                translated from Python style: '__' becomes ':', then '_'
                becomes '-', and one trailing '-' is dropped (so `class_`
                becomes `class` and `xlink__href` becomes `xlink:href`).

        Raises:
            KeyError: if `origin` is an unrecognized string.
        '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if isinstance(origin, str):
            self.view_box = {
                'center': (-width/2, -height/2, width, height),
                'top-left': (0, 0, width, height),
                'top-right': (-width, 0, width, height),
                'bottom-left': (0, -height, width, height),
                'bottom-right': (-width, -height, width, height),
            }[origin]
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.view_box = origin + (width, height)
        # Elements without an explicit z, in insertion order.
        self.elements = []
        # z value -> elements; rendered after `self.elements`, sorted by z.
        self.ordered_elements = defaultdict(list)
        self.other_defs = []
        self.css_list = []
        self.js_list = []
        self.pixel_scale = 1
        self.render_width = None
        self.render_height = None
        self.id_prefix = str(id_prefix)
        self.svg_args = {}
        for k, v in svg_args.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            # Trailing underscore lets callers spell reserved words (class_).
            if k[-1] == '-':
                k = k[:-1]
            self.svg_args[k] = v

    def set_render_size(self, w=None, h=None):
        '''Set the output width and/or height; return self for chaining.

        If only one of `w`/`h` is given, the other is derived from the
        drawing's aspect ratio in `calc_render_size()`.
        '''
        self.render_width = w
        self.render_height = h
        return self

    def set_pixel_scale(self, s=1):
        '''Render at `s` pixels per user unit; clears any explicit render
        size. Returns self for chaining.'''
        self.render_width = None
        self.render_height = None
        self.pixel_scale = s
        return self

    def calc_render_size(self):
        '''Return the (width, height) written to the root <svg> tag.

        With no explicit render size, this is the drawing size times
        `pixel_scale`. With only one dimension set, the other is scaled to
        keep the drawing's aspect ratio. With both set, they are returned
        as-is.
        '''
        if self.render_width is None and self.render_height is None:
            return (self.width * self.pixel_scale,
                    self.height * self.pixel_scale)
        elif self.render_width is None:
            s = self.render_height / self.height
            return self.width * s, self.render_height
        elif self.render_height is None:
            s = self.render_width / self.width
            return self.render_width, self.height * s
        else:
            return self.render_width, self.render_height

    def draw(self, obj, *, z=None, **kwargs):
        '''Add any object that knows how to draw itself to the drawing.

        Objects that already have `write_svg_element` are appended directly
        (and `kwargs` must be empty). Otherwise the object must implement
        `to_drawables(**kwargs)`, returning a `DrawingElement` or an
        iterable of elements. `None` is ignored. `z` selects a layer; see
        `append()`.
        '''
        if obj is None:
            return
        if not hasattr(obj, 'write_svg_element'):
            elements = obj.to_drawables(**kwargs)
        else:
            assert len(kwargs) == 0
            elements = obj
        if hasattr(elements, 'write_svg_element'):
            self.append(elements, z=z)
        else:
            self.extend(elements, z=z)

    def append(self, element, *, z=None):
        '''Add any `DrawingElement` to the drawing.

        Elements with a `z` value are rendered after all elements without
        one, in ascending z order (ties keep insertion order).

        Do not append a `DrawingDef` referenced by other elements. These are
        included automatically. Use `.append_def()` for an unreferenced
        `DrawingDef`.
        '''
        if z is not None:
            self.ordered_elements[z].append(element)
        else:
            self.elements.append(element)

    def extend(self, iterable, *, z=None):
        '''Add every element of `iterable`; `z` works as in `append()`.'''
        if z is not None:
            self.ordered_elements[z].extend(iterable)
        else:
            self.elements.extend(iterable)

    # The list-like helpers below operate only on `self.elements` (elements
    # added without a z value); z-ordered elements are not affected.
    def insert(self, i, element):
        self.elements.insert(i, element)

    def remove(self, element):
        self.elements.remove(element)

    def clear(self):
        self.elements.clear()

    def index(self, *args, **kwargs):
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        return self.elements.count(element)

    def reverse(self):
        self.elements.reverse()

    def draw_def(self, obj, **kwargs):
        '''Like `draw()`, but adds the result to the <defs> section.

        `None` is ignored, matching `draw()`.
        '''
        if obj is None:
            return
        if not hasattr(obj, 'write_svg_element'):
            elements = obj.to_drawables(**kwargs)
        else:
            assert len(kwargs) == 0
            elements = obj
        if hasattr(elements, 'write_svg_element'):
            self.append_def(elements)
        else:
            self.other_defs.extend(elements)

    def append_def(self, element):
        '''Add an element to the <defs> section (for unreferenced defs).'''
        self.other_defs.append(element)

    def append_title(self, text, **kwargs):
        '''Append a `Title` element built from `text` and `kwargs`.'''
        # The elements module is imported as `elements_module`; a bare
        # `elements` name does not exist at module scope.
        self.append(elements_module.Title(text, **kwargs))

    def append_css(self, css_text):
        '''Add a stylesheet, emitted inside a <style> CDATA block.'''
        self.css_list.append(css_text)

    def append_javascript(self, js_text, onload=None):
        '''Add a script, emitted inside a <script> CDATA block.

        If `onload` is given, it is added to the root tag's `onload`
        attribute, joined with ';' after any existing handler code.
        '''
        if onload:
            if self.svg_args.get('onload'):
                self.svg_args['onload'] = f'{self.svg_args["onload"]};{onload}'
            else:
                self.svg_args['onload'] = onload
        self.js_list.append(js_text)

    # Backward-compatible alias for the original misspelled method name.
    append_javascriipt = append_javascript

    def all_elements(self):
        '''Return self.elements and self.ordered_elements as a single list.

        Unordered elements come first, followed by z-ordered elements in
        ascending z order.
        '''
        output = list(self.elements)
        for z in sorted(self.ordered_elements):
            output.extend(self.ordered_elements[z])
        return output

    def as_svg(self, output_file=None, randomize_ids=False):
        '''Serialize the drawing as SVG text.

        Args:
            output_file: A writable text stream. If None, the SVG is built
                in memory and returned as a string.
            randomize_ids: Use a random id prefix instead of
                `self.id_prefix`. This is used by inline Jupyter display,
                presumably so several drawings on one page do not collide
                on ids.

        Returns:
            The SVG string if `output_file` is None, otherwise None.
        '''
        if output_file is None:
            with StringIO() as f:
                self.as_svg(f, randomize_ids=randomize_ids)
                # Must read before the `with` closes the buffer.
                return f.getvalue()
        img_width, img_height = self.calc_render_size()
        start_str = SVG_START_FMT.format(img_width, img_height, *self.view_box)
        output_file.write(start_str)
        elements_module.write_xml_node_args(self.svg_args, output_file)
        output_file.write('>\n')
        if self.css_list:
            output_file.write(SVG_CSS_FMT.format('\n'.join(self.css_list)))
            output_file.write('\n')
        if self.js_list:
            output_file.write(SVG_JS_FMT.format('\n'.join(self.js_list)))
            output_file.write('\n')
        output_file.write('<defs>\n')
        # Write definition elements
        id_prefix = self._random_id() if randomize_ids else self.id_prefix
        id_index = 0
        def id_gen(base=''):
            # Sequential ids: <prefix><base><counter>.
            nonlocal id_index
            id_str = f'{id_prefix}{base}{id_index}'
            id_index += 1
            return id_str
        # Lazily assigns a fresh generated id to each new key looked up.
        id_map = defaultdict(id_gen)
        # Python ids of objects already "seen"; pre-seeded with the
        # explicitly appended defs.
        prev_set = {id(defn) for defn in self.other_defs}
        def is_duplicate(obj):
            # Report whether obj was seen before, and mark it as seen.
            nonlocal prev_set
            dup = id(obj) in prev_set
            prev_set.add(id(obj))
            return dup
        for element in self.other_defs:
            if hasattr(element, 'write_svg_element'):
                element.write_svg_element(
                        id_map, is_duplicate, output_file, False)
                output_file.write('\n')
        all_elements = self.all_elements()
        for element in all_elements:
            if hasattr(element, 'write_svg_defs'):
                element.write_svg_defs(
                        id_map, is_duplicate, output_file, False)
        output_file.write('</defs>\n')
        # Generate ids for normal elements. NOTE(review): the trailing True
        # flag appears to select an id-assignment-only pass — confirm in
        # the elements module.
        prev_def_set = set(prev_set)
        for element in all_elements:
            if hasattr(element, 'write_svg_element'):
                element.write_svg_element(
                        id_map, is_duplicate, output_file, True)
        # Rebinding the closure variable makes is_duplicate start the real
        # write pass from the state captured right after <defs>.
        prev_set = prev_def_set
        # Write normal elements
        for element in all_elements:
            if hasattr(element, 'write_svg_element'):
                element.write_svg_element(
                        id_map, is_duplicate, output_file, False)
                output_file.write('\n')
        output_file.write(SVG_END)

    @staticmethod
    def _random_id(length=8):
        '''Return a random alphanumeric id starting with a letter.'''
        return (random.choice(string.ascii_letters)
                + ''.join(random.choices(
                    string.ascii_letters+string.digits, k=length-1)))

    def save_svg(self, fname, encoding='utf-8'):
        '''Write the SVG text to the file `fname`.'''
        with open(fname, 'w', encoding=encoding) as f:
            self.as_svg(output_file=f)

    def save_png(self, fname):
        '''Rasterize the drawing and write it to `fname`.'''
        self.rasterize(to_file=fname)

    def rasterize(self, to_file=None):
        '''Rasterize via the `Raster` module, to `to_file` if truthy.'''
        if to_file:
            return Raster.from_svg_to_file(self.as_svg(), to_file)
        else:
            return Raster.from_svg(self.as_svg())

    def _repr_svg_(self):
        '''Display in Jupyter notebook.'''
        return self.as_svg(randomize_ids=True)

    def display_inline(self):
        '''Display inline in the Jupyter web page.'''
        return jupyter.JupyterSvgInline(self.as_svg(randomize_ids=True))

    def display_iframe(self):
        '''Display within an iframe in the Jupyter web page.'''
        w, h = self.calc_render_size()
        return jupyter.JupyterSvgFrame(self.as_svg(), w, h)

    def display_image(self):
        '''Display within an img in the Jupyter web page.'''
        return jupyter.JupyterSvgImage(self.as_svg())