1
2from io import StringIO
3
4from . import Raster
5from . import elements as elementsModule
6
7
class Drawing:
    ''' A canvas to draw on

        Keeps an ordered list of elements and serializes them into an SVG
        document.  The list-like helpers (append, extend, insert, remove,
        clear, index, count, reverse) forward to that element list.

        Supports iPython: If a Drawing is the last line of a cell, it will be
        displayed as an SVG below. '''

    def __init__(self, width, height, origin=(0,0)):
        ''' Create an empty drawing.

            width, height: Numeric size of the drawing in user units.
            origin: Either the string 'center' or a 2-item iterable (x, y)
                    used to derive the viewBox offset. '''
        assert float(width) == width, 'width must be a number'
        assert float(height) == height, 'height must be a number'
        self.width = width
        self.height = height
        if origin == 'center':
            box = (width/2, height/2, width, height)
        else:
            origin = tuple(origin)
            assert len(origin) == 2, 'origin must have exactly two items'
            box = origin + (width, height)
        # Negate the x offset and shift the y offset by the height to get the
        # final SVG viewBox (min-x, min-y, width, height).
        self.viewBox = (-box[0], box[1]-box[3], box[2], box[3])
        self.elements = []
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None

    def setRenderSize(self, w=None, h=None):
        ''' Set an explicit output size; a None dimension is derived from the
            other one in calcRenderSize().  Returns self for chaining. '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        ''' Set the scale factor applied to width and height and clear any
            explicit render size.  Returns self for chaining. '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        ''' Return the (width, height) pair written to the SVG root element.

            With no explicit render size the drawing size is multiplied by
            pixelScale.  With only one explicit dimension the other is scaled
            to keep the drawing's aspect ratio. '''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    def draw(self, obj, **kwargs):
        ''' Add obj to the drawing.

            Objects with a writeSvgElement attribute are added directly and
            accept no keyword arguments.  Any other object is converted by
            calling obj.toDrawables(elements=elementsModule, **kwargs) and all
            resulting elements are added. '''
        if not hasattr(obj, 'writeSvgElement'):
            elements = obj.toDrawables(elements=elementsModule, **kwargs)
        else:
            assert len(kwargs) == 0, \
                'keyword arguments are only passed to toDrawables()'
            elements = (obj,)
        self.extend(elements)

    def append(self, element):
        ''' Add element to the end of the element list. '''
        self.elements.append(element)

    def extend(self, iterable):
        ''' Add every item of iterable to the end of the element list. '''
        self.elements.extend(iterable)

    def insert(self, i, element):
        ''' Insert element before position i of the element list. '''
        self.elements.insert(i, element)

    def remove(self, element):
        ''' Remove the first occurrence of element (ValueError if absent). '''
        self.elements.remove(element)

    def clear(self):
        ''' Remove all elements. '''
        self.elements.clear()

    def index(self, *args, **kwargs):
        ''' Return the position of an element, as list.index does. '''
        # The result used to be dropped, so callers always received None.
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        ''' Return how many times element occurs in the element list. '''
        # The result used to be dropped, so callers always received None.
        return self.elements.count(element)

    def reverse(self):
        ''' Reverse the element list in place. '''
        self.elements.reverse()

    def asSvg(self, outputFile=None):
        ''' Serialize the drawing as an SVG document.

            outputFile: Optional object with a write() method.  When given,
                        the SVG text is written to it and None is returned.
                        When omitted, the SVG text is returned as a string. '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = (
            '<?xml version="1.0" encoding="UTF-8"?>\n'
            '<svg xmlns="http://www.w3.org/2000/svg" '
            'xmlns:xlink="http://www.w3.org/1999/xlink"\n'
            ' width="{}" height="{}" viewBox="{} {} {} {}">'
        ).format(imgWidth, imgHeight, *self.viewBox)
        endStr = '</svg>'
        outputFile.write(startStr)
        outputFile.write('\n')
        for element in self.elements:
            # Best effort: elements that raise AttributeError (for example
            # because they have no writeSvgElement method) are skipped.
            # NOTE(review): this also hides an AttributeError raised inside
            # an element's own writeSvgElement implementation.
            try:
                element.writeSvgElement(outputFile)
                outputFile.write('\n')
            except AttributeError:
                pass
        outputFile.write(endStr)
        if returnString:
            return outputFile.getvalue()

    def saveSvg(self, fname):
        ''' Write the SVG document to the file named fname. '''
        with open(fname, 'w') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        ''' Rasterize the drawing into the file named fname. '''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        ''' Rasterize the SVG text through the Raster module.

            Returns the result of Raster.fromSvgToFile() when toFile is
            truthy, otherwise the result of Raster.fromSvg(). '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        ''' Display in Jupyter notebook '''
        return self.asSvg()
108