# 2013/SKIMonNengo: skimnetwork.py

File `skimnetwork.py`, 4.8 KB (added by drasmuss, 5 years ago)

Line | |
---|---|

1 | import nef |

2 | |

3 | from ca.nengo.math.impl import WeightedCostApproximator, FixedSignalFunction, ConstantFunction, IndicatorPDF |

4 | from ca.nengo.math import Function |

5 | from ca.nengo.util import MU |

6 | from ca.nengo.model.impl import NetworkImpl, FunctionInput |

7 | from ca.nengo.model.nef.impl import DecodedOrigin |

8 | from ca.nengo.model import Units |

9 | from ca.nengo.plot.impl import DefaultPlotter |

10 | |

class SKIMNetwork(NetworkImpl):
    """Nengo network whose output neurons are driven by decoded dendrite populations.

    Each output neuron gets its own population of rate-mode "dendrites". Every
    dendrite population receives the full input signal through a termination
    built with randomly drawn postsynaptic time constants and synaptic delays.
    The decoders from each dendrite population onto its output neuron are solved
    for at construction time. To do this, the network is run on the supplied
    input signals, and the recorded dendrite activity is fitted to the
    corresponding target output signal. See the design notes at the end of
    ``__init__`` for the reasoning behind this construction.
    """

    def __init__(self, num_dend, inputs, outputs, name="SKIMNetwork", dt=0.001,
                 threshold=0.5):
        """Build the network and compute its dendrite decoders.

        :param num_dend: number of dendrite units created for each output neuron.
        :param inputs: one signal per input dimension; each signal is a sequence
            of samples. All signals are presumably the same length: the number
            of data-collection steps is taken from ``inputs[0]``.
        :param outputs: one target signal per output neuron, sampled like
            ``inputs``.
        :param name: name given to this network.
        :param dt: duration passed to ``net.run`` for each sample during the
            data-collection run.
        :param threshold: intercept of every output neuron, i.e. the encoded
            value above which the neuron starts firing. This is what turns the
            output population into a pattern detector.
        """
        NetworkImpl.__init__(self)
        self.name = name

        net = nef.Network(self)
        num_outputs = len(outputs)
        num_inputs = len(inputs)
        # Ranges from which each dendrite's PSTC and synaptic delay are drawn.
        pstc_range = IndicatorPDF(0.01, 0.1)
        delay_range = IndicatorPDF(0.001, 0.03)

        # The input signal: one FixedSignalFunction per input dimension, all
        # reading from the same (sample x dimension) matrix.
        input_matrix = MU.transpose(inputs)
        signal = net.make_input(
            "signal", [FixedSignalFunction(input_matrix, i) for i in range(num_inputs)])

        # The target output signal, built the same way.
        output_matrix = MU.transpose(outputs)
        corr_sig = net.make_input(
            "corr_sig", [FixedSignalFunction(output_matrix, i) for i in range(num_outputs)])

        # One neuron per input dimension (diagonal encoders) turns the input
        # signal into spikes.
        input_pop = net.make("input_pop", num_inputs, num_inputs,
                             encoders=MU.diag([1] * num_inputs),
                             intercept=(0, 0), max_rate=(100, 150))
        net.connect(signal, input_pop)

        # One output neuron per output signal (diagonal encoders). The shared
        # intercept `threshold` makes each neuron fire only on a strong match.
        neurons = net.make("neurons", num_outputs, num_outputs,
                           encoders=MU.diag([1] * num_outputs),
                           intercept=(threshold, threshold), max_rate=(100, 150))

        # A separate dendrite population for each output neuron.
        dendrites = []
        for i in range(num_outputs):
            # Dendrites are approximated here with rate-mode LIF neurons; sigmoids
            # or any other rate nonlinearity could be used instead.
            dend = net.make("dendrites_%i" % i, num_dend, num_inputs, mode="rate",
                            max_rate=(5, 10), intercept=(0, 0))
            # The population's own encoders are used as the input weights, so the
            # "random" input->dendrite weights come from the dendrite population.
            # PSTCs and delays are passed as PDFs so that different dendrites get
            # different, randomly chosen values.
            weights = [[1.0 * e for e in enc] for enc in dend.getEncoders()]
            dend.addTermination("input", weights, pstc_range, delay_range, False)
            net.connect(input_pop, dend.getTermination("input"))
            dendrites.append(dend)

        # Run the network over the whole input signal and record every dendrite's
        # output at each step. These recordings form the activity (A) matrix used
        # to solve for the decoders. Recordings are laid out population-major:
        # index k belongs to population k // num_dend, node k % num_dend.
        # NOTE(review): each net.run(dt) call is presumably one step of length dt
        # continuing the previous run; whether the simulator's integration
        # timestep also follows `dt` (rather than nef's default) is not visible
        # here -- confirm.
        activities = [[] for _ in range(num_outputs * num_dend)]
        for _ in range(len(inputs[0])):
            net.run(dt)
            for k, record in enumerate(activities):
                pop_index, node_index = divmod(k, num_dend)
                node = dendrites[pop_index].nodes[node_index]
                record.append(node.getOrigin("AXON").getValues().getValues()[0])

        # Solve for each dendrite population's decoders against its target signal.
        for i, dend in enumerate(dendrites):
            dend_activities = activities[i * num_dend:(i + 1) * num_dend]
            # 0.1 is passed to WeightedCostApproximator.Factory (presumably its
            # noise/regularization level -- confirm against the Nengo API).
            wcost = WeightedCostApproximator.Factory(0.1).getApproximator(
                [[list(sig) for sig in inputs]],
                [[list(record)] for record in dend_activities])
            coeffs = wcost.findCoefficients(outputs[i])

            origin = dend.addDecodedOrigin(DecodedOrigin(
                dend, "signalX", dend.nodes, "AXON",
                [FixedSignalFunction(MU.transpose(outputs[i]), 0, num_inputs)],
                [[x] for x in coeffs], 0))

            # Feed this population's decoded output into dimension i of
            # `neurons`. With diagonal encoders, only output neuron i is sensitive
            # to that dimension.
            transform = [[1] if n == i else [0] for n in range(num_outputs)]
            net.connect(origin, neurons, transform=transform)

        # Reset the state left over from the data-collection run.
        net.reset()

        # Design notes
        # ------------
        # The basic idea:
        #   - Random weights from input -> dendrites. In practice this is just a
        #     normal connection; the randomness comes from the randomness in the
        #     dendrite population.
        #   - Different, randomly chosen PSTCs for different dendrites.
        #   - Decoded weights from dendrites -> neurons, calculated from the
        #     specified input/output signals.
        #   - No decoding on the output of the neurons, since only their spikes
        #     are of interest.
        #   In effect, the decoding is moved into the dendrite connections.
        #
        # Calculating the decoders:
        #   Run the network with the given input and save the activity of the
        #   dendrites; this becomes the A matrix in the decoder calculation.
        #   The dendrite output feeds the output population. So if the dendrite
        #   output is optimized to produce the desired signal, the neurons should
        #   also produce it: when the dendrite signal is high, the neuron's
        #   activity will be high. A threshold on the output neurons (the
        #   `threshold` intercept) turns this into pattern detection.