1
2from io import StringIO
3
4from . import elements as elementsModule
5
6
class Drawing:
    ''' A canvas to draw on

    Supports IPython/Jupyter: if a Drawing is the last expression of a cell,
    it is displayed as an SVG below the cell (via ``_repr_svg_``).

    Elements are kept in the list ``self.elements``; the list-like methods
    (append, extend, insert, remove, clear, index, count, reverse) forward
    to that list. '''

    def __init__(self, width, height, origin=(0,0)):
        ''' Create an empty drawing.

        width, height: canvas size in user units; must be numeric
            (checked with ``float(x) == x``).
        origin: the string 'center' (equivalent to (width/2, height/2)) or a
            2-sequence (ox, oy). The offset is negated into the SVG viewBox,
            giving viewBox = (-ox, oy - height, width, height).
            NOTE(review): this presumably places the coordinate origin ox/oy
            units from the left/bottom edges, assuming elements negate y
            when rendering — confirm against the elements module.
        '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        if origin == 'center':
            self.viewBox = (width/2, height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
            self.viewBox = origin + (width, height)
        # Convert the origin offset into SVG viewBox form (min-x, min-y, w, h).
        self.viewBox = (-self.viewBox[0], self.viewBox[1]-self.viewBox[3],
                        self.viewBox[2], self.viewBox[3])
        self.elements = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None

    def setRenderSize(self, w=None, h=None):
        ''' Set the output pixel size; a None dimension is derived from the
        other by keeping the aspect ratio (see calcRenderSize).
        Returns self for chaining. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Render at ``s`` pixels per user unit, clearing any explicit
        render size. Returns self for chaining. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) written to the SVG root element.

        - neither render dimension set: canvas size times pixelScale
        - only one set: the other is scaled to keep the aspect ratio
        - both set: returned as-is
        '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, **kwargs):
        ''' Add ``obj`` to the drawing.

        If ``obj`` has a ``writeSvgElement`` method it is added directly
        (no kwargs allowed). Otherwise ``obj.toDrawables(elements=<elements
        module>, **kwargs)`` is called and every returned item is added.
        '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0
            elements = (obj,)
        self.extend(elements)

    def append(self, element):
        ''' Append one element (list.append semantics). '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Append every element of ``iterable`` (list.extend semantics). '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert ``element`` before position ``i`` (list.insert semantics). '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of ``element``; raises ValueError if
        absent (list.remove semantics). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all elements. '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the position of an element (list.index semantics);
        raises ValueError if absent. '''
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times ``element`` occurs in the drawing. '''
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the element order in place. '''
        self.elements.reverse()

    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

        outputFile: a writable text stream. When None, the document is built
            in memory and returned as a string; otherwise it is written to
            ``outputFile`` and None is returned.

        Elements without a ``writeSvgElement`` method are skipped. Errors
        raised while an element renders itself propagate to the caller.
        '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = '''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
 width="{}" height="{}" viewBox="{} {} {} {}">'''.format(
            imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        outputFile.write('\n')
        for element in self.elements:
            # Look the method up separately from calling it, so that an
            # AttributeError raised *inside* a renderer is not mistaken for
            # "this object is not drawable" and silently dropped.
            writeElement = getattr(element, 'writeSvgElement', None)
            if writeElement is None:
                continue
            writeElement(outputFile)
            outputFile.write('\n')
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname):
        ''' Write the SVG document to the file ``fname``.

        The file is written as UTF-8 to match the encoding declared in the
        XML header, regardless of the platform's default encoding.
        '''
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        return self.asSvg()
100