2011/results/att11/spikevision: LCAsal_UDP.py

File LCAsal_UDP.py, 6.3 KB (added by sshapero, 7 years ago)

NEF network that takes UDP inputs, finds a sparse representation, and attempts to calculate a saliency map. Needs serious tweaking. Takes a long time to load!

Line 
1import nef
2from numeric import *
3import socket, sys, time
4from struct import *
5import shelve
6
# --- Event record layout -------------------------------------------------
# struct format of one event record (network byte order: int, unsigned int)
fmt_string = '!iI'
# Step, in bytes, between consecutive event records in a packet.
stride = 8

# --- Address word bit fields ---------------------------------------------
# Shift and nominal width used to extract the x and y coordinates.
x_offset = 1
x_length = 8
y_offset = 8
y_length = 8

# --- UDP endpoint ---------------------------------------------------------
HOST = "127.0.0.1"
PORT = "9000"      # kept as text; converted with int() where a socket needs it
BUFFSIZE = 2048    # maximum number of bytes read per datagram
ADDR = (HOST, PORT)
18
19
class event:
    """A single event decoded from a packed address word.

    Attributes set by the constructor:
      time      -- the timestamp passed in, stored unchanged
      x         -- 127 minus the x field of the address
      y         -- the y field of the address
      intensity -- 1 when the lowest address bit is clear, 0 when it is set
    """

    def __init__(self, addr, time):
        # Each mask keeps (length - 1) low bits of the shifted address.
        x_mask = 2 ** (x_length - 1) - 1
        y_mask = 2 ** (y_length - 1) - 1

        self.intensity = 1 - (addr & 1)
        self.x = 127 - ((addr >> x_offset) & x_mask)
        self.y = (addr >> y_offset) & y_mask
        self.time = time

    def __repr__(self):
        summary = (self.intensity, self.x, self.y, self.time)
        return repr(summary)
28
# Endpoint actually used by this script. To take it from the command line
# instead, replace the assignment below with:
#host = sys.argv[1]
#textport = sys.argv[2]

host, textport = HOST, PORT
35
36
37
def sth(x):
    """Soft-threshold the scalar *x* with a fixed threshold of 0.6.

    Values above the threshold are shifted down by it, values below the
    negated threshold are shifted up by it, and anything in between
    (bounds included) yields the integer 0.
    """
    threshold = .6
    if x > threshold:
        return x - threshold
    if x < -threshold:
        return x + threshold
    return 0
46
def sthn(x):
    """Return a new list holding sth() of every element of *x*, in order."""
    shrunk = []
    for value in x:
        shrunk.append(sth(value))
    return shrunk
49
def absn(x):
    """Return a new list holding abs(sth(v)) for every element v of *x*."""
    magnitudes = []
    for value in x:
        magnitudes.append(abs(sth(value)))
    return magnitudes
52
def rect(x):
    """Half-wave rectify *x*: return x when it is positive, else integer 0."""
    return x if x > 0 else 0
58
def rectn(x):
    """Return a new list holding rect() of every element of *x*, in order."""
    rectified = []
    for value in x:
        rectified.append(rect(value))
    return rectified
61
def zero(m, n):
    """Return an m-by-n matrix of integer zeros as a list of m row lists.

    Every row is a distinct list, so rows can be modified independently.
    """
    rows = []
    for _ in range(m):
        rows.append([0] * n)
    return rows
66
def eye(n):
    """Return the n-by-n identity pattern as nested lists of booleans.

    Entry [i][j] is True on the diagonal (i == j) and False elsewhere.
    """
    rows = []
    for i in range(n):
        rows.append([i == j for j in range(n)])
    return rows
70
def multmin(matrix1):
    """Return the negated row-by-row dot products with a zero diagonal.

    The result is square with one row/column per row of *matrix1*. For
    i != j, entry [i][j] is minus the dot product of rows i and j, summed
    over the length of the first row. Diagonal entries stay 0.
    """
    size = len(matrix1)
    result = zero(size, size)
    for i in range(size):
        for j in range(size):
            if i == j:
                # Self-interaction is deliberately left at zero.
                continue
            for k in range(len(matrix1[0])):
                result[i][j] -= matrix1[i][k] * matrix1[j][k]
    return result
81
82
def transpose(matrix1):
    """Return the transpose of *matrix1* as a new list of lists.

    The column count is taken from the first row, so *matrix1* must have
    at least one row.
    """
    n_cols = len(matrix1[0])
    n_rows = len(matrix1)
    flipped = zero(n_cols, n_rows)
    for i, row in enumerate(matrix1):
        for j in range(n_cols):
            flipped[j][i] = row[j]
    return flipped
90
def _load_matrix(path, num_rows):
    """Read *num_rows* lines of space-terminated numbers from text file *path*.

    Returns a list of *num_rows* lists of floats, each rounded to 4 decimals.
    Only tokens that are followed by a space are parsed; whatever comes after
    the last space on a line (normally just the newline) is discarded. Lines
    past the end of the file produce empty rows. A token that is not a valid
    float (including the empty token between two adjacent spaces) raises
    ValueError.
    """
    rows = []
    handle = open(path)
    try:
        for _ in range(num_rows):
            tokens = handle.readline().split(' ')[:-1]
            rows.append([round(float(token), 4) for token in tokens])
    finally:
        # Release the file even when a malformed token aborts parsing.
        handle.close()
    return rows


# Number of rows read from each of the three text files.
num_dict = 128*8

