2013/uns13/resources: spinn.py

File spinn.py, 18.5 KB (added by drasmuss, 5 years ago)

updated spinn module

Line 
#!/usr/bin/python
"""
Interface to dump a Nengo network into a Python file that can be
imported by PACMAN.


This Nengo module is called at the end of a Nengo script that needs
to be translated for SpiNNaker, and produces a Python file containing
the network specifications as follows ::

   pop = {} # a dictionary of populations


Each entry in the pop dictionary is keyed by the label of the population
(populations inside sub-networks are prefixed with the names of their
enclosing networks, joined by '.'), and has the following entries:
 - n: number of neurons in the population, int
 - dimensions: dimensions of the encoders/decoders in the population, int
 - bias: bias current values, a list of n values
 - encoders: encoders for the population, D x n array
 - decoders: decoders for the population, D x n array
 - taus: a 4 value array with the synaptic time constants used by
   projections onto the population (up to 4 different synapses), stored
   as int(round(tau * 1000)), i.e. in ms if Nengo's tau is in seconds;
   unused slots are padded with 2

Analogously, a projection dictionary proj{} is created ::

   proj = {} # a dictionary of projections

where each entry is keyed by the tuple (pre, post), pre and post being
the presynaptic and postsynaptic population labels, and has the
following entries:

 - tau: synaptic time constant, int (same units as taus above)
 - w: weight matrix, one row per presynaptic neuron and one column per
   postsynaptic neuron

The following structures are also inserted:

 - inputs: a list of population labels, identified as inputs, which can be
   controlled in the runtime environment
 - outputs: a list of population labels identified as outputs, which can be
   controlled in the runtime environment
 - robot_outputs: a list of population labels identified as robotic outputs,
   which can be used to control the robot
 - neurons_per_core: a dictionary of {'population label' : number of neurons
   per core (int)}, as an indication to override default options in the
   splitter on a population basis
 - runtime: the simulation duration (defaults to 1*60*1000, presumably
   milliseconds, i.e. one minute)

Author: Terry Stewart, Francesco Galluppi, Daniel Rasmussen
Email:  francesco.galluppi@cs.man.ac.uk
"""
44
45from ca.nengo.model.impl import NetworkImpl, NetworkArrayImpl
46import ca.nengo.model.nef
47import numeric as np
48import optsparse
49from ca.nengo.util import MU
50import nef
51import math
52
53
def iszero(transform):
    """Return True when every entry of the 2-D ``transform`` equals zero.

    ``transform`` is any iterable of iterable rows (nested lists, Java
    arrays, array slices).  An empty transform counts as all-zero.
    """
    # any() stops at the first non-zero entry, just like an early return
    return not any(x != 0 for row in transform for x in row)
59
class Population:
    """Description of one NEF ensemble in the form written to the pop dict.

    Holds the per-neuron bias, the encoders (each neuron's encoder scaled by
    that neuron's ``scale``), the decoders of the ensemble's 'X' origin
    divided by DT, and the synaptic time constants that projections onto
    this ensemble register in ``taus``.
    """

    # number of tau slots written per population, and the value used to
    # pad the unused slots in the generated text
    MAX_TAUS = 4
    TAU_PAD = 2

    def __init__(self, n, prefix=''):
        """Build the description from the NEF ensemble ``n``.

        ``prefix`` is prepended to the ensemble's name to form the label
        (SpiNN.process_network passes '<subnetwork>.' for nested ensembles).
        """
        DT = 0.001
        self.bias = [nn.bias for nn in n.nodes]
        # encoders[j][i]: dimension j of neuron i's encoder, times its scale
        self.encoders = [[nn.scale * n.encoders[i][j] for i, nn in enumerate(n.nodes)]
                         for j in range(n.dimension)]
        # decoders[j][i]: dimension j of neuron i's 'X' decoder, divided by DT
        self.decoders = [[x[j] / DT for x in n.getOrigin('X').decoders]
                         for j in range(n.dimension)]
        self.name = prefix + n.name
        self.dimension = n.dimension
        # int time constants appended by Projection (int(round(tau * 1000)))
        self.taus = []

    def generate_text(self):
        """Return this population as the source text of a Python dict literal.

        The taus list is padded to MAX_TAUS entries with TAU_PAD in the
        output only; ``self.taus`` itself is left unchanged.  A warning is
        printed if more than MAX_TAUS time constants are in use.
        """
        spacing = ' ' * 4
        if len(self.taus) > self.MAX_TAUS:
            print('WARNING: population %s uses %d synaptic time constants (max %d)'
                  % (self.name, len(self.taus), self.MAX_TAUS))
        taus = list(self.taus) + [self.TAU_PAD] * (self.MAX_TAUS - len(self.taus))

        entries = [('n', len(self.bias)),  # one bias value per neuron
                   ('dimensions', self.dimension),
                   ('bias', self.bias),
                   ('encoders', self.encoders),
                   ('decoders', self.decoders),
                   ('taus', taus)]
        text = ['{']
        for key, value in entries:
            text.append("%s'%s': %s," % (spacing, key, value))
        text.append("%s}" % spacing)
        return '\n'.join(text)
