home *** CD-ROM | disk | FTP | other *** search
/*-----------------------------------------------------------------------*
 * Greg Stevens                                                  6/24/93
 * NNBKPROP.C
 * [file 6 in a series of 6]
 *
 * This file contains the functions for calculating the error and weight
 * changes for the backpropagation algorithm for the network. Defined
 * in this file are EPSILON, the weight-change increment (coefficient),
 * and the function that updates the weights [UpDateWeightandThresh()].
 * It also contains a function called GetDerivs(), but this is meant to
 * be called by UpDateWeightandThresh(), not by a main program. Finally,
 * it contains a function InitOutPatterns(), which is similar to
 * InitPatterns() for the input patterns, but reads from a different file
 * (nnoutput.dat) containing the desired output patterns that correspond
 * to the set of input patterns.
 *
 *-----------------------------------------------------------------------*/
- #include "nnloadin.c"
-
/*
 * Learning rate: the coefficient applied to every weight and threshold
 * change in UpDateWeightandThresh().  Guarded with #ifndef so it can be
 * tuned at build time (e.g. cc -DEPSILON=0.1 ...) without editing source.
 */
#ifndef EPSILON
#define EPSILON 0.25
#endif
-
- /* type for holding error values for weight connections */
-
- typedef struct
- {
- float e[ NUMLAYERS ][ MAXNODES ][ MAXNODES ];
- } wERRORtype;
-
- /* type for holding error values for threshhold weights */
-
- typedef struct
- {
- float e[ NUMLAYERS ][ MAXNODES ];
- } tERRORtype;
-
- /* Function Prototypes */
- PATTERNtype InitOutPatterns( void );
- tERRORtype GetDerivs( NNETtype n, PATTERNtype GoalOut, int Pattern );
- NNETtype UpDateWeightandThresh(NNETtype nn, PATTERNtype goal, int p );
-
- /* Function Definitions */
- PATTERNtype InitOutPatterns( void )
- {
- FILE *InFile; /*file w/ pattern*/
- /*data */
- PATTERNtype patns; /*stores patterns*/
- float val; /*pattern value */
- int P,U; /*loop variables */
-
- InFile = fopen( "nnoutput.dat", "rt" ); /*open: read text*/
-
- if ( InFile==NULL ) /* if no file... */
- {
- printf( "File nnoutput.dat does not exist!\n" ); /*error message */
- return( patns ); /*leaves function*/
- }
-
- for (P=0; (P<NUM_PATTERNS); ++P) /* for each pattern.... */
- for (U=0; (U<OUTPUT_LAYER_SIZE); ++U) /* for each unit in it:*/
- {
- fscanf( InFile, "%f", &val );
- patns.p[P][U] = val;
- }
-
- fclose( InFile );
-
- return( patns );
- }
-
-
- tERRORtype GetDerivs(NNETtype n, PATTERNtype GoalOut, int pattern)
- {
- int layer; /* looping variables */
- int node;
- int tonode;
-
- tERRORtype Deriv1; /* for holding dE/dy */
- tERRORtype Deriv2; /* for holding dE/ds */
-
- layer = NUMLAYERS - 1; /* set layer to output layer */
-
- /* calculate dE/dy for output nodes */
- for (node=0; (node<NUMNODES[layer]); ++node) /* for each output node */
- {
- Deriv1.e[layer][node]=GoalOut.p[pattern][node]-n.unit[layer][node].state;
- }
-
- /* calculate dE/ds for output nodes */
- for (node=0; (node<NUMNODES[layer]); ++node)
- {
- if (n.unit[layer][node].actfn==0) /* if it's a linear node... */
- Deriv2.e[layer][node] = Deriv1.e[layer][node];
- else if (n.unit[layer][node].actfn==1) /* if it's a logistic node...*/
- Deriv2.e[layer][node] = Deriv1.e[layer][node] *
- n.unit[layer][node].state *
- (1.0 - n.unit[layer][node].state);
- }
-
- /* calculate dE/dy and dE/ds for hidden layers (backwards from output,*/
- /* not including input layer). */
- for (layer=NUMLAYERS-2; (layer>0); --layer )
- {
- /* calculate dE/dy */
- for (node=0; (node<NUMNODES[layer]); ++node)
- {
- Deriv1.e[layer][node] = 0;
- for (tonode=0; (tonode<NUMNODES[layer+1]); ++tonode)
- {
- Deriv1.e[layer][node] += Deriv2.e[layer+1][tonode] *
- n.unit[layer+1][tonode].weights[node];
- }
- }
-
- /* calculate dE/ds */
- for (node=0; (node<NUMNODES[layer]); ++node)
- {
- if (n.unit[layer][node].actfn==0)
- Deriv2.e[layer][node] = Deriv1.e[layer][node];
- else if (n.unit[layer][node].actfn==1)
- Deriv2.e[layer][node] = Deriv1.e[layer][node] *
- n.unit[layer][node].state *
- (1.0 - n.unit[layer][node].state);
- }
- }
-
- return( Deriv2 ); /* return dE/ds for each layer */
- }
-
- NNETtype UpDateWeightandThresh(NNETtype nn, PATTERNtype goal, int p )
- {
- NNETtype newnet;
- wERRORtype WeightError;
- tERRORtype ThreshError;
- tERRORtype Derivs;
- int layer, unit, inunit;
-
- /* find WeightError and ThreshError */
-
- Derivs = GetDerivs( nn, goal, p );
- for (layer=1; (layer<NUMLAYERS); ++layer )
- {
- for (unit=0; (unit<NUMNODES[ layer ]); ++unit )
- {
- ThreshError.e[ layer ][ unit ] = Derivs.e[ layer ][ unit ];
-
- for (inunit=0; (inunit<NUMNODES[ layer-1 ]); ++inunit)
- {
- WeightError.e[layer][unit][inunit] = Derivs.e[layer][unit] *
- nn.unit[layer-1][inunit].state;
- }
- }
- }
-
- /* Change Weights */
- newnet = nn;
- for (layer=1; (layer<NUMLAYERS); ++layer)
- for (unit=0; (unit<NUMNODES[layer]); ++unit)
- {
- newnet.unit[layer][unit].thresh += EPSILON*ThreshError.e[layer][unit];
-
- for (inunit=0; (inunit<NUMNODES[layer-1]); ++inunit)
- newnet.unit[layer][unit].weights[inunit] += EPSILON*
- WeightError.e[layer][unit][inunit];
- }
-
- return( newnet );
- }
-