1
2from io import StringIO
3import base64
4
5from . import Raster
6from . import elements as elementsModule
7
8
class Drawing:
    ''' A canvas to draw on.

    Supports IPython/Jupyter: if a Drawing is the last line of a cell, it is
    displayed below the cell. With ``displayInline=True`` the SVG markup is
    returned from ``_repr_svg_``; otherwise ``_repr_html_`` returns an
    ``<img>`` tag whose source is a base64 ``data:`` URI of the SVG.

    The list-like methods (``append``, ``extend``, ``insert``, ``remove``,
    ``clear``, ``index``, ``count``, ``reverse``) forward to the list of
    top-level elements stored in ``self.elements``. '''

    def __init__(self, width, height, origin=(0,0), idPrefix='d',
                 displayInline=True, **svgArgs):
        ''' Create an empty drawing.

        width, height: drawing size in user units; must compare equal to
            their ``float()`` conversion.
        origin: the string 'center' to center the viewBox on (0, 0), or a
            2-sequence giving the viewBox origin (before the y flip below).
        idPrefix: prefix for every id generated while writing SVG.
        displayInline: selects which Jupyter display hook is active
            (``_repr_svg_`` when True, ``_repr_html_`` when False).
        **svgArgs: extra attributes for the root <svg> tag. Names are
            converted: '__' -> ':', then '_' -> '-', then one trailing '-'
            is dropped (``class_`` -> ``class``,
            ``xlink__href`` -> ``xlink:href``).
        '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            viewBox = (-width/2, -height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            viewBox = origin + (width, height)
        # The stored min-y is -(y + height), i.e. the y range is negated.
        # NOTE(review): presumably elements negate y too, giving a y-up
        # coordinate system — confirm against the elements module.
        self.viewBox = (viewBox[0], -viewBox[1]-viewBox[3],
                        viewBox[2], viewBox[3])
        self.elements = []
        self.otherDefs = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.idPrefix = str(idPrefix)
        self.displayInline = displayInline
        self.svgArgs = {}
        for k, v in svgArgs.items():
            k = k.replace('__', ':').replace('_', '-')
            # endswith() is safe for an empty key, unlike k[-1].
            if k.endswith('-'):
                k = k[:-1]
            self.svgArgs[k] = v

    def setRenderSize(self, w=None, h=None):
        ''' Set the output width/height; either may be None to derive it
        from the other while keeping the aspect ratio. Returns self. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Clear any explicit render size and scale the drawing size by
        ``s`` instead. Returns self. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) used for the <svg> width/height.

        With no render size set, the drawing size times pixelScale is used.
        With only one dimension set, the other is scaled to preserve the
        drawing's aspect ratio. With both set, they are returned as-is. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    @staticmethod
    def _toElementList(obj, kwargs):
        ''' Convert ``obj`` into a sequence of drawable elements.

        An object with a ``writeSvgElement`` method is a single element and
        takes no extra keyword arguments. Any other object must provide
        ``toDrawables(elements=..., **kwargs)``, which is called with this
        package's elements module. '''
        if hasattr(obj, 'writeSvgElement'):
            assert len(kwargs) == 0
            return (obj,)
        return obj.toDrawables(elements=elementsModule, **kwargs)

    def draw(self, obj, **kwargs):
        ''' Add ``obj`` (an element, or an object convertible via
        ``toDrawables``) to the drawing's top-level elements. '''
        self.extend(self._toElementList(obj, kwargs))

    def append(self, element):
        ''' Append one element to the top-level elements. '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Append every element of ``iterable`` to the top-level elements. '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert ``element`` before index ``i``. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of ``element`` (ValueError if absent). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all top-level elements. '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the index of an element, like ``list.index``. '''
        # The result was previously discarded, so callers always got None.
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times ``element`` occurs, like ``list.count``. '''
        # The result was previously discarded, so callers always got None.
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the top-level elements in place. '''
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        ''' Like ``draw`` but adds the resulting elements to the <defs>
        section instead of the drawing body. '''
        self.otherDefs.extend(self._toElementList(obj, kwargs))

    def appendDef(self, element):
        ''' Append one element to the <defs> section. '''
        self.otherDefs.append(element)

    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

        outputFile: writable text file-like object. If None, the document is
            built in memory and returned.
        Returns the SVG string when ``outputFile`` is None, else None.

        Objects lacking the method a pass needs (``writeSvgElement`` or
        ``writeSvgDefs``) are skipped. Errors raised inside those methods
        propagate instead of being silently swallowed. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
 width="{}" height="{}" viewBox="{} {} {} {}"'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')

        idIndex = 0
        def idGen(base=''):
            ''' Return idPrefix + base + a counter shared by the whole call. '''
            nonlocal idIndex
            idStr = self.idPrefix + base + str(idIndex)
            idIndex += 1
            return idStr

        # id()s of objects seen so far. otherDefs are pre-registered, so
        # isDuplicate reports them as already seen from the start.
        prevSet = {id(defn) for defn in self.otherDefs}
        def isDuplicate(obj):
            ''' Record ``obj``; return True if it was already recorded. '''
            # prevSet is read through the closure, so the rebinding further
            # down in asSvg is visible here.
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup

        def writeEach(elements, methodName, flag, separator=''):
            ''' Call element.<methodName>(idGen, isDuplicate, outputFile,
            flag) for each element defining it, writing ``separator`` after
            each successful call. '''
            for element in elements:
                method = getattr(element, methodName, None)
                if method is None:
                    continue
                method(idGen, isDuplicate, outputFile, flag)
                if separator:
                    outputFile.write(separator)

        # Write definition elements: explicit defs, then element-required defs.
        writeEach(self.otherDefs, 'writeSvgElement', False, '\n')
        writeEach(self.elements, 'writeSvgDefs', False)
        outputFile.write('</defs>\n')

        # Generate ids for normal elements (pass with the flag set to True).
        prevDefSet = set(prevSet)
        writeEach(self.elements, 'writeSvgElement', True)
        # Restore the seen-set snapshot taken before the id pass so the
        # write pass starts from the same duplicate state.
        prevSet = prevDefSet

        # Write normal elements.
        writeEach(self.elements, 'writeSvgElement', False, '\n')
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()
        return None

    def saveSvg(self, fname):
        ''' Write the SVG document to the file ``fname``.

        The file is opened as UTF-8 to match the ``encoding="UTF-8"`` in the
        XML prolog; the platform default encoding could disagree with it. '''
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing and write the result to ``fname``. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize via the Raster module.

        If ``toFile`` is truthy, returns ``Raster.fromSvgToFile(svg, toFile)``,
        otherwise ``Raster.fromSvg(svg)``. '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Jupyter hook: the SVG markup, or None unless displayInline. '''
        if not self.displayInline:
            return None
        return self.asSvg()

    def _repr_html_(self):
        ''' Jupyter hook: an <img> tag embedding the SVG as a base64 data
        URI, or None when displayInline is set. '''
        if self.displayInline:
            return None
        prefix = b'data:image/svg+xml;base64,'
        data = base64.b64encode(self.asSvg().encode())
        src = (prefix+data).decode()
        return '<img src="{}">'.format(src)