81       
82       
class Projection:
    """Neuron-to-neuron weights and synaptic time constant of one projection."""

    def __init__(self, spinn, origin, termination, transform=None):
        """Compute the weight matrix from ``origin`` to ``termination``.

        spinn       -- the SpiNN exporter; provides ``max_fan_in`` and the
                       ``populations`` map used for labels and taus.
        origin      -- decoded origin of the presynaptic ensemble.
        termination -- a DecodedTermination, or otherwise an ensemble
                       termination whose node terminations carry weights.
        transform   -- optional transform used instead of the termination's
                       own (SpiNN passes slices when splitting NetworkArrays).

        Raises KeyError if either endpoint is not in ``spinn.populations``.
        """
        if isinstance(termination, ca.nengo.model.nef.impl.DecodedTermination):
            scale = [nn.scale for nn in termination.node.nodes]
            if transform is None:
                transform = termination.transform
            if origin.node.neurons > spinn.max_fan_in:
                w = optsparse.compute_sparse_weights(origin, termination.node, transform, spinn.max_fan_in)
            else:
                # post encoders * transform * pre decoders^T, divided by the post radius
                w = MU.prod(termination.node.encoders, MU.prod(transform, MU.transpose(origin.decoders)))
                w = MU.prod(w, 1.0 / termination.node.radii[0])
        else:
            # then it is an ensemble termination
            scale = [t.node.scale for t in termination.getNodeTerminations()]
            # 'is not None': transform may be an array, for which '!= None'
            # compares elementwise
            if transform is not None:
                t_weights = transform
            else:
                t_weights = [t.getWeights() for t in termination.getNodeTerminations()]
            w = MU.prod(t_weights, MU.transpose(origin.decoders))

        self._scale_rows(w, scale, termination.tau)
        # after the transpose: rows -> presynaptic, columns -> postsynaptic neurons
        self.weights = MU.transpose(w)

        self.tau = int(round(termination.tau * 1000))
        post_pop = spinn.populations[termination.node]
        if self.tau not in post_pop.taus:
            post_pop.taus.append(self.tau)

        self.pre = spinn.populations[origin.node].name
        self.post = post_pop.name

    @staticmethod
    def _scale_rows(w, scale, tau):
        """Multiply row i of ``w`` in place by scale[i] / tau."""
        for i in range(len(w)):
            factor = float(scale[i]) / tau
            row = w[i]
            for j in range(len(row)):
                row[j] *= factor

    def generate_text(self):
        """Return this projection as the source text of a Python dict literal."""
        spacing = ' ' * 4
        rows = ['[%s],' % ', '.join(['%1.5f' % x for x in row]) for row in self.weights]
        # continuation rows are aligned under the first row of the 'w' entry
        weight = '\n          '.join(rows)
        text = ['{',
                "%s'%s': %d," % (spacing, 'tau', self.tau),
                "%s'%s': [%s]," % (spacing, 'w', weight),
                "%s}" % spacing]
        return '\n'.join(text)
