Programmatically generate SVG (vector) images, animations, and interactive Jupyter widgets
from io import StringIO
import base64

from . import Raster
from . import elements as elementsModule


class Drawing:
    ''' A canvas to draw on

        Supports iPython: If a Drawing is the last line of a cell, it will be
        displayed as an SVG below.

        The drawing keeps two ordered lists: `elements` (the visible content)
        and `otherDefs` (extra items written inside the SVG `<defs>` section).
        The stored `viewBox` has its y component negated relative to the
        given origin (see `__init__`). '''

    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        ''' Create an empty drawing.

            width, height: Numeric size of the drawing in user units.
            origin: Either the string 'center' or an (x, y) pair giving the
                lower-left corner of the view box.
            idPrefix: Prefix for every id generated by `asSvg()`.
            displayInline: If true, Jupyter shows the SVG markup directly
                (`_repr_svg_`); otherwise it is shown through an `<img>` tag
                (`_repr_html_`).
            svgArgs: Extra attributes for the root `<svg>` tag.  In each key
                `__` becomes `:`, `_` becomes `-` and one trailing `-` is
                dropped (e.g. `class_` -> `class`,
                `xlink__href` -> `xlink:href`). '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Flip the y axis: the SVG view box starts at -(y + height)
        self.viewBox = (self.viewBox[0], -self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':')
            k = k.replace('_', '-')
            # endswith() is safe even if the key has become an empty string
            if k.endswith('-'):
                k = k[:-1]
            self.svgArgs[k] = v

    def setRenderSize(self, w=None, h=None):
        ''' Set an explicit output width and/or height.

            A value left as None is derived from the other one by
            `calcRenderSize()`.  Returns self so calls can be chained. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Scale the output size by `s` and clear any explicit render size.

            Returns self so calls can be chained. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) written into the root `<svg>` tag.

            With no explicit render size the drawing size is multiplied by
            `pixelScale`.  With only one explicit dimension the other is
            scaled to keep the aspect ratio. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, **kwargs):
        ''' Add `obj` to the drawing.

            An object with a `writeSvgElement` attribute is appended as-is
            (no keyword arguments allowed).  Any other object is asked for
            its elements through
            `obj.toDrawables(elements=elementsModule, **kwargs)`. '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.extend(elements)

    def append(self, element):
        ''' Append one element to the end of the drawing. '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Append every element of `iterable` to the drawing. '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert `element` before position `i`. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of `element` (ValueError if absent). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all elements (definitions in `otherDefs` are kept). '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the position of an element, like `list.index`.

            Raises ValueError if the element is not in the drawing. '''
        # The result must be returned; without it callers always get None
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times `element` appears in the drawing. '''
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the order of the elements in place. '''
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        ''' Like `draw()` but adds to the extra `<defs>` list instead of the
            visible elements. '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.otherDefs.extend(elements)

    def appendDef(self, element):
        ''' Append one element to the extra `<defs>` list. '''
        self.otherDefs.append(element)

    def asSvg(self, outputFile=None):
        ''' Write the drawing as an SVG document.

            outputFile: A text file-like object to write to.  If None, the
                document is built in memory and returned as a string.

            Returns the SVG text when `outputFile` is None, otherwise None.

            Elements that raise AttributeError while being written (for
            example because they lack the required method) are skipped. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        # Extra root attributes, written before the opening tag is closed
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        idIndex = 0
        def idGen(base=''):
            ''' Return the next unique id: prefix + base + running index. '''
            nonlocal idIndex
            idStr = self.idPrefix + base + str(idIndex)
            idIndex += 1
            return idStr
        prevSet = {id(defn) for defn in self.otherDefs}
        def isDuplicate(obj):
            ''' Report whether `obj` was seen before and mark it as seen. '''
            # nonlocal: prevSet is rebound below and the new set must be used
            nonlocal prevSet
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        for element in self.otherDefs:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        for element in self.elements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile, False)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')
        # Generate ids for normal elements
        # NOTE(review): the final True argument presumably requests a dry run
        # that only assigns ids -- confirm against elements.writeSvgElement
        prevDefSet = set(prevSet)
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, True)
            except AttributeError:
                pass
        # Forget what the id pass marked as seen; keep only the definitions
        prevSet = prevDefSet
        # Write normal elements
        for element in self.elements:
            try:
                element.writeSvgElement(idGen, isDuplicate, outputFile, False)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname):
        ''' Write the SVG document to the file `fname`.

            The file is written as UTF-8 to match the encoding declared in
            the XML header, independent of the platform default. '''
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing and save it to the file `fname`. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize the drawing through the `Raster` module.

            With a truthy `toFile` the result of
            `Raster.fromSvgToFile(svg, toFile)` is returned, otherwise the
            result of `Raster.fromSvg(svg)`. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook

            Returns the SVG text, or None when `displayInline` is false. '''
        if not self.displayInline:
            return None
        return self.asSvg()

    def _repr_html_(self):
        ''' Display in Jupyter notebook

            Returns an `<img>` tag holding the SVG as a base64 data URI, or
            None when `displayInline` is true. '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode())
        src = (prefix+data).decode()
        return '<img src="{}">'.format(src)