# The parsed matrices are cached in a shelf so that later runs can skip the
# text parsing. NOTE: shelve is pickle-based; only open a 'data_sal' file
# that this script created.
db = shelve.open('data_sal')
try:
    if 'dict' in db:
        PHI_pre = db['dict']
        recur_pre = db['recur']
        salrec_pre = db['salrec']
    else:
        PHI_pre = _load_matrix('thdictionary.txt', num_dict)
        recur_pre = _load_matrix('threcur.txt', num_dict)
        salrec_pre = _load_matrix('salrecur.txt', num_dict)

        db['dict'] = PHI_pre
        db['recur'] = recur_pre
        db['salrec'] = salrec_pre
finally:
    # Always close the shelf so a parsing failure does not leave it open.
    db.close()

# NOTE: 'dict' shadows the builtin from here on; the name is kept because the
# rest of the script refers to it.
dict = array(PHI_pre)
recur = array(recur_pre)
salrec = array(salrec_pre)


numNodes = len(dict)
numInputs = len(dict[0])
161
162
163
class MyInput(nef.SimpleNode):
    """Input node that turns received UDP event packets into a flat frame.

    The node keeps the events of the most recent packets in
    ``self.oldevents`` (newest first) and renders all of them into a
    24 x 24 frame on every call of ``origin_value``.
    """

    def origin_value(self):
        """Return the current frame as a flat list of 576 values.

        Each entry is 0 where no remembered event falls, otherwise
        ``2*intensity - 1`` of the last event drawn at that position.

        While ``self.t_start`` is below 0.11 the history is reset and the
        socket is not touched. After that, each call reads at most one
        datagram from the module-level ``sock`` and pushes its events onto
        the history. If the read raises ``socket.error``, an empty entry is
        pushed instead.
        """
        side = 24       # the output frame is side x side
        compress = 2    # event coordinates are divided by this before drawing
        logtime = 12    # number of packets kept in the history
        header = 4      # bytes at the start of a packet that precede the events

        out = [0] * (side * side)

        if self.t_start < .11:
            self.oldevents = logtime * [[]]
        else:
            event_list = []
            try:
                data, _sender = sock.recvfrom(BUFFSIZE)
            except socket.error:
                # Nothing could be read: the history still ages by one slot.
                pass
            else:
                # Stop at the last complete record so a truncated tail is
                # ignored rather than handed to unpack().
                for start in range(header, len(data) - stride + 1, stride):
                    message = unpack(fmt_string, data[start:start + stride])
                    event_list.append(event(message[0], message[1]))

            # Age the history exactly once per call. Shifting once per event
            # would duplicate the newest entry and push older packets out
            # early, and an empty packet would overwrite the newest entry
            # without shifting.
            self.oldevents[1:] = self.oldevents[:-1]
            self.oldevents[0] = event_list

        for evlist in self.oldevents:
            for events in evlist:
                x = events.x // compress
                y = events.y // compress
                if 0 <= x < side and 0 <= y < side:
                    # Rows are flipped vertically: y == 0 lands in the last row.
                    out[x + side * (side - y - 1)] = 2 * int(events.intensity) - 1
        return out
205
206
# Identity pattern scaled by 0.5. Only the commented-out 'infilter'
# self-connection further down refers to it.
lif = array(eye(numInputs))*0.5

# Build the network. Statement order matters: the MyInput node is created
# here, before the UDP socket below exists.
net=nef.Network('LCA',quick=True)
#net.add_to(world)

#input=net.make_input('input',clippre[15])
myinput=MyInput('input')
net.add(myinput)
#infilter=net.make('infilter',1,numInputs,mode='direct')
# Populations: two arrays with one 20-neuron ensemble per dictionary row
# ('neurons', 'salience') and three direct-mode ensembles used as read-outs.
neuron=net.make_array('neurons',20,numNodes,intercept=(0,1))
decoders=net.make('decoders',1,numInputs,mode='direct')
neuron_value=net.make('neuron_value',1,numNodes,mode='direct')
salience=net.make_array('salience',20,numNodes,intercept=(0,1))
sal_decode=net.make('sal_decoder',1,numInputs,mode='direct')

#net.connect(input,neuron,transform=dict)
#net.connect(myinput.getOrigin('value'),infilter,pstc=.0005)
#net.connect(infilter,infilter,transform=lif,pstc=.0005)
# Input frame -> 'neurons', projected through the dictionary matrix.
net.connect(myinput.getOrigin('value'),neuron,transform=dict,pstc=.002)
# Recurrent connection on 'neurons': soft-thresholded values through 'recur'.
net.connect(neuron,neuron,transform=recur,func=sthn,pstc=.002)
# Soft-thresholded 'neurons' values mapped back through the transposed
# dictionary (presumably a reconstruction of the input -- confirm).
net.connect(neuron,decoders,transform=dict.T,func=sthn,pstc=.0005)
# Soft-thresholded 'neurons' values exposed unchanged for inspection.
net.connect(neuron,neuron_value,func=sthn,pstc=.0005)
# Magnitude of the soft-thresholded values drives 'salience'.
net.connect(neuron,salience,func=absn,pstc=.0005)
# Recurrent connection on 'salience': rectified values through 'salrec'.
net.connect(salience,salience,transform=salrec,func=rectn,pstc=.0005)
# Rectified 'salience' values mapped through the transposed dictionary.
net.connect(salience,sal_decode,transform=dict.T,func=rectn,pstc=.0002)

# UDP socket read by MyInput.origin_value. The 1 ms timeout makes recvfrom()
# raise socket.error quickly instead of blocking when no packet is waiting.
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind((HOST,int(PORT) ))
sock.settimeout(.001)

# Hand the network to the viewer (play=0.1 presumably auto-runs 0.1 s of
# simulation -- confirm against the nef documentation).
net.view(play=0.1)

#sim=net.network.simulator
#sim.run(0,1,0.001)

#print neuron_value.getOrigin('X').getValues().getValues()
#print outs.getOrigin('X').getValues().getValues()
243#print outs.getOrigin('X').getValues().getValues()