1
2from io import StringIO
3import base64
4import urllib.parse
5import re
6from collections import defaultdict
7
8from . import Raster
9from . import elements as elementsModule
10
11
# C0 control characters other than tab, line feed and carriage return.  These
# are not allowed anywhere in an XML 1.0 document; the data-URI helpers on
# Drawing strip them because the SVG will not render even if they are escaped.
STRIP_CHARS = ('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f\x10\x11'
               '\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f')


class Drawing:
    ''' A canvas to draw on

        Supports iPython: If a Drawing is the last line of a cell, it will be
        displayed as an SVG below (or as an <img> tag when
        displayInline=False).

        Contents are kept in three places:
          elements         -- rendered in insertion order
          orderedElements  -- dict z -> list, rendered after `elements` in
                              ascending z order
          otherDefs        -- written inside the <defs> section
    '''
    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        ''' Create an empty drawing.

            width, height: canvas size in user units (real numbers).
            origin: (x, y) pair used as the viewBox origin before the y-axis
                flip below, or 'center' to center the canvas on (0, 0).
            idPrefix: prefix for ids generated while rendering.
            displayInline: in Jupyter, use _repr_svg_ when True and
                _repr_html_ (an <img> data URI) when False.
            **svgArgs: extra attributes for the root <svg> tag.  In each
                name '__' becomes ':', '_' becomes '-', and a trailing '-'
                is dropped (xlink__href -> xlink:href, class_ -> class).
        '''
        assert float(width) == width, 'width must be a real number'
        assert float(height) == height, 'height must be a real number'
        self.width = width
        self.height = height
        # isinstance check first: comparing an array-like origin to a string
        # can be ambiguous (elementwise) instead of a plain False.
        if isinstance(origin, str) and origin == 'center':
            viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2, "origin must be an (x, y) pair or 'center'"
            viewBox = origin + (width, height)
        # Mirror min-y: the viewBox covers y in [-(y0+h), -y0].  Presumably
        # this pairs with elements negating y so user coordinates are y-up
        # (NOTE(review): confirm against elements.py).
        minX, minY, vbWidth, vbHeight = viewBox
        self.viewBox = (minX, -minY - vbHeight, vbWidth, vbHeight)
        self.elements = []
        self.orderedElements = defaultdict(list)
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            if k.endswith('-'):
                k = k[:-1]
            self.svgArgs[k] = v
        # Counter behind generated ids.  asSvg never resets it, so ids keep
        # increasing across renders (presumably so ids that elements keep
        # from an earlier render cannot collide — confirm in elements.py).
        self.idIndex = 0

    def setRenderSize(self, w=None, h=None):
        ''' Set the output width/height attributes; a missing one is derived
            from the aspect ratio.  Returns self for chaining. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Render at `s` output units per user unit, clearing any explicit
            render size.  Returns self for chaining. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) written on the root <svg> tag. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    @staticmethod
    def _asElements(obj, kwargs):
        ''' Return the elements for `obj`: obj itself if it can write SVG,
            otherwise whatever obj.toDrawables(...) produces. '''
        if hasattr(obj, 'writeSvgElement'):
            assert len(kwargs) == 0, \
                'keyword arguments are only accepted for toDrawables objects'
            return (obj,)
        return obj.toDrawables(elements=elementsModule, **kwargs)

    def draw(self, obj, *, z=None, **kwargs):
        ''' Add `obj` (an element, or an object with toDrawables) to the
            drawing; None is ignored.  `kwargs` go to toDrawables.  With `z`,
            the result is stored in orderedElements[z]. '''
        if obj is None:
            return
        self.extend(self._asElements(obj, kwargs), z=z)

    def append(self, element, *, z=None):
        ''' Append one element, to orderedElements[z] when z is given. '''
        if z is not None:
            self.orderedElements[z].append(element)
        else:
            self.elements.append(element)

    def extend(self, iterable, *, z=None):
        ''' Append several elements, to orderedElements[z] when z is given. '''
        if z is not None:
            self.orderedElements[z].extend(iterable)
        else:
            self.elements.extend(iterable)

    # The list-style helpers below act on the unordered `elements` list only.
    def insert(self, i, element):
        self.elements.insert(i, element)

    def remove(self, element):
        self.elements.remove(element)

    def clear(self):
        ''' Remove every drawn element, including z-ordered ones.
            Definitions in otherDefs are kept. '''
        self.elements.clear()
        # Without this, z-ordered elements would still render after clear().
        self.orderedElements.clear()

    def index(self, *args, **kwargs):
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        return self.elements.count(element)

    def reverse(self):
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        ''' Like draw(), but the result goes to otherDefs (the <defs>
            section).  None is ignored, as in draw(). '''
        if obj is None:
            return
        self.otherDefs.extend(self._asElements(obj, kwargs))

    def appendDef(self, element):
        ''' Add one element to otherDefs (the <defs> section). '''
        self.otherDefs.append(element)

    def allElements(self):
        ''' Returns self.elements and self.orderedElements as a single list. '''
        output = list(self.elements)
        for z in sorted(self.orderedElements):
            output.extend(self.orderedElements[z])
        return output

    def asSvg(self, outputFile=None):
        ''' Write the SVG document to `outputFile`, or return it as a str
            when outputFile is None. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
 width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')

        def idGen(base=''):
            idStr = self.idPrefix + base + str(self.idIndex)
            self.idIndex += 1
            return idStr

        # id()s of objects already seen; otherDefs start out as seen.
        prevSet = {id(defn) for defn in self.otherDefs}

        def isDuplicate(obj):
            # Reports whether obj was seen before, then marks it seen.
            nonlocal prevSet
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup

        # NOTE(review): each `except AttributeError` below skips objects that
        # lack the method, but also hides AttributeErrors raised inside it.
        # Write definition elements
        for element in self.otherDefs:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        allElements = self.allElements()
        for element in allElements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile, False)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')
        # Generate ids for normal elements (dry run: last argument True)
        prevDefSet = set(prevSet)
        for element in allElements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, True)
            except AttributeError:
                pass
        # Restore the post-<defs> "seen" set so the real pass below sees the
        # same first-occurrence/duplicate pattern as the dry run did.
        prevSet = prevDefSet
        # Write normal elements
        for element in allElements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname, encoding='utf-8'):
        ''' Write the SVG document to the file `fname`. '''
        with open(fname, 'w', encoding=encoding) as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing into the file `fname`. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize via the Raster module; write to `toFile` if given,
            otherwise return Raster.fromSvg's result. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook (inline SVG when displayInline). '''
        if not self.displayInline:
            return None
        return self.asSvg()

    def _repr_html_(self):
        ''' Display in Jupyter notebook (<img> with a base64 data URI when
            not displayInline). '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode(encoding='utf-8'))
        src = (prefix+data).decode(encoding='ascii')
        return '<img src="{}">'.format(src)

    def asDataUri(self, strip_chars=STRIP_CHARS):
        ''' Returns a data URI with base64 encoding.

            Every entry of `strip_chars` (None means none) is removed from
            the SVG text before encoding. '''
        data = self.asSvg()
        # re.escape so entries such as '.' or '*' are matched literally
        # (previously '.' would have deleted every character).
        search = re.compile('|'.join(map(re.escape, strip_chars or '')))
        data_safe = search.sub('', data)
        b64 = base64.b64encode(data_safe.encode())
        return 'data:image/svg+xml;base64,' + b64.decode(encoding='ascii')

    def asUtf8DataUri(self, unsafe_chars='"', strip_chars=STRIP_CHARS):
        ''' Returns a data URI without base64 encoding.

            The characters '#&%' are always escaped. '#' and '&' break parsing
            of the data URI. If '%' is not escaped, plain text like '%50' will
            be incorrectly decoded to 'P'. The characters in `strip_chars`
            (None means none) cause the SVG not to render even if they are
            escaped, so they are removed; a character listed in both is
            removed. '''
        data = self.asSvg()
        unsafe_chars = (unsafe_chars or '') + '#&%'
        replacements = {
            char: urllib.parse.quote(char, safe='')
            for char in unsafe_chars
        }
        replacements.update({
            char: ''
            for char in (strip_chars or '')
        })
        search = re.compile('|'.join(map(re.escape, replacements.keys())))
        data_safe = search.sub(lambda m: replacements[m.group(0)], data)
        return 'data:image/svg+xml;utf8,' + data_safe