home *** CD-ROM | disk | FTP | other *** search
/ Liren Large Software Subsidy 10 / 10.iso / l / l460 / 2.ddi / FUNFUN.DI$ / ODE45.C < prev    next >
Encoding:
C/C++ Source or Header  |  1993-02-01  |  9.7 KB  |  456 lines

  1. /*
  2.  
  3.    ODE45.C    MEX file implementation of:
  4.  
  5.     [tout, yout] = ode45(F, t0, tfinal, y0, tol)
  6.  
  7.    Integrate a system of ordinary differential equations.
  8.  
  9.    INPUT:
  10.    F     - String containing name of user-supplied problem description.
  11.          Call: yprime = fun(t,y) where F = 'fun'.
  12.          t      - Time (scalar).
  13.          y      - Solution column-vector.
  14.          yprime - Returned derivative column-vector; yprime(i) = dy(i)/dt.
  15.    t0    - Initial value of t.
  16.    tfinal- Final value of t.
  17.    y0    - Initial value column-vector.
  18.    tol   - The desired accuracy. (Default: tol = 1.e-6).
  19.  
  20.    OUTPUT:
  21.    tout  - Returned integration time points (row-vector).
  22.    yout  - Returned solution, one solution column-vector per tout-value.
  23.  
  24.    C.B. Moler, 3-25-87.
  25.    Marc Ullman  June 23, 1987
  26.    Copyright (C) 1987  The Mathworks Inc.
  27.    All Rights Reserved
  28. */
  29.  
  30. #include <math.h>
  31. #include "mex.h"
  32.  
  33. /* Input Arguments */
  34.  
  35. #define    F_IN    prhs[0]
  36. #define    T0_IN    prhs[1]
  37. #define    TF_IN    prhs[2]
  38. #define    Y0_IN    prhs[3]
  39. #define    TOL_IN    prhs[4]
  40.  
  41.  
  42. /* Output Arguments */
  43.  
  44. #define    T_OUT    plhs[0]
  45. #define    Y_OUT    plhs[1]
  46.  
  47.  
  48. /* Temporary Variables */
  49.  
  50. #define    OLD_T_OUT    old_plhs[0]
  51. #define    OLD_Y_OUT    old_plhs[1]
  52.  
  53. #define    T_TEMP        tmp_prhs[0]
  54. #define    Y_TEMP        tmp_prhs[1]
  55. #define    Y_OUT_TEMP    tmp_plhs[0]
  56.  
  57.  
  58. #define    MAX(A, B)    ((A) > (B) ? (A) : (B))
  59. #define    MIN(A, B)    ((A) < (B) ? (A) : (B))
  60.  
  61. /* Fehlberg coefficients */
  62.  
  63. static double alpha[5] = { 1.0/4.0,  3.0/8.0,  12.0/13.0,  1.0,  1.0/2.0 };
  64.  
  65. static double beta[5][5] = {
  66.     { 1.0/4.0, 3.0/32.0,  1932.0/2197.0,   8341.0/4104.0,  -6080.0/20520.0 },
  67.     {     0.0, 9.0/32.0, -7200.0/2197.0, -32832.0/4104.0,  41040.0/20520.0 },
  68.     {     0.0,      0.0,  7296.0/2197.0,  29440.0/4104.0, -28352.0/20520.0 },
  69.     {     0.0,      0.0,            0.0,   -845.0/4104.0,   9295.0/20520.0 },
  70.     {     0.0,      0.0,            0.0,             0.0,  -5643.0/20520.0 }
  71. };
  72.  
  73. static double gama[6][2] = {
  74.     {   902880.0/7618050.0,   -2090.0/752400.0 },
  75.     {                  0.0,                0.0 },
  76.     {  3953664.0/7618050.0,   22528.0/752400.0 },
  77.     {  3855735.0/7618050.0,   21970.0/752400.0 },
  78.     { -1371249.0/7618050.0,  -15048.0/752400.0 },
  79.     {   277020.0/7618050.0,  -27360.0/752400.0 }
  80. };
  81.  
  82. #ifdef __STDC__
  83. static double inf_norm(
  84.     double    y[],
  85.     unsigned int M
  86.     )
  87. #else
  88. static double inf_norm(y,M)
  89. double    y[];
  90. unsigned int M;
  91. #endif
  92. {
  93.     register unsigned int m;
  94.     double    temp_max,yabs;
  95.  
  96.     temp_max = fabs(y[0]);
  97.  
  98.     for (m = 1; m < M; m++) {
  99.         yabs = fabs(y[m]);
  100.         temp_max = MAX(temp_max,yabs);
  101.     }
  102.  
  103.     return(temp_max);
  104. }
  105.  
  106.  
  107.  
  108. #ifdef __STDC__
  109. void mexFunction(
  110.     int        nlhs,
  111.     Matrix    *plhs[],
  112.     int        nrhs,
  113.     Matrix    *prhs[]
  114.     )
  115. #else
  116. mexFunction(nlhs, plhs, nrhs, prhs)
  117. int nlhs, nrhs;
  118. Matrix *plhs[], *prhs[];
  119. #endif
  120. {
  121.  
  122.     double    *tout,*yout; 
  123.     double    *toutp,*youtp;
  124.     double    *t,*y;
  125.     double    *y0;
  126.     char    fcn_name[20];
  127.  
  128.     register unsigned int    m,k,kk;
  129.     unsigned int    M,K;
  130.     int        nr, nc;
  131.  
  132.     double    ts;
  133.     double    *s0,*s1,*s2,*s3,*s4,*s5,*s6;
  134.     double    t0;
  135.     double    tfinal;
  136.     double    tol,tau,delta,h,hmax;
  137.     double    ymax;
  138.     double    *ydel;
  139.  
  140.     static double powr = 1.0/5.0;
  141.  
  142.     Matrix    *tmp_plhs[1],*tmp_prhs[2];
  143.     Matrix    *s_matptr[6];
  144.     Matrix    *old_plhs[2];
  145.  
  146.  
  147.     /* Check out the arguments */
  148.  
  149.     if ((nrhs < 4) || (nrhs > 6)) {
  150.         mexErrMsgTxt("Wrong number of input arguments for ODE45");
  151.     } else if (nlhs != 2) {
  152.         mexErrMsgTxt("Wrong number of output arguments for ODE45");
  153.     }
  154.  
  155.  
  156.     /*
  157.      *  Get user function name
  158.      */
  159.  
  160.     if (!mxIsString(F_IN))
  161.         mexErrMsgTxt("String argument expected for function name in ODE23.");
  162.     mxGetString(F_IN, fcn_name, 20);
  163.  
  164.     /*
  165.      * Get Input Arguments
  166.      */
  167.  
  168.     nr = mxGetM(T0_IN);
  169.     nc = mxGetN(T0_IN);
  170.     if (!mxIsNumeric(T0_IN) || mxIsComplex(T0_IN) || 
  171.         !mxIsFull(T0_IN)  || !mxIsDouble(T0_IN) || nr*nc != 1)
  172.         mexErrMsgTxt("Bad t0 input for ODE23.");
  173.     t0 = mxGetScalar(T0_IN);
  174.  
  175.     nr = mxGetM(TF_IN);
  176.     nc = mxGetN(TF_IN);
  177.     if (!mxIsNumeric(TF_IN) || mxIsComplex(TF_IN) || 
  178.         !mxIsFull(TF_IN)  || !mxIsDouble(TF_IN) || nr*nc != 1)
  179.         mexErrMsgTxt("Bad tfinal input for ODE23.");
  180.     tfinal = mxGetScalar(TF_IN);
  181.  
  182.     M = MAX(mxGetM(Y0_IN),mxGetN(Y0_IN));
  183.  
  184.     if (!mxIsNumeric(Y0_IN) || mxIsComplex(Y0_IN) || 
  185.         !mxIsFull(Y0_IN)  || !mxIsDouble(Y0_IN) || !M)
  186.         mexErrMsgTxt("Bad y0 input for ODE23.");
  187.     y0 = mxGetPr(Y0_IN);
  188.  
  189.  
  190.     /*
  191.      * Create and Initialize Return Arguments
  192.      */
  193.  
  194.     T_OUT = mxCreateFull(1, 1, REAL);
  195.     Y_OUT = mxCreateFull(mxGetM(Y0_IN),mxGetN(Y0_IN), REAL);
  196.  
  197.     tout = mxGetPr(T_OUT);
  198.     yout = mxGetPr(Y_OUT);
  199.  
  200.     tout[0] = t0;
  201.     for (m = 0; m <M; m++) {
  202.         yout[m] = y0[m];
  203.     }
  204.  
  205.     /*
  206.      * Create arguments for calling user function
  207.      */
  208.  
  209.     T_TEMP = mxCreateFull(1, 1, REAL);
  210.     Y_TEMP = mxCreateFull(mxGetM(Y0_IN),mxGetN(Y0_IN), REAL);
  211.  
  212.     t = mxGetPr(T_TEMP);
  213.     y = mxGetPr(Y_TEMP);
  214.  
  215.     /*
  216.      * Create an array for ydel
  217.      */
  218.  
  219.     ydel = (double *) mxCalloc(M, sizeof(double));
  220.  
  221.     /*
  222.      * Initialization
  223.      */
  224.  
  225.     if (nrhs < 5) {
  226.         tol = 0.000001;
  227.     } else {
  228.         nr = mxGetM(TOL_IN);
  229.         nc = mxGetN(TOL_IN);
  230.         if (!mxIsNumeric(TOL_IN) || mxIsComplex(TOL_IN) || 
  231.             !mxIsFull(TOL_IN)  || !mxIsDouble(TOL_IN) || nr*nc != 1)
  232.             mexErrMsgTxt("Bad tol input for ODE45.");
  233.         tol = mxGetScalar(TOL_IN);
  234.     }
  235.  
  236.     hmax = (tfinal - t0)/16;
  237.     h = hmax/8;
  238.  
  239.     tout = &t0; 
  240.     yout = y0;
  241.  
  242.     *t = t0;
  243.  
  244.     k = 0;
  245.  
  246.     /*
  247.      * The main loop
  248.      */
  249.  
  250.     while ((*t < tfinal) && ((*t + h) >= *t)) {
  251.  
  252.         if ((*t + h) > tfinal) {
  253.             h = tfinal - *t;
  254.         }
  255.  
  256.         /*
  257.          * Compute the slopes
  258.          */
  259.  
  260.  
  261.         /* ts = t  and  s0 = y */
  262.  
  263.         ts = tout[k];
  264.         s0 = &yout[k*M];
  265.  
  266.         /* s1 = feval(F, t, y); */
  267.  
  268.         *t = ts;
  269.         for  (m = 0; m < M; m++) {
  270.             y[m] = s0[m];
  271.         }
  272.  
  273.         mexCallMATLAB(1,tmp_plhs,2,tmp_prhs,fcn_name);
  274.         if (mxGetM(Y_OUT_TEMP)*mxGetN(Y_OUT_TEMP) != M) {
  275.             goto errorexit;
  276.         }
  277.         s1 = mxGetPr(Y_OUT_TEMP);
  278.         s_matptr[0] = Y_OUT_TEMP;
  279.  
  280.         /* s2 = feval(F, t+h*alpha(1), y+h*s1*beta(1,1)); */
  281.  
  282.         *t = ts + h*alpha[0];
  283.         for  (m = 0; m < M; m++) {
  284.             y[m] = s0[m]+h*s1[m]*beta[0][0];
  285.         }
  286.  
  287.         mexCallMATLAB(1,tmp_plhs,2,tmp_prhs,fcn_name);
  288.         if (mxGetM(Y_OUT_TEMP)*mxGetN(Y_OUT_TEMP) != M) {
  289.             goto errorexit;
  290.         }
  291.         s2 = mxGetPr(Y_OUT_TEMP);
  292.         s_matptr[1] = Y_OUT_TEMP;
  293.  
  294.         /* s3 = feval(F, t+h*alpha(2), y+h*[s1*beta(2,1)+s2*beta(2,2)]; */
  295.  
  296.         *t = ts + h*alpha[1];
  297.         for  (m = 0; m < M; m++) {
  298.             y[m] = s0[m]+h*(s1[m]*beta[0][1] + s2[m]*beta[1][1]);
  299.         }
  300.  
  301.         mexCallMATLAB(1,tmp_plhs,2,tmp_prhs,fcn_name);
  302.         if (mxGetM(Y_OUT_TEMP)*mxGetN(Y_OUT_TEMP) != M) {
  303.             goto errorexit;
  304.         }
  305.         s3 = mxGetPr(Y_OUT_TEMP);
  306.         s_matptr[2] = Y_OUT_TEMP;
  307.  
  308.         /* s4 = feval(F, t+h*alpha(3), y+h*[s1*beta(3,1)+s2*beta(3,2)+...]; */
  309.  
  310.         *t = ts + h*alpha[2];
  311.         for  (m = 0; m < M; m++) {
  312.             y[m] = s0[m]+h*(s1[m]*beta[0][2] + s2[m]*beta[1][2]
  313.                     +s3[m]*beta[2][2]);
  314.         }
  315.  
  316.         mexCallMATLAB(1,tmp_plhs,2,tmp_prhs,fcn_name);
  317.         if (mxGetM(Y_OUT_TEMP)*mxGetN(Y_OUT_TEMP) != M) {
  318.             goto errorexit;
  319.         }
  320.         s4 = mxGetPr(Y_OUT_TEMP);
  321.         s_matptr[3] = Y_OUT_TEMP;
  322.  
  323.         /* s5 = feval(F, t+h*alpha(4), y+h*[s1*beta(4,1)+s2*beta(4,2)+...]; */
  324.  
  325.         *t = ts + h*alpha[3];
  326.         for  (m = 0; m < M; m++) {
  327.             y[m] = s0[m]+h*(s1[m]*beta[0][3] + s2[m]*beta[1][3]
  328.                     +s3[m]*beta[2][3] + s4[m]*beta[3][3]);
  329.         }
  330.  
  331.         mexCallMATLAB(1,tmp_plhs,2,tmp_prhs,fcn_name);
  332.         if (mxGetM(Y_OUT_TEMP)*mxGetN(Y_OUT_TEMP) != M) {
  333.             goto errorexit;
  334.         }
  335.         s5 = mxGetPr(Y_OUT_TEMP);
  336.         s_matptr[4] = Y_OUT_TEMP;
  337.  
  338.         /* s6 = feval(F, t+h*alpha(5), y+h*[s1*beta(5,1)+s2*beta(5,2)+...]; */
  339.  
  340.         *t = ts + h*alpha[4];
  341.         for  (m = 0; m < M; m++) {
  342.             y[m] = s0[m]+h*(s1[m]*beta[0][4] + s2[m]*beta[1][4]
  343.                     +s3[m]*beta[2][4] + s4[m]*beta[3][4]
  344.                      +s5[m]*beta[4][4]);
  345.         }
  346.  
  347.         mexCallMATLAB(1,tmp_plhs,2,tmp_prhs,fcn_name);
  348.         if (mxGetM(Y_OUT_TEMP)*mxGetN(Y_OUT_TEMP) != M) {
  349.             goto errorexit;
  350.         }
  351.         s6 = mxGetPr(Y_OUT_TEMP);
  352.         s_matptr[5] = Y_OUT_TEMP;
  353.  
  354.         /*
  355.          * Estimate the error and the acceptable error
  356.          */
  357.  
  358.         for  (m = 0; m < M; m++) {
  359.             ydel[m] = h*(s1[m]*gama[0][1] + s2[m]*gama[1][1]
  360.                     + s3[m]*gama[2][1] + s4[m]*gama[3][1]
  361.                      + s5[m]*gama[4][1] + s6[m]*gama[5][1]);
  362.         }
  363.         delta = inf_norm(ydel,M);
  364.  
  365.         ymax = inf_norm(&yout[k*M],M);
  366.         tau = tol*MAX(ymax,1.0);
  367.  
  368.         /*
  369.          * Update the solution only if the error is acceptable
  370.          */
  371.  
  372.         if (delta <= tau) {
  373.  
  374.             K = k+1;
  375.  
  376.             OLD_T_OUT = T_OUT;
  377.             OLD_Y_OUT = Y_OUT;
  378.             T_OUT = mxCreateFull(1,K+1,REAL);
  379.             Y_OUT = mxCreateFull(m,K+1,REAL);
  380.  
  381.             toutp = tout;
  382.             youtp = yout;
  383.             tout = mxGetPr(T_OUT);
  384.             yout = mxGetPr(Y_OUT);
  385.  
  386.             for (kk = 0;kk < K; kk++) {
  387.                 tout[kk] = toutp[kk];
  388.                 for(m = 0; m < M; m++) {
  389.                     yout[kk*M+m] = youtp[kk*M+m];
  390.                 }
  391.             }
  392.  
  393.             /* if k <> 0 Free OLD_T_OUT,OLD_Y_OUT */
  394.  
  395.             if (k) {
  396.                 mxFreeMatrix(OLD_T_OUT);
  397.                 mxFreeMatrix(OLD_Y_OUT);
  398.             }
  399.  
  400.             tout[K] = *t = tout[k] + h;
  401.             for (m = 0; m < M; m++) {
  402.                 yout[K*M+m] = yout[k*M+m]
  403.                   + h*(s1[m]*gama[0][0] + s2[m]*gama[1][0]
  404.                      + s3[m]*gama[2][0] + s4[m]*gama[3][0]
  405.                      + s5[m]*gama[4][0] + s6[m]*gama[5][0]);
  406.             }
  407.  
  408.             k++;
  409.         }
  410.  
  411.         /* Free s_matptr[0], s_matptr[1], ... , s_matptr[5] */
  412.  
  413.         for (m = 0; m < 6; m++) {
  414.             mxFreeMatrix(s_matptr[m]);
  415.         }
  416.  
  417.         /*
  418.          * Update the step size
  419.          */
  420.  
  421.         if (delta) {
  422.             h = MIN(hmax, 0.8*h*pow(tau/delta,powr));
  423.         }
  424.  
  425.  
  426.     }
  427.  
  428.     if (*t < tfinal) {
  429. #ifdef THINK_C
  430.         mexPrintf("Singularity likely at t = %f\r",*t);
  431. #else
  432.         mexPrintf("Singularity likely at t = %f\n",*t);
  433. #endif
  434.     }
  435.  
  436.     /*
  437.      * Transpose the outputs
  438.      */
  439.  
  440.     m = mxGetM(T_OUT);
  441.     mxSetM(T_OUT, mxGetN(T_OUT));
  442.     mxSetN(T_OUT, m);
  443.  
  444.     Y_TEMP = Y_OUT;
  445.     mexCallMATLAB(1,&Y_OUT,1,&Y_TEMP,".'");
  446.  
  447.     return;
  448.  
  449. errorexit:
  450.     mexPrintf("Function %s is not returing the correct number of state \
  451. derivatives.",fcn_name);
  452.     mexErrMsgTxt("Error in ODE45.");
  453.  
  454. }
  455.  
  456.