Programmatically generate SVG (vector) images, animations, and interactive Jupyter widgets
from io import StringIO

from . import Raster
from . import elements as elementsModule


class Drawing:
    ''' A canvas to draw on

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below. '''
    def __init__(self, width, height, origin=(0,0), **svgArgs):
        '''Create an empty drawing.

        width, height: canvas size in user units (must be numeric values
            equal to their float() conversion).
        origin: the string 'center', or a 2-item sequence (ox, oy) giving
            where the point (0, 0) sits, measured from the left and bottom
            edges of the canvas. 'center' is equivalent to
            (width/2, height/2).
        **svgArgs: extra attributes written onto the root <svg> tag.
        '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (width/2, height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Convert (ox, oy, w, h) into an SVG viewBox (min-x, min-y, w, h)
        # so that (0, 0) lands ox from the left and oy from the bottom.
        self.viewBox = (-self.viewBox[0], self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.svgArgs = svgArgs

    def setRenderSize(self, w=None, h=None):
        '''Set the output pixel size; a None dimension is derived from the
        other one by preserving the aspect ratio. Returns self.'''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        '''Render at s pixels per user unit, clearing any explicit render
        size. Returns self.'''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        '''Return the (width, height) written to the root <svg> tag.

        - neither render dimension set: canvas size times pixelScale;
        - only one set: the other is scaled to keep the aspect ratio;
        - both set: used as-is.
        '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, **kwargs):
        '''Add obj to the drawing.

        Objects with a writeSvgElement method are appended directly (no
        kwargs allowed). Anything else must provide
        toDrawables(elements=..., **kwargs), whose result is appended.
        '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.extend(elements)

    # List-like helpers that delegate to self.elements.
    def append(self, element):
        self.elements.append(element)

    def extend(self, iterable):
        self.elements.extend(iterable)

    def insert(self, i, element):
        self.elements.insert(i, element)

    def remove(self, element):
        self.elements.remove(element)

    def clear(self):
        self.elements.clear()

    def index(self, *args, **kwargs):
        '''Return the position of an element (same arguments as
        list.index); raises ValueError if absent.'''
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        '''Return how many times element occurs in the drawing.'''
        return self.elements.count(element)

    def reverse(self):
        self.elements.reverse()

    def asSvg(self, outputFile=None):
        '''Serialize the drawing as an SVG document.

        If outputFile is None the document is returned as a str; otherwise
        it is written to outputFile (anything with write(str)) and None is
        returned. Elements lacking writeSvgDefs / writeSvgElement are
        skipped in the corresponding pass; errors raised while an element
        writes itself propagate instead of being silently dropped.
        '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        idIndex = 0
        def idGen(base='d'):
            # Hand out unique ids: d0, d1, ... (or base0, base1, ...).
            nonlocal idIndex
            idStr = base + str(idIndex)
            idIndex += 1
            return idStr
        prevSet = set()
        def isDuplicate(obj):
            # True if this exact object was already seen in the defs pass.
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        for element in self.elements:
            writeDefs = getattr(element, 'writeSvgDefs', None)
            if writeDefs is not None:
                writeDefs(idGen, isDuplicate, outputFile)
        outputFile.write('</defs>\n')
        # Write normal elements
        for element in self.elements:
            writeElement = getattr(element, 'writeSvgElement', None)
            if writeElement is not None:
                writeElement(outputFile)
                outputFile.write('\n')
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname):
        '''Write the SVG document to fname, UTF-8 encoded to match the
        XML declaration.'''
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        '''Rasterize the drawing and save the result to fname.'''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        '''Rasterize via Raster: Raster.fromSvgToFile when toFile is truthy,
        otherwise Raster.fromSvg. Returns whatever that call returns.'''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        return self.asSvg()