# 2013/uns13/resources: spinn.py

File spinn.py, 18.5 KB (added by drasmuss, 5 years ago) |
---|

Line | |
---|---|

#!/usr/bin/python
"""
Interface to dump a nengo network in a python file that can be
imported by PACMAN.

This Nengo module is called at the end of a Nengo script that needs
to be translated for SpiNNaker, and produces a python file containing
the network specification as follows ::

    pop = {}  # a dictionary of populations

Each entry in the pop dictionary is keyed by the name of the population
(prefixed with the names of its enclosing subnetworks, e.g. 'net.A') and is
a dictionary with the following entries:

 - dimensions: dimensions of the encoders/decoders in the population, int
 - bias: bias current values, a flat list of n values (one per neuron)
 - encoders: encoders of the population multiplied by each neuron's scale,
   D x n nested list
 - decoders: decoders of the population's 'X' origin divided by 0.001,
   D x n nested list
 - taus: synaptic time constants (in ms, rounded to ints) of the projections
   into the population, padded with the value 2 to at least 4 entries

The number of neurons n is not written explicitly; it equals len(bias).

Analogously, a projection dictionary is created ::

    proj = {}  # a dictionary of projections

where each entry is keyed by the pair (pre, post) of presynaptic and
postsynaptic population names, e.g. proj['A', 'B'], and is a dictionary
with the following entries:

 - tau: synaptic time constant in ms, int
 - w: weight matrix written as a list of rows (for dense weights, one row
   per presynaptic neuron)

The following structures are also written:

 - inputs: a list of population labels identified as inputs, which can be
   controlled in the runtime environment
 - robot_outputs: a list of population labels identified as robotic outputs,
   which can be used to control the robot
 - neurons_per_core: a dictionary of {'population label': number of neurons
   per core (int)}, as an indication to override default options in the
   splitter on a population basis
 - runtime: the simulation duration (default 1*60*1000, presumably ms)

Output populations are also detected (SpiNN.data_outputs) but are not
currently written to the file.

Author: Terry Stewart, Francesco Galluppi, Daniel Rasmussen
Email: francesco.galluppi@cs.man.ac.uk
"""

from ca.nengo.model.impl import NetworkImpl, NetworkArrayImpl
import ca.nengo.model.nef
import numeric as np
import optsparse
from ca.nengo.util import MU
import nef
import math

52 | |

53 | |

def iszero(transform):
    """Return True when every element of the 2-D ``transform`` is zero.

    An empty matrix (no rows, or only empty rows) counts as zero.
    """
    return not any(x != 0 for row in transform for x in row)

59 | |

class Population:
    """Description of one NEF ensemble, written as an entry of the ``pop`` dict.

    Attributes:
      name      -- ``prefix`` + ensemble name
      dimension -- ensemble dimensionality D
      bias      -- one bias value per neuron
      encoders  -- D x n nested list: encoder component j of neuron i,
                   multiplied by that neuron's scale
      decoders  -- D x n nested list: 'X' origin decoders divided by DT
      taus      -- distinct synaptic time constants (ms) of the projections
                   into this population; appended to by Projection
    """

    def __init__(self, n, prefix=''):
        """
        :param n: the NEFEnsemble to describe
        :param prefix: string prepended to the ensemble name (e.g. 'subnet.')
        """
        DT = 0.001  # decoders are divided by this; presumably the timestep in s — confirm
        self.bias = [nn.bias for nn in n.nodes]
        self.encoders = [[nn.scale * n.encoders[i][j] for i, nn in enumerate(n.nodes)]
                         for j in range(n.dimension)]
        self.decoders = [[x[j] / DT for x in n.getOrigin('X').decoders]
                         for j in range(n.dimension)]
        self.name = prefix + n.name
        self.dimension = n.dimension
        self.taus = []

    def generate_text(self):
        """Return the dict literal for this population (the value of pop['name']).

        In the output, ``taus`` is padded with the value 2 up to four entries.
        Lists that already hold four or more taus are written unchanged. The
        padding is applied to a copy, so ``self.taus`` keeps only the real time
        constants. Later Projections and repeated calls are therefore unaffected.
        """
        spacing = ' ' * 4
        # [2] * k is [] for k <= 0, so longer lists are not truncated
        taus = list(self.taus) + [2] * (4 - len(self.taus))
        text = ['{']
        text.append("%s'%s': %s," % (spacing, 'dimensions', self.dimension))
        text.append("%s'%s': %s," % (spacing, 'bias', self.bias))
        text.append("%s'%s': %s," % (spacing, 'encoders', self.encoders))
        text.append("%s'%s': %s," % (spacing, 'decoders', self.decoders))
        text.append("%s'%s': %s," % (spacing, 'taus', taus))
        text.append("%s}" % spacing)
        return '\n'.join(text)

81 | |

82 | |

class Projection:
    """Weights and time constant of one pre -> post connection.

    Each Projection is written as an entry of the ``proj`` dict. Creating one
    also registers its time constant in the postsynaptic Population's ``taus``.
    """

    def __init__(self, spinn, origin, termination, transform=None):
        """
        :param spinn: the SpiNN translator; provides ``max_fan_in`` and the
            node -> Population map ``populations``
        :param origin: presynaptic origin
        :param termination: postsynaptic termination. Either a DecodedTermination,
            or an ensemble termination made of node terminations.
        :param transform: optional matrix used instead of the termination's own
            transform/weights (passed when NetworkArray projections are sliced)
        """
        if isinstance(termination, ca.nengo.model.nef.impl.DecodedTermination):
            scale = [nn.scale for nn in termination.node.nodes]
            if transform is None:
                transform = termination.transform
            if origin.node.neurons > spinn.max_fan_in:
                # large presynaptic populations get a sparsified weight matrix
                w = optsparse.compute_sparse_weights(origin, termination.node, transform,
                                                     spinn.max_fan_in)
            else:
                # dense: post encoders x transform x pre decoders^T (post x pre)
                w = MU.prod(termination.node.encoders,
                            MU.prod(transform, MU.transpose(origin.decoders)))
                # NOTE(review): radius normalisation is applied on the dense path only;
                # confirm that optsparse.compute_sparse_weights accounts for the radius.
                w = MU.prod(w, 1.0 / termination.node.radii[0])
        else:
            # then it is an ensemble termination: one weight row per node termination
            node_terms = termination.getNodeTerminations()
            scale = [t.node.scale for t in node_terms]
            # ``is not None``: transform may be a numeric array slice, where != compares elementwise
            if transform is not None:
                t_weights = transform
            else:
                t_weights = [t.getWeights() for t in node_terms]
            w = MU.prod(t_weights, MU.transpose(origin.decoders))

        # scale row i by postsynaptic neuron i's gain / tau, then store transposed
        self.weights = MU.transpose(self._scale_rows(w, scale, termination.tau))
        self.tau = int(round(termination.tau * 1000))  # presumably s -> ms
        post_pop = spinn.populations[termination.node]
        if self.tau not in post_pop.taus:
            post_pop.taus.append(self.tau)

        self.pre = spinn.populations[origin.node].name
        self.post = post_pop.name

    @staticmethod
    def _scale_rows(w, scale, tau):
        """Multiply each entry of row i of ``w`` in place by scale[i] / tau; return ``w``."""
        for i, row in enumerate(w):
            factor = float(scale[i]) / tau
            for j in range(len(row)):
                row[j] *= factor
        return w

    def generate_text(self):
        """Return the dict literal for this projection (the value of proj['pre', 'post'])."""
        spacing = ' ' * 4
        rows = ['[%s],' % ', '.join(['%1.5f' % x for x in row]) for row in self.weights]
        weight = '\n '.join(rows)
        return '\n'.join(['{',
                          "%s'%s': %d," % (spacing, 'tau', self.tau),
                          "%s'%s': [%s]," % (spacing, 'w', weight),
                          "%s}" % spacing])

140 | |

141 | |

class SpiNN:
    """Translate a Nengo network into the python text that PACMAN imports.

    Constructing a SpiNN object modifies ``network`` in place:

    1. Relays listed in the network's "spinn_relays" metadata are flattened
       away (flatten_relays).
    2. FunctionInputs and outputs wider than ``io_dim_limit`` dimensions are
       split into smaller ensembles (add_spinn_io).
    3. The result is walked to build the Population and Projection
       descriptions (process_network).

    generate_text() and write_to_file() then emit those descriptions.
    """

    def __init__(self, network, max_fan_in=50000):
        """
        :param network: the Nengo network to translate; it is modified in place
        :param max_fan_in: presynaptic neuron count above which Projection
            computes sparse weights with optsparse instead of dense ones
        """
        self.neurons_per_core = {}  # population label -> neurons per core
        self.robot_outputs = []     # labels registered with set_robot_output
        self.data_outputs = []      # output ensembles found by add_spinn_io (not written out)
        self.network = network
        self.populations = {}       # NEFEnsemble -> Population
        self.projections = []       # Projection objects, in processing order
        self.inputs = []            # labels of populations fed by a FunctionInput
        self.max_fan_in = max_fan_in
        self.runtime = 1*60*1000    # simulation duration; presumably ms (one minute)
        self.io_dim_limit = 2       # max dimensions of a single I/O population

        self.flatten_relays(self.network)
        self.add_spinn_io()
        self.process_network(self.network)

    def get_taus(self):
        """Return the distinct projection time constants (ms), in first-seen order."""
        taus = []
        for p in self.projections:
            if p.tau not in taus:
                taus.append(p.tau)
        return taus

    def set_neurons_per_core(self, name, neurons_per_core):
        """Request ``neurons_per_core`` neurons per core for population ``name``."""
        self.neurons_per_core[name] = neurons_per_core

    def set_robot_output(self, name):
        """Mark population ``name`` as a robot output."""
        self.robot_outputs.append(name)

    def set_runtime(self, runtime):
        """
        This function is used to set the duration of the simulation on spinnaker
        """
        self.runtime = runtime

    def process_network(self, network, prefix=''):
        """Create a Population for every NEF ensemble in ``network``, then
        translate the network's projections.

        Subnetworks are processed recursively, with their name plus '.'
        appended to ``prefix``.
        """
        for n in network.nodes:
            if isinstance(n, ca.nengo.model.nef.NEFEnsemble):
                self.populations[n] = Population(n, prefix)
            elif isinstance(n, ca.nengo.model.Network):
                self.process_network(n, prefix + n.name + '.')
            elif isinstance(n, ca.nengo.model.impl.FunctionInput):
                # inputs are handled when their projections are processed
                pass
            else:
                print('WARNING: Unknown node %s' % (n,))

        for p in network.projections:
            print('%s -> %s' % (p.origin.node.name, p.termination.node.name))
            self.process_projection(p)

    def process_projection(self, p):
        """Translate one Nengo projection into Projection objects or input labels.

        - Exposed (wrapped) origins and terminations are unwrapped first.
        - NetworkArray endpoints are split into per-sub-ensemble projections.
        - A projection from a FunctionInput only records its target
          population(s) in ``self.inputs``.
        """
        origin = p.origin
        termination = p.termination
        if hasattr(origin, 'baseOrigin'):
            origin = origin.baseOrigin
        if hasattr(termination, 'baseTermination'):
            termination = termination.baseTermination

        pre = origin.node
        post = termination.node
        pre_is_array = isinstance(pre, ca.nengo.model.impl.NetworkArrayImpl)
        post_is_array = isinstance(post, ca.nengo.model.impl.NetworkArrayImpl)
        pre_is_input = isinstance(pre, ca.nengo.model.impl.FunctionInput)

        if pre_is_array and post_is_array:
            # each post sub-termination gets the columns belonging to each pre sub-origin
            for term in termination.nodeTerminations:
                self._add_sliced_projections(origin.nodeOrigins, term, np.array(term.transform))
        elif pre_is_array:
            if hasattr(termination, "transform"):
                transform = np.array(termination.transform)
            else:
                # then it's an ensemble termination
                transform = np.array([term.getWeights() for term in termination.getNodeTerminations()])
            self._add_sliced_projections(origin.nodeOrigins, termination, transform)
        elif post_is_array:
            for t in termination.nodeTerminations:
                if pre_is_input:
                    self.inputs.append(str(self.populations[t.node].name))
                else:
                    self.projections.append(Projection(self, origin, t))
        elif pre_is_input:
            self.inputs.append(str(self.populations[post].name))
        elif pre in self.populations and post in self.populations:
            self.projections.append(Projection(self, origin, termination))
        else:
            print('WARNING: Unknown projection %s' % (p,))
            print(' pre:  %s' % (pre,))
            print(' post:  %s' % (post,))

    def _add_sliced_projections(self, node_origins, termination, transform):
        """Add one Projection per NetworkArray sub-origin.

        Each Projection uses the columns of ``transform`` that belong to that
        sub-origin. All-zero slices are skipped.
        """
        index = 0
        for sub_origin in node_origins:
            t = transform[:, index:index + sub_origin.dimensions]
            index += sub_origin.dimensions
            if not iszero(t):
                self.projections.append(Projection(self, sub_origin, termination, transform=t))

    @staticmethod
    def _relay_base_name(name):
        """Strip the '_relay_<i>' suffix from an exposed relay termination name.

        e.g. 'in_a_relay_0' -> 'in_a'. rindex is used so that base names which
        themselves contain underscores are preserved intact.
        """
        return name[:name.rindex("_relay_")]

    def flatten_relays(self, network):
        r"""Remove relays from the network and replace them with direct connections.

        Relays are the nodes listed in a network's "spinn_relays" metadata. For
        example, a network configuration such as

            A --\         /-- C
                 --relay--
            B --/         \-- D

        will be changed to

            A --\--/-- C
                 /\
            B --/--\-- D

        How this is done:

        - Every termination that the relay's "X" origin projected to is exposed
          on the relay's network as
          '<exposed name of the relay's first termination>_relay_<i>'.
        - The relay is then removed.
        - The parent network copies projections that targeted the old exposed
          termination onto each of the new terminations.
        - Finally, the old projections are removed.
        """
        # the basic algorithm:
        # expose all relays in this network
        # recurse on subnetworks
        # find any newly exposed terminations on subnetworks, and connect to them
        relays = network.getMetaData("spinn_relays")
        if relays is not None:
            for r in relays:
                relay_origin = r.getOrigin("X")
                posts = [p.getTermination() for p in network.projections
                         if p.getOrigin() == relay_origin]
                for i, t in enumerate(posts):
                    network.exposeTermination(
                        t, network.getExposedTerminationName(r.getTerminations()[0]) + "_relay_%i" % i)
                    network.removeProjection(t)
                network.removeNode(r.name)

        for n in network.nodes:
            if not isinstance(n, ca.nengo.model.impl.NetworkImpl):
                continue
            self.flatten_relays(n)

            old_ts = []
            relay_ts = [t for t in n.getTerminations() if "_relay_" in t.name]
            for t in relay_ts:
                # find the termination that was originally projected to on the network
                old_t = n.getTermination(self._relay_base_name(t.name))
                if old_t not in old_ts:
                    old_ts.append(old_t)
                # send projections aimed at the old termination to the new one
                for p in network.projections:
                    if p.getTermination() == old_t:
                        network.addProjection(p.getOrigin(), t)

            # remove the old projections
            for t in old_ts:
                try:
                    network.removeProjection(t)
                except ca.nengo.model.StructuralException:
                    # then there was no projection on that termination anyway
                    pass

    def _chunk_ranges(self, dim):
        """Split range(dim) into consecutive index lists.

        Each list has at most io_dim_limit entries,
        e.g. dim=5, limit=2 -> [[0, 1], [2, 3], [4]].
        """
        limit = self.io_dim_limit
        return [list(range(start, min(start + limit, dim)))
                for start in range(0, dim, limit)]

    def add_spinn_io(self):
        """This method modifies the network to add the input/output populations
        needed to communicate with the Spinnaker I/O. The basic restriction is
        that all inputs and outputs can be at most io_dim_limit (two) dimensions.
        This method therefore takes any larger inputs/outputs and breaks them up
        into chunks of at most that size. It then combines them into a complete
        representation internally within the network."""
        net = nef.Network(self.network)

        # break all inputs down into lower-d chunks
        inputs = [n for n in net.network.nodes
                  if isinstance(n, ca.nengo.model.impl.FunctionInput)]
        for inp in inputs:
            self._split_input(net, inp)

        self._split_outputs(net, inputs)

    @staticmethod
    def _termination_transform(term):
        """Return the transform that ``term`` applies.

        - For a DecodedTermination, this is its own transform.
        - For an ensemble termination, it is the rows of each node termination
          stacked together: their transforms if decoded, otherwise their weights.
        """
        if isinstance(term, ca.nengo.model.nef.impl.DecodedTermination):
            return term.transform
        t = []
        for nterm in term.getNodeTerminations():
            if isinstance(nterm, ca.nengo.model.nef.impl.DecodedTermination):
                t += nterm.transform
            else:
                t += [nterm.weights]
        return t

    @staticmethod
    def _resolve_input_target(term):
        """Return (node name for net.connect, unwrapped termination) for a
        termination that an input projected to.

        For terminations exposed on an ordinary network, the name becomes
        'network.inner', so the new projection targets the inner population
        directly. NetworkArrays keep their own name, because they can handle
        direct connections.
        """
        nodename = term.node.name
        if isinstance(term, NetworkImpl.TerminationWrapper):
            # then we know this is a termination on a network
            if not isinstance(term.node, NetworkArrayImpl):
                nodename = nodename + "." + term.getBaseTermination().node.name
            # NOTE(review): NetworkArray terminations are unwrapped too, presumably so
            # that _termination_transform can read their node terminations — confirm.
            term = term.getBaseTermination()
        return nodename, term

    def _split_input(self, net, inp):
        """Split one FunctionInput wider than io_dim_limit into sub-ensembles.

        - Each sub-ensemble is named "<input>_sub_<i>" and has 50 neurons per
          dimension.
        - It receives a slice of the input's dimensions.
        - It projects to every target the input originally projected to, using
          the matching columns of the original transform.
        - The input's original projections are removed.
        """
        origin = inp.getOrigin("origin")
        dim = origin.dimensions
        if dim <= self.io_dim_limit:
            return

        # find all the places this input projects to, then remove those projections
        posts = [p.getTermination() for p in net.network.projections
                 if p.getOrigin() == origin]
        for term in posts:
            net.network.removeProjection(term)

        # (node name, transposed transform) per original target; transposing makes
        # the columns for a slice of input dimensions selectable as rows
        targets = []
        for term in posts:
            nodename, base_term = self._resolve_input_target(term)
            targets.append((nodename, MU.transpose(self._termination_transform(base_term))))

        for i, index_range in enumerate(self._chunk_ranges(dim)):
            sub = net.make("%s_sub_%i" % (inp.name, i), 50 * len(index_range), len(index_range))
            net.connect(inp, sub, index_pre=index_range)
            for nodename, t in targets:
                net.connect(sub, nodename, transform=MU.transpose([t[j] for j in index_range]))

    @staticmethod
    def _is_input_target(net, inputs, target):
        """Return True if any node in ``inputs`` projects directly to ``target``.

        Wrapped terminations are unwrapped before the comparison.
        """
        for proj in net.network.projections:
            if proj.getOrigin().node in inputs:
                term = proj.getTermination()
                if isinstance(term, NetworkImpl.TerminationWrapper):
                    term = term.getBaseTermination()
                if term.node == target:
                    return True
        return False

    def _split_outputs(self, net, inputs):
        """Collect output populations into self.data_outputs, splitting wide ones.

        - Outputs are the nodes in the "spinn_outputs" metadata, plus every node
          with no outgoing projection.
        - Outputs with at most io_dim_limit dimensions that are not fed by an
          input are used as they are.
        - All other outputs get "<name>_sub_<i>" sub-ensembles, each carrying
          a slice of their dimensions.
        """
        # assuming here that all outputs are nodes, and we will always use
        # the default origin as output
        outputs = net.network.getMetaData("spinn_outputs")
        # copy, so appending auto-detected outputs does not modify the metadata
        outputs = [] if outputs is None else list(outputs)

        # detect all nodes with no outgoing projections, and assume they are outputs
        projections = net.network.projections
        for n in net.network.nodes:
            if n in outputs:
                continue
            if not any(p.getOrigin() == o for o in n.getOrigins() for p in projections):
                outputs.append(n)

        for o in outputs:
            nodename = o.name
            try:
                origin = o.getOrigin("X")  # note: assume all output is from "X"
            except ca.nengo.model.StructuralException:
                # then it has no X origin
                continue

            if isinstance(origin, NetworkImpl.OriginWrapper) and not isinstance(o, NetworkArrayImpl):
                origin = origin.getBaseOrigin()
                nodename = nodename + "." + origin.node.name

            target = net.get(nodename)
            dim = target.dimension

            # a population can't be both an input and an output, so an output that
            # is also the target of an input is always given separate sub-ensembles
            isinput = self._is_input_target(net, inputs, target)

            if dim <= self.io_dim_limit and not isinput:
                self.data_outputs.append(target)
            else:
                for i, index_range in enumerate(self._chunk_ranges(dim)):
                    sub = net.make("%s_sub_%i" % (o.name, i), 50 * len(index_range), len(index_range))
                    net.connect(nodename, sub, index_pre=index_range)
                    self.data_outputs.append(sub)

    def print_info(self):
        """Print the distinct taus, the sorted population names and every projection."""
        print(self.get_taus())
        for p in sorted(self.populations.values(), key=lambda a: a.name):
            print(p.name)
        for p in self.projections:
            print('%s -> %s' % (p.pre, p.post))

    def generate_text(self):
        """Return the full python text (pop, proj, inputs, ...) for PACMAN.

        A projection whose text cannot be generated is reported with a warning
        and skipped. The remaining projections are still written.
        """
        text = ['pop = {}']
        for p in sorted(self.populations.values(), key=lambda p: p.name):
            text.append("pop['%s'] = %s" % (p.name, p.generate_text()))
        text.append('proj = {}')
        for p in sorted(self.projections, key=lambda p: (p.pre, p.post)):
            try:
                body = p.generate_text()
            except Exception:
                print('WARNING: could not generate text for projection %s -> %s, skipping it'
                      % (p.pre, p.post))
                continue
            text.append("proj['%s', '%s'] = %s" % (p.pre, p.post, body))
        text.append("inputs = %s" % (self.inputs,))
        text.append("robot_outputs = %s" % (self.robot_outputs,))
        text.append("neurons_per_core = %s" % (self.neurons_per_core,))
        text.append("runtime = %s\n" % (self.runtime,))
        return '\n'.join(text)

    def write_to_file(self, filename):
        """Write generate_text() to ``filename``.

        The text is built before the file is opened, so a generation error does
        not leave a truncated file. The file is always closed.
        """
        text = self.generate_text()
        f = open(filename, 'w')
        try:
            f.write(text)
        finally:
            f.close()

435 |