UESMANN CPP  1.0
Reference implementation of UESMANN
/home/travis/build/jimfinnis/uesmanncpp/netFactory.hpp
Go to the documentation of this file.
1 
9 #ifndef __NETFACTORY_HPP
10 #define __NETFACTORY_HPP
11 
#include <cstdint>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <vector>

#include "bpnet.hpp"
#include "obnet.hpp"
#include "hinet.hpp"
#include "uesnet.hpp"
16 
17 
24 class NetFactory { // not a namespace because Doxygen gets confused.
25 public:
32  static Net *makeNet(NetType t,ExampleSet &e,int hnodes){
33  Net *net;
34 
35  int layers[3];
36  layers[0] = e.getInputCount();
37  layers[1] = hnodes;
38  layers[2] = e.getOutputCount();
39 
40  return makeNet(t,3,layers);
41  }
42 
43  static Net *makeNet(NetType t,int layercount, int *layers){
44  switch(t){
45  case NetType::PLAIN:
46  return new BPNet(layercount,layers);
48  return new OutputBlendingNet(layercount,layers);
49  case NetType::HINPUT:
50  return new HInputNet(layercount,layers);
51  case NetType::UESMANN:
52  return new UESNet(layercount,layers);
53  default:break;
54  }
55  }
56 
61  inline static Net *load(const char *fn){
62  FILE *a = fopen(fn,"rb");
63  if(!a)
64  throw new std::runtime_error("cannot open file");
65 
66  // get type
67  uint32_t magic;
68  if(!fread(&magic,sizeof(uint32_t),1,a)){
69  fclose(a);
70  throw new std::runtime_error("bad net save file");
71  }
72 
73  NetType t = static_cast<NetType>(magic);
74 
75  // build layer specification reading the layer count and then
76  // the layer sizes
77  uint32_t layercount,tmp;
78  if(!fread(&layercount,sizeof(uint32_t),1,a)){
79  fclose(a);
80  throw new std::runtime_error("bad net save file");
81  }
82  int *layers = new int[layercount];
83  for(int i=0;i<layercount;i++){
84  if(!fread(&tmp,sizeof(uint32_t),1,a)){
85  delete [] layers;
86  fclose(a);
87  throw new std::runtime_error("bad net save file");
88  }
89  layers[i]=tmp;
90  }
91 
92  // build the net
93  Net *n = makeNet(t,layercount,layers);
94 
95  // get the parameter data
96  int size = n->getDataSize();
97  double *buf = new double[size];
98  // and read it
99 // printf("loading %d doubles\n",size);
100  int readData = fread(buf,sizeof(double),size,a);
101  if(readData!=size){
102  delete [] buf;
103  delete [] layers;
104  fclose(a);
105  throw new std::runtime_error("bad net save file");
106  }
107  n->load(buf);
108 
109  delete [] buf;
110  delete [] layers;
111  fclose(a);
112  return n;
113  }
114 
119  inline static void save(const char *fn,Net *n) {
120  FILE *a = fopen(fn,"wb");
121  if(!a)
122  throw new std::runtime_error("cannot open file");
123 
124  // get and write the magic number
125  uint32_t magic=static_cast<uint32_t>(n->type); // magic number
126  fwrite(&magic,sizeof(uint32_t),1,a);
127 
128  // write the layer count and layer sizes, all as 32-bit.
129  uint32_t layercount = n->getLayerCount();
130  fwrite(&layercount,sizeof(uint32_t),1,a);
131  for(int i=0;i<layercount;i++){
132  uint32_t layersize = n->getLayerSize(i);
133  fwrite(&layersize,sizeof(uint32_t),1,a);
134  }
135 
136  // get the parameter data
137  int size = n->getDataSize();
138 // printf("saving %d doubles\n",size);
139  double *buf = new double[size];
140  n->save(buf);
141  // and write it
142  fwrite(buf,sizeof(double),size,a);
143  delete [] buf;
144 
145  fclose(a);
146  }
147 
148 };
149 
150 
151 
152 #endif /* __NETFACTORY_HPP */
NetType
The different types of network - each has an associated integer for saving/loading file data...
Definition: netType.hpp:15
static Net * load(const char *fn)
Load a network of any type from a file - note, endianness not checked!
Definition: netFactory.hpp:61
output blending
This class - really a namespace - contains functions which create, load or save networks of all types...
Definition: netFactory.hpp:24
virtual int getLayerCount() const =0
Get the number of layers.
A modulatory network architecture which uses two plain backprop networks, each of which is trained se...
Definition: obnet.hpp:18
virtual void save(double *buf) const =0
Serialize the data (not including any network type magic number or layer/node counts) to the given me...
A modulatory network architecture which uses a plain backprop network with an extra input to carry th...
Definition: hinet.hpp:17
The "basic" back-propagation network using a logistic sigmoid, as described by Rumelhart, Hinton and Williams (and many others). This class is used by output blending and h-as-input networks.
Definition: bpnet.hpp:18
plain back-propagation
int getInputCount() const
get the number of inputs in all examples
Definition: data.hpp:311
static Net * makeNet(NetType t, int layercount, int *layers)
Definition: netFactory.hpp:43
virtual int getLayerSize(int n) const =0
Get the number of nodes in a given layer.
h-as-input network - only
This implements a plain backprop network.
static void save(const char *fn, Net *n)
Save a net of any type to a file - note, endianness not checked!
Definition: netFactory.hpp:119
static Net * makeNet(NetType t, ExampleSet &e, int hnodes)
Construct a single hidden layer network of a given type which conforms to the example set...
Definition: netFactory.hpp:32
This file contains the implementation of the UESMANN network itself - at least, those parts which are...
int getOutputCount() const
get the number of outputs in all examples
Definition: data.hpp:319
virtual int getDataSize() const =0
Get the length of the serialised data block for this network.
virtual void load(double *buf)=0
Given that the pointer points to a data block of the correct size for the current network...
NetType type
type of the network, used for load/save
Definition: net.hpp:49
h-as-input
The UESMANN network, which is itself based on the BPNet code as it has the same architecture as the p...
Definition: uesnet.hpp:17
Output blending network - only works with 2 h-levels, 0 and 1, and only with SGD. ...
The abstract network type upon which all others are based. It's not pure virtual, in that it encapsul...
Definition: net.hpp:39
A set of example data. Each datum consists of hormone (i.e. modulator value), inputs and outputs...
Definition: data.hpp:57