Programmatically generate SVG (vector) images, animations, and interactive Jupyter widgets
at version 1.0.1 · 5.2 kB · view raw
from io import StringIO

from . import Raster
from . import elements as elementsModule


class Drawing:
    ''' A canvas to draw on

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below.

    Visible elements are kept in draw order in ``self.elements``; extra
    definitions written inside ``<defs>`` are kept in ``self.otherDefs``.
    The list-like methods (append, extend, insert, remove, clear, index,
    count, reverse) operate on ``self.elements``. '''

    def __init__(self, width, height, origin=(0,0), **svgArgs):
        ''' Create an empty canvas.

        width, height: canvas size in drawing units. Must be real numbers
            (checked with ``assert``, so the check is skipped under -O).
        origin: ``(x, y)`` of the canvas's minimum corner in drawing
            coordinates, or the string ``'center'`` to put ``(0, 0)`` in the
            middle of the canvas.
        svgArgs: extra arguments passed to ``elements.writeXmlNodeArgs``
            inside the root ``<svg>`` tag by ``asSvg``. '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        # isinstance guard: comparing an array-like origin to a str with ==
        # may compare element-wise instead of yielding a plain bool.
        if isinstance(origin, str) and origin == 'center':
            # Shift the minimum corner by half the canvas so that (0, 0)
            # lands in the middle of the visible area.
            originX, originY = -width/2, -height/2
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            originX, originY = origin
        # The y component is stored as -originY-height, so the visible SVG y
        # range [-originY-height, -originY] mirrors [originY, originY+height].
        # Presumably the elements module negates y when serializing so that
        # callers use y-up coordinates -- TODO confirm against elements.py.
        self.viewBox = (originX, -originY-height, width, height)
        self.elements = []
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.svgArgs = svgArgs

    def setRenderSize(self, w=None, h=None):
        ''' Set the output size written on the ``<svg>`` tag.

        A dimension left as None is derived from the other one (see
        calcRenderSize). Returns self for chaining. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Render at ``s`` output units per drawing unit.

        Discards any size set with setRenderSize. Returns self for chaining. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the ``(width, height)`` written on the ``<svg>`` tag.

        With no explicit render size, the canvas size times pixelScale is
        used. If only one of renderWidth/renderHeight is set, the other is
        scaled to keep the canvas aspect ratio. If both are set they are
        returned unchanged. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    @staticmethod
    def _toElements(obj, kwargs):
        ''' Convert ``obj`` into a sequence of drawable elements.

        Objects with a ``writeSvgElement`` method are already elements. They
        are wrapped in a 1-tuple and may not be given extra kwargs. Anything
        else must provide ``toDrawables(elements=<elements module>, **kwargs)``. '''
        if hasattr(obj, 'writeSvgElement'):
            assert len(kwargs) == 0
            return (obj,)
        return obj.toDrawables(elements=elementsModule, **kwargs)

    def draw(self, obj, **kwargs):
        ''' Add ``obj`` (an element, or an object with ``toDrawables``) to
        the visible content. kwargs are forwarded to ``toDrawables``. '''
        self.extend(self._toElements(obj, kwargs))

    # --- list-like API delegating to self.elements ---
    def append(self, element):
        ''' Append one element to the visible content. '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Append every element of ``iterable`` to the visible content. '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert ``element`` before position ``i``. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of ``element`` (ValueError if absent). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all visible elements (definitions are kept). '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the position of an element; same arguments as list.index. '''
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times ``element`` occurs in the visible content. '''
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the drawing order in place. '''
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        ''' Like draw, but add the resulting elements to ``<defs>`` instead
        of the visible content. '''
        self.otherDefs.extend(self._toElements(obj, kwargs))

    def appendDef(self, element):
        ''' Append a single element to the ``<defs>`` section. '''
        self.otherDefs.append(element)

    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

        outputFile: writable text file-like object. If None, the document is
            built in memory and returned as a str. Otherwise it is written to
            outputFile and None is returned.

        Objects lacking writeSvgElement / writeSvgDefs are skipped. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')

        # Write definition elements.
        # idGen hands out ids 'd0', 'd1', ... unique within this document.
        idIndex = 0
        def idGen(base='d'):
            nonlocal idIndex
            idStr = base + str(idIndex)
            idIndex += 1
            return idStr
        # Identities of definitions already emitted. otherDefs are seeded in
        # because they are written explicitly just below.
        prevSet = {id(defn) for defn in self.otherDefs}
        def isDuplicate(obj):
            ''' Return whether obj was already seen, then mark it as seen. '''
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        # NOTE(review): the AttributeError guards below are meant to skip
        # objects without the method. They also hide AttributeErrors raised
        # inside an element's own method, possibly after partial output.
        # They are kept as-is for compatibility.
        for element in self.otherDefs:
            try:
                element.writeSvgElement(outputFile)
            except AttributeError:
                pass
            else:
                outputFile.write('\n')
        for element in self.elements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile)
            except AttributeError:
                pass
        outputFile.write('</defs>\n')

        # Write normal elements
        for element in self.elements:
            try:
                element.writeSvgElement(outputFile)
            except AttributeError:
                pass
            else:
                outputFile.write('\n')
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname):
        ''' Write the drawing to the file ``fname`` as an SVG document. '''
        # The XML prolog declares encoding="UTF-8", so write UTF-8 regardless
        # of the platform's default text encoding.
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing and save it to ``fname`` via Raster. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize the SVG output with the Raster module.

        If ``toFile`` is truthy, returns ``Raster.fromSvgToFile(svg, toFile)``.
        Otherwise returns ``Raster.fromSvg(svg)``. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        return self.asSvg()