140       
141       
class SpiNN:
    """Translate a Nengo network into the PACMAN description text.

    Construction modifies ``network`` in place (relays are flattened and
    low-dimensional I/O populations are added), then walks it to collect
    Population and Projection objects; ``write_to_file`` dumps the result.
    """

    def __init__(self, network, max_fan_in=50000):
        """
        network    -- the Nengo network to translate; it is modified in place.
        max_fan_in -- presynaptic ensembles with more neurons than this get
                      their weights from optsparse.compute_sparse_weights.
        """
        self.neurons_per_core = {}    # population label -> neurons per core
        self.robot_outputs = []       # labels registered with set_robot_output
        self.data_outputs = []        # output nodes collected by add_spinn_io
        self.network = network
        self.populations = {}         # NEF ensemble -> Population
        self.projections = []         # Projection objects
        self.inputs = []              # labels of populations fed by a FunctionInput
        self.max_fan_in = max_fan_in
        self.runtime = 1 * 60 * 1000  # one minute, presumably in ms
        self.io_dim_limit = 2         # max dimensions of one I/O population

        self.flatten_relays(self.network)

        self.add_spinn_io()

        # must run last: the two passes above modify the network it walks
        self.process_network(self.network)

    def get_taus(self):
        """Return the distinct projection taus, in first-seen order."""
        taus = []
        for p in self.projections:
            if p.tau not in taus:
                taus.append(p.tau)
        return taus

    def set_neurons_per_core(self, name, neurons_per_core):
        """Request ``neurons_per_core`` neurons per core for population ``name``."""
        self.neurons_per_core[name] = neurons_per_core

    def set_robot_output(self, name):
        """Mark population ``name`` as a robot output."""
        self.robot_outputs.append(name)

    def set_runtime(self, runtime):
        """
        This function is used to set the duration of the simulation on spinnaker
        """
        self.runtime = runtime

    def process_network(self, network, prefix=''):
        """Create Populations and Projections for ``network`` and its sub-networks.

        ``prefix`` is prepended to population labels and grows by
        '<subnetwork name>.' at each nesting level.
        """
        for n in network.nodes:
            if isinstance(n, ca.nengo.model.nef.NEFEnsemble):
                self.populations[n] = Population(n, prefix)
            elif isinstance(n, ca.nengo.model.Network):
                self.process_network(n, prefix + n.name + '.')
            elif isinstance(n, ca.nengo.model.impl.FunctionInput):
                pass  # handled through its projections in process_projection
            else:
                print('WARNING: Unknown node %s' % (n,))

        for p in network.projections:
            print('%s -> %s' % (p.origin.node.name, p.termination.node.name))
            self.process_projection(p)

    def process_projection(self, p):
        """Translate one Nengo projection into Projections or input labels.

        NetworkArray endpoints are expanded into per-ensemble projections;
        a projection from a FunctionInput only records its target in
        ``self.inputs``.
        """
        origin = p.origin
        termination = p.termination

        # unwrap origins/terminations that are exposed on a network
        if hasattr(origin, 'baseOrigin'):
            origin = origin.baseOrigin
        if hasattr(termination, 'baseTermination'):
            termination = termination.baseTermination

        pre = origin.node
        post = termination.node

        if isinstance(pre, ca.nengo.model.impl.NetworkArrayImpl) and isinstance(post, ca.nengo.model.impl.NetworkArrayImpl):
            for term in termination.nodeTerminations:
                self._add_split_projections(origin, np.array(term.transform), term)
        elif isinstance(pre, ca.nengo.model.impl.NetworkArrayImpl):
            if hasattr(termination, "transform"):
                transform = np.array(termination.transform)
            else:
                # then it's an ensemble termination
                transform = np.array([term.getWeights() for term in termination.getNodeTerminations()])
            self._add_split_projections(origin, transform, termination)
        elif isinstance(post, ca.nengo.model.impl.NetworkArrayImpl):
            for t in termination.nodeTerminations:
                if isinstance(pre, ca.nengo.model.impl.FunctionInput):
                    self.inputs.append(str(self.populations[t.node].name))
                else:
                    self.projections.append(Projection(self, origin, t))
        elif isinstance(pre, ca.nengo.model.impl.FunctionInput):
            self.inputs.append(str(self.populations[post].name))
        elif pre in self.populations and post in self.populations:
            self.projections.append(Projection(self, origin, termination))
        else:
            print('WARNING: Unknown projection %s' % (p,))
            print('         pre:  %s' % (pre,))
            print('         post:  %s' % (post,))

    def _add_split_projections(self, origin, transform, termination):
        """Add one Projection per sub-origin of the NetworkArray ``origin``,
        using that sub-origin's column slice of ``transform``; all-zero
        slices are skipped."""
        index = 0
        for sub_origin in origin.nodeOrigins:
            t = transform[:, index:index + sub_origin.dimensions]
            index += sub_origin.dimensions
            if not iszero(t):
                self.projections.append(Projection(self, sub_origin, termination, transform=t))

    def flatten_relays(self, network):
        r"""This function will remove any relays from the network and replace them with
        direct connections.  For example, a network configuration such as
         A --\           /-- C
              ---relay---
         B --/           \-- D

         will be changed to

         A --\--/-- C
              /
         B --/-\-- D

        Relays are the nodes listed in the network's "spinn_relays" metadata.
        Each termination a relay projected to is exposed on the network as
        '<relay's exposed termination name>_relay_<i>', and the parent
        network re-routes its projections to those new terminations."""

        #the basic algorithm:
        #expose all relays in this network
        #recurse on subnetworks
        #find any newly exposed terminations on subnetworks, and connect to them

        relays = network.getMetaData("spinn_relays")
        if relays is not None:
            for r in relays:
                o = r.getOrigin("X")
                posts = [p.getTermination() for p in network.projections if p.getOrigin() == o]
                for i, t in enumerate(posts):
                    network.exposeTermination(t,
                                              network.getExposedTerminationName(r.getTerminations()[0]) + "_relay_%i" % i)
                    network.removeProjection(t)
                network.removeNode(r.name)

        for n in network.nodes:
            if isinstance(n, ca.nengo.model.impl.NetworkImpl):
                self.flatten_relays(n)

                old_ts = set()
                ts = [t for t in n.getTerminations() if "_relay_" in t.name]
                for t in ts:
                    #find the termination that was originally projected to on the network
                    old_t = n.getTermination(self._relay_prefix(t.name))
                    old_ts.add(old_t)

                    #send projections to new terminations
                    for p in network.projections:
                        #find projections to the old termination
                        if p.getTermination() == old_t:
                            network.addProjection(p.getOrigin(), t)

                #remove the old projections
                for t in old_ts:
                    try:
                        network.removeProjection(t)
                    except ca.nengo.model.StructuralException:
                        #then there was no projection on that termination anyway
                        pass

    @staticmethod
    def _relay_prefix(name):
        """Return the original termination name from '<name>_relay_<i>'.

        Only the last '_relay_' is stripped, so names that themselves
        contain underscores are preserved.
        """
        return name[:name.rindex("_relay_")]

    @staticmethod
    def _num_chunks(dim, limit):
        """Return ceil(dim / limit) using integer arithmetic."""
        return dim // limit + (1 if dim % limit != 0 else 0)

    def add_spinn_io(self):
        """This method modifies the network to add the input/output populations
        needed to communicate with the Spinnaker I/O.  The basic restriction is
        that all inputs and outputs can be at most two dimensions, so this method
        takes any larger inputs/outputs, breaks them up into two dimensional chunks,
        and then combines them into a complete representation internally within the
        network."""
        net = nef.Network(self.network)
        inputs = self._split_inputs(net)
        self._split_outputs(net, inputs)

    def _split_inputs(self, net):
        """Route every FunctionInput wider than io_dim_limit through
        io_dim_limit-wide sub-populations; return all FunctionInput nodes."""
        inputs = [n for n in net.network.nodes if isinstance(n, ca.nengo.model.impl.FunctionInput)]
        for inp in inputs:
            in_origin = inp.getOrigin("origin")
            dim = in_origin.dimensions
            if dim <= self.io_dim_limit:
                continue

            #find all the places this input projects to
            posts = [p.getTermination() for p in net.network.projections if p.getOrigin() == in_origin]

            #remove the existing projections
            for term in posts:
                net.network.removeProjection(term)

            #create sub-inputs
            for i in range(self._num_chunks(dim, self.io_dim_limit)):
                index_range = range(i * self.io_dim_limit, min((i + 1) * self.io_dim_limit, dim))
                sub = net.make("%s_sub_%i" % (inp.name, i), 50 * len(index_range), len(index_range))
                net.connect(inp, sub, index_pre=index_range)

                #connect sub back to original post node
                for term in posts:
                    nodename = term.node.name
                    base_term = term
                    if isinstance(term, NetworkImpl.TerminationWrapper):
                        #then we know this is a termination on a network
                        if not isinstance(term.node, NetworkArrayImpl):
                            #we don't connect to the inner populations if it's a networkarray,
                            #because networkarrays can handle direct connections
                            nodename = nodename + "." + term.getBaseTermination().node.name
                        base_term = term.getBaseTermination()

                    #figure out the original transform
                    if isinstance(base_term, ca.nengo.model.nef.impl.DecodedTermination):
                        t = base_term.transform
                    else:
                        #then it's an ensemble termination
                        t = []
                        for nterm in base_term.getNodeTerminations():
                            if isinstance(nterm, ca.nengo.model.nef.impl.DecodedTermination):
                                t += nterm.transform
                            else:
                                t += [nterm.weights]

                    #connect sub to post with the appropriate part of transform matrix
                    t = MU.transpose(t)  #this makes it easier to select the appropriate column
                    net.connect(sub, nodename,
                                transform=MU.transpose([t[j] for j in index_range]))
        return inputs

    def _split_outputs(self, net, inputs):
        """Collect output nodes into self.data_outputs, splitting any output
        wider than io_dim_limit (or one that is also fed by an input) into
        io_dim_limit-wide sub-populations."""
        #assuming here that all outputs are nodes, and we will always use
        #the default origin as output
        outputs = net.network.getMetaData("spinn_outputs")
        #copy, so that appending below does not modify the network's metadata
        outputs = [] if outputs is None else list(outputs)

        #detect all nodes with no outgoing projections, and assume they are outputs
        projections = net.network.projections
        for n in net.network.nodes:
            if n not in outputs and not any(p.getOrigin() == o for o in n.getOrigins() for p in projections):
                outputs.append(n)

        for o in outputs:
            nodename = o.name
            try:
                origin = o.getOrigin("X")  #note: assume all output is from "X"
            except ca.nengo.model.StructuralException:
                #then it has no X origin
                continue

            if isinstance(origin, NetworkImpl.OriginWrapper) and not isinstance(o, NetworkArrayImpl):
                origin = origin.getBaseOrigin()
                nodename = nodename + "." + origin.node.name

            out_node = net.get(nodename)
            dim = out_node.dimension

            #check whether the output node is the target of an input (since a population can't be
            #both an input and an output)
            isinput = False
            for proj in net.network.projections:
                if proj.getOrigin().node in inputs:
                    term = proj.getTermination()
                    if isinstance(term, NetworkImpl.TerminationWrapper):
                        tnode = term.getBaseTermination().node
                    else:
                        tnode = term.node
                    if tnode == out_node:
                        isinput = True
                        break

            if dim <= self.io_dim_limit and not isinput:
                self.data_outputs.append(out_node)
            else:
                for i in range(self._num_chunks(dim, self.io_dim_limit)):
                    index_range = range(i * self.io_dim_limit, min((i + 1) * self.io_dim_limit, dim))
                    sub = net.make("%s_sub_%i" % (o.name, i), 50 * len(index_range), len(index_range))
                    net.connect(nodename, sub, index_pre=index_range)
                    self.data_outputs.append(sub)

    def print_info(self):
        """Print the taus, population labels and projections (debug aid)."""
        print(self.get_taus())
        for p in sorted(self.populations.values(), key=lambda a: a.name):
            print(p.name)
        for p in self.projections:
            print('%s -> %s' % (p.pre, p.post))

    def _output_labels(self):
        """Return labels for self.data_outputs: the Population label when the
        node has one, otherwise the node's own name."""
        labels = []
        for node in self.data_outputs:
            pop = self.populations.get(node)
            labels.append(str(pop.name if pop is not None else node.name))
        return labels

    def generate_text(self):
        """Return the complete PACMAN description as Python source text.

        If a projection fails to render, its label is printed and the error
        is re-raised, rather than emitting a silently truncated proj dict.
        """
        text = ['pop = {}']
        for p in sorted(self.populations.values(), key=lambda pop: pop.name):
            text.append("pop['%s'] = %s" % (p.name, p.generate_text()))
        text.append('proj = {}')
        for p in sorted(self.projections, key=lambda proj: (proj.pre, proj.post)):
            try:
                body = p.generate_text()
            except Exception:
                print('ERROR: could not generate projection %s -> %s' % (p.pre, p.post))
                raise
            text.append("proj['%s', '%s'] = %s" % (p.pre, p.post, body))
        text.append("inputs = %s" % self.inputs)
        text.append("outputs = %s" % self._output_labels())
        text.append("robot_outputs = %s" % self.robot_outputs)
        text.append("neurons_per_core = %s" % self.neurons_per_core)
        text.append("runtime = %s\n" % self.runtime)
        return '\n'.join(text)

    def write_to_file(self, filename):
        """Write generate_text() to ``filename``.

        The text is generated before the file is opened, so a failure leaves
        an existing file untouched; the file is always closed.
        """
        text = self.generate_text()
        f = open(filename, 'w')
        try:
            f.write(text)
        finally:
            f.close()
435