Programmatically generate SVG (vector) images, animations, and interactive Jupyter widgets
from io import StringIO
import base64
import urllib.parse
import re
from collections import defaultdict

from . import Raster
from . import elements as elementsModule


# ASCII control characters below 0x20, excluding tab (0x09), line feed (0x0a)
# and carriage return (0x0d).  The data-URI helpers remove these by default.
STRIP_CHARS = ('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f\x10\x11'
               '\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f')


class Drawing:
    ''' A canvas to draw on

        Supports iPython: If a Drawing is the last line of a cell, it will be
        displayed as an SVG below. '''

    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        ''' Create an empty drawing.

            width, height: Numeric size of the drawing in user units.
            origin: Either the string 'center' or an (x, y) pair giving the
                lower-left corner of the drawing.
            idPrefix: Prefix for the element ids generated while rendering.
            displayInline: Selects which Jupyter representation is produced
                (see `_repr_svg_` and `_repr_html_`).
            svgArgs: Extra attributes for the root <svg> tag.  In each name
                '__' becomes ':', '_' becomes '-' and one trailing '-' is
                dropped (so `class_` becomes `class`). '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Negate and shift the y coordinate: the origin is given with y
        # pointing up while the SVG viewBox has y pointing down.
        self.viewBox = (self.viewBox[0], -self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        # Maps a z value to the list of elements drawn at that z.
        self.orderedElements = defaultdict(list)
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            # endswith() instead of k[-1] so an empty key cannot raise
            # IndexError.
            if k.endswith('-'):
                k = k[:-1]
            self.svgArgs[k] = v
        # Counter used to build unique ids; it keeps growing across renders.
        self.idIndex = 0

    def setRenderSize(self, w=None, h=None):
        ''' Set the rendered width and/or height.  Returns self.

            A dimension left as None is derived from the other one by
            `calcRenderSize`. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Render at `s` pixels per drawing unit.  Returns self.

            Discards any size previously given to `setRenderSize`. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the rendered (width, height).

            With no explicit render size the drawing size is multiplied by
            pixelScale.  With only one render dimension set, the other is
            scaled to keep the aspect ratio. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, *, z=None, **kwargs):
        ''' Add `obj` to the drawing; does nothing when `obj` is None.

            An object with a `writeSvgElement` attribute is added directly
            and accepts no extra keyword arguments.  Any other object is
            converted by calling its `toDrawables(elements=..., **kwargs)`.
            `z` is passed on to `extend`. '''
        if obj is None:
            return
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.extend(elements, z=z)

    def append(self, element, *, z=None):
        ''' Add one element, to the z-ordered group `z` when it is given. '''
        if z is not None:
            self.orderedElements[z].append(element)
        else:
            self.elements.append(element)

    def extend(self, iterable, *, z=None):
        ''' Add several elements, to the z-ordered group `z` when given. '''
        if z is not None:
            self.orderedElements[z].extend(iterable)
        else:
            self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert `element` at index `i` of the un-ordered element list. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove `element` from the un-ordered element list. '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove every drawn element, including the z-ordered ones.

            Definitions added with `drawDef`/`appendDef` are kept. '''
        self.elements.clear()
        # Without this, elements added with z=... would still be rendered
        # after the drawing was "cleared".
        self.orderedElements.clear()

    def index(self, *args, **kwargs):
        ''' Forward to `list.index` on the un-ordered element list. '''
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Count occurrences of `element` in the un-ordered element list. '''
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the un-ordered element list in place. '''
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        ''' Add `obj` to the extra definitions written inside <defs>.

            Objects are converted the same way as in `draw`. '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.otherDefs.extend(elements)

    def appendDef(self, element):
        ''' Add one element to the extra definitions. '''
        self.otherDefs.append(element)

    def allElements(self):
        ''' Returns self.elements and self.orderedElements as a single list.

            Un-ordered elements come first, followed by the ordered groups
            in ascending z. '''
        output = list(self.elements)
        for z in sorted(self.orderedElements):
            output.extend(self.orderedElements[z])
        return output

    def asSvg(self, outputFile=None):
        ''' Render the drawing as SVG text.

            outputFile: Optional object with a `write` method.  When given
                the SVG is written to it and None is returned; otherwise
                the SVG is returned as a string. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        def idGen(base=''):
            idStr = self.idPrefix + base + str(self.idIndex)
            self.idIndex += 1
            return idStr
        prevSet = {id(defn) for defn in self.otherDefs}
        def isDuplicate(obj):
            # Reports whether obj was seen before and records it as seen.
            nonlocal prevSet
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        # NOTE: objects lacking the called method are skipped on purpose,
        # but an AttributeError raised inside the method is swallowed too.
        for element in self.otherDefs:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        allElements = self.allElements()
        for element in allElements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile, False)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')
        # Generate ids for normal elements
        # Snapshot the duplicate set first so it can be restored before the
        # real write pass below.
        prevDefSet = set(prevSet)
        for element in allElements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, True)
            except AttributeError:
                pass
        # Rebinding is seen by isDuplicate because it declares prevSet
        # nonlocal.
        prevSet = prevDefSet
        # Write normal elements
        for element in allElements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname, encoding='utf-8'):
        ''' Write the SVG to the file `fname` using `encoding`. '''
        with open(fname, 'w', encoding=encoding) as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing into the file `fname`. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize the drawing through the Raster module.

            Returns the result of `Raster.fromSvgToFile` when `toFile` is
            truthy, otherwise the result of `Raster.fromSvg`. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook

            Returns the SVG text, or None when displayInline is false. '''
        if not self.displayInline:
            return None
        return self.asSvg()

    def _repr_html_(self):
        ''' Display in Jupyter notebook

            Returns an <img> tag holding a base64 data URI, or None when
            displayInline is true. '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode(encoding='utf-8'))
        src = (prefix+data).decode(encoding='ascii')
        return '<img src="{}">'.format(src)

    def asDataUri(self, strip_chars=STRIP_CHARS):
        ''' Returns a data URI with base64 encoding.

            strip_chars: Characters (or strings) removed from the SVG
                before encoding.  They are matched literally; None or an
                empty value strips nothing. '''
        data = self.asSvg()
        # Escape every entry so regex metacharacters in a caller-supplied
        # strip_chars are treated as plain text, as in asUtf8DataUri.
        patterns = [re.escape(chars) for chars in (strip_chars or ())]
        if patterns:
            data = re.sub('|'.join(patterns), '', data)
        b64 = base64.b64encode(data.encode())
        return 'data:image/svg+xml;base64,' + b64.decode(encoding='ascii')

    def asUtf8DataUri(self, unsafe_chars='"', strip_chars=STRIP_CHARS):
        ''' Returns a data URI without base64 encoding.

            The characters '#&%' are always escaped.  '#' and '&' break parsing
            of the data URI.  If '%' is not escaped, plain text like '%50' will
            be incorrectly decoded to 'P'.  The characters in `strip_chars`
            cause the SVG not to render even if they are escaped.

            unsafe_chars: Additional characters to percent-encode.
            strip_chars: Characters removed entirely; None strips nothing.
                A character listed in both arguments is removed. '''
        data = self.asSvg()
        unsafe_chars = (unsafe_chars or '') + '#&%'
        replacements = {
            char: urllib.parse.quote(char, safe='')
            for char in unsafe_chars
        }
        replacements.update({
            char: ''
            for char in (strip_chars or ())
        })
        search = re.compile('|'.join(map(re.escape, replacements)))
        data_safe = search.sub(lambda m: replacements[m.group(0)], data)
        return 'data:image/svg+xml;utf8,' + data_safe