Programmatically generate SVG (vector) images, animations, and interactive Jupyter widgets
from io import StringIO
import base64
import urllib.parse
import re

from . import Raster
from . import elements as elementsModule


# Control characters (every code point below 0x20 except tab, LF and CR) that
# are removed when building data URIs: they stop the SVG from rendering even
# when they are percent-escaped.
STRIP_CHARS = ('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f\x10\x11'
               '\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f')


class Drawing:
    ''' A canvas to draw on

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below.

    Elements added with draw/append/extend/insert are written, in order, as
    the SVG body; elements added with drawDef/appendDef are written inside
    <defs>.  The list-like methods (append, extend, insert, remove, clear,
    index, count, reverse) delegate to the underlying `elements` list. '''

    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        ''' Create an empty drawing.

        width, height: canvas size in user units.  Each must compare equal
            to its float conversion (checked with assert).
        origin: 'center' to centre the canvas on (0, 0), or an (x, y) pair
            giving the canvas corner in user coordinates.  The y component
            is negated when building the SVG viewBox -- presumably user
            space is y-up and this is the bottom-left corner; confirm
            against the elements module.
        idPrefix: prefix for every auto-generated element id.
        displayInline: if True, Jupyter renders the drawing through
            _repr_svg_; otherwise through _repr_html_ as an <img> tag with
            a base64 data URI.
        **svgArgs: extra attributes for the root <svg> tag.  In each key,
            '__' becomes ':', '_' becomes '-', and a single trailing '-' is
            dropped (so class_='x' yields class="x").
        '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Flip y: the viewBox's top edge becomes -(origin_y + height).
        self.viewBox = (self.viewBox[0], -self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            if k[-1] == '-':
                k = k[:-1]
            self.svgArgs[k] = v
        # Counter used to generate ids.  It is never reset, so repeated
        # asSvg() calls produce different (but internally consistent) ids.
        self.idIndex = 0

    def setRenderSize(self, w=None, h=None):
        ''' Set the output size in pixels.

        Either dimension may be None; calcRenderSize then derives it from
        the other one, keeping the aspect ratio.  Returns self. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Clear any explicit render size and render at `s` pixels per
        user unit instead.  Returns self. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) in pixels used for the <svg> tag.

        With no explicit render size, the canvas size is multiplied by
        pixelScale.  With only one dimension set, the other is scaled to
        keep the aspect ratio.  With both set, they are returned as-is. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, **kwargs):
        ''' Add obj to the drawing body.

        None is ignored.  An object with a writeSvgElement method is added
        as-is (no kwargs allowed).  Anything else is converted with
        obj.toDrawables(elements=elementsModule, **kwargs), and every
        resulting element is added. '''
        if obj is None:
            return
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.extend(elements)

    def append(self, element):
        ''' Append a single element to the drawing body. '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Append every element of `iterable` to the drawing body. '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert an element at position i of the drawing body. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of element (ValueError if absent). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all body elements (defs are kept). '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the position of an element; same arguments as
        list.index. '''
        # Previously the result was discarded and None was returned.
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times element occurs in the drawing body. '''
        # Previously the result was discarded and None was returned.
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the order of the body elements in place. '''
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        ''' Like draw(), but the resulting elements go into <defs>.

        None is ignored, consistent with draw(). '''
        if obj is None:
            return
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.otherDefs.extend(elements)

    def appendDef(self, element):
        ''' Append a single element to <defs>. '''
        self.otherDefs.append(element)

    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

        outputFile: a writable text stream.  If None, the document is
            returned as a string; otherwise it is written to outputFile and
            None is returned.

        Objects lacking writeSvgElement / writeSvgDefs are skipped.  Note
        that this is done by catching AttributeError, so an AttributeError
        raised inside an element's writer is also swallowed. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        def idGen(base=''):
            # Produces idPrefix + base + a running counter.
            idStr = self.idPrefix + base + str(self.idIndex)
            self.idIndex += 1
            return idStr
        # Ids (Python object identities) already seen.  Seeded with the
        # explicit defs, so isDuplicate() reports True for any of them.
        prevSet = set((id(defn) for defn in self.otherDefs))
        def isDuplicate(obj):
            # Report whether obj was seen before, and record it as seen.
            nonlocal prevSet
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        for element in self.otherDefs:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        for element in self.elements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile, False)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')
        # Generate ids for normal elements.  The final True argument
        # presumably marks an id-assignment-only (dry) pass -- confirm in
        # the elements module.
        prevDefSet = set(prevSet)
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, True)
            except AttributeError:
                pass
        # Rewind the seen-set so the real write pass starts from the same
        # state the id pass did (rebinding also updates isDuplicate's view).
        prevSet = prevDefSet
        # Write normal elements
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname):
        ''' Write the SVG document to the file at `fname`.

        The file is written as UTF-8 to match the encoding declared in the
        XML header, independent of the platform's default encoding. '''
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing and write it to `fname`. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize the drawing via the Raster module.

        If toFile is given, Raster.fromSvgToFile writes to it; otherwise
        Raster.fromSvg's result is returned. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook (inline SVG when displayInline). '''
        if not self.displayInline:
            return None
        return self.asSvg()

    def _repr_html_(self):
        ''' Display in Jupyter notebook as an <img> tag with a base64 data
        URI when displayInline is False. '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode())
        src = (prefix+data).decode()
        return '<img src="{}">'.format(src)

    def asDataUri(self, strip_chars=STRIP_CHARS):
        ''' Returns a data URI with base64 encoding.

        strip_chars: characters (or strings) removed from the SVG text
            before encoding.  They are matched literally; None or empty
            means nothing is removed. '''
        data = self.asSvg()
        if strip_chars:
            # re.escape keeps this consistent with asUtf8DataUri and makes
            # regex metacharacters in a custom strip_chars match literally.
            search = re.compile('|'.join(map(re.escape, strip_chars)))
            data = search.sub('', data)
        b64 = base64.b64encode(data.encode())
        return 'data:image/svg+xml;base64,' + b64.decode(encoding='ascii')

    def asUtf8DataUri(self, unsafe_chars='"', strip_chars=STRIP_CHARS):
        ''' Returns a data URI without base64 encoding.

        The characters '#&%' are always escaped. '#' and '&' break parsing
        of the data URI. If '%' is not escaped, plain text like '%50' will
        be incorrectly decoded to 'P'. The characters in `strip_chars`
        cause the SVG not to render even if they are escaped.

        unsafe_chars: extra characters to percent-escape (None means none).
        strip_chars: characters to remove (None means none); removal wins
            over escaping if a character appears in both. '''
        data = self.asSvg()
        unsafe_chars = (unsafe_chars or '') + '#&%'
        replacements = {
            char: urllib.parse.quote(char, safe='')
            for char in unsafe_chars
        }
        replacements.update({
            char: ''
            for char in (strip_chars or '')
        })
        search = re.compile('|'.join(map(re.escape, replacements.keys())))
        data_safe = search.sub(lambda m: replacements[m.group(0)], data)
        return 'data:image/svg+xml;utf8,' + data_safe