2013/uns13/resources: spinn_viewer_example.py

File spinn_viewer_example.py, 7.2 KB (added by drasmuss, 5 years ago)

Example of a new viewer for Nengo models running on SpiNNaker.

Line 
1import nef
2import java
3import jarray
4import struct
5import numeric
6import random
7import nengo_spinnaker_interface.packet as packet
8import robotsim
9import gradientutils
10import os
11import timeview
12import boxworldwatch
13import types
14
15
#input_groups is a list of strings, where each string contains a
#newline-separated list of populations, each followed by its
#location on SpiNNaker (chip x, chip y, core).
#e.g., in order to create two four-dimensional inputs,
#it would take the form
#input_groups=["""input0.0 chip_x chip_y core
#input0.1 chip_x chip_y core""",
#"""input1.0 chip_x chip_y core
#input1.1 chip_x chip_y core"""]
#one list entry per group; inside a group every line reads
#"<population> <chip_x> <chip_y> <core>" and lines are joined by "\n"
input_groups = [
    "target_sub_1 0 0 3\n"
    "target_sub_0 0 0 5",
    "place_sub_0 0 0 7\n"
    "place_sub_1 0 0 8",
    "robot_move 1 0 5",
    "inhib_pre 1 0 9",
]

#output groups are written in exactly the same format
output_groups = [
    "decoder_GradientAction.action_ 0 0 4",
    "decoder_GradientAction.actions 0 0 1",
    "decoder_inhib_post 0 0 9",
]
38
def parse_map(groups):
    """Turn group description strings into lists of core locations.

    Each string in `groups` holds newline-separated entries of the form
    "<name> <x> <y> <p>".  Within a group the entries are sorted (so
    effectively by name) and each one becomes a dict with integer
    "x", "y" and "p" values; blank lines are dropped.

    Returns a list with one list of dicts per input group.
    """
    parsed = []
    for block in groups:
        # sort on the token lists so the order follows the population names
        rows = sorted(line.split() for line in block.split("\n"))
        parsed.append([
            {"x": int(row[1]), "y": int(row[2]), "p": int(row[3])}
            for row in rows
            if len(row) > 0
        ])
    return parsed
47
#parsed core locations: one list of {"x","y","p"} dicts per group
INPUTS, OUTPUTS = parse_map(input_groups), parse_map(output_groups)
51
class NullOutputStream(java.io.OutputStream):
    """Output stream whose write() accepts any arguments and does nothing."""

    def write(self, *ignored):
        pass
#keep a handle on the original System.out (SpikeSender writes to it),
#then point System.out at a stream that discards everything
real_out = java.lang.System.out
_discarding_stream = java.io.PrintStream(NullOutputStream())
java.lang.System.setOut(_discarding_stream)

#only let FATAL messages through for the ca.nengo.util.Memory logger
#NOTE(review): `ca` is not imported in the visible part of this file --
#confirm the environment running this script already provides it
from org.apache.log4j import Level, Logger
logger = Logger.getLogger(ca.nengo.util.Memory)
logger.setLevel(Level.FATAL)
61
62
63
class SpikeSender(nef.SimpleNode):
    """Node that writes the indices of an ensemble's spiking neurons to real_out.

    On every tick the values of the ensemble's 'AXON' origin are read;
    for each truthy entry, (index + 2048) is written as a little-endian
    unsigned 32-bit value.  Every tick ends with writeInt(0xFFFFFFFF).
    """

    def __init__(self, name, ensemble):
        nef.SimpleNode.__init__(self, name)
        self.ensemble = ensemble
        self.datastream = java.io.DataOutputStream(real_out)

    def tick(self):
        axon_values = self.ensemble.getOrigin('AXON').getValues().getValues()
        stream = self.datastream
        for neuron, fired in enumerate(axon_values):
            if not fired:
                continue
            stream.writeBytes(struct.pack('<I', neuron + 2048))
        # written once per tick, after all spikes
        stream.writeInt(0xFFFFFFFF)
75
class UDPValueSender(nef.SimpleNode):
    """Node that forwards the values on its terminations to SpiNNaker over UDP.

    One termination "input_<i>" is created per group in the module-level
    INPUTS list, with two dimensions per population in that group.  On
    every tick each population's pair of values is scaled by 256,
    truncated to int and, if it differs from the pair last sent to that
    population, sent as one datagram addressed to the population's
    chip/core.

    name    -- node name passed to nef.SimpleNode
    address -- host name or IP of the SpiNNaker board
    port    -- UDP port on the board
    """
    pstc=0

    def __init__(self,name,address,port):
        nef.SimpleNode.__init__(self,name)
        self.socket=java.net.DatagramSocket()
        self.address=java.net.InetAddress.getByName(address)
        self.port=port
        self.header=packet.Packet()
        #last pair of integer values sent to each population (None = nothing sent yet)
        self.last_value=[[None for _ in in_set] for in_set in INPUTS]
        self.header.tx_sdp_header=dict(flags=7,ip_tag=255,dst_port=(2<<5)+1,src_port=255,dst_chip=0,src_chip=0,cmd=257,arg0=1,arg1=0,arg2=0)

        #most recent value received on each termination (None until first input)
        self.data=[None for _ in INPUTS]

        class FuncWrapper:
            #this wrapper is necessary to ensure we get unique functions,
            #rather than a bunch of pointers to the same function
            def __init__(self, index, data):
                self.index = index
                self.d = data
            def func(self, x):
                self.d[self.index] = x

        for i,in_set in enumerate(INPUTS):
            self.create_termination("input_%i"%i, FuncWrapper(i,self.data).func)
            self.getTermination("input_%i"%i).setDimensions(len(in_set)*2)

    def tick(self):
        for j,raw in enumerate(self.data):
            if raw is None:
                continue
            data=[int(x*256) for x in raw]

            for i,curr_in in enumerate(INPUTS[j]):
                #population i owns the value pair data[2*i:2*i+2]; compare and
                #remember that pair (not data[i]) so each population is only
                #resent when its own values change
                values=data[i*2:(i+1)*2]
                if values == self.last_value[j][i]:
                    continue
                self.header.tx_sdp_header['arg0']=1
                self.header.tx_sdp_header['dst_port'] = (4<<5) + curr_in["p"]
                self.header.tx_sdp_header["dst_chip"] = (curr_in["x"]<<8) | curr_in["y"]
                msg=struct.pack('BB',0x01,00)+self.header.pack_sdp_header()
                for value in values:
                    #mask to 32 bits so negative values are packed as their
                    #two's-complement bit pattern instead of being out of
                    #range for the unsigned 'I' format
                    msg+=struct.pack('<I',value & 0xFFFFFFFF)
                datagram=java.net.DatagramPacket(msg,len(msg),self.address,self.port)
                self.socket.send(datagram)
                self.last_value[j][i] = values
119
class UDPValueReceiver(nef.SimpleNode):
    """Node that exposes values received from SpiNNaker over UDP as origins.

    One origin "output_<i>" is created per group in the module-level
    OUTPUTS list.  Each tick (once t_start > 0) blocks until one datagram
    arrives; if it carries the expected command words and its chip/core
    matches an entry in OUTPUTS, the two values it carries are divided by
    256.0 and stored for that entry.

    name       -- node name passed to nef.SimpleNode
    dimensions -- number of output slots reserved per population
    port       -- local UDP port to listen on
    """
    def __init__(self,name,dimensions=2,port=54321):
        self.socket=java.net.DatagramSocket(port)
        maxLength=65535
        self.buffer=jarray.zeros(maxLength,'b')
        self.packet=java.net.DatagramPacket(self.buffer,maxLength)
        self.dimensions=dimensions

        self.outputs = [[0 for _ in range(dimensions*len(o))] for o in OUTPUTS]
        nef.SimpleNode.__init__(self,name)

        class FuncWrapper:
            #this wrapper is necessary to ensure we get unique functions,
            #rather than a bunch of pointers to the same function
            def __init__(self, index, data):
                self.index = index
                self.d = data
            def func(self):
                return self.d[self.index]

        for i,_ in enumerate(self.outputs):
            self.create_origin("output_%i"%i, FuncWrapper(i,self.outputs).func)

    def tick(self):
        if self.t_start<=0:
            return

        self.socket.receive(self.packet)
        d=java.io.DataInputStream(java.io.ByteArrayInputStream(self.packet.getData()))
        #the first five bytes are not used
        for _ in range(5):
            d.readByte()
        p=d.readByte()
        x=d.readByte()
        y=d.readByte()
        d.readShort()
        #readInt() reads big-endian and returns a *signed* int, so re-pack
        #with the signed '>i' format (the unsigned '>I' is out of range for
        #negative values) and then reinterpret the same bytes as little-endian
        int0=struct.unpack('<I',struct.pack('>i',d.readInt()))[0]
        int1=struct.unpack('<I',struct.pack('>i',d.readInt()))[0]
        d.readInt()
        d.readInt()
        int4=struct.unpack('<i',struct.pack('>i',d.readInt()))[0]
        int5=struct.unpack('<i',struct.pack('>i',d.readInt()))[0]

        #only packets carrying these two words hold output values
        if int0!=258 or int1!=1:
            return

        core=p & 31
        for i,group in enumerate(OUTPUTS):
            for j,o in enumerate(group):
                if x == o["x"] and y == o["y"] and core == o["p"]:
                    #NOTE(review): slots are indexed with a fixed stride of 2,
                    #independent of self.dimensions -- confirm that is intended
                    self.outputs[i][j*2]=int4/256.0
                    self.outputs[i][j*2+1]=int5/256.0
167
#create network
net=nef.Network('spinn_viewer_example')

#set up packet i/o

#this node will have inputs as specified in the
#input_groups at the top of this file.  connections
#to this node will be redirected to the spinnaker
#cores specified in that list.
uvs=UDPValueSender('uvs', 'spinn4', 17893)
net.add(uvs)

#this node will receive data from spinnaker, and make
#it available as outputs on this node (again, as specified
#in output_groups above)
uvr=UDPValueReceiver('uvr')
net.add(uvr)

#connect up inputs to uvs
#this will depend on what inputs you need in your network
#(connect the input nodes created here, not the builtin `input`)
#NOTE(review): with the input_groups above, "input_0" and "input_1" are
#given 4 dimensions each while these inputs supply 2 values -- confirm
#the sizes match for your model
input1 = net.make_input("input1", [0,0])
net.connect(input1, uvs.getTermination("input_0"))
input2 = net.make_input("input2", [0,0])
net.connect(input2, uvs.getTermination("input_1"))

#open up a viewer
#outputs can be displayed by right clicking on the "uvr" node
net.view()