2013/SKIMonNengo: skimnetwork.py

File skimnetwork.py, 4.8 KB (added by drasmuss, 5 years ago)

code for creating a skim network in nengo

1import nef
3from ca.nengo.math.impl import WeightedCostApproximator, FixedSignalFunction, ConstantFunction, IndicatorPDF
4from ca.nengo.math import Function
5from ca.nengo.util import MU
6from ca.nengo.model.impl import NetworkImpl, FunctionInput
7from ca.nengo.model.nef.impl import DecodedOrigin
8from ca.nengo.model import Units
9from ca.nengo.plot.impl import DefaultPlotter
class SKIMNetwork(NetworkImpl):
    """Nengo network that builds a SKIM network from example input/output signals.

    The network consists of:
      * "signal" / "corr_sig": function inputs replaying the given input and
        target output signals,
      * "input_pop": a population that turns the input signal into spikes,
      * "dendrites_<i>": one rate-mode population per output neuron, fed from
        "input_pop" through a termination whose PSTCs and delays are drawn
        from fixed ranges,
      * "neurons": the output population, driven by a decoded origin
        ("signalX") on each dendrite population.

    The decoders of each "signalX" origin are computed in the constructor by
    running the network on the given inputs and recording dendrite activity.
    """

    def __init__(self, num_dend, inputs, outputs, name="SKIMNetwork", dt=0.001, threshold=0.5):
        """Build the network and compute the dendrite decoders.

        num_dend  -- number of nodes in each dendrite population
        inputs    -- input signals; len(inputs) is used as the number of input
                     dimensions and len(inputs[0]) as the number of samples
                     (presumably one time series per dimension, sampled
                     every dt -- confirm against callers)
        outputs   -- target signals; len(outputs) is used as the number of
                     output neurons, and outputs[i] is the target for neuron i
        name      -- name assigned to this network
        dt        -- value passed to net.run() once per input sample while
                     collecting dendrite activity
        threshold -- used as both ends of the intercept range of the output
                     neurons
        """
        NetworkImpl.__init__(self)
        self.name = name

        net = nef.Network(self)
        num_outputs = len(outputs)
        num_inputs = len(inputs)
        pstc_range = IndicatorPDF(0.01, 0.1)  # range from which pstcs will be drawn for dendrites
        delay_range = IndicatorPDF(0.001, 0.03)  # range from which synaptic delays will be drawn for dendrites

        # the input signal
        signal = net.make_input("signal",
                                [FixedSignalFunction(MU.transpose(inputs), d) for d in range(num_inputs)])

        # the output signal (added to the network but not connected to anything here)
        net.make_input("corr_sig",
                       [FixedSignalFunction(MU.transpose(outputs), d) for d in range(num_outputs)])

        # population of neurons that transforms input signal into spikes
        input_pop = net.make("input_pop", num_inputs, num_inputs,
                             encoders=MU.diag([1] * num_inputs),
                             intercept=(0, 0), max_rate=(100, 150))
        net.connect(signal, input_pop)

        # population of output neurons to detect patterns in input
        neurons = net.make("neurons", num_outputs, num_outputs,
                           encoders=MU.diag([1] * num_outputs),
                           intercept=(threshold, threshold), max_rate=(100, 150))

        # create a population of dendrites for each output neuron
        dendrites = []
        for i in range(num_outputs):
            # note that we approximate dendrites here with rate-mode LIF neurons.
            # could use sigmoids or anything else instead
            dend = net.make("dendrites_%i" % i, num_dend, num_inputs, mode="rate",
                            max_rate=(5, 10), intercept=(0, 0))
            dendrites.append(dend)

            # one weight row per dendrite node, taken from that node's encoder
            encoders = dend.getEncoders()
            weights = [[1.0 * e for e in encoders[j]] for j, _ in enumerate(dend.getNodes())]
            dend.addTermination("input", weights, pstc_range, delay_range, False)
            net.connect(input_pop, dend.getTermination("input"))

        # run the network in order to collect data for calculating decoders.
        # activities[i][j] is the recorded output of node j of dendrite
        # population i, one value per run step. Indexing by (population, node)
        # avoids computing i / num_dend, which is only an integer under
        # Python 2 classic division.
        activities = [[[] for _ in range(num_dend)] for _ in range(num_outputs)]
        for _ in range(len(inputs[0])):
            net.run(dt)
            for dend, dend_activity in zip(dendrites, activities):
                nodes = dend.nodes
                for j, trace in enumerate(dend_activity):
                    trace.append(nodes[j].getOrigin("AXON").getValues().getValues()[0])

        # calculate decoders
        # the evaluation signal is the same for every dendrite population
        eval_signals = [[list(sig) for sig in inputs]]
        for i, dend in enumerate(dendrites):
            wcost = WeightedCostApproximator.Factory(0.1).getApproximator(
                eval_signals,
                [[list(trace)] for trace in activities[i]])
            coeffs = wcost.findCoefficients(outputs[i])

            o = dend.addDecodedOrigin(DecodedOrigin(dend, "signalX", dend.nodes, "AXON",
                                                    [FixedSignalFunction(MU.transpose(outputs[i]), 0, num_inputs)],
                                                    [[x] for x in coeffs], 0))

            # connect dendrites up to output neuron: origin i drives only neuron i
            net.connect(o, neurons, transform=[[1 if i == n else 0] for n in range(num_outputs)])

        net.reset()

    # the basic idea:
    # random weights from input -> dendrites
    #     actually though we'll just do a normal connection, the randomness will come from the randomness
    #     in the dendrite population
    #     but we want different pstc's for different dendrites (randomly chosen)
    # decoded weights from dendrites -> neurons (calculated based on specified input/output signals)
    # no decoding on output of neurons (since only looking at spikes)
    #     basically the decoding is being put into the dendrite connections

    # calculating decoders
    # run the network with the given input
    # save the activity of the dendrites
    #     this becomes the A matrix in the decoder calculations.
    #     since we're feeding the dendrite output into the output population, if we optimize
    #     the dendrite output to produce the desired signal, that should also cause the neurons
    #     to output the desired signal (e.g. when dendrite signal is high, neuron signal will be high).
    #     can put a threshold on the output neurons to do pattern detection instead