1
2from io import StringIO
3import base64
4import urllib.parse
5import re
6
7from . import Raster
8from . import elements as elementsModule
9
10
# ASCII control characters removed from SVG text before it is embedded in a
# data URI: every code point below 0x20 except tab (0x09), line feed (0x0a)
# and carriage return (0x0d).
STRIP_CHARS = ''.join(
    chr(codePoint)
    for codePoint in range(0x20)
    if codePoint not in (0x09, 0x0a, 0x0d)
)
13
14
class Drawing:
    ''' A canvas to draw on

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below.

    Arguments:
        width, height: Numeric size of the drawing.
        origin: Either the string 'center' or an (x, y) pair.  The pair is
            combined with the size into the SVG viewBox and the y component
            is flipped (see `__init__`).
        idPrefix: Prefix of every id generated while writing the SVG.
        displayInline: If true, `_repr_svg_` provides the notebook
            representation; otherwise `_repr_html_` provides an <img> tag.
        **svgArgs: Extra attributes for the root <svg> tag.  In each name
            '__' becomes ':', '_' becomes '-' and one trailing '-' is
            dropped.
    '''
    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        # Validation is kept as assertions so callers keep seeing the same
        # exception types as before.
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Flip the y component: the stored viewBox starts at -(y + height).
        self.viewBox = (self.viewBox[0], -self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            # endswith() is safe for an empty key, unlike k[-1].
            if k.endswith('-'):
                k = k[:-1]
            self.svgArgs[k] = v
        self.idIndex = 0
    def setRenderSize(self, w=None, h=None):
        ''' Set an explicit render width and/or height.

        A value left as None is derived from the other one in
        `calcRenderSize`.  Returns self so calls can be chained. '''
        self.renderWidth = w
        self.renderHeight = h
        return self
    def setPixelScale(self, s=1):
        ''' Set the scale factor applied to width and height and clear any
        explicit render size.  Returns self so calls can be chained. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self
    def calcRenderSize(self):
        ''' Return the (width, height) written to the root <svg> tag.

        With no explicit render size the drawing size is multiplied by
        `pixelScale`.  If only one of render width/height is set, the other
        is scaled by the same ratio. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight
    @staticmethod
    def _toElements(obj, kwargs):
        ''' Return the iterable of elements that `obj` stands for.

        An object with a `writeSvgElement` attribute is used as is (and
        accepts no keyword arguments); any other object is asked to convert
        itself through `toDrawables`. '''
        if hasattr(obj, 'writeSvgElement'):
            assert len(kwargs) == 0
            return (obj,)
        return obj.toDrawables(elements=elementsModule, **kwargs)
    def draw(self, obj, **kwargs):
        ''' Add `obj` (or the elements it converts to) to the drawing.

        None is ignored.  Keyword arguments are forwarded to
        `obj.toDrawables` and are only allowed for objects that lack
        `writeSvgElement`. '''
        if obj is None:
            return
        self.extend(self._toElements(obj, kwargs))
    def append(self, element):
        ''' Append `element` to the list of drawn elements. '''
        self.elements.append(element)
    def extend(self, iterable):
        ''' Append every item of `iterable` to the drawn elements. '''
        self.elements.extend(iterable)
    def insert(self, i, element):
        ''' Insert `element` before position `i` of the drawn elements. '''
        self.elements.insert(i, element)
    def remove(self, element):
        ''' Remove the first occurrence of `element`.

        Raises ValueError (from list.remove) if it is not present. '''
        self.elements.remove(element)
    def clear(self):
        ''' Remove all drawn elements. '''
        self.elements.clear()
    def index(self, *args, **kwargs):
        ''' Return the position of an element; arguments are passed to
        list.index.

        Raises ValueError (from list.index) if it is not present. '''
        # The result used to be discarded, so this always returned None.
        return self.elements.index(*args, **kwargs)
    def count(self, element):
        ''' Return how many times `element` occurs in the drawn elements. '''
        # The result used to be discarded, so this always returned None.
        return self.elements.count(element)
    def reverse(self):
        ''' Reverse the order of the drawn elements in place. '''
        self.elements.reverse()
    def drawDef(self, obj, **kwargs):
        ''' Add `obj` (or the elements it converts to) to the extra
        definitions written inside <defs>.

        Unlike `draw`, None is not special-cased here. '''
        self.otherDefs.extend(self._toElements(obj, kwargs))
    def appendDef(self, element):
        ''' Append `element` to the extra definitions. '''
        self.otherDefs.append(element)
    def asSvg(self, outputFile=None):
        ''' Write the drawing as an SVG document.

        If `outputFile` is None the document is returned as a string;
        otherwise it is written to `outputFile` (any object with a
        `write(str)` method) and None is returned.

        Each call advances `self.idIndex` for every id it generates. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
 width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        def idGen(base=''):
            idStr = self.idPrefix + base + str(self.idIndex)
            self.idIndex += 1
            return idStr
        # Identities of objects already seen; seeded with the explicit defs.
        prevSet = {id(defn) for defn in self.otherDefs}
        def isDuplicate(obj):
            # Reads whatever set `prevSet` is bound to at call time; the
            # name is rebound below between the two element passes.
            nonlocal prevSet
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        # NOTE(review): the AttributeError handlers below skip objects that
        # lack the writer method, but they also hide an AttributeError
        # raised inside a writer.  Kept as is to preserve the best-effort
        # behaviour.
        for element in self.otherDefs:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        for element in self.elements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile, False)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')
        # Generate ids for normal elements
        # (same call as the real write below, but with True as the last
        # argument -- presumably a dry-run flag; confirm in elements.)
        prevDefSet = set(prevSet)
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, True)
            except AttributeError:
                pass
        # Restore the snapshot so objects first seen during the id pass are
        # not reported as duplicates during the real write.
        prevSet = prevDefSet
        # Write normal elements
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()
        return None
    def saveSvg(self, fname):
        ''' Write the SVG document to the file named `fname`. '''
        # The document declares encoding="UTF-8", so the file must be
        # written as UTF-8 regardless of the platform default encoding.
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)
    def savePng(self, fname):
        ''' Rasterize the drawing and write it to the file `fname`. '''
        self.rasterize(toFile=fname)
    def rasterize(self, toFile=None):
        ''' Rasterize the SVG through the `Raster` helper.

        With a truthy `toFile` the result of `Raster.fromSvgToFile` is
        returned, otherwise the result of `Raster.fromSvg`. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())
    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        if not self.displayInline:
            return None
        return self.asSvg()
    def _repr_html_(self):
        ''' Display in Jupyter notebook '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode())
        src = (prefix+data).decode()
        return '<img src="{}">'.format(src)
    def asDataUri(self, strip_chars=STRIP_CHARS):
        ''' Returns a data URI with base64 encoding.

        Every occurrence of an item of `strip_chars` is removed from the
        SVG text before it is encoded. '''
        data = self.asSvg()
        # Escape each item so characters such as '.' or '(' are matched
        # literally, consistent with asUtf8DataUri.
        search = re.compile('|'.join(map(re.escape, strip_chars)))
        data_safe = search.sub(lambda m: '', data)
        b64 = base64.b64encode(data_safe.encode())
        return 'data:image/svg+xml;base64,' + b64.decode(encoding='ascii')
    def asUtf8DataUri(self, unsafe_chars='"', strip_chars=STRIP_CHARS):
        ''' Returns a data URI without base64 encoding.

        The characters '#&%' are always escaped.  '#' and '&' break parsing
        of the data URI.  If '%' is not escaped, plain text like '%50' will
        be incorrectly decoded to 'P'.  The characters in `strip_chars`
        cause the SVG not to render even if they are escaped. '''
        data = self.asSvg()
        unsafe_chars = (unsafe_chars or '') + '#&%'
        replacements = {
            char: urllib.parse.quote(char, safe='')
            for char in unsafe_chars
        }
        # Stripping takes precedence over escaping for a character that is
        # listed in both arguments.
        replacements.update({
            char: ''
            for char in strip_chars
        })
        search = re.compile('|'.join(map(re.escape, replacements.keys())))
        data_safe = search.sub(lambda m: replacements[m.group(0)], data)
        return 'data:image/svg+xml;utf8,' + data_safe