16 #include <arpa/inet.h> 35 MNIST(
const char *labelFile,
const char *imgFile,
int start=0,
int len=0){
38 FILE *a = fopen(labelFile,
"rb");
40 printf(
"Error opening label file %s: %s\n",labelFile,strerror(errno));;
41 throw std::runtime_error(
"cannot open label file: " + std::string(labelFile));
45 rd=fread(&magic,
sizeof(uint32_t),1,a);
48 printf(
"incorrect magic number in label file %s: %x\n",labelFile,magic);
49 throw std::runtime_error(
"bad magic number in label file");
51 rd=fread(&ct,
sizeof(uint32_t),1,a);
54 printf(
"unfeasibly large count in label file %s: %x\n",labelFile,ct);
55 throw std::runtime_error(
"bad count in label file");
60 printf(
"specified range [%d-%d], only %d in file %s\n",start,start+len,ct,labelFile);
61 throw std::runtime_error(
"bad range in label file");
63 fseek(a,start*
sizeof(uint8_t),SEEK_CUR);
65 labels =
new uint8_t[len];
66 rd = fread(labels,
sizeof(uint8_t),len,a);
68 printf(
"not enough items in label file %s: %d\n",labelFile,rd);
69 throw std::runtime_error(
"not enough elements in label file");
74 a = fopen(imgFile,
"rb");
76 printf(
"Error opening image file %s: %s\n",imgFile,strerror(errno));
77 throw std::runtime_error(
"cannot open image file: " + std::string(imgFile));
79 rd=fread(&magic,
sizeof(uint32_t),1,a);
82 printf(
"incorrect magic number in image file %s: %d\n",imgFile,magic);
83 throw std::runtime_error(
"bad magic in image file");
86 rd=fread(&ct2,
sizeof(uint32_t),1,a);
89 printf(
"image file count does not agree with label file count:\n" 91 imgFile,ct2,labelFile,ct);
92 throw std::runtime_error(
"bad count in image file");
95 rd=fread(&rows,
sizeof(uint32_t),1,a);
97 rd=fread(&cols,
sizeof(uint32_t),1,a);
99 if(rows > 128 || cols > 128){
100 printf(
"Bad dimensions in image file %s: %dx%d\n",imgFile,rows,cols);
101 throw std::runtime_error(
"bad dimensions in image file");
104 fseek(a,start*
sizeof(uint8_t)*rows*cols,SEEK_CUR);
105 imgs =
new uint8_t[len*rows*cols];
106 rd = fread(imgs,
sizeof(uint8_t),rows*cols*len,a);
107 if(rd!=len*rows*cols){
108 printf(
"wrong amount of pixels in image file %s: %d\n",imgFile,rd);
109 throw std::runtime_error(
"bad filesize in image file");
116 for(
int i=0;i<ct;i++){
176 return imgs+rows*cols*n;
184 uint8_t
getPix(
int n,
int x,
int y)
const {
194 printf(
"Out of range\n");
198 for(
int x=0;x<
r();x++){
199 for(
int y=0;y<
c();y++){
200 uint8_t qq = *d++ / 25;
202 putchar(qq ? qq+
'0':
'.');
uint8_t getMaxLabel() const
get the maximum label value (0 to 9 in the original data, but possibly different in other tests) ...
int r() const
returns the number of rows in each image
void dump(int i) const
dump the image data for the i-th example to standard output as text (one character per pixel)
This class encapsulates and loads data in the standard MNIST format. The data resides in two files...
uint8_t getLabel(int n) const
get the label for a given example
uint8_t getPix(int n, int x, int y) const
get a pixel for a given example
int c() const
returns the number of columns in each image
MNIST(const char *labelFile, const char *imgFile, int start=0, int len=0)
constructor which loads the data from the given label and image files, and can optionally load only part of the data in each file...
uint8_t * getImg(int n) const
get the bitmap for a given example
int getCount() const
returns the number of examples