1
2from io import StringIO
3
4from . import Raster
5from . import elements as elementsModule
6
7
class Drawing:
    ''' A canvas to draw on

    Holds an ordered list of elements and serializes them to an SVG
    document with `asSvg`.

    Supports iPython: If a Drawing is the last line of a cell, it will be
    displayed as an SVG below. '''

    def __init__(self, width, height, origin=(0,0), **svgArgs):
        ''' Create an empty drawing.

        width, height: numeric size of the canvas in user units.
        origin: either the string 'center' or a pair of numbers used to
            build the viewBox.
        svgArgs: extra keyword arguments, written onto the root <svg> tag
            by `asSvg` via `elementsModule.writeXmlNodeArgs`. '''
        # Reject values that do not compare equal to their float conversion
        # (note: asserts are skipped when Python runs with -O).
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (width/2, height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Negate the x offset and move the y offset up by one canvas height
        # (presumably so user-space y grows upward -- confirm with elements).
        self.viewBox = (-self.viewBox[0], self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.svgArgs = svgArgs

    def setRenderSize(self, w=None, h=None):
        ''' Set an explicit output width and/or height; returns self. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Set the output scale factor and clear any explicit render size;
        returns self. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) written on the root <svg> tag.

        With no explicit render size the canvas size is multiplied by
        pixelScale.  When only one dimension was given, the other one is
        scaled by the same factor. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, **kwargs):
        ''' Add obj to the drawing.

        Objects that have a `writeSvgElement` attribute are added as-is
        (keyword arguments are not accepted in that case).  Any other
        object is asked for its elements through
        `obj.toDrawables(elements=elementsModule, **kwargs)`. '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.extend(elements)

    # The methods below forward to the underlying element list.

    def append(self, element):
        ''' Add element after all existing elements. '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Add every item of iterable after all existing elements. '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert element before position i. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of element (ValueError if absent). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all elements. '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the position of an element; arguments are passed straight
        to `list.index`. '''
        # The result was previously discarded, so callers always got None.
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times element occurs in the drawing. '''
        # The result was previously discarded, so callers always got None.
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the element order in place. '''
        self.elements.reverse()

    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

        outputFile: object with a `write` method to receive the text.  When
            omitted, the document is built in memory.

        Returns the SVG text when outputFile is None, otherwise None. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = ('<?xml version="1.0" encoding="UTF-8"?>\n'
                    '<svg xmlns="http://www.w3.org/2000/svg" '
                    'xmlns:xlink="http://www.w3.org/1999/xlink"\n'
                    ' width="{}" height="{}" viewBox="{} {} {} {}"'
                    ).format(imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')
        # Write definition elements
        idIndex = 0
        def idGen(base='d'):
            # Hands out base0, base1, ... using one counter for all bases.
            nonlocal idIndex
            idStr = base + str(idIndex)
            idIndex += 1
            return idStr
        prevSet = set()
        def isDuplicate(obj):
            # True when this exact object (by identity) was seen before.
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup
        for element in self.elements:
            try:
                element.writeSvgDefs(idGen, isDuplicate, outputFile)
            except AttributeError:
                # Elements without definitions support are skipped.
                pass
        outputFile.write('</defs>\n')
        # Write normal elements
        for element in self.elements:
            try:
                element.writeSvgElement(outputFile)
                outputFile.write('\n')
            except AttributeError:
                # Elements that cannot write themselves are skipped.
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname):
        ''' Write the SVG document to the file named fname. '''
        # The XML declaration written by asSvg says UTF-8, so the file must
        # be encoded as UTF-8 regardless of the platform's default encoding.
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing into the file named fname. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Hand the SVG text to the Raster module.

        Returns the result of `Raster.fromSvgToFile` when toFile is given
        (and truthy), otherwise the result of `Raster.fromSvg`. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        return self.asSvg()
130