1
2from io import StringIO
3import base64
4import urllib.parse
5import re
6
7from . import Raster
8from . import elements as elementsModule
9
10
# Default set of characters removed from SVG text before it is placed in a
# data URI: every ASCII control code below 0x20 except tab (0x09), line feed
# (0x0A) and carriage return (0x0D).
STRIP_CHARS = ''.join(map(chr, (
    *range(0x00, 0x09),   # NUL .. BS
    0x0B, 0x0C,           # VT, FF
    *range(0x0E, 0x20),   # SO .. US
)))
13
14
class Drawing:
    ''' A canvas to draw on

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below.

    Drawn elements are kept in `self.elements` and extra definitions in
    `self.otherDefs`; `asSvg()` serializes both. '''
    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        ''' Create an empty drawing.

        width, height: Numeric size of the drawing.
        origin: Either the string 'center' or an (x, y) pair.  The stored
            view box is (x, -y-height, width, height), i.e. the y coordinate
            is negated.
        idPrefix: Prefix (converted to str) of ids generated by `asSvg()`.
        displayInline: If true `_repr_svg_()` supplies the notebook output,
            otherwise `_repr_html_()` supplies an <img> tag.
        **svgArgs: Extra attributes for the root <svg> tag.  In each name
            '__' becomes ':', '_' becomes '-' and one trailing '-' is
            dropped (`class_` -> `class`, `xlink__href` -> `xlink:href`).

        Raises AssertionError if width or height is not numeric or origin
        does not have exactly two items. '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Negate the y coordinate of the view box
        minX, minY, boxWidth, boxHeight = self.viewBox
        self.viewBox = (minX, -minY-boxHeight, boxWidth, boxHeight)
        self.elements = []
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            # endswith() is also safe for an empty name (k[-1] would raise)
            if k.endswith('-'):
                k = k[:-1]
            self.svgArgs[k] = v
    def setRenderSize(self, w=None, h=None):
        ''' Set the output size explicitly.  A dimension left as None is
        derived from the other one in `calcRenderSize()`.  Returns self. '''
        self.renderWidth = w
        self.renderHeight = h
        return self
    def setPixelScale(self, s=1):
        ''' Set the scale factor applied to width and height and clear any
        explicit render size.  Returns self. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self
    def calcRenderSize(self):
        ''' Return the (width, height) written to the <svg> tag.

        Uses `pixelScale` when no render size is set; when only one render
        dimension is set the other is scaled to keep the aspect ratio. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight
    @staticmethod
    def _toElements(obj, kwargs):
        ''' Return the elements to store for `obj` (shared by draw/drawDef).

        An object with a `writeSvgElement` attribute is used as is (and must
        not be given keyword arguments); anything else is converted with its
        `toDrawables()` method. '''
        if hasattr(obj, 'writeSvgElement'):
            assert len(kwargs) == 0
            return (obj,)
        return obj.toDrawables(elements=elementsModule, **kwargs)
    def draw(self, obj, **kwargs):
        ''' Add `obj`, or the drawables it converts to, to the drawing. '''
        self.extend(self._toElements(obj, kwargs))
    def append(self, element):
        ''' Append an element to the end of the drawing. '''
        self.elements.append(element)
    def extend(self, iterable):
        ''' Append every element of `iterable` to the drawing. '''
        self.elements.extend(iterable)
    def insert(self, i, element):
        ''' Insert an element before position `i`. '''
        self.elements.insert(i, element)
    def remove(self, element):
        ''' Remove the first occurrence of an element. '''
        self.elements.remove(element)
    def clear(self):
        ''' Remove all elements (definitions in `otherDefs` are kept). '''
        self.elements.clear()
    def index(self, *args, **kwargs):
        ''' Return the position of an element; arguments are those of
        `list.index`. '''
        # Previously the result was computed but never returned
        return self.elements.index(*args, **kwargs)
    def count(self, element):
        ''' Return how many times an element appears in the drawing. '''
        # Previously the result was computed but never returned
        return self.elements.count(element)
    def reverse(self):
        ''' Reverse the order of the elements in place. '''
        self.elements.reverse()
    def drawDef(self, obj, **kwargs):
        ''' Like `draw()` but adds to the extra definitions instead. '''
        self.otherDefs.extend(self._toElements(obj, kwargs))
    def appendDef(self, element):
        ''' Append an element to the extra definitions. '''
        self.otherDefs.append(element)
    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

        outputFile: Text file-like object to write to.  If None the document
            is returned as a string; otherwise nothing is returned. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
 width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        idIndex = 0
        def idGen(base=''):
            # Ids are idPrefix + base + a counter shared by the whole document
            nonlocal idIndex
            idStr = self.idPrefix + base + str(idIndex)
            idIndex += 1
            return idStr
        prevSet = {id(defn) for defn in self.otherDefs}
        def isDuplicate(obj):
            # True if obj was already seen; records obj as seen either way
            nonlocal prevSet
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        # Objects lacking the called method are skipped.  NOTE(review): an
        # AttributeError raised inside an element's writer is hidden too.
        for element in self.otherDefs:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        for element in self.elements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile, False)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')
        # Generate ids for normal elements
        # (the last argument is True only in this pass -- presumably a dry
        # run flag; confirm against the elements module)
        prevDefSet = set(prevSet)
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, True)
            except AttributeError:
                pass
        # Forget what the id pass marked as seen, keeping only the defs
        prevSet = prevDefSet
        # Write normal elements
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()
    def saveSvg(self, fname):
        ''' Write the SVG document to the file `fname`. '''
        # The XML declaration says UTF-8, so do not rely on the platform
        # default encoding
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)
    def savePng(self, fname):
        ''' Rasterize the drawing into the file `fname`. '''
        self.rasterize(toFile=fname)
    def rasterize(self, toFile=None):
        ''' Rasterize the SVG with the Raster module, to `toFile` if given.
        Returns whatever the Raster function returns. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())
    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        if not self.displayInline:
            return None
        return self.asSvg()
    def _repr_html_(self):
        ''' Display in Jupyter notebook '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode())
        src = (prefix+data).decode()
        return '<img src="{}">'.format(src)
    def asDataUri(self, strip_chars=STRIP_CHARS):
        ''' Returns a data URI with base64 encoding.

        Every string in `strip_chars` is removed from the SVG text first;
        None or an empty value removes nothing. '''
        data_safe = self.asSvg()
        # Escape so the entries match literally, as in asUtf8DataUri()
        patterns = [re.escape(char) for char in (strip_chars or ())]
        if patterns:
            data_safe = re.sub('|'.join(patterns), '', data_safe)
        b64 = base64.b64encode(data_safe.encode())
        return 'data:image/svg+xml;base64,' + b64.decode(encoding='ascii')
    def asUtf8DataUri(self, unsafe_chars='"', strip_chars=STRIP_CHARS):
        ''' Returns a data URI without base64 encoding.

        The characters '#&%' are always escaped.  '#' and '&' break parsing
        of the data URI.  If '%' is not escaped, plain text like '%50' will
        be incorrectly decoded to 'P'.  The characters in `strip_chars`
        cause the SVG not to render even if they are escaped.

        `unsafe_chars` and `strip_chars` may be None or empty. '''
        data = self.asSvg()
        unsafe_chars = (unsafe_chars or '') + '#&%'
        replacements = {
            char: urllib.parse.quote(char, safe='')
            for char in unsafe_chars
        }
        # Stripped characters take precedence over percent-escaping
        replacements.update({
            char: ''
            for char in (strip_chars or ())
        })
        search = re.compile('|'.join(map(re.escape, replacements.keys())))
        data_safe = search.sub(lambda m: replacements[m.group(0)], data)
        return 'data:image/svg+xml;utf8,' + data_safe