1
2from io import StringIO
3
4from . import Raster
5from . import elements as elementsModule
6
7
class Drawing:
    '''A canvas that collects drawable elements and serializes them to SVG.

    Supports IPython/Jupyter: if a Drawing is the last expression of a cell,
    it is displayed as an SVG below the cell (via ``_repr_svg_``).

    The drawing also behaves like a small list of elements (``append``,
    ``extend``, ``insert``, ``remove``, ``clear``, ``index``, ``count``,
    ``reverse``); elements are written to the SVG in list order.
    '''

    def __init__(self, width, height, origin=(0, 0), **svgArgs):
        '''Create an empty drawing.

        Args:
            width: Drawing width in user units; must be a real number
                (``float(width) == width``).
            height: Drawing height in user units; same requirement.
            origin: Either ``'center'`` or a 2-sequence ``(ox, oy)``.
                The SVG viewBox becomes
                ``(-ox, oy - height, width, height)``; ``'center'`` is
                shorthand for ``(width / 2, height / 2)``.
            **svgArgs: Extra attributes for the root ``<svg>`` tag; they are
                passed to ``elements.writeXmlNodeArgs`` in ``asSvg``.
        '''
        assert float(width) == width
        assert float(height) == height
        self.width = width
        self.height = height
        # Only compare against 'center' when origin is a string, so
        # array-like origins are never compared elementwise to a str.
        if isinstance(origin, str) and origin == 'center':
            origin = (width / 2, height / 2)
        else:
            origin = tuple(origin)
            assert len(origin) == 2
        ox, oy = origin
        self.viewBox = (-ox, oy - height, width, height)
        self.elements = []   # drawn elements, written inside <svg> in order
        self.otherDefs = []  # elements written as-is inside <defs>
        self.pixelScale = 1
        self.renderWidth = None
        self.renderHeight = None
        self.svgArgs = svgArgs

    def setRenderSize(self, w=None, h=None):
        '''Set the output size in pixels and return self for chaining.

        If only one of ``w``/``h`` is given, ``calcRenderSize`` derives the
        other from the drawing's aspect ratio.
        '''
        self.renderWidth = w
        self.renderHeight = h
        return self

    def setPixelScale(self, s=1):
        '''Render at ``s`` pixels per user unit and return self.

        Clears any explicit size set by ``setRenderSize``.
        '''
        self.renderWidth = None
        self.renderHeight = None
        self.pixelScale = s
        return self

    def calcRenderSize(self):
        '''Return the ``(width, height)`` of the rendered image in pixels.'''
        if self.renderWidth is None and self.renderHeight is None:
            return (self.width * self.pixelScale,
                    self.height * self.pixelScale)
        elif self.renderWidth is None:
            s = self.renderHeight / self.height
            return self.width * s, self.renderHeight
        elif self.renderHeight is None:
            s = self.renderWidth / self.width
            return self.renderWidth, self.height * s
        else:
            return self.renderWidth, self.renderHeight

    @staticmethod
    def _toElements(obj, kwargs):
        '''Return the drawable elements that ``obj`` expands to.

        Objects that can write themselves (have ``writeSvgElement``) are used
        as-is and take no keyword arguments; anything else is expanded via
        its ``toDrawables`` method.
        '''
        if hasattr(obj, 'writeSvgElement'):
            assert not kwargs, \
                'keyword arguments are only used with toDrawables objects'
            return (obj,)
        return obj.toDrawables(elements=elementsModule, **kwargs)

    def draw(self, obj, **kwargs):
        '''Add ``obj`` (or the elements it expands to) to the drawing.'''
        self.extend(self._toElements(obj, kwargs))

    def append(self, element):
        self.elements.append(element)

    def extend(self, iterable):
        self.elements.extend(iterable)

    def insert(self, i, element):
        self.elements.insert(i, element)

    def remove(self, element):
        self.elements.remove(element)

    def clear(self):
        self.elements.clear()

    def index(self, *args, **kwargs):
        '''Return the position of an element (same arguments as list.index).'''
        return self.elements.index(*args, **kwargs)

    def count(self, element):
        '''Return how many times ``element`` occurs in the drawing.'''
        return self.elements.count(element)

    def reverse(self):
        self.elements.reverse()

    def drawDef(self, obj, **kwargs):
        '''Add ``obj`` (or the elements it expands to) to ``<defs>``.'''
        self.otherDefs.extend(self._toElements(obj, kwargs))

    def appendDef(self, element):
        self.otherDefs.append(element)

    @staticmethod
    def _writeElements(elements, outputFile):
        '''Write every element that has ``writeSvgElement``, one per line.

        Elements without the method are skipped. The method is looked up
        first, rather than catching AttributeError around the call, so that
        genuine errors raised inside it are not hidden and cannot leave
        half-written markup behind.
        '''
        for element in elements:
            writeElement = getattr(element, 'writeSvgElement', None)
            if writeElement is not None:
                writeElement(outputFile)
                outputFile.write('\n')

    def asSvg(self, outputFile=None):
        '''Serialize the drawing as a standalone SVG document.

        Args:
            outputFile: Writable text stream to write to. If None, the
                document is built in memory and returned.

        Returns:
            The SVG text if ``outputFile`` is None, otherwise None.
        '''
        returnString = outputFile is None
        if returnString:
            outputFile = StringIO()
        imgWidth, imgHeight = self.calcRenderSize()
        startStr = (
            '<?xml version="1.0" encoding="UTF-8"?>\n'
            '<svg xmlns="http://www.w3.org/2000/svg"'
            ' xmlns:xlink="http://www.w3.org/1999/xlink"\n'
            ' width="{}" height="{}" viewBox="{} {} {} {}"'
        ).format(imgWidth, imgHeight, *self.viewBox)
        outputFile.write(startStr)
        # Extra root attributes go inside the still-open <svg ...> tag.
        elementsModule.writeXmlNodeArgs(self.svgArgs, outputFile)
        outputFile.write('>\n<defs>\n')

        # Unique ids ('d0', 'd1', ... for the default base), handed to each
        # element's writeSvgDefs.
        idIndex = 0

        def idGen(base='d'):
            nonlocal idIndex
            idStr = base + str(idIndex)
            idIndex += 1
            return idStr

        # Identity-based de-duplication: objects already in otherDefs, or
        # previously passed to isDuplicate, report True; every call records
        # its argument as seen.
        prevSet = {id(defn) for defn in self.otherDefs}

        def isDuplicate(obj):
            dup = id(obj) in prevSet
            prevSet.add(id(obj))
            return dup

        self._writeElements(self.otherDefs, outputFile)
        for element in self.elements:
            # Same rationale as _writeElements: skip elements lacking the
            # method, but let errors raised inside it propagate.
            writeDefs = getattr(element, 'writeSvgDefs', None)
            if writeDefs is not None:
                writeDefs(idGen, isDuplicate, outputFile)
        outputFile.write('</defs>\n')

        self._writeElements(self.elements, outputFile)
        outputFile.write('</svg>')
        if returnString:
            return outputFile.getvalue()
        return None

    def saveSvg(self, fname):
        '''Write the SVG document to the file at ``fname``.

        The file is opened as UTF-8 to match the ``encoding="UTF-8"``
        declaration in the XML prolog, independent of the platform locale.
        '''
        with open(fname, 'w', encoding='utf-8') as f:
            self.asSvg(outputFile=f)

    def savePng(self, fname):
        '''Rasterize the drawing and write the image to ``fname``.'''
        self.rasterize(toFile=fname)

    def rasterize(self, toFile=None):
        '''Rasterize the drawing's SVG using the ``Raster`` module.

        If ``toFile`` is truthy, the result of ``Raster.fromSvgToFile`` is
        returned; otherwise the result of ``Raster.fromSvg``. The return
        type is defined by ``Raster`` (not visible here).
        '''
        if toFile:
            return Raster.fromSvgToFile(self.asSvg(), toFile)
        else:
            return Raster.fromSvg(self.asSvg())

    def _repr_svg_(self):
        '''Display in Jupyter notebook.'''
        return self.asSvg()
146