"""Programmatically generate SVG (vector) images, animations, and
interactive Jupyter widgets."""

from io import StringIO
import base64
import urllib.parse
import re

from . import Raster
from . import elements as elementsModule


# ASCII control characters removed from data URIs (see asUtf8DataUri: they
# stop the SVG from rendering even when escaped).  Tab (\x09), line feed
# (\x0a) and carriage return (\x0d) are not in the list.
STRIP_CHARS = ('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x0b\x0c\x0e\x0f\x10\x11'
               '\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f')


class Drawing:
    ''' A canvas to draw on

        Supports iPython: If a Drawing is the last line of a cell, it will be
        displayed as an SVG below. '''
    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        ''' Create an empty drawing.

            width, height: Numeric size of the drawing (viewBox units).
            origin: 'center' to put (0, 0) in the middle of the drawing, or
                an (x, y) pair used as the first two viewBox values before
                the y axis is flipped.
            idPrefix: Prefix (converted to str) of the ids generated while
                writing the SVG.
            displayInline: Selects which Jupyter representation is offered:
                raw SVG when true, an <img> tag when false.
            svgArgs: Extra attributes for the root <svg> tag.  In each name
                '__' becomes ':', '_' becomes '-' and one trailing '-' is
                dropped (so `class_` becomes `class`). '''
        # NOTE(review): asserts are skipped under `python -O`; kept so the
        # exception type seen by existing callers does not change.
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Flip the y axis: min-y becomes -(y + height).
        self.viewBox = (self.viewBox[0], -self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            # endswith() instead of k[-1] so an empty name cannot raise
            # IndexError.
            if k.endswith('-'):
                k = k[:-1]
            self.svgArgs[k] = v

    def setRenderSize(self, w=None, h=None):
        ''' Set the rendered width and/or height explicitly.

            A dimension left as None is derived from the other one in
            calcRenderSize().  Returns self to allow chaining. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Render at `s` pixels per drawing unit.

            Clears any size set with setRenderSize().  Returns self to allow
            chaining. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) written to the root <svg> tag.

            With no explicit render size the drawing size is multiplied by
            pixelScale.  With only one explicit dimension the other is scaled
            to keep the aspect ratio. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, **kwargs):
        ''' Add `obj` to the drawing.

            Objects without a writeSvgElement attribute are converted by
            calling obj.toDrawables(elements=<elements module>, **kwargs);
            kwargs must be empty for objects that have the attribute. '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.extend(elements)

    def append(self, element):
        ''' Add `element` after the existing elements. '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Add every element of `iterable` after the existing elements. '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert `element` before position `i`. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of `element` (ValueError if it is
            absent, as for list.remove). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all elements. '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the position of an element; arguments are forwarded to
            list.index. '''
        # The result used to be discarded, so callers always got None.
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times `element` occurs in the drawing. '''
        # The result used to be discarded, so callers always got None.
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the order of the elements in place. '''
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        ''' Like draw(), but the result is added to the extra definitions
            written inside <defs> instead of the visible elements. '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.otherDefs.extend(elements)

    def appendDef(self, element):
        ''' Add `element` to the extra definitions written inside <defs>. '''
        self.otherDefs.append(element)

    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

            outputFile: Object with a write(str) method.  When None the
                document is built in memory and returned as a string;
                otherwise it is written to outputFile and None is returned.

            Entries whose write call raises AttributeError are skipped. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        idIndex = 0
        def idGen(base=''):
            # Ids are idPrefix + base + a counter shared by the whole call.
            nonlocal idIndex
            idStr = self.idPrefix + base + str(idIndex)
            idIndex += 1
            return idStr
        # Objects are tracked by identity; the explicit defs count as seen.
        prevSet = {id(defn) for defn in self.otherDefs}
        def isDuplicate(obj):
            # True if obj was seen before; records it either way.
            nonlocal prevSet
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        # NOTE(review): the `except AttributeError` blocks below skip objects
        # lacking the method, but also hide AttributeErrors raised inside an
        # element's own writer -- kept as-is to preserve behavior.
        for element in self.otherDefs:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        for element in self.elements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile, False)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')
        # Generate ids for normal elements (last argument True; presumably a
        # dry run that writes nothing -- confirm in the elements module).
        prevDefSet = set(prevSet)
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, True)
            except AttributeError:
                pass
        # Forget what the id-generation pass marked as seen.
        prevSet = prevDefSet
        # Write normal elements
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()
        return None

    def saveSvg(self, fname):
        ''' Write the SVG document to the file `fname`. '''
        # The XML declaration promises UTF-8, so do not rely on the locale's
        # default encoding.
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing and write it to the file `fname`. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize the SVG through the Raster module.

            Returns the result of Raster.fromSvgToFile when `toFile` is
            truthy, otherwise the result of Raster.fromSvg. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        if not self.displayInline:
            return None
        return self.asSvg()

    def _repr_html_(self):
        ''' Display in Jupyter notebook '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode())
        src = (prefix+data).decode()
        return '<img src="{}">'.format(src)

    def asDataUri(self, strip_chars=STRIP_CHARS):
        ''' Returns a data URI with base64 encoding.

            Every string in `strip_chars` is removed from the SVG text
            before it is encoded; None or an empty value removes nothing. '''
        data = self.asSvg()
        if strip_chars:
            # Escape each entry so it is matched literally, as in
            # asUtf8DataUri.
            search = re.compile('|'.join(map(re.escape, strip_chars)))
            data = search.sub('', data)
        b64 = base64.b64encode(data.encode())
        return 'data:image/svg+xml;base64,' + b64.decode(encoding='ascii')

    def asUtf8DataUri(self, unsafe_chars='"', strip_chars=STRIP_CHARS):
        ''' Returns a data URI without base64 encoding.

            The characters '#&%' are always escaped.  '#' and '&' break parsing
            of the data URI.  If '%' is not escaped, plain text like '%50' will
            be incorrectly decoded to 'P'.  The characters in `strip_chars`
            cause the SVG not to render even if they are escaped. '''
        data = self.asSvg()
        unsafe_chars = (unsafe_chars or '') + '#&%'
        replacements = {
            char: urllib.parse.quote(char, safe='')
            for char in unsafe_chars
        }
        # Stripping wins over escaping for a character listed in both.
        replacements.update({
            char: ''
            for char in (strip_chars or '')
        })
        search = re.compile('|'.join(map(re.escape, replacements)))
        data_safe = search.sub(lambda m: replacements[m.group(0)], data)
        return 'data:image/svg+xml;utf8,' + data